1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/Math/IR/Math.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 using namespace mlir; 23 24 namespace { 25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 26 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 30 ConversionPatternRewriter &rewriter) const override { 31 complex::AbsOp::Adaptor transformed(operands); 32 auto loc = op.getLoc(); 33 auto type = op.getType(); 34 35 Value real = 36 rewriter.create<complex::ReOp>(loc, type, transformed.complex()); 37 Value imag = 38 rewriter.create<complex::ImOp>(loc, type, transformed.complex()); 39 Value realSqr = rewriter.create<MulFOp>(loc, real, real); 40 Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag); 41 Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr); 42 43 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 44 return success(); 45 } 46 }; 47 48 template <typename ComparisonOp, CmpFPredicate p> 49 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 50 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 51 using ResultCombiner = 52 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 53 
AndOp, OrOp>; 54 55 LogicalResult 56 matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands, 57 ConversionPatternRewriter &rewriter) const override { 58 typename ComparisonOp::Adaptor transformed(operands); 59 auto loc = op.getLoc(); 60 auto type = transformed.lhs() 61 .getType() 62 .template cast<ComplexType>() 63 .getElementType(); 64 65 Value realLhs = 66 rewriter.create<complex::ReOp>(loc, type, transformed.lhs()); 67 Value imagLhs = 68 rewriter.create<complex::ImOp>(loc, type, transformed.lhs()); 69 Value realRhs = 70 rewriter.create<complex::ReOp>(loc, type, transformed.rhs()); 71 Value imagRhs = 72 rewriter.create<complex::ImOp>(loc, type, transformed.rhs()); 73 Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs); 74 Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs); 75 76 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 77 imagComparison); 78 return success(); 79 } 80 }; 81 82 // Default conversion which applies the BinaryStandardOp separately on the real 83 // and imaginary parts. Can for example be used for complex::AddOp and 84 // complex::SubOp. 
template <typename BinaryComplexOp, typename BinaryStandardOp>
struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
  using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;

  /// Emits, in order: re(lhs), re(rhs), BinaryStandardOp on the real parts,
  /// im(lhs), im(rhs), BinaryStandardOp on the imaginary parts. It then
  /// replaces `op` with a complex value of the operand type built from the two
  /// results.
  LogicalResult
  matchAndRewrite(BinaryComplexOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename BinaryComplexOp::Adaptor transformed(operands);
    auto type = transformed.lhs().getType().template cast<ComplexType>();
    auto elementType = type.getElementType().template cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value realLhs = b.create<complex::ReOp>(elementType, transformed.lhs());
    Value realRhs = b.create<complex::ReOp>(elementType, transformed.rhs());
    Value resultReal =
        b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
    Value imagLhs = b.create<complex::ImOp>(elementType, transformed.lhs());
    Value imagRhs = b.create<complex::ImOp>(elementType, transformed.rhs());
    Value resultImag =
        b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers `complex.div` to Standard/Math ops on the element type.
///
/// The quotient is first computed with Smith's algorithm. Both branches are
/// emitted and one is picked with a select. If *both* parts of that result
/// are NaN, it is overridden by the first matching special case:
///   1. zero denominator and at least one non-NaN numerator part,
///   2. infinite numerator and finite denominator,
///   3. finite numerator and infinite denominator.
/// The lowering is branch-free: every candidate value is always computed and
/// the final value is chosen with SelectOps.
/// NOTE(review): the special-case formulas appear to follow the reference
/// complex division in C99 Annex G. Confirm against it before changing them.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::DivOp::Adaptor transformed(operands);
    auto loc = op.getLoc();
    auto type = transformed.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // Dividing by the larger-magnitude part keeps the ratio at most 1 in
    // magnitude, which avoids the overflow of forming ||rhs||^2 directly.
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.

    // First form (used when |rhsReal| < |rhsImag|).
    Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<AddFOp>(
        loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag);
    Value resultReal1 =
        rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<SubFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal);
    Value resultImag1 =
        rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Second form (used otherwise).
    Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<AddFOp>(
        loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<AddFOp>(
        loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<SubFOp>(
        loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    Value zero = rewriter.create<ConstantOp>(loc, elementType,
                                             rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal);
    Value rhsRealIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag);
    Value rhsImagIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the (never NaN) zero constant is true iff the operand is not
    // NaN.
    Value lhsRealIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<AndOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    // Result for Case 1: copysign(inf, rhsReal) * lhs, part by part.
    Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // ONE (ordered and not equal) of |x| against +inf is true iff x is finite:
    // it is false for +-inf and, being an ordered predicate, false for NaN.
    Value rhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal);
    Value lhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag);
    Value lhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // Each lhs part becomes +-1 if it is infinite and +-0 otherwise, keeping
    // the sign of the original part.
    Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    Value lhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite);
    // Each rhs part becomes +-1 if it is infinite and +-0 otherwise, keeping
    // the sign of the original part.
    Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    // Scaling by zero (rather than dividing by the infinite denominator)
    // yields a signed zero whenever the bracketed sum is finite.
    Value resultReal4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith branch: the first form when |rhsReal| < |rhsImag|, the
    // second one otherwise (OLT is false for NaN, so NaN selects the second).
    Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>(
        loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultReal1, resultReal2);
    Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultImag1, resultImag2);
    // Nest the special cases so that Case 1 takes precedence over Case 2,
    // which takes precedence over Case 3, which overrides the Smith result.
    Value resultRealSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // Use the special-case values only when the Smith result is NaN in both
    // parts. UNO against the (never NaN) zero constant tests for NaN.
    Value resultRealIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};

/// Lowers `complex.exp` as exp(re) * (cos(im) + i * sin(im)).
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
  using OpConversionPattern<complex::ExpOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::ExpOp::Adaptor transformed(operands);
    auto loc = op.getLoc();
    auto type = transformed.complex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
    Value expReal = rewriter.create<math::ExpOp>(loc, real);
    Value cosImag = rewriter.create<math::CosOp>(loc, imag);
    Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag);
    Value sinImag = rewriter.create<math::SinOp>(loc, imag);
    Value resultImag = rewriter.create<MulFOp>(loc, expReal, sinImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers `complex.log` as log(|z|) + i * atan2(im, re).
///
/// |z| is emitted as a `complex.abs` op, which is itself one of the ops
/// handled by this file's patterns.
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
  using OpConversionPattern<complex::LogOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::LogOp::Adaptor transformed(operands);
    auto type = transformed.complex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
    Value resultReal = b.create<math::LogOp>(elementType, abs);
    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
    Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers `complex.log1p(z)` to `complex.log((re + 1) + i * im)`.
///
/// The emitted `complex.log` is expected to be lowered in turn by
/// LogOpConversion during the same conversion.
/// NOTE(review): forming `re + 1` explicitly loses precision when |z| is much
/// smaller than 1, which is the range log1p normally exists to serve.
struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
  using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Log1pOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::Log1pOp::Adaptor transformed(operands);
    auto type = transformed.complex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
    Value one =
        b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
    Value realPlusOne = b.create<AddFOp>(real, one);
    Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
    rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
    return success();
  }
};

/// Lowers `complex.mul` to Standard/Math ops on the element type.
///
/// The naive product (a*c - b*d) + (a*d + b*c)i is computed first. When both of
/// its parts are NaN and one of Cases 1-3 below applies, the operands are
/// adjusted and the product is recomputed and scaled by +inf. The lowering is
/// branch-free: all values are computed and combined with SelectOps.
/// NOTE(review): this appears to mirror the NaN-recovery logic of the C99
/// Annex G reference complex multiplication. Confirm against it before
/// changing it.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::MulOp::Adaptor transformed(operands);
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = transformed.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, transformed.lhs());
    Value lhsRealAbs = b.create<AbsFOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, transformed.lhs());
    Value lhsImagAbs = b.create<AbsFOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, transformed.rhs());
    Value rhsRealAbs = b.create<AbsFOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, transformed.rhs());
    Value rhsImagAbs = b.create<AbsFOp>(rhsImag);

    Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<AbsFOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<AbsFOp>(lhsImagTimesRhsImag);
    Value real = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    Value lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<AbsFOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<AbsFOp>(lhsRealTimesRhsImag);
    Value imag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // UNO of a value against itself is true iff the value is NaN.
    Value realIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, real, real);
    Value imagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<AndOp>(realIsNan, imagIsNan);

    Value inf = b.create<ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Each lhs part becomes copysign(1, x) if infinite and copysign(0, x)
    // otherwise. NaN rhs parts become copysign(0, x).
    Value lhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<OrOp>(lhsRealIsInf, lhsImagIsInf);
    // These rhs NaN flags are taken from the original rhs values. They are
    // reused in Case 3, which only applies when the operands were not adjusted.
    Value rhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one =
        b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<SelectOp>(
        lhsIsInf, b.create<CopySignOp>(lhsRealIsInfFloat, lhsReal), lhsReal);
    Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<SelectOp>(
        lhsIsInf, b.create<CopySignOp>(lhsImagIsInfFloat, lhsImag), lhsImag);
    Value lhsIsInfAndRhsRealIsNan = b.create<AndOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<SelectOp>(lhsIsInfAndRhsRealIsNan,
                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
    Value lhsIsInfAndRhsImagIsNan = b.create<AndOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<SelectOp>(lhsIsInfAndRhsImagIsNan,
                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // rhsRealAbs/rhsImagAbs come from the original rhs. Case 1 only rewrites
    // NaN rhs parts (to +-0), and neither NaN nor +-0 compares equal to inf,
    // so this test is unaffected by Case 1.
    Value rhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<OrOp>(rhsRealIsInf, rhsImagIsInf);
    // These lhs NaN flags see the lhs values as updated by Case 1.
    Value lhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<SelectOp>(
        rhsIsInf, b.create<CopySignOp>(rhsRealIsInfFloat, rhsReal), rhsReal);
    Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<SelectOp>(
        rhsIsInf, b.create<CopySignOp>(rhsImagIsInfFloat, rhsImag), rhsImag);
    Value rhsIsInfAndLhsRealIsNan = b.create<AndOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<SelectOp>(rhsIsInfAndLhsRealIsNan,
                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
    Value rhsIsInfAndLhsImagIsNan = b.create<AndOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<SelectOp>(rhsIsInfAndLhsImagIsNan,
                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
    // Either Case 1 or Case 2 requests a recalculation.
    Value recalc = b.create<OrOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    Value lhsRealTimesRhsRealIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase =
        b.create<OrOp>(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    Type i1Type = b.getI1Type();
    // XOR with an i1 `true` constant is a logical NOT.
    Value notRecalc = b.create<XOrOp>(
        recalc, b.create<ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    // Case 3 only applies when neither Case 1 nor Case 2 did.
    isSpecialCase = b.create<AndOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<AndOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan,
                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<AndOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan,
                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<AndOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan,
                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<AndOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan,
                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);
    // Recalculate only when the naive result is NaN in both parts and one of
    // the three cases applied.
    recalc = b.create<OrOp>(recalc, isSpecialCase);
    recalc = b.create<AndOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
    Value newReal = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
    Value newImag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};

/// Lowers `complex.neg` by negating the real and imaginary parts.
struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
  using OpConversionPattern<complex::NegOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::NegOp::Adaptor transformed(operands);
    auto loc = op.getLoc();
    auto type = transformed.complex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
    Value negReal = rewriter.create<NegFOp>(loc, real);
    Value negImag = rewriter.create<NegFOp>(loc, imag);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
    return success();
  }
};

/// Lowers `complex.sign`.
///
/// If both parts compare equal to zero, the result is the operand itself.
/// Otherwise it is z / |z|, i.e. each part divided by `complex.abs` of the
/// operand.
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
  using OpConversionPattern<complex::SignOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SignOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::SignOp::Adaptor transformed(operands);
    auto type = transformed.complex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
    Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value realIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, real, zero);
    Value imagIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, imag, zero);
    Value isZero = b.create<AndOp>(realIsZero, imagIsZero);
    auto abs = b.create<complex::AbsOp>(elementType, transformed.complex());
    Value realSign = b.create<DivFOp>(real, abs);
    Value imagSign = b.create<DivFOp>(imag, abs);
    Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
    // The division result is still computed for a zero input; the select
    // simply discards it in that case.
    rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, transformed.complex(),
                                          sign);
    return success();
  }
};
} // namespace

/// Adds every Complex -> Standard/Math lowering pattern defined in this file to
/// `patterns`.
///
/// `complex.eq` compares parts with OEQ. `complex.neq` uses UNE, which is true
/// when either operand is NaN.
void mlir::populateComplexToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AbsOpConversion,
      ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
      ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
      BinaryComplexOpConversion<complex::AddOp, AddFOp>,
      BinaryComplexOpConversion<complex::SubOp, SubFOp>,
      DivOpConversion,
      ExpOpConversion,
      LogOpConversion,
      Log1pOpConversion,
      MulOpConversion,
      NegOpConversion,
      SignOpConversion>(patterns.getContext());
  // clang-format on
}

namespace {
/// Function pass that rewrites complex arithmetic into Standard/Math ops,
/// leaving only `complex.create`, `complex.re` and `complex.im` behind.
struct ConvertComplexToStandardPass
    : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
  void runOnFunction() override;
};

void ConvertComplexToStandardPass::runOnFunction() {
  auto function = getFunction();

  // Collect the Complex -> Standard/Math lowering patterns.
  RewritePatternSet patterns(&getContext());
  populateComplexToStandardConversionPatterns(patterns);

  // Standard and Math ops, plus the complex ops the lowerings themselves
  // emit (create/re/im), are legal.
  // NOTE(review): the Complex dialect is not marked illegal, so under a
  // partial conversion a complex op with no applicable pattern is presumably
  // left in place rather than causing a failure. Confirm this is intended.
  ConversionTarget target(getContext());
  target.addLegalDialect<StandardOpsDialect, math::MathDialect>();
  target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
  if (failed(applyPartialConversion(function, target, std::move(patterns))))
    signalPassFailure();
}
} // namespace

/// Creates the Complex -> Standard conversion pass (runs on FuncOp).
std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertComplexToStandardPass() {
  return std::make_unique<ConvertComplexToStandardPass>();
}