1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Complex/IR/Complex.h" 17 #include "mlir/Dialect/Math/IR/Math.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 namespace { 26 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 27 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 28 29 LogicalResult 30 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 31 ConversionPatternRewriter &rewriter) const override { 32 auto loc = op.getLoc(); 33 auto type = op.getType(); 34 35 Value real = rewriter.create<complex::ReOp>(loc, type, adaptor.complex()); 36 Value imag = rewriter.create<complex::ImOp>(loc, type, adaptor.complex()); 37 Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real); 38 Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag); 39 Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr); 40 41 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 42 return success(); 43 } 44 }; 45 46 template <typename ComparisonOp, arith::CmpFPredicate p> 47 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 48 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 49 using ResultCombiner = 50 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 51 arith::AndIOp, arith::OrIOp>; 52 53 LogicalResult 54 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 55 ConversionPatternRewriter &rewriter) const override { 56 auto loc = op.getLoc(); 57 auto type = 58 adaptor.lhs().getType().template cast<ComplexType>().getElementType(); 59 60 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.lhs()); 61 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.lhs()); 62 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.rhs()); 63 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.rhs()); 64 Value realComparison = 65 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); 66 Value imagComparison = 67 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); 68 69 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 70 imagComparison); 71 return success(); 72 } 73 }; 74 75 // Default conversion which applies the BinaryStandardOp separately on the real 76 // and imaginary parts. Can for example be used for complex::AddOp and 77 // complex::SubOp. 78 template <typename BinaryComplexOp, typename BinaryStandardOp> 79 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { 80 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; 81 82 LogicalResult 83 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, 84 ConversionPatternRewriter &rewriter) const override { 85 auto type = adaptor.lhs().getType().template cast<ComplexType>(); 86 auto elementType = type.getElementType().template cast<FloatType>(); 87 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 88 89 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.lhs()); 90 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.rhs()); 91 Value resultReal = 92 b.create<BinaryStandardOp>(elementType, realLhs, realRhs); 93 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.lhs()); 94 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.rhs()); 95 Value resultImag = 96 b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs); 97 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 98 resultImag); 99 return success(); 100 } 101 }; 102 103 struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 104 using OpConversionPattern<complex::DivOp>::OpConversionPattern; 105 106 LogicalResult 107 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 108 ConversionPatternRewriter &rewriter) const override { 109 auto loc = op.getLoc(); 110 auto type = adaptor.lhs().getType().cast<ComplexType>(); 111 auto elementType = type.getElementType().cast<FloatType>(); 112 113 Value lhsReal = 114 rewriter.create<complex::ReOp>(loc, elementType, adaptor.lhs()); 115 Value lhsImag = 116 rewriter.create<complex::ImOp>(loc, elementType, adaptor.lhs()); 117 Value rhsReal = 118 rewriter.create<complex::ReOp>(loc, elementType, adaptor.rhs()); 119 Value rhsImag = 120 rewriter.create<complex::ImOp>(loc, elementType, adaptor.rhs()); 121 122 // Smith's algorithm to divide complex numbers. It is just a bit smarter 123 // way to compute the following formula: 124 // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 125 // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 126 // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 127 // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 128 // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 129 // 130 // Depending on whether |rhsReal| < |rhsImag| we compute either 131 // rhsRealImagRatio = rhsReal / rhsImag 132 // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 133 // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 134 // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 135 // 136 // or 137 // 138 // rhsImagRealRatio = rhsImag / rhsReal 139 // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 140 // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 141 // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 142 // 143 // See https://dl.acm.org/citation.cfm?id=368661 for more details. 144 Value rhsRealImagRatio = 145 rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag); 146 Value rhsRealImagDenom = rewriter.create<arith::AddFOp>( 147 loc, rhsImag, 148 rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal)); 149 Value realNumerator1 = rewriter.create<arith::AddFOp>( 150 loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio), 151 lhsImag); 152 Value resultReal1 = 153 rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom); 154 Value imagNumerator1 = rewriter.create<arith::SubFOp>( 155 loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio), 156 lhsReal); 157 Value resultImag1 = 158 rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom); 159 160 Value rhsImagRealRatio = 161 rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal); 162 Value rhsImagRealDenom = rewriter.create<arith::AddFOp>( 163 loc, rhsReal, 164 rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag)); 165 Value realNumerator2 = rewriter.create<arith::AddFOp>( 166 loc, lhsReal, 167 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio)); 168 Value resultReal2 = 169 rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom); 170 Value imagNumerator2 = rewriter.create<arith::SubFOp>( 171 loc, lhsImag, 172 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio)); 173 Value resultImag2 = 174 rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom); 175 176 // Consider corner cases. 177 // Case 1. Zero denominator, numerator contains at most one NaN value. 178 Value zero = rewriter.create<arith::ConstantOp>( 179 loc, elementType, rewriter.getZeroAttr(elementType)); 180 Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal); 181 Value rhsRealIsZero = rewriter.create<arith::CmpFOp>( 182 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); 183 Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag); 184 Value rhsImagIsZero = rewriter.create<arith::CmpFOp>( 185 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); 186 Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>( 187 loc, arith::CmpFPredicate::ORD, lhsReal, zero); 188 Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>( 189 loc, arith::CmpFPredicate::ORD, lhsImag, zero); 190 Value lhsContainsNotNaNValue = 191 rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 192 Value resultIsInfinity = rewriter.create<arith::AndIOp>( 193 loc, lhsContainsNotNaNValue, 194 rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero)); 195 Value inf = rewriter.create<arith::ConstantOp>( 196 loc, elementType, 197 rewriter.getFloatAttr( 198 elementType, APFloat::getInf(elementType.getFloatSemantics()))); 199 Value infWithSignOfRhsReal = 200 rewriter.create<math::CopySignOp>(loc, inf, rhsReal); 201 Value infinityResultReal = 202 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal); 203 Value infinityResultImag = 204 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag); 205 206 // Case 2. Infinite numerator, finite denominator. 207 Value rhsRealFinite = rewriter.create<arith::CmpFOp>( 208 loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); 209 Value rhsImagFinite = rewriter.create<arith::CmpFOp>( 210 loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); 211 Value rhsFinite = 212 rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite); 213 Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal); 214 Value lhsRealInfinite = rewriter.create<arith::CmpFOp>( 215 loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 216 Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag); 217 Value lhsImagInfinite = rewriter.create<arith::CmpFOp>( 218 loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 219 Value lhsInfinite = 220 rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite); 221 Value infNumFiniteDenom = 222 rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite); 223 Value one = rewriter.create<arith::ConstantOp>( 224 loc, elementType, rewriter.getFloatAttr(elementType, 1)); 225 Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 226 loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero), 227 lhsReal); 228 Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 229 loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero), 230 lhsImag); 231 Value lhsRealIsInfWithSignTimesRhsReal = 232 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal); 233 Value lhsImagIsInfWithSignTimesRhsImag = 234 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag); 235 Value resultReal3 = rewriter.create<arith::MulFOp>( 236 loc, inf, 237 rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 238 lhsImagIsInfWithSignTimesRhsImag)); 239 Value lhsRealIsInfWithSignTimesRhsImag = 240 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag); 241 Value lhsImagIsInfWithSignTimesRhsReal = 242 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal); 243 Value resultImag3 = rewriter.create<arith::MulFOp>( 244 loc, inf, 245 rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 246 lhsRealIsInfWithSignTimesRhsImag)); 247 248 // Case 3: Finite numerator, infinite denominator. 249 Value lhsRealFinite = rewriter.create<arith::CmpFOp>( 250 loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); 251 Value lhsImagFinite = rewriter.create<arith::CmpFOp>( 252 loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); 253 Value lhsFinite = 254 rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite); 255 Value rhsRealInfinite = rewriter.create<arith::CmpFOp>( 256 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 257 Value rhsImagInfinite = rewriter.create<arith::CmpFOp>( 258 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 259 Value rhsInfinite = 260 rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite); 261 Value finiteNumInfiniteDenom = 262 rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite); 263 Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 264 loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero), 265 rhsReal); 266 Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 267 loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero), 268 rhsImag); 269 Value rhsRealIsInfWithSignTimesLhsReal = 270 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign); 271 Value rhsImagIsInfWithSignTimesLhsImag = 272 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign); 273 Value resultReal4 = rewriter.create<arith::MulFOp>( 274 loc, zero, 275 rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 276 rhsImagIsInfWithSignTimesLhsImag)); 277 Value rhsRealIsInfWithSignTimesLhsImag = 278 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign); 279 Value rhsImagIsInfWithSignTimesLhsReal = 280 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign); 281 Value resultImag4 = rewriter.create<arith::MulFOp>( 282 loc, zero, 283 rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 284 rhsImagIsInfWithSignTimesLhsReal)); 285 286 Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>( 287 loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 288 Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 289 resultReal1, resultReal2); 290 Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 291 resultImag1, resultImag2); 292 Value resultRealSpecialCase3 = rewriter.create<SelectOp>( 293 loc, finiteNumInfiniteDenom, resultReal4, resultReal); 294 Value resultImagSpecialCase3 = rewriter.create<SelectOp>( 295 loc, finiteNumInfiniteDenom, resultImag4, resultImag); 296 Value resultRealSpecialCase2 = rewriter.create<SelectOp>( 297 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 298 Value resultImagSpecialCase2 = rewriter.create<SelectOp>( 299 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 300 Value resultRealSpecialCase1 = rewriter.create<SelectOp>( 301 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 302 Value resultImagSpecialCase1 = rewriter.create<SelectOp>( 303 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 304 305 Value resultRealIsNaN = rewriter.create<arith::CmpFOp>( 306 loc, arith::CmpFPredicate::UNO, resultReal, zero); 307 Value resultImagIsNaN = rewriter.create<arith::CmpFOp>( 308 loc, arith::CmpFPredicate::UNO, resultImag, zero); 309 Value resultIsNaN = 310 rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN); 311 Value resultRealWithSpecialCases = rewriter.create<SelectOp>( 312 loc, resultIsNaN, resultRealSpecialCase1, resultReal); 313 Value resultImagWithSpecialCases = rewriter.create<SelectOp>( 314 loc, resultIsNaN, resultImagSpecialCase1, resultImag); 315 316 rewriter.replaceOpWithNewOp<complex::CreateOp>( 317 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 318 return success(); 319 } 320 }; 321 322 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 323 using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 324 325 LogicalResult 326 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, 327 ConversionPatternRewriter &rewriter) const override { 328 auto loc = op.getLoc(); 329 auto type = adaptor.complex().getType().cast<ComplexType>(); 330 auto elementType = type.getElementType().cast<FloatType>(); 331 332 Value real = 333 rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex()); 334 Value imag = 335 rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex()); 336 Value expReal = rewriter.create<math::ExpOp>(loc, real); 337 Value cosImag = rewriter.create<math::CosOp>(loc, imag); 338 Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag); 339 Value sinImag = rewriter.create<math::SinOp>(loc, imag); 340 Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag); 341 342 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 343 resultImag); 344 return success(); 345 } 346 }; 347 348 struct LogOpConversion : public OpConversionPattern<complex::LogOp> { 349 using OpConversionPattern<complex::LogOp>::OpConversionPattern; 350 351 LogicalResult 352 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, 353 ConversionPatternRewriter &rewriter) const override { 354 auto type = adaptor.complex().getType().cast<ComplexType>(); 355 auto elementType = type.getElementType().cast<FloatType>(); 356 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 357 358 Value abs = b.create<complex::AbsOp>(elementType, adaptor.complex()); 359 Value resultReal = b.create<math::LogOp>(elementType, abs); 360 Value real = b.create<complex::ReOp>(elementType, adaptor.complex()); 361 Value imag = b.create<complex::ImOp>(elementType, adaptor.complex()); 362 Value resultImag = b.create<math::Atan2Op>(elementType, imag, real); 363 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 364 resultImag); 365 return success(); 366 } 367 }; 368 369 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { 370 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; 371 372 LogicalResult 373 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, 374 ConversionPatternRewriter &rewriter) const override { 375 auto type = adaptor.complex().getType().cast<ComplexType>(); 376 auto elementType = type.getElementType().cast<FloatType>(); 377 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 378 379 Value real = b.create<complex::ReOp>(elementType, adaptor.complex()); 380 Value imag = b.create<complex::ImOp>(elementType, adaptor.complex()); 381 Value one = b.create<arith::ConstantOp>(elementType, 382 b.getFloatAttr(elementType, 1)); 383 Value realPlusOne = b.create<arith::AddFOp>(real, one); 384 Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag); 385 rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex); 386 return success(); 387 } 388 }; 389 390 struct MulOpConversion : public OpConversionPattern<complex::MulOp> { 391 using OpConversionPattern<complex::MulOp>::OpConversionPattern; 392 393 LogicalResult 394 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 395 ConversionPatternRewriter &rewriter) const override { 396 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 397 auto type = adaptor.lhs().getType().cast<ComplexType>(); 398 auto elementType = type.getElementType().cast<FloatType>(); 399 400 Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.lhs()); 401 Value lhsRealAbs = b.create<math::AbsOp>(lhsReal); 402 Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.lhs()); 403 Value lhsImagAbs = b.create<math::AbsOp>(lhsImag); 404 Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.rhs()); 405 Value rhsRealAbs = b.create<math::AbsOp>(rhsReal); 406 Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.rhs()); 407 Value rhsImagAbs = b.create<math::AbsOp>(rhsImag); 408 409 Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 410 Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal); 411 Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 412 Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag); 413 Value real = 414 b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 415 416 Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 417 Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal); 418 Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 419 Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag); 420 Value imag = 421 b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 422 423 // Handle cases where the "naive" calculation results in NaN values. 424 Value realIsNan = 425 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real); 426 Value imagIsNan = 427 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag); 428 Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan); 429 430 Value inf = b.create<arith::ConstantOp>( 431 elementType, 432 b.getFloatAttr(elementType, 433 APFloat::getInf(elementType.getFloatSemantics()))); 434 435 // Case 1. `lhsReal` or `lhsImag` are infinite. 436 Value lhsRealIsInf = 437 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 438 Value lhsImagIsInf = 439 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 440 Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf); 441 Value rhsRealIsNan = 442 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal); 443 Value rhsImagIsNan = 444 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag); 445 Value zero = 446 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 447 Value one = b.create<arith::ConstantOp>(elementType, 448 b.getFloatAttr(elementType, 1)); 449 Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero); 450 lhsReal = b.create<SelectOp>( 451 lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal), 452 lhsReal); 453 Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero); 454 lhsImag = b.create<SelectOp>( 455 lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag), 456 lhsImag); 457 Value lhsIsInfAndRhsRealIsNan = 458 b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan); 459 rhsReal = 460 b.create<SelectOp>(lhsIsInfAndRhsRealIsNan, 461 b.create<math::CopySignOp>(zero, rhsReal), rhsReal); 462 Value lhsIsInfAndRhsImagIsNan = 463 b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan); 464 rhsImag = 465 b.create<SelectOp>(lhsIsInfAndRhsImagIsNan, 466 b.create<math::CopySignOp>(zero, rhsImag), rhsImag); 467 468 // Case 2. `rhsReal` or `rhsImag` are infinite. 469 Value rhsRealIsInf = 470 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 471 Value rhsImagIsInf = 472 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 473 Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf); 474 Value lhsRealIsNan = 475 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal); 476 Value lhsImagIsNan = 477 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag); 478 Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero); 479 rhsReal = b.create<SelectOp>( 480 rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal), 481 rhsReal); 482 Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero); 483 rhsImag = b.create<SelectOp>( 484 rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag), 485 rhsImag); 486 Value rhsIsInfAndLhsRealIsNan = 487 b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan); 488 lhsReal = 489 b.create<SelectOp>(rhsIsInfAndLhsRealIsNan, 490 b.create<math::CopySignOp>(zero, lhsReal), lhsReal); 491 Value rhsIsInfAndLhsImagIsNan = 492 b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan); 493 lhsImag = 494 b.create<SelectOp>(rhsIsInfAndLhsImagIsNan, 495 b.create<math::CopySignOp>(zero, lhsImag), lhsImag); 496 Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf); 497 498 // Case 3. One of the pairwise products of left hand side with right hand 499 // side is infinite. 500 Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>( 501 arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); 502 Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>( 503 arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); 504 Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf, 505 lhsImagTimesRhsImagIsInf); 506 Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>( 507 arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); 508 isSpecialCase = 509 b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf); 510 Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>( 511 arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); 512 isSpecialCase = 513 b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf); 514 Type i1Type = b.getI1Type(); 515 Value notRecalc = b.create<arith::XOrIOp>( 516 recalc, 517 b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1))); 518 isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc); 519 Value isSpecialCaseAndLhsRealIsNan = 520 b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan); 521 lhsReal = 522 b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan, 523 b.create<math::CopySignOp>(zero, lhsReal), lhsReal); 524 Value isSpecialCaseAndLhsImagIsNan = 525 b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan); 526 lhsImag = 527 b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan, 528 b.create<math::CopySignOp>(zero, lhsImag), lhsImag); 529 Value isSpecialCaseAndRhsRealIsNan = 530 b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan); 531 rhsReal = 532 b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan, 533 b.create<math::CopySignOp>(zero, rhsReal), rhsReal); 534 Value isSpecialCaseAndRhsImagIsNan = 535 b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan); 536 rhsImag = 537 b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan, 538 b.create<math::CopySignOp>(zero, rhsImag), rhsImag); 539 recalc = b.create<arith::OrIOp>(recalc, isSpecialCase); 540 recalc = b.create<arith::AndIOp>(isNan, recalc); 541 542 // Recalculate real part. 543 lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 544 lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 545 Value newReal = 546 b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 547 real = 548 b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newReal), real); 549 550 // Recalculate imag part. 551 lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 552 lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 553 Value newImag = 554 b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 555 imag = 556 b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newImag), imag); 557 558 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); 559 return success(); 560 } 561 }; 562 563 struct NegOpConversion : public OpConversionPattern<complex::NegOp> { 564 using OpConversionPattern<complex::NegOp>::OpConversionPattern; 565 566 LogicalResult 567 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, 568 ConversionPatternRewriter &rewriter) const override { 569 auto loc = op.getLoc(); 570 auto type = adaptor.complex().getType().cast<ComplexType>(); 571 auto elementType = type.getElementType().cast<FloatType>(); 572 573 Value real = 574 rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex()); 575 Value imag = 576 rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex()); 577 Value negReal = rewriter.create<arith::NegFOp>(loc, real); 578 Value negImag = rewriter.create<arith::NegFOp>(loc, imag); 579 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); 580 return success(); 581 } 582 }; 583 584 struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 585 using OpConversionPattern<complex::SignOp>::OpConversionPattern; 586 587 LogicalResult 588 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, 589 ConversionPatternRewriter &rewriter) const override { 590 auto type = adaptor.complex().getType().cast<ComplexType>(); 591 auto elementType = type.getElementType().cast<FloatType>(); 592 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 593 594 Value real = b.create<complex::ReOp>(elementType, adaptor.complex()); 595 Value imag = b.create<complex::ImOp>(elementType, adaptor.complex()); 596 Value zero = 597 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 598 Value realIsZero = 599 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 600 Value imagIsZero = 601 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 602 Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 603 auto abs = b.create<complex::AbsOp>(elementType, adaptor.complex()); 604 Value realSign = b.create<arith::DivFOp>(real, abs); 605 Value imagSign = b.create<arith::DivFOp>(imag, abs); 606 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 607 rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, adaptor.complex(), sign); 608 return success(); 609 } 610 }; 611 } // namespace 612 613 void mlir::populateComplexToStandardConversionPatterns( 614 RewritePatternSet &patterns) { 615 // clang-format off 616 patterns.add< 617 AbsOpConversion, 618 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, 619 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, 620 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, 621 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, 622 DivOpConversion, 623 ExpOpConversion, 624 LogOpConversion, 625 Log1pOpConversion, 626 MulOpConversion, 627 NegOpConversion, 628 SignOpConversion>(patterns.getContext()); 629 // clang-format on 630 } 631 632 namespace { 633 struct ConvertComplexToStandardPass 634 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 635 void runOnFunction() override; 636 }; 637 638 void ConvertComplexToStandardPass::runOnFunction() { 639 auto function = getFunction(); 640 641 // Convert to the Standard dialect using the converter defined above. 642 RewritePatternSet patterns(&getContext()); 643 populateComplexToStandardConversionPatterns(patterns); 644 645 ConversionTarget target(getContext()); 646 target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect, 647 math::MathDialect>(); 648 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 649 if (failed(applyPartialConversion(function, target, std::move(patterns)))) 650 signalPassFailure(); 651 } 652 } // namespace 653 654 std::unique_ptr<OperationPass<FuncOp>> 655 mlir::createConvertComplexToStandardPass() { 656 return std::make_unique<ConvertComplexToStandardPass>(); 657 } 658