1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/Math/IR/Math.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 using namespace mlir; 23 24 namespace { 25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 26 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 30 ConversionPatternRewriter &rewriter) const override { 31 auto loc = op.getLoc(); 32 auto type = op.getType(); 33 34 Value real = rewriter.create<complex::ReOp>(loc, type, adaptor.complex()); 35 Value imag = rewriter.create<complex::ImOp>(loc, type, adaptor.complex()); 36 Value realSqr = rewriter.create<MulFOp>(loc, real, real); 37 Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag); 38 Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr); 39 40 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 41 return success(); 42 } 43 }; 44 45 template <typename ComparisonOp, CmpFPredicate p> 46 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 47 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 48 using ResultCombiner = 49 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 50 AndOp, OrOp>; 51 52 LogicalResult 53 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 54 ConversionPatternRewriter &rewriter) const override { 55 auto loc = op.getLoc(); 56 auto type = 57 adaptor.lhs().getType().template cast<ComplexType>().getElementType(); 58 59 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.lhs()); 60 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.lhs()); 61 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.rhs()); 62 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.rhs()); 63 Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs); 64 Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs); 65 66 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 67 imagComparison); 68 return success(); 69 } 70 }; 71 72 // Default conversion which applies the BinaryStandardOp separately on the real 73 // and imaginary parts. Can for example be used for complex::AddOp and 74 // complex::SubOp. 75 template <typename BinaryComplexOp, typename BinaryStandardOp> 76 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { 77 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; 78 79 LogicalResult 80 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, 81 ConversionPatternRewriter &rewriter) const override { 82 auto type = adaptor.lhs().getType().template cast<ComplexType>(); 83 auto elementType = type.getElementType().template cast<FloatType>(); 84 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 85 86 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.lhs()); 87 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.rhs()); 88 Value resultReal = 89 b.create<BinaryStandardOp>(elementType, realLhs, realRhs); 90 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.lhs()); 91 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.rhs()); 92 Value resultImag = 93 b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs); 94 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 95 resultImag); 96 return success(); 97 } 98 }; 99 100 struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 101 using OpConversionPattern<complex::DivOp>::OpConversionPattern; 102 103 LogicalResult 104 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 105 ConversionPatternRewriter &rewriter) const override { 106 auto loc = op.getLoc(); 107 auto type = adaptor.lhs().getType().cast<ComplexType>(); 108 auto elementType = type.getElementType().cast<FloatType>(); 109 110 Value lhsReal = 111 rewriter.create<complex::ReOp>(loc, elementType, adaptor.lhs()); 112 Value lhsImag = 113 rewriter.create<complex::ImOp>(loc, elementType, adaptor.lhs()); 114 Value rhsReal = 115 rewriter.create<complex::ReOp>(loc, elementType, adaptor.rhs()); 116 Value rhsImag = 117 rewriter.create<complex::ImOp>(loc, elementType, adaptor.rhs()); 118 119 // Smith's algorithm to divide complex numbers. It is just a bit smarter 120 // way to compute the following formula: 121 // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 122 // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 123 // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 124 // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 125 // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 126 // 127 // Depending on whether |rhsReal| < |rhsImag| we compute either 128 // rhsRealImagRatio = rhsReal / rhsImag 129 // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 130 // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 131 // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 132 // 133 // or 134 // 135 // rhsImagRealRatio = rhsImag / rhsReal 136 // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 137 // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 138 // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 139 // 140 // See https://dl.acm.org/citation.cfm?id=368661 for more details. 141 Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag); 142 Value rhsRealImagDenom = rewriter.create<AddFOp>( 143 loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal)); 144 Value realNumerator1 = rewriter.create<AddFOp>( 145 loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag); 146 Value resultReal1 = 147 rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom); 148 Value imagNumerator1 = rewriter.create<SubFOp>( 149 loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal); 150 Value resultImag1 = 151 rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom); 152 153 Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal); 154 Value rhsImagRealDenom = rewriter.create<AddFOp>( 155 loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag)); 156 Value realNumerator2 = rewriter.create<AddFOp>( 157 loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio)); 158 Value resultReal2 = 159 rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom); 160 Value imagNumerator2 = rewriter.create<SubFOp>( 161 loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio)); 162 Value resultImag2 = 163 rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom); 164 165 // Consider corner cases. 166 // Case 1. Zero denominator, numerator contains at most one NaN value. 167 Value zero = rewriter.create<ConstantOp>(loc, elementType, 168 rewriter.getZeroAttr(elementType)); 169 Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal); 170 Value rhsRealIsZero = 171 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero); 172 Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag); 173 Value rhsImagIsZero = 174 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero); 175 Value lhsRealIsNotNaN = 176 rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero); 177 Value lhsImagIsNotNaN = 178 rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero); 179 Value lhsContainsNotNaNValue = 180 rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 181 Value resultIsInfinity = rewriter.create<AndOp>( 182 loc, lhsContainsNotNaNValue, 183 rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero)); 184 Value inf = rewriter.create<ConstantOp>( 185 loc, elementType, 186 rewriter.getFloatAttr( 187 elementType, APFloat::getInf(elementType.getFloatSemantics()))); 188 Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal); 189 Value infinityResultReal = 190 rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal); 191 Value infinityResultImag = 192 rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag); 193 194 // Case 2. Infinite numerator, finite denominator. 195 Value rhsRealFinite = 196 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf); 197 Value rhsImagFinite = 198 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf); 199 Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite); 200 Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal); 201 Value lhsRealInfinite = 202 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf); 203 Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag); 204 Value lhsImagInfinite = 205 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf); 206 Value lhsInfinite = 207 rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite); 208 Value infNumFiniteDenom = 209 rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite); 210 Value one = rewriter.create<ConstantOp>( 211 loc, elementType, rewriter.getFloatAttr(elementType, 1)); 212 Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>( 213 loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero), 214 lhsReal); 215 Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>( 216 loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero), 217 lhsImag); 218 Value lhsRealIsInfWithSignTimesRhsReal = 219 rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal); 220 Value lhsImagIsInfWithSignTimesRhsImag = 221 rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag); 222 Value resultReal3 = rewriter.create<MulFOp>( 223 loc, inf, 224 rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 225 lhsImagIsInfWithSignTimesRhsImag)); 226 Value lhsRealIsInfWithSignTimesRhsImag = 227 rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag); 228 Value lhsImagIsInfWithSignTimesRhsReal = 229 rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal); 230 Value resultImag3 = rewriter.create<MulFOp>( 231 loc, inf, 232 rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 233 lhsRealIsInfWithSignTimesRhsImag)); 234 235 // Case 3: Finite numerator, infinite denominator. 236 Value lhsRealFinite = 237 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf); 238 Value lhsImagFinite = 239 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf); 240 Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite); 241 Value rhsRealInfinite = 242 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf); 243 Value rhsImagInfinite = 244 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf); 245 Value rhsInfinite = 246 rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite); 247 Value finiteNumInfiniteDenom = 248 rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite); 249 Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>( 250 loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero), 251 rhsReal); 252 Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>( 253 loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero), 254 rhsImag); 255 Value rhsRealIsInfWithSignTimesLhsReal = 256 rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign); 257 Value rhsImagIsInfWithSignTimesLhsImag = 258 rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign); 259 Value resultReal4 = rewriter.create<MulFOp>( 260 loc, zero, 261 rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 262 rhsImagIsInfWithSignTimesLhsImag)); 263 Value rhsRealIsInfWithSignTimesLhsImag = 264 rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign); 265 Value rhsImagIsInfWithSignTimesLhsReal = 266 rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign); 267 Value resultImag4 = rewriter.create<MulFOp>( 268 loc, zero, 269 rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 270 rhsImagIsInfWithSignTimesLhsReal)); 271 272 Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>( 273 loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 274 Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 275 resultReal1, resultReal2); 276 Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 277 resultImag1, resultImag2); 278 Value resultRealSpecialCase3 = rewriter.create<SelectOp>( 279 loc, finiteNumInfiniteDenom, resultReal4, resultReal); 280 Value resultImagSpecialCase3 = rewriter.create<SelectOp>( 281 loc, finiteNumInfiniteDenom, resultImag4, resultImag); 282 Value resultRealSpecialCase2 = rewriter.create<SelectOp>( 283 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 284 Value resultImagSpecialCase2 = rewriter.create<SelectOp>( 285 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 286 Value resultRealSpecialCase1 = rewriter.create<SelectOp>( 287 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 288 Value resultImagSpecialCase1 = rewriter.create<SelectOp>( 289 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 290 291 Value resultRealIsNaN = 292 rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero); 293 Value resultImagIsNaN = 294 rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero); 295 Value resultIsNaN = 296 rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN); 297 Value resultRealWithSpecialCases = rewriter.create<SelectOp>( 298 loc, resultIsNaN, resultRealSpecialCase1, resultReal); 299 Value resultImagWithSpecialCases = rewriter.create<SelectOp>( 300 loc, resultIsNaN, resultImagSpecialCase1, resultImag); 301 302 rewriter.replaceOpWithNewOp<complex::CreateOp>( 303 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 304 return success(); 305 } 306 }; 307 308 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 309 using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 310 311 LogicalResult 312 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, 313 ConversionPatternRewriter &rewriter) const override { 314 auto loc = op.getLoc(); 315 auto type = adaptor.complex().getType().cast<ComplexType>(); 316 auto elementType = type.getElementType().cast<FloatType>(); 317 318 Value real = 319 rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex()); 320 Value imag = 321 rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex()); 322 Value expReal = rewriter.create<math::ExpOp>(loc, real); 323 Value cosImag = rewriter.create<math::CosOp>(loc, imag); 324 Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag); 325 Value sinImag = rewriter.create<math::SinOp>(loc, imag); 326 Value resultImag = rewriter.create<MulFOp>(loc, expReal, sinImag); 327 328 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 329 resultImag); 330 return success(); 331 } 332 }; 333 334 struct LogOpConversion : public OpConversionPattern<complex::LogOp> { 335 using OpConversionPattern<complex::LogOp>::OpConversionPattern; 336 337 LogicalResult 338 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, 339 ConversionPatternRewriter &rewriter) const override { 340 auto type = adaptor.complex().getType().cast<ComplexType>(); 341 auto elementType = type.getElementType().cast<FloatType>(); 342 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 343 344 Value abs = b.create<complex::AbsOp>(elementType, adaptor.complex()); 345 Value resultReal = b.create<math::LogOp>(elementType, abs); 346 Value real = b.create<complex::ReOp>(elementType, adaptor.complex()); 347 Value imag = b.create<complex::ImOp>(elementType, adaptor.complex()); 348 Value resultImag = b.create<math::Atan2Op>(elementType, imag, real); 349 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 350 resultImag); 351 return success(); 352 } 353 }; 354 355 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { 356 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; 357 358 LogicalResult 359 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, 360 ConversionPatternRewriter &rewriter) const override { 361 auto type = adaptor.complex().getType().cast<ComplexType>(); 362 auto elementType = type.getElementType().cast<FloatType>(); 363 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 364 365 Value real = b.create<complex::ReOp>(elementType, adaptor.complex()); 366 Value imag = b.create<complex::ImOp>(elementType, adaptor.complex()); 367 Value one = 368 b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1)); 369 Value realPlusOne = b.create<AddFOp>(real, one); 370 Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag); 371 rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex); 372 return success(); 373 } 374 }; 375 376 struct MulOpConversion : public OpConversionPattern<complex::MulOp> { 377 using OpConversionPattern<complex::MulOp>::OpConversionPattern; 378 379 LogicalResult 380 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 381 ConversionPatternRewriter &rewriter) const override { 382 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 383 auto type = adaptor.lhs().getType().cast<ComplexType>(); 384 auto elementType = type.getElementType().cast<FloatType>(); 385 386 Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.lhs()); 387 Value lhsRealAbs = b.create<AbsFOp>(lhsReal); 388 Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.lhs()); 389 Value lhsImagAbs = b.create<AbsFOp>(lhsImag); 390 Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.rhs()); 391 Value rhsRealAbs = b.create<AbsFOp>(rhsReal); 392 Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.rhs()); 393 Value rhsImagAbs = b.create<AbsFOp>(rhsImag); 394 395 Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal); 396 Value lhsRealTimesRhsRealAbs = b.create<AbsFOp>(lhsRealTimesRhsReal); 397 Value lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag); 398 Value lhsImagTimesRhsImagAbs = b.create<AbsFOp>(lhsImagTimesRhsImag); 399 Value real = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 400 401 Value lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal); 402 Value lhsImagTimesRhsRealAbs = b.create<AbsFOp>(lhsImagTimesRhsReal); 403 Value lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag); 404 Value lhsRealTimesRhsImagAbs = b.create<AbsFOp>(lhsRealTimesRhsImag); 405 Value imag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 406 407 // Handle cases where the "naive" calculation results in NaN values. 408 Value realIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, real, real); 409 Value imagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, imag, imag); 410 Value isNan = b.create<AndOp>(realIsNan, imagIsNan); 411 412 Value inf = b.create<ConstantOp>( 413 elementType, 414 b.getFloatAttr(elementType, 415 APFloat::getInf(elementType.getFloatSemantics()))); 416 417 // Case 1. `lhsReal` or `lhsImag` are infinite. 418 Value lhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealAbs, inf); 419 Value lhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagAbs, inf); 420 Value lhsIsInf = b.create<OrOp>(lhsRealIsInf, lhsImagIsInf); 421 Value rhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsReal, rhsReal); 422 Value rhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsImag, rhsImag); 423 Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType)); 424 Value one = 425 b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1)); 426 Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero); 427 lhsReal = b.create<SelectOp>( 428 lhsIsInf, b.create<CopySignOp>(lhsRealIsInfFloat, lhsReal), lhsReal); 429 Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero); 430 lhsImag = b.create<SelectOp>( 431 lhsIsInf, b.create<CopySignOp>(lhsImagIsInfFloat, lhsImag), lhsImag); 432 Value lhsIsInfAndRhsRealIsNan = b.create<AndOp>(lhsIsInf, rhsRealIsNan); 433 rhsReal = b.create<SelectOp>(lhsIsInfAndRhsRealIsNan, 434 b.create<CopySignOp>(zero, rhsReal), rhsReal); 435 Value lhsIsInfAndRhsImagIsNan = b.create<AndOp>(lhsIsInf, rhsImagIsNan); 436 rhsImag = b.create<SelectOp>(lhsIsInfAndRhsImagIsNan, 437 b.create<CopySignOp>(zero, rhsImag), rhsImag); 438 439 // Case 2. `rhsReal` or `rhsImag` are infinite. 440 Value rhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsRealAbs, inf); 441 Value rhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsImagAbs, inf); 442 Value rhsIsInf = b.create<OrOp>(rhsRealIsInf, rhsImagIsInf); 443 Value lhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsReal, lhsReal); 444 Value lhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsImag, lhsImag); 445 Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero); 446 rhsReal = b.create<SelectOp>( 447 rhsIsInf, b.create<CopySignOp>(rhsRealIsInfFloat, rhsReal), rhsReal); 448 Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero); 449 rhsImag = b.create<SelectOp>( 450 rhsIsInf, b.create<CopySignOp>(rhsImagIsInfFloat, rhsImag), rhsImag); 451 Value rhsIsInfAndLhsRealIsNan = b.create<AndOp>(rhsIsInf, lhsRealIsNan); 452 lhsReal = b.create<SelectOp>(rhsIsInfAndLhsRealIsNan, 453 b.create<CopySignOp>(zero, lhsReal), lhsReal); 454 Value rhsIsInfAndLhsImagIsNan = b.create<AndOp>(rhsIsInf, lhsImagIsNan); 455 lhsImag = b.create<SelectOp>(rhsIsInfAndLhsImagIsNan, 456 b.create<CopySignOp>(zero, lhsImag), lhsImag); 457 Value recalc = b.create<OrOp>(lhsIsInf, rhsIsInf); 458 459 // Case 3. One of the pairwise products of left hand side with right hand 460 // side is infinite. 461 Value lhsRealTimesRhsRealIsInf = 462 b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); 463 Value lhsImagTimesRhsImagIsInf = 464 b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); 465 Value isSpecialCase = 466 b.create<OrOp>(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf); 467 Value lhsRealTimesRhsImagIsInf = 468 b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); 469 isSpecialCase = b.create<OrOp>(isSpecialCase, lhsRealTimesRhsImagIsInf); 470 Value lhsImagTimesRhsRealIsInf = 471 b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); 472 isSpecialCase = b.create<OrOp>(isSpecialCase, lhsImagTimesRhsRealIsInf); 473 Type i1Type = b.getI1Type(); 474 Value notRecalc = b.create<XOrOp>( 475 recalc, b.create<ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1))); 476 isSpecialCase = b.create<AndOp>(isSpecialCase, notRecalc); 477 Value isSpecialCaseAndLhsRealIsNan = 478 b.create<AndOp>(isSpecialCase, lhsRealIsNan); 479 lhsReal = b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan, 480 b.create<CopySignOp>(zero, lhsReal), lhsReal); 481 Value isSpecialCaseAndLhsImagIsNan = 482 b.create<AndOp>(isSpecialCase, lhsImagIsNan); 483 lhsImag = b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan, 484 b.create<CopySignOp>(zero, lhsImag), lhsImag); 485 Value isSpecialCaseAndRhsRealIsNan = 486 b.create<AndOp>(isSpecialCase, rhsRealIsNan); 487 rhsReal = b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan, 488 b.create<CopySignOp>(zero, rhsReal), rhsReal); 489 Value isSpecialCaseAndRhsImagIsNan = 490 b.create<AndOp>(isSpecialCase, rhsImagIsNan); 491 rhsImag = b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan, 492 b.create<CopySignOp>(zero, rhsImag), rhsImag); 493 recalc = b.create<OrOp>(recalc, isSpecialCase); 494 recalc = b.create<AndOp>(isNan, recalc); 495 496 // Recalculate real part. 497 lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal); 498 lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag); 499 Value newReal = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 500 real = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newReal), real); 501 502 // Recalculate imag part. 503 lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal); 504 lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag); 505 Value newImag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 506 imag = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newImag), imag); 507 508 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); 509 return success(); 510 } 511 }; 512 513 struct NegOpConversion : public OpConversionPattern<complex::NegOp> { 514 using OpConversionPattern<complex::NegOp>::OpConversionPattern; 515 516 LogicalResult 517 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, 518 ConversionPatternRewriter &rewriter) const override { 519 auto loc = op.getLoc(); 520 auto type = adaptor.complex().getType().cast<ComplexType>(); 521 auto elementType = type.getElementType().cast<FloatType>(); 522 523 Value real = 524 rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex()); 525 Value imag = 526 rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex()); 527 Value negReal = rewriter.create<NegFOp>(loc, real); 528 Value negImag = rewriter.create<NegFOp>(loc, imag); 529 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); 530 return success(); 531 } 532 }; 533 534 struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 535 using OpConversionPattern<complex::SignOp>::OpConversionPattern; 536 537 LogicalResult 538 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, 539 ConversionPatternRewriter &rewriter) const override { 540 auto type = adaptor.complex().getType().cast<ComplexType>(); 541 auto elementType = type.getElementType().cast<FloatType>(); 542 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 543 544 Value real = b.create<complex::ReOp>(elementType, adaptor.complex()); 545 Value imag = b.create<complex::ImOp>(elementType, adaptor.complex()); 546 Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType)); 547 Value realIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, real, zero); 548 Value imagIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, imag, zero); 549 Value isZero = b.create<AndOp>(realIsZero, imagIsZero); 550 auto abs = b.create<complex::AbsOp>(elementType, adaptor.complex()); 551 Value realSign = b.create<DivFOp>(real, abs); 552 Value imagSign = b.create<DivFOp>(imag, abs); 553 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 554 rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, adaptor.complex(), sign); 555 return success(); 556 } 557 }; 558 } // namespace 559 560 void mlir::populateComplexToStandardConversionPatterns( 561 RewritePatternSet &patterns) { 562 // clang-format off 563 patterns.add< 564 AbsOpConversion, 565 ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>, 566 ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>, 567 BinaryComplexOpConversion<complex::AddOp, AddFOp>, 568 BinaryComplexOpConversion<complex::SubOp, SubFOp>, 569 DivOpConversion, 570 ExpOpConversion, 571 LogOpConversion, 572 Log1pOpConversion, 573 MulOpConversion, 574 NegOpConversion, 575 SignOpConversion>(patterns.getContext()); 576 // clang-format on 577 } 578 579 namespace { 580 struct ConvertComplexToStandardPass 581 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 582 void runOnFunction() override; 583 }; 584 585 void ConvertComplexToStandardPass::runOnFunction() { 586 auto function = getFunction(); 587 588 // Convert to the Standard dialect using the converter defined above. 589 RewritePatternSet patterns(&getContext()); 590 populateComplexToStandardConversionPatterns(patterns); 591 592 ConversionTarget target(getContext()); 593 target.addLegalDialect<StandardOpsDialect, math::MathDialect>(); 594 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 595 if (failed(applyPartialConversion(function, target, std::move(patterns)))) 596 signalPassFailure(); 597 } 598 } // namespace 599 600 std::unique_ptr<OperationPass<FuncOp>> 601 mlir::createConvertComplexToStandardPass() { 602 return std::make_unique<ConvertComplexToStandardPass>(); 603 } 604