12ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 22ea7fb7bSAdrian Kuegel // 32ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information. 52ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62ea7fb7bSAdrian Kuegel // 72ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===// 82ea7fb7bSAdrian Kuegel 92ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 102ea7fb7bSAdrian Kuegel 112ea7fb7bSAdrian Kuegel #include <memory> 12fb8b2b86SAdrian Kuegel #include <type_traits> 132ea7fb7bSAdrian Kuegel 142ea7fb7bSAdrian Kuegel #include "../PassDetail.h" 15a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 162ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h" 172ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h" 18f112bd61SAdrian Kuegel #include "mlir/IR/ImplicitLocOpBuilder.h" 192ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h" 202ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h" 212ea7fb7bSAdrian Kuegel 222ea7fb7bSAdrian Kuegel using namespace mlir; 232ea7fb7bSAdrian Kuegel 242ea7fb7bSAdrian Kuegel namespace { 252ea7fb7bSAdrian Kuegel struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 262ea7fb7bSAdrian Kuegel using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 272ea7fb7bSAdrian Kuegel 282ea7fb7bSAdrian Kuegel LogicalResult 29b54c724bSRiver Riddle matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 302ea7fb7bSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 312ea7fb7bSAdrian Kuegel auto loc = op.getLoc(); 322ea7fb7bSAdrian Kuegel auto type = op.getType(); 332ea7fb7bSAdrian Kuegel 34c0342a2dSJacques Pienaar Value real = 35c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); 36c0342a2dSJacques Pienaar Value imag = 37c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); 38a54f4eaeSMogball Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real); 39a54f4eaeSMogball Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag); 40a54f4eaeSMogball Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr); 412ea7fb7bSAdrian Kuegel 422ea7fb7bSAdrian Kuegel rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 432ea7fb7bSAdrian Kuegel return success(); 442ea7fb7bSAdrian Kuegel } 452ea7fb7bSAdrian Kuegel }; 46ac00cb0dSAdrian Kuegel 47a54f4eaeSMogball template <typename ComparisonOp, arith::CmpFPredicate p> 48fb8b2b86SAdrian Kuegel struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 49fb8b2b86SAdrian Kuegel using OpConversionPattern<ComparisonOp>::OpConversionPattern; 50fb8b2b86SAdrian Kuegel using ResultCombiner = 51fb8b2b86SAdrian Kuegel std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 52a54f4eaeSMogball arith::AndIOp, arith::OrIOp>; 53ac00cb0dSAdrian Kuegel 54ac00cb0dSAdrian Kuegel LogicalResult 55b54c724bSRiver Riddle matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 56ac00cb0dSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 57ac00cb0dSAdrian Kuegel auto loc = op.getLoc(); 58c0342a2dSJacques Pienaar auto type = adaptor.getLhs() 59c0342a2dSJacques Pienaar .getType() 60c0342a2dSJacques Pienaar .template cast<ComplexType>() 61c0342a2dSJacques Pienaar .getElementType(); 62ac00cb0dSAdrian Kuegel 63c0342a2dSJacques Pienaar Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); 64c0342a2dSJacques Pienaar Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); 65c0342a2dSJacques Pienaar Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); 66c0342a2dSJacques Pienaar Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); 67a54f4eaeSMogball Value realComparison = 68a54f4eaeSMogball rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); 69a54f4eaeSMogball Value imagComparison = 70a54f4eaeSMogball rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); 71ac00cb0dSAdrian Kuegel 72fb8b2b86SAdrian Kuegel rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 73fb8b2b86SAdrian Kuegel imagComparison); 74ac00cb0dSAdrian Kuegel return success(); 75ac00cb0dSAdrian Kuegel } 76ac00cb0dSAdrian Kuegel }; 77942be7cbSAdrian Kuegel 78fb978f09SAdrian Kuegel // Default conversion which applies the BinaryStandardOp separately on the real 79fb978f09SAdrian Kuegel // and imaginary parts. Can for example be used for complex::AddOp and 80fb978f09SAdrian Kuegel // complex::SubOp. 81fb978f09SAdrian Kuegel template <typename BinaryComplexOp, typename BinaryStandardOp> 82fb978f09SAdrian Kuegel struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { 83fb978f09SAdrian Kuegel using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; 84fb978f09SAdrian Kuegel 85fb978f09SAdrian Kuegel LogicalResult 86b54c724bSRiver Riddle matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, 87fb978f09SAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 88c0342a2dSJacques Pienaar auto type = adaptor.getLhs().getType().template cast<ComplexType>(); 89fb978f09SAdrian Kuegel auto elementType = type.getElementType().template cast<FloatType>(); 90fb978f09SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 91fb978f09SAdrian Kuegel 92c0342a2dSJacques Pienaar Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 93c0342a2dSJacques Pienaar Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 94fb978f09SAdrian Kuegel Value resultReal = 95fb978f09SAdrian Kuegel b.create<BinaryStandardOp>(elementType, realLhs, realRhs); 96c0342a2dSJacques Pienaar Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 97c0342a2dSJacques Pienaar Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 98fb978f09SAdrian Kuegel Value resultImag = 99fb978f09SAdrian Kuegel b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs); 100fb978f09SAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 101fb978f09SAdrian Kuegel resultImag); 102fb978f09SAdrian Kuegel return success(); 103fb978f09SAdrian Kuegel } 104fb978f09SAdrian Kuegel }; 105fb978f09SAdrian Kuegel 106672b908bSGoran Flegar template <typename TrigonometricOp> 107672b908bSGoran Flegar struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> { 108672b908bSGoran Flegar using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor; 109672b908bSGoran Flegar 110672b908bSGoran Flegar using OpConversionPattern<TrigonometricOp>::OpConversionPattern; 111672b908bSGoran Flegar 112672b908bSGoran Flegar LogicalResult 113672b908bSGoran Flegar matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor, 114672b908bSGoran Flegar ConversionPatternRewriter &rewriter) const override { 115672b908bSGoran Flegar auto loc = op.getLoc(); 116672b908bSGoran Flegar auto type = adaptor.getComplex().getType().template cast<ComplexType>(); 117672b908bSGoran Flegar auto elementType = type.getElementType().template cast<FloatType>(); 118672b908bSGoran Flegar 119672b908bSGoran Flegar Value real = 120672b908bSGoran Flegar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 121672b908bSGoran Flegar Value imag = 122672b908bSGoran Flegar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 123672b908bSGoran Flegar 124672b908bSGoran Flegar // Trigonometric ops use a set of common building blocks to convert to real 125672b908bSGoran Flegar // ops. Here we create these building blocks and call into an op-specific 126672b908bSGoran Flegar // implementation in the subclass to combine them. 127672b908bSGoran Flegar Value half = rewriter.create<arith::ConstantOp>( 128672b908bSGoran Flegar loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); 129672b908bSGoran Flegar Value exp = rewriter.create<math::ExpOp>(loc, imag); 130672b908bSGoran Flegar Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp); 131672b908bSGoran Flegar Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp); 132672b908bSGoran Flegar Value sin = rewriter.create<math::SinOp>(loc, real); 133672b908bSGoran Flegar Value cos = rewriter.create<math::CosOp>(loc, real); 134672b908bSGoran Flegar 135672b908bSGoran Flegar auto resultPair = 136672b908bSGoran Flegar combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter); 137672b908bSGoran Flegar 138672b908bSGoran Flegar rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first, 139672b908bSGoran Flegar resultPair.second); 140672b908bSGoran Flegar return success(); 141672b908bSGoran Flegar } 142672b908bSGoran Flegar 143672b908bSGoran Flegar virtual std::pair<Value, Value> 144672b908bSGoran Flegar combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, 145672b908bSGoran Flegar Value cos, ConversionPatternRewriter &rewriter) const = 0; 146672b908bSGoran Flegar }; 147672b908bSGoran Flegar 148672b908bSGoran Flegar struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> { 149672b908bSGoran Flegar using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion; 150672b908bSGoran Flegar 151672b908bSGoran Flegar std::pair<Value, Value> 152672b908bSGoran Flegar combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, 153672b908bSGoran Flegar Value cos, ConversionPatternRewriter &rewriter) const override { 154672b908bSGoran Flegar // Complex cosine is defined as; 155672b908bSGoran Flegar // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy))) 156672b908bSGoran Flegar // Plugging in: 157672b908bSGoran Flegar // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) 158672b908bSGoran Flegar // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) 159672b908bSGoran Flegar // and defining t := exp(y) 160672b908bSGoran Flegar // We get: 161672b908bSGoran Flegar // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x 162672b908bSGoran Flegar // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x 163672b908bSGoran Flegar Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp); 164672b908bSGoran Flegar Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos); 165672b908bSGoran Flegar Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp); 166672b908bSGoran Flegar Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin); 167672b908bSGoran Flegar return {resultReal, resultImag}; 168672b908bSGoran Flegar } 169672b908bSGoran Flegar }; 170672b908bSGoran Flegar 171942be7cbSAdrian Kuegel struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 172942be7cbSAdrian Kuegel using OpConversionPattern<complex::DivOp>::OpConversionPattern; 173942be7cbSAdrian Kuegel 174942be7cbSAdrian Kuegel LogicalResult 175b54c724bSRiver Riddle matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 176942be7cbSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 177942be7cbSAdrian Kuegel auto loc = op.getLoc(); 178c0342a2dSJacques Pienaar auto type = adaptor.getLhs().getType().cast<ComplexType>(); 179942be7cbSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 180942be7cbSAdrian Kuegel 181942be7cbSAdrian Kuegel Value lhsReal = 182c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs()); 183942be7cbSAdrian Kuegel Value lhsImag = 184c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs()); 185942be7cbSAdrian Kuegel Value rhsReal = 186c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs()); 187942be7cbSAdrian Kuegel Value rhsImag = 188c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs()); 189942be7cbSAdrian Kuegel 190942be7cbSAdrian Kuegel // Smith's algorithm to divide complex numbers. It is just a bit smarter 191942be7cbSAdrian Kuegel // way to compute the following formula: 192942be7cbSAdrian Kuegel // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 193942be7cbSAdrian Kuegel // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 194942be7cbSAdrian Kuegel // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 195942be7cbSAdrian Kuegel // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 196942be7cbSAdrian Kuegel // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 197942be7cbSAdrian Kuegel // 198942be7cbSAdrian Kuegel // Depending on whether |rhsReal| < |rhsImag| we compute either 199942be7cbSAdrian Kuegel // rhsRealImagRatio = rhsReal / rhsImag 200942be7cbSAdrian Kuegel // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 201942be7cbSAdrian Kuegel // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 202942be7cbSAdrian Kuegel // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 203942be7cbSAdrian Kuegel // 204942be7cbSAdrian Kuegel // or 205942be7cbSAdrian Kuegel // 206942be7cbSAdrian Kuegel // rhsImagRealRatio = rhsImag / rhsReal 207942be7cbSAdrian Kuegel // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 208942be7cbSAdrian Kuegel // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 209942be7cbSAdrian Kuegel // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 210942be7cbSAdrian Kuegel // 211942be7cbSAdrian Kuegel // See https://dl.acm.org/citation.cfm?id=368661 for more details. 212a54f4eaeSMogball Value rhsRealImagRatio = 213a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag); 214a54f4eaeSMogball Value rhsRealImagDenom = rewriter.create<arith::AddFOp>( 215a54f4eaeSMogball loc, rhsImag, 216a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal)); 217a54f4eaeSMogball Value realNumerator1 = rewriter.create<arith::AddFOp>( 218a54f4eaeSMogball loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio), 219a54f4eaeSMogball lhsImag); 220942be7cbSAdrian Kuegel Value resultReal1 = 221a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom); 222a54f4eaeSMogball Value imagNumerator1 = rewriter.create<arith::SubFOp>( 223a54f4eaeSMogball loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio), 224a54f4eaeSMogball lhsReal); 225942be7cbSAdrian Kuegel Value resultImag1 = 226a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom); 227942be7cbSAdrian Kuegel 228a54f4eaeSMogball Value rhsImagRealRatio = 229a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal); 230a54f4eaeSMogball Value rhsImagRealDenom = rewriter.create<arith::AddFOp>( 231a54f4eaeSMogball loc, rhsReal, 232a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag)); 233a54f4eaeSMogball Value realNumerator2 = rewriter.create<arith::AddFOp>( 234a54f4eaeSMogball loc, lhsReal, 235a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio)); 236942be7cbSAdrian Kuegel Value resultReal2 = 237a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom); 238a54f4eaeSMogball Value imagNumerator2 = rewriter.create<arith::SubFOp>( 239a54f4eaeSMogball loc, lhsImag, 240a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio)); 241942be7cbSAdrian Kuegel Value resultImag2 = 242a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom); 243942be7cbSAdrian Kuegel 244942be7cbSAdrian Kuegel // Consider corner cases. 245942be7cbSAdrian Kuegel // Case 1. Zero denominator, numerator contains at most one NaN value. 246a54f4eaeSMogball Value zero = rewriter.create<arith::ConstantOp>( 247a54f4eaeSMogball loc, elementType, rewriter.getZeroAttr(elementType)); 248a54f4eaeSMogball Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal); 249a54f4eaeSMogball Value rhsRealIsZero = rewriter.create<arith::CmpFOp>( 250a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); 251a54f4eaeSMogball Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag); 252a54f4eaeSMogball Value rhsImagIsZero = rewriter.create<arith::CmpFOp>( 253a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); 254a54f4eaeSMogball Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>( 255a54f4eaeSMogball loc, arith::CmpFPredicate::ORD, lhsReal, zero); 256a54f4eaeSMogball Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>( 257a54f4eaeSMogball loc, arith::CmpFPredicate::ORD, lhsImag, zero); 258942be7cbSAdrian Kuegel Value lhsContainsNotNaNValue = 259a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 260a54f4eaeSMogball Value resultIsInfinity = rewriter.create<arith::AndIOp>( 261942be7cbSAdrian Kuegel loc, lhsContainsNotNaNValue, 262a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero)); 263a54f4eaeSMogball Value inf = rewriter.create<arith::ConstantOp>( 264942be7cbSAdrian Kuegel loc, elementType, 265942be7cbSAdrian Kuegel rewriter.getFloatAttr( 266942be7cbSAdrian Kuegel elementType, APFloat::getInf(elementType.getFloatSemantics()))); 267a54f4eaeSMogball Value infWithSignOfRhsReal = 268a54f4eaeSMogball rewriter.create<math::CopySignOp>(loc, inf, rhsReal); 269942be7cbSAdrian Kuegel Value infinityResultReal = 270a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal); 271942be7cbSAdrian Kuegel Value infinityResultImag = 272a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag); 273942be7cbSAdrian Kuegel 274942be7cbSAdrian Kuegel // Case 2. Infinite numerator, finite denominator. 275a54f4eaeSMogball Value rhsRealFinite = rewriter.create<arith::CmpFOp>( 276a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); 277a54f4eaeSMogball Value rhsImagFinite = rewriter.create<arith::CmpFOp>( 278a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); 279a54f4eaeSMogball Value rhsFinite = 280a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite); 281a54f4eaeSMogball Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal); 282a54f4eaeSMogball Value lhsRealInfinite = rewriter.create<arith::CmpFOp>( 283a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 284a54f4eaeSMogball Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag); 285a54f4eaeSMogball Value lhsImagInfinite = rewriter.create<arith::CmpFOp>( 286a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 287942be7cbSAdrian Kuegel Value lhsInfinite = 288a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite); 289942be7cbSAdrian Kuegel Value infNumFiniteDenom = 290a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite); 291a54f4eaeSMogball Value one = rewriter.create<arith::ConstantOp>( 292942be7cbSAdrian Kuegel loc, elementType, rewriter.getFloatAttr(elementType, 1)); 293a54f4eaeSMogball Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 294dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero), 295942be7cbSAdrian Kuegel lhsReal); 296a54f4eaeSMogball Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 297dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero), 298942be7cbSAdrian Kuegel lhsImag); 299942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsReal = 300a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal); 301942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsImag = 302a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag); 303a54f4eaeSMogball Value resultReal3 = rewriter.create<arith::MulFOp>( 304942be7cbSAdrian Kuegel loc, inf, 305a54f4eaeSMogball rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 306942be7cbSAdrian Kuegel lhsImagIsInfWithSignTimesRhsImag)); 307942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsImag = 308a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag); 309942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsReal = 310a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal); 311a54f4eaeSMogball Value resultImag3 = rewriter.create<arith::MulFOp>( 312942be7cbSAdrian Kuegel loc, inf, 313a54f4eaeSMogball rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 314942be7cbSAdrian Kuegel lhsRealIsInfWithSignTimesRhsImag)); 315942be7cbSAdrian Kuegel 316942be7cbSAdrian Kuegel // Case 3: Finite numerator, infinite denominator. 317a54f4eaeSMogball Value lhsRealFinite = rewriter.create<arith::CmpFOp>( 318a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); 319a54f4eaeSMogball Value lhsImagFinite = rewriter.create<arith::CmpFOp>( 320a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); 321a54f4eaeSMogball Value lhsFinite = 322a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite); 323a54f4eaeSMogball Value rhsRealInfinite = rewriter.create<arith::CmpFOp>( 324a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 325a54f4eaeSMogball Value rhsImagInfinite = rewriter.create<arith::CmpFOp>( 326a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 327942be7cbSAdrian Kuegel Value rhsInfinite = 328a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite); 329942be7cbSAdrian Kuegel Value finiteNumInfiniteDenom = 330a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite); 331a54f4eaeSMogball Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 332dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero), 333942be7cbSAdrian Kuegel rhsReal); 334a54f4eaeSMogball Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 335dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero), 336942be7cbSAdrian Kuegel rhsImag); 337942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsReal = 338a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign); 339942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsImag = 340a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign); 341a54f4eaeSMogball Value resultReal4 = rewriter.create<arith::MulFOp>( 342942be7cbSAdrian Kuegel loc, zero, 343a54f4eaeSMogball rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 344942be7cbSAdrian Kuegel rhsImagIsInfWithSignTimesLhsImag)); 345942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsImag = 346a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign); 347942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsReal = 348a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign); 349a54f4eaeSMogball Value resultImag4 = rewriter.create<arith::MulFOp>( 350942be7cbSAdrian Kuegel loc, zero, 351a54f4eaeSMogball rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 352942be7cbSAdrian Kuegel rhsImagIsInfWithSignTimesLhsReal)); 353942be7cbSAdrian Kuegel 354a54f4eaeSMogball Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>( 355a54f4eaeSMogball loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 356dec8af70SRiver Riddle Value resultReal = rewriter.create<arith::SelectOp>( 357dec8af70SRiver Riddle loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); 358dec8af70SRiver Riddle Value resultImag = rewriter.create<arith::SelectOp>( 359dec8af70SRiver Riddle loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); 360dec8af70SRiver Riddle Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>( 361942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultReal4, resultReal); 362dec8af70SRiver Riddle Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>( 363942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultImag4, resultImag); 364dec8af70SRiver Riddle Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>( 365942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 366dec8af70SRiver Riddle Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>( 367942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 368dec8af70SRiver Riddle Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>( 369942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 370dec8af70SRiver Riddle Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>( 371942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 372942be7cbSAdrian Kuegel 373a54f4eaeSMogball Value resultRealIsNaN = rewriter.create<arith::CmpFOp>( 374a54f4eaeSMogball loc, arith::CmpFPredicate::UNO, resultReal, zero); 375a54f4eaeSMogball Value resultImagIsNaN = rewriter.create<arith::CmpFOp>( 376a54f4eaeSMogball loc, arith::CmpFPredicate::UNO, resultImag, zero); 377942be7cbSAdrian Kuegel Value resultIsNaN = 378a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN); 379dec8af70SRiver Riddle Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>( 380942be7cbSAdrian Kuegel loc, resultIsNaN, resultRealSpecialCase1, resultReal); 381dec8af70SRiver Riddle Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>( 382942be7cbSAdrian Kuegel loc, resultIsNaN, resultImagSpecialCase1, resultImag); 383942be7cbSAdrian Kuegel 384942be7cbSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>( 385942be7cbSAdrian Kuegel op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 386942be7cbSAdrian Kuegel return success(); 387942be7cbSAdrian Kuegel } 388942be7cbSAdrian Kuegel }; 38973cbc91cSAdrian Kuegel 39073cbc91cSAdrian Kuegel struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 39173cbc91cSAdrian Kuegel using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 39273cbc91cSAdrian Kuegel 39373cbc91cSAdrian Kuegel LogicalResult 394b54c724bSRiver Riddle matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, 39573cbc91cSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 39673cbc91cSAdrian Kuegel auto loc = op.getLoc(); 397c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>(); 39873cbc91cSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 39973cbc91cSAdrian Kuegel 40073cbc91cSAdrian Kuegel Value real = 401c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 40273cbc91cSAdrian Kuegel Value imag = 403c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 40473cbc91cSAdrian Kuegel Value expReal = rewriter.create<math::ExpOp>(loc, real); 40573cbc91cSAdrian Kuegel Value cosImag = rewriter.create<math::CosOp>(loc, imag); 406a54f4eaeSMogball Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag); 40773cbc91cSAdrian Kuegel Value sinImag = rewriter.create<math::SinOp>(loc, imag); 408a54f4eaeSMogball Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag); 40973cbc91cSAdrian Kuegel 41073cbc91cSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 41173cbc91cSAdrian Kuegel resultImag); 41273cbc91cSAdrian Kuegel return success(); 41373cbc91cSAdrian Kuegel } 41473cbc91cSAdrian Kuegel }; 415662e074dSAdrian Kuegel 416338e76f8Sbixia1 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> { 417338e76f8Sbixia1 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern; 418338e76f8Sbixia1 419338e76f8Sbixia1 LogicalResult 420338e76f8Sbixia1 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, 421338e76f8Sbixia1 ConversionPatternRewriter &rewriter) const override { 422338e76f8Sbixia1 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 423338e76f8Sbixia1 auto elementType = type.getElementType().cast<FloatType>(); 424338e76f8Sbixia1 425338e76f8Sbixia1 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 426338e76f8Sbixia1 Value exp = b.create<complex::ExpOp>(adaptor.getComplex()); 427338e76f8Sbixia1 428338e76f8Sbixia1 Value real = b.create<complex::ReOp>(elementType, exp); 429338e76f8Sbixia1 Value one = b.create<arith::ConstantOp>(elementType, 430338e76f8Sbixia1 b.getFloatAttr(elementType, 1)); 431338e76f8Sbixia1 Value realMinusOne = b.create<arith::SubFOp>(real, one); 432338e76f8Sbixia1 Value imag = b.create<complex::ImOp>(elementType, exp); 433338e76f8Sbixia1 434338e76f8Sbixia1 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne, 435338e76f8Sbixia1 imag); 436338e76f8Sbixia1 return success(); 437338e76f8Sbixia1 } 438338e76f8Sbixia1 }; 439338e76f8Sbixia1 440380fa71fSAdrian Kuegel struct LogOpConversion : public OpConversionPattern<complex::LogOp> { 441380fa71fSAdrian Kuegel using OpConversionPattern<complex::LogOp>::OpConversionPattern; 442380fa71fSAdrian Kuegel 443380fa71fSAdrian Kuegel LogicalResult 444b54c724bSRiver Riddle matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, 445380fa71fSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 446c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>(); 447380fa71fSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 448380fa71fSAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 449380fa71fSAdrian Kuegel 450c0342a2dSJacques Pienaar Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex()); 451380fa71fSAdrian Kuegel Value resultReal = b.create<math::LogOp>(elementType, abs); 452c0342a2dSJacques Pienaar Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 453c0342a2dSJacques Pienaar Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 454380fa71fSAdrian Kuegel Value resultImag = b.create<math::Atan2Op>(elementType, imag, real); 455380fa71fSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 456380fa71fSAdrian Kuegel resultImag); 457380fa71fSAdrian Kuegel return success(); 458380fa71fSAdrian Kuegel } 459380fa71fSAdrian Kuegel }; 460380fa71fSAdrian Kuegel 4616e80e3bdSAdrian Kuegel struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { 4626e80e3bdSAdrian Kuegel using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; 4636e80e3bdSAdrian Kuegel 4646e80e3bdSAdrian Kuegel LogicalResult 465b54c724bSRiver Riddle matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, 4666e80e3bdSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 467c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>(); 4686e80e3bdSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 4696e80e3bdSAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 4706e80e3bdSAdrian Kuegel 471c0342a2dSJacques Pienaar Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 472c0342a2dSJacques Pienaar Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 473a54f4eaeSMogball Value one = b.create<arith::ConstantOp>(elementType, 474a54f4eaeSMogball b.getFloatAttr(elementType, 1)); 475a54f4eaeSMogball Value realPlusOne = b.create<arith::AddFOp>(real, one); 4766e80e3bdSAdrian Kuegel Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag); 4776e80e3bdSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex); 4786e80e3bdSAdrian Kuegel return success(); 4796e80e3bdSAdrian Kuegel } 4806e80e3bdSAdrian Kuegel }; 4816e80e3bdSAdrian Kuegel 482bf17ee19SAdrian Kuegel struct MulOpConversion : public OpConversionPattern<complex::MulOp> { 483bf17ee19SAdrian Kuegel using OpConversionPattern<complex::MulOp>::OpConversionPattern; 484bf17ee19SAdrian Kuegel 485bf17ee19SAdrian Kuegel LogicalResult 486b54c724bSRiver Riddle matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 487bf17ee19SAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 488bf17ee19SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 489c0342a2dSJacques Pienaar auto type = adaptor.getLhs().getType().cast<ComplexType>(); 490bf17ee19SAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 491bf17ee19SAdrian Kuegel 492c0342a2dSJacques Pienaar Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 493a54f4eaeSMogball Value lhsRealAbs = b.create<math::AbsOp>(lhsReal); 494c0342a2dSJacques Pienaar Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 495a54f4eaeSMogball Value lhsImagAbs = b.create<math::AbsOp>(lhsImag); 496c0342a2dSJacques Pienaar Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 497a54f4eaeSMogball Value rhsRealAbs = b.create<math::AbsOp>(rhsReal); 498c0342a2dSJacques Pienaar Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 499a54f4eaeSMogball Value rhsImagAbs = b.create<math::AbsOp>(rhsImag); 500bf17ee19SAdrian Kuegel 501a54f4eaeSMogball Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 502a54f4eaeSMogball Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal); 503a54f4eaeSMogball Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 504a54f4eaeSMogball Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag); 505a54f4eaeSMogball Value real = 506a54f4eaeSMogball b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 507bf17ee19SAdrian Kuegel 508a54f4eaeSMogball Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 509a54f4eaeSMogball Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal); 510a54f4eaeSMogball Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 511a54f4eaeSMogball Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag); 512a54f4eaeSMogball Value imag = 513a54f4eaeSMogball b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 514bf17ee19SAdrian Kuegel 515bf17ee19SAdrian Kuegel // Handle cases where the "naive" calculation results in NaN values. 516a54f4eaeSMogball Value realIsNan = 517a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real); 518a54f4eaeSMogball Value imagIsNan = 519a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag); 520a54f4eaeSMogball Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan); 521bf17ee19SAdrian Kuegel 522a54f4eaeSMogball Value inf = b.create<arith::ConstantOp>( 523bf17ee19SAdrian Kuegel elementType, 524bf17ee19SAdrian Kuegel b.getFloatAttr(elementType, 525bf17ee19SAdrian Kuegel APFloat::getInf(elementType.getFloatSemantics()))); 526bf17ee19SAdrian Kuegel 527bf17ee19SAdrian Kuegel // Case 1. `lhsReal` or `lhsImag` are infinite. 528a54f4eaeSMogball Value lhsRealIsInf = 529a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 530a54f4eaeSMogball Value lhsImagIsInf = 531a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 532a54f4eaeSMogball Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf); 533a54f4eaeSMogball Value rhsRealIsNan = 534a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal); 535a54f4eaeSMogball Value rhsImagIsNan = 536a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag); 537a54f4eaeSMogball Value zero = 538a54f4eaeSMogball b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 539a54f4eaeSMogball Value one = b.create<arith::ConstantOp>(elementType, 540a54f4eaeSMogball b.getFloatAttr(elementType, 1)); 541dec8af70SRiver Riddle Value lhsRealIsInfFloat = 542dec8af70SRiver Riddle b.create<arith::SelectOp>(lhsRealIsInf, one, zero); 543dec8af70SRiver Riddle lhsReal = b.create<arith::SelectOp>( 544a54f4eaeSMogball lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal), 545a54f4eaeSMogball lhsReal); 546dec8af70SRiver Riddle Value lhsImagIsInfFloat = 547dec8af70SRiver Riddle b.create<arith::SelectOp>(lhsImagIsInf, one, zero); 548dec8af70SRiver Riddle lhsImag = b.create<arith::SelectOp>( 549a54f4eaeSMogball lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag), 550a54f4eaeSMogball lhsImag); 551a54f4eaeSMogball Value lhsIsInfAndRhsRealIsNan = 552a54f4eaeSMogball b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan); 553dec8af70SRiver Riddle rhsReal = b.create<arith::SelectOp>( 554dec8af70SRiver Riddle lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal), 555dec8af70SRiver Riddle rhsReal); 556a54f4eaeSMogball Value lhsIsInfAndRhsImagIsNan = 557a54f4eaeSMogball b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan); 558dec8af70SRiver Riddle rhsImag = b.create<arith::SelectOp>( 559dec8af70SRiver Riddle lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag), 560dec8af70SRiver Riddle rhsImag); 561bf17ee19SAdrian Kuegel 562bf17ee19SAdrian Kuegel // Case 2. `rhsReal` or `rhsImag` are infinite. 563a54f4eaeSMogball Value rhsRealIsInf = 564a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 565a54f4eaeSMogball Value rhsImagIsInf = 566a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 567a54f4eaeSMogball Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf); 568a54f4eaeSMogball Value lhsRealIsNan = 569a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal); 570a54f4eaeSMogball Value lhsImagIsNan = 571a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag); 572dec8af70SRiver Riddle Value rhsRealIsInfFloat = 573dec8af70SRiver Riddle b.create<arith::SelectOp>(rhsRealIsInf, one, zero); 574dec8af70SRiver Riddle rhsReal = b.create<arith::SelectOp>( 575a54f4eaeSMogball rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal), 576a54f4eaeSMogball rhsReal); 577dec8af70SRiver Riddle Value rhsImagIsInfFloat = 578dec8af70SRiver Riddle b.create<arith::SelectOp>(rhsImagIsInf, one, zero); 579dec8af70SRiver Riddle rhsImag = b.create<arith::SelectOp>( 580a54f4eaeSMogball rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag), 581a54f4eaeSMogball rhsImag); 582a54f4eaeSMogball Value rhsIsInfAndLhsRealIsNan = 583a54f4eaeSMogball b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan); 584dec8af70SRiver Riddle lhsReal = b.create<arith::SelectOp>( 585dec8af70SRiver Riddle rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal), 586dec8af70SRiver Riddle lhsReal); 587a54f4eaeSMogball Value rhsIsInfAndLhsImagIsNan = 588a54f4eaeSMogball b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan); 589dec8af70SRiver Riddle lhsImag = b.create<arith::SelectOp>( 590dec8af70SRiver Riddle rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag), 591dec8af70SRiver Riddle lhsImag); 592a54f4eaeSMogball Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf); 593bf17ee19SAdrian Kuegel 594bf17ee19SAdrian Kuegel // Case 3. One of the pairwise products of left hand side with right hand 595bf17ee19SAdrian Kuegel // side is infinite. 596a54f4eaeSMogball Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>( 597a54f4eaeSMogball arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); 598a54f4eaeSMogball Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>( 599a54f4eaeSMogball arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); 600a54f4eaeSMogball Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf, 601a54f4eaeSMogball lhsImagTimesRhsImagIsInf); 602a54f4eaeSMogball Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>( 603a54f4eaeSMogball arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); 604a54f4eaeSMogball isSpecialCase = 605a54f4eaeSMogball b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf); 606a54f4eaeSMogball Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>( 607a54f4eaeSMogball arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); 608a54f4eaeSMogball isSpecialCase = 609a54f4eaeSMogball b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf); 610bf17ee19SAdrian Kuegel Type i1Type = b.getI1Type(); 611a54f4eaeSMogball Value notRecalc = b.create<arith::XOrIOp>( 612a54f4eaeSMogball recalc, 613a54f4eaeSMogball b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1))); 614a54f4eaeSMogball isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc); 615bf17ee19SAdrian Kuegel Value isSpecialCaseAndLhsRealIsNan = 616a54f4eaeSMogball b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan); 617dec8af70SRiver Riddle lhsReal = b.create<arith::SelectOp>( 618dec8af70SRiver Riddle isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal), 619dec8af70SRiver Riddle lhsReal); 620bf17ee19SAdrian Kuegel Value isSpecialCaseAndLhsImagIsNan = 621a54f4eaeSMogball b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan); 622dec8af70SRiver Riddle lhsImag = b.create<arith::SelectOp>( 623dec8af70SRiver Riddle isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag), 624dec8af70SRiver Riddle lhsImag); 625bf17ee19SAdrian Kuegel Value isSpecialCaseAndRhsRealIsNan = 626a54f4eaeSMogball b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan); 627dec8af70SRiver Riddle rhsReal = b.create<arith::SelectOp>( 628dec8af70SRiver Riddle isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal), 629dec8af70SRiver Riddle rhsReal); 630bf17ee19SAdrian Kuegel Value isSpecialCaseAndRhsImagIsNan = 631a54f4eaeSMogball b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan); 632dec8af70SRiver Riddle rhsImag = b.create<arith::SelectOp>( 633dec8af70SRiver Riddle isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag), 634dec8af70SRiver Riddle rhsImag); 635a54f4eaeSMogball recalc = b.create<arith::OrIOp>(recalc, isSpecialCase); 636a54f4eaeSMogball recalc = b.create<arith::AndIOp>(isNan, recalc); 637bf17ee19SAdrian Kuegel 638bf17ee19SAdrian Kuegel // Recalculate real part. 639a54f4eaeSMogball lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 640a54f4eaeSMogball lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 641a54f4eaeSMogball Value newReal = 642a54f4eaeSMogball b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 643dec8af70SRiver Riddle real = b.create<arith::SelectOp>( 644dec8af70SRiver Riddle recalc, b.create<arith::MulFOp>(inf, newReal), real); 645bf17ee19SAdrian Kuegel 646bf17ee19SAdrian Kuegel // Recalculate imag part. 647a54f4eaeSMogball lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 648a54f4eaeSMogball lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 649a54f4eaeSMogball Value newImag = 650a54f4eaeSMogball b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 651dec8af70SRiver Riddle imag = b.create<arith::SelectOp>( 652dec8af70SRiver Riddle recalc, b.create<arith::MulFOp>(inf, newImag), imag); 653bf17ee19SAdrian Kuegel 654bf17ee19SAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); 655bf17ee19SAdrian Kuegel return success(); 656bf17ee19SAdrian Kuegel } 657bf17ee19SAdrian Kuegel }; 658bf17ee19SAdrian Kuegel 659662e074dSAdrian Kuegel struct NegOpConversion : public OpConversionPattern<complex::NegOp> { 660662e074dSAdrian Kuegel using OpConversionPattern<complex::NegOp>::OpConversionPattern; 661662e074dSAdrian Kuegel 662662e074dSAdrian Kuegel LogicalResult 663b54c724bSRiver Riddle matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, 664662e074dSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 665662e074dSAdrian Kuegel auto loc = op.getLoc(); 666c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>(); 667662e074dSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 668662e074dSAdrian Kuegel 669662e074dSAdrian Kuegel Value real = 670c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 671662e074dSAdrian Kuegel Value imag = 672c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 673a54f4eaeSMogball Value negReal = rewriter.create<arith::NegFOp>(loc, real); 674a54f4eaeSMogball Value negImag = rewriter.create<arith::NegFOp>(loc, imag); 675662e074dSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); 676662e074dSAdrian Kuegel return success(); 677662e074dSAdrian Kuegel } 678662e074dSAdrian Kuegel }; 679f112bd61SAdrian Kuegel 680672b908bSGoran Flegar struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> { 681672b908bSGoran Flegar using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion; 682672b908bSGoran Flegar 683672b908bSGoran Flegar std::pair<Value, Value> 684672b908bSGoran Flegar combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, 685672b908bSGoran Flegar Value cos, ConversionPatternRewriter &rewriter) const override { 686672b908bSGoran Flegar // Complex sine is defined as; 687672b908bSGoran Flegar // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy))) 688672b908bSGoran Flegar // Plugging in: 689672b908bSGoran Flegar // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) 690672b908bSGoran Flegar // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) 691672b908bSGoran Flegar // and defining t := exp(y) 692672b908bSGoran Flegar // We get: 693672b908bSGoran Flegar // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x 694672b908bSGoran Flegar // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x 695672b908bSGoran Flegar Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp); 696672b908bSGoran Flegar Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin); 697672b908bSGoran Flegar Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp); 698672b908bSGoran Flegar Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos); 699672b908bSGoran Flegar return {resultReal, resultImag}; 700672b908bSGoran Flegar } 701672b908bSGoran Flegar }; 702672b908bSGoran Flegar 703f112bd61SAdrian Kuegel struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 704f112bd61SAdrian Kuegel using OpConversionPattern<complex::SignOp>::OpConversionPattern; 705f112bd61SAdrian Kuegel 706f112bd61SAdrian Kuegel LogicalResult 707b54c724bSRiver Riddle matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, 708f112bd61SAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 709c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>(); 710f112bd61SAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 711f112bd61SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 712f112bd61SAdrian Kuegel 713c0342a2dSJacques Pienaar Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 714c0342a2dSJacques Pienaar Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 715a54f4eaeSMogball Value zero = 716a54f4eaeSMogball b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 717a54f4eaeSMogball Value realIsZero = 718a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 719a54f4eaeSMogball Value imagIsZero = 720a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 721a54f4eaeSMogball Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 722c0342a2dSJacques Pienaar auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex()); 723a54f4eaeSMogball Value realSign = b.create<arith::DivFOp>(real, abs); 724a54f4eaeSMogball Value imagSign = b.create<arith::DivFOp>(imag, abs); 725f112bd61SAdrian Kuegel Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 726dec8af70SRiver Riddle rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, 727dec8af70SRiver Riddle adaptor.getComplex(), sign); 728f112bd61SAdrian Kuegel return success(); 729f112bd61SAdrian Kuegel } 730f112bd61SAdrian Kuegel }; 7316d75c897Slewuathe 7326d75c897Slewuathe struct TanOpConversion : public OpConversionPattern<complex::TanOp> { 7336d75c897Slewuathe using OpConversionPattern<complex::TanOp>::OpConversionPattern; 7346d75c897Slewuathe 7356d75c897Slewuathe LogicalResult 7366d75c897Slewuathe matchAndRewrite(complex::TanOp op, OpAdaptor adaptor, 7376d75c897Slewuathe ConversionPatternRewriter &rewriter) const override { 7386d75c897Slewuathe auto loc = op.getLoc(); 7396d75c897Slewuathe Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex()); 7406d75c897Slewuathe Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex()); 7416d75c897Slewuathe rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos); 7426d75c897Slewuathe return success(); 7436d75c897Slewuathe } 7446d75c897Slewuathe }; 745*ffb8eecdSlewuathe 746*ffb8eecdSlewuathe struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> { 747*ffb8eecdSlewuathe using OpConversionPattern<complex::TanhOp>::OpConversionPattern; 748*ffb8eecdSlewuathe 749*ffb8eecdSlewuathe LogicalResult 750*ffb8eecdSlewuathe matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor, 751*ffb8eecdSlewuathe ConversionPatternRewriter &rewriter) const override { 752*ffb8eecdSlewuathe auto loc = op.getLoc(); 753*ffb8eecdSlewuathe auto type = adaptor.getComplex().getType().cast<ComplexType>(); 754*ffb8eecdSlewuathe auto elementType = type.getElementType().cast<FloatType>(); 755*ffb8eecdSlewuathe 756*ffb8eecdSlewuathe // The hyperbolic tangent for complex number can be calculated as follows. 757*ffb8eecdSlewuathe // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y)) 758*ffb8eecdSlewuathe // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number 759*ffb8eecdSlewuathe Value real = 760*ffb8eecdSlewuathe rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 761*ffb8eecdSlewuathe Value imag = 762*ffb8eecdSlewuathe rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 763*ffb8eecdSlewuathe Value tanhA = rewriter.create<math::TanhOp>(loc, real); 764*ffb8eecdSlewuathe Value cosB = rewriter.create<math::CosOp>(loc, imag); 765*ffb8eecdSlewuathe Value sinB = rewriter.create<math::SinOp>(loc, imag); 766*ffb8eecdSlewuathe Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB); 767*ffb8eecdSlewuathe Value numerator = 768*ffb8eecdSlewuathe rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB); 769*ffb8eecdSlewuathe Value one = rewriter.create<arith::ConstantOp>( 770*ffb8eecdSlewuathe loc, elementType, rewriter.getFloatAttr(elementType, 1)); 771*ffb8eecdSlewuathe Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB); 772*ffb8eecdSlewuathe Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul); 773*ffb8eecdSlewuathe rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator); 774*ffb8eecdSlewuathe return success(); 775*ffb8eecdSlewuathe } 776*ffb8eecdSlewuathe }; 777*ffb8eecdSlewuathe 7782ea7fb7bSAdrian Kuegel } // namespace 7792ea7fb7bSAdrian Kuegel 7802ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns( 7812ea7fb7bSAdrian Kuegel RewritePatternSet &patterns) { 782f112bd61SAdrian Kuegel // clang-format off 783f112bd61SAdrian Kuegel patterns.add< 784f112bd61SAdrian Kuegel AbsOpConversion, 785a54f4eaeSMogball ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, 786a54f4eaeSMogball ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, 787a54f4eaeSMogball BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, 788a54f4eaeSMogball BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, 789672b908bSGoran Flegar CosOpConversion, 790f112bd61SAdrian Kuegel DivOpConversion, 791f112bd61SAdrian Kuegel ExpOpConversion, 792338e76f8Sbixia1 Expm1OpConversion, 793380fa71fSAdrian Kuegel LogOpConversion, 7946e80e3bdSAdrian Kuegel Log1pOpConversion, 795bf17ee19SAdrian Kuegel MulOpConversion, 796f112bd61SAdrian Kuegel NegOpConversion, 797672b908bSGoran Flegar SignOpConversion, 7986d75c897Slewuathe SinOpConversion, 799*ffb8eecdSlewuathe TanOpConversion, 800*ffb8eecdSlewuathe TanhOpConversion>(patterns.getContext()); 801f112bd61SAdrian Kuegel // clang-format on 8022ea7fb7bSAdrian Kuegel } 8032ea7fb7bSAdrian Kuegel 8042ea7fb7bSAdrian Kuegel namespace { 8052ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass 8062ea7fb7bSAdrian Kuegel : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 80741574554SRiver Riddle void runOnOperation() override; 8082ea7fb7bSAdrian Kuegel }; 8092ea7fb7bSAdrian Kuegel 81041574554SRiver Riddle void ConvertComplexToStandardPass::runOnOperation() { 8112ea7fb7bSAdrian Kuegel // Convert to the Standard dialect using the converter defined above. 8122ea7fb7bSAdrian Kuegel RewritePatternSet patterns(&getContext()); 8132ea7fb7bSAdrian Kuegel populateComplexToStandardConversionPatterns(patterns); 8142ea7fb7bSAdrian Kuegel 8152ea7fb7bSAdrian Kuegel ConversionTarget target(getContext()); 8161f971e23SRiver Riddle target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>(); 817fb978f09SAdrian Kuegel target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 81847f175b0SRiver Riddle if (failed( 81947f175b0SRiver Riddle applyPartialConversion(getOperation(), target, std::move(patterns)))) 8202ea7fb7bSAdrian Kuegel signalPassFailure(); 8212ea7fb7bSAdrian Kuegel } 8222ea7fb7bSAdrian Kuegel } // namespace 8232ea7fb7bSAdrian Kuegel 82447f175b0SRiver Riddle std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() { 8252ea7fb7bSAdrian Kuegel return std::make_unique<ConvertComplexToStandardPass>(); 8262ea7fb7bSAdrian Kuegel } 827