12ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 22ea7fb7bSAdrian Kuegel // 32ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information. 52ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62ea7fb7bSAdrian Kuegel // 72ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===// 82ea7fb7bSAdrian Kuegel 92ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 102ea7fb7bSAdrian Kuegel 112ea7fb7bSAdrian Kuegel #include <memory> 12fb8b2b86SAdrian Kuegel #include <type_traits> 132ea7fb7bSAdrian Kuegel 142ea7fb7bSAdrian Kuegel #include "../PassDetail.h" 15*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 162ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h" 172ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h" 182ea7fb7bSAdrian Kuegel #include "mlir/Dialect/StandardOps/IR/Ops.h" 19f112bd61SAdrian Kuegel #include "mlir/IR/ImplicitLocOpBuilder.h" 202ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h" 212ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h" 222ea7fb7bSAdrian Kuegel 232ea7fb7bSAdrian Kuegel using namespace mlir; 242ea7fb7bSAdrian Kuegel 252ea7fb7bSAdrian Kuegel namespace { 262ea7fb7bSAdrian Kuegel struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 272ea7fb7bSAdrian Kuegel using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 282ea7fb7bSAdrian Kuegel 292ea7fb7bSAdrian Kuegel LogicalResult 30b54c724bSRiver Riddle matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 312ea7fb7bSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 322ea7fb7bSAdrian Kuegel auto loc = op.getLoc(); 332ea7fb7bSAdrian Kuegel auto type = op.getType(); 342ea7fb7bSAdrian Kuegel 35b54c724bSRiver Riddle Value real = rewriter.create<complex::ReOp>(loc, type, adaptor.complex()); 36b54c724bSRiver Riddle Value imag = rewriter.create<complex::ImOp>(loc, type, adaptor.complex()); 37*a54f4eaeSMogball Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real); 38*a54f4eaeSMogball Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag); 39*a54f4eaeSMogball Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr); 402ea7fb7bSAdrian Kuegel 412ea7fb7bSAdrian Kuegel rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 422ea7fb7bSAdrian Kuegel return success(); 432ea7fb7bSAdrian Kuegel } 442ea7fb7bSAdrian Kuegel }; 45ac00cb0dSAdrian Kuegel 46*a54f4eaeSMogball template <typename ComparisonOp, arith::CmpFPredicate p> 47fb8b2b86SAdrian Kuegel struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 48fb8b2b86SAdrian Kuegel using OpConversionPattern<ComparisonOp>::OpConversionPattern; 49fb8b2b86SAdrian Kuegel using ResultCombiner = 50fb8b2b86SAdrian Kuegel std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 51*a54f4eaeSMogball arith::AndIOp, arith::OrIOp>; 52ac00cb0dSAdrian Kuegel 53ac00cb0dSAdrian Kuegel LogicalResult 54b54c724bSRiver Riddle matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 55ac00cb0dSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 56ac00cb0dSAdrian Kuegel auto loc = op.getLoc(); 57b54c724bSRiver Riddle auto type = 58b54c724bSRiver Riddle adaptor.lhs().getType().template cast<ComplexType>().getElementType(); 59ac00cb0dSAdrian Kuegel 60b54c724bSRiver Riddle Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.lhs()); 61b54c724bSRiver Riddle Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.lhs()); 62b54c724bSRiver Riddle Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.rhs()); 63b54c724bSRiver Riddle Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.rhs()); 64*a54f4eaeSMogball Value realComparison = 65*a54f4eaeSMogball rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); 66*a54f4eaeSMogball Value imagComparison = 67*a54f4eaeSMogball rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); 68ac00cb0dSAdrian Kuegel 69fb8b2b86SAdrian Kuegel rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 70fb8b2b86SAdrian Kuegel imagComparison); 71ac00cb0dSAdrian Kuegel return success(); 72ac00cb0dSAdrian Kuegel } 73ac00cb0dSAdrian Kuegel }; 74942be7cbSAdrian Kuegel 75fb978f09SAdrian Kuegel // Default conversion which applies the BinaryStandardOp separately on the real 76fb978f09SAdrian Kuegel // and imaginary parts. Can for example be used for complex::AddOp and 77fb978f09SAdrian Kuegel // complex::SubOp. 78fb978f09SAdrian Kuegel template <typename BinaryComplexOp, typename BinaryStandardOp> 79fb978f09SAdrian Kuegel struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { 80fb978f09SAdrian Kuegel using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; 81fb978f09SAdrian Kuegel 82fb978f09SAdrian Kuegel LogicalResult 83b54c724bSRiver Riddle matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, 84fb978f09SAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 85b54c724bSRiver Riddle auto type = adaptor.lhs().getType().template cast<ComplexType>(); 86fb978f09SAdrian Kuegel auto elementType = type.getElementType().template cast<FloatType>(); 87fb978f09SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 88fb978f09SAdrian Kuegel 89b54c724bSRiver Riddle Value realLhs = b.create<complex::ReOp>(elementType, adaptor.lhs()); 90b54c724bSRiver Riddle Value realRhs = b.create<complex::ReOp>(elementType, adaptor.rhs()); 91fb978f09SAdrian Kuegel Value resultReal = 92fb978f09SAdrian Kuegel b.create<BinaryStandardOp>(elementType, realLhs, realRhs); 93b54c724bSRiver Riddle Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.lhs()); 94b54c724bSRiver Riddle Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.rhs()); 95fb978f09SAdrian Kuegel Value resultImag = 96fb978f09SAdrian Kuegel b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs); 97fb978f09SAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 98fb978f09SAdrian Kuegel resultImag); 99fb978f09SAdrian Kuegel return success(); 100fb978f09SAdrian Kuegel } 101fb978f09SAdrian Kuegel }; 102fb978f09SAdrian Kuegel 103942be7cbSAdrian Kuegel struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 104942be7cbSAdrian Kuegel using OpConversionPattern<complex::DivOp>::OpConversionPattern; 105942be7cbSAdrian Kuegel 106942be7cbSAdrian Kuegel LogicalResult 107b54c724bSRiver Riddle matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 108942be7cbSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 109942be7cbSAdrian Kuegel auto loc = op.getLoc(); 110b54c724bSRiver Riddle auto type = adaptor.lhs().getType().cast<ComplexType>(); 111942be7cbSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 112942be7cbSAdrian Kuegel 113942be7cbSAdrian Kuegel Value lhsReal = 114b54c724bSRiver Riddle rewriter.create<complex::ReOp>(loc, elementType, adaptor.lhs()); 115942be7cbSAdrian Kuegel Value lhsImag = 116b54c724bSRiver Riddle rewriter.create<complex::ImOp>(loc, elementType, adaptor.lhs()); 117942be7cbSAdrian Kuegel Value rhsReal = 118b54c724bSRiver Riddle rewriter.create<complex::ReOp>(loc, elementType, adaptor.rhs()); 119942be7cbSAdrian Kuegel Value rhsImag = 120b54c724bSRiver Riddle rewriter.create<complex::ImOp>(loc, elementType, adaptor.rhs()); 121942be7cbSAdrian Kuegel 122942be7cbSAdrian Kuegel // Smith's algorithm to divide complex numbers. It is just a bit smarter 123942be7cbSAdrian Kuegel // way to compute the following formula: 124942be7cbSAdrian Kuegel // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 125942be7cbSAdrian Kuegel // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 126942be7cbSAdrian Kuegel // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 127942be7cbSAdrian Kuegel // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 128942be7cbSAdrian Kuegel // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 129942be7cbSAdrian Kuegel // 130942be7cbSAdrian Kuegel // Depending on whether |rhsReal| < |rhsImag| we compute either 131942be7cbSAdrian Kuegel // rhsRealImagRatio = rhsReal / rhsImag 132942be7cbSAdrian Kuegel // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 133942be7cbSAdrian Kuegel // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 134942be7cbSAdrian Kuegel // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 135942be7cbSAdrian Kuegel // 136942be7cbSAdrian Kuegel // or 137942be7cbSAdrian Kuegel // 138942be7cbSAdrian Kuegel // rhsImagRealRatio = rhsImag / rhsReal 139942be7cbSAdrian Kuegel // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 140942be7cbSAdrian Kuegel // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 141942be7cbSAdrian Kuegel // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 142942be7cbSAdrian Kuegel // 143942be7cbSAdrian Kuegel // See https://dl.acm.org/citation.cfm?id=368661 for more details. 144*a54f4eaeSMogball Value rhsRealImagRatio = 145*a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag); 146*a54f4eaeSMogball Value rhsRealImagDenom = rewriter.create<arith::AddFOp>( 147*a54f4eaeSMogball loc, rhsImag, 148*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal)); 149*a54f4eaeSMogball Value realNumerator1 = rewriter.create<arith::AddFOp>( 150*a54f4eaeSMogball loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio), 151*a54f4eaeSMogball lhsImag); 152942be7cbSAdrian Kuegel Value resultReal1 = 153*a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom); 154*a54f4eaeSMogball Value imagNumerator1 = rewriter.create<arith::SubFOp>( 155*a54f4eaeSMogball loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio), 156*a54f4eaeSMogball lhsReal); 157942be7cbSAdrian Kuegel Value resultImag1 = 158*a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom); 159942be7cbSAdrian Kuegel 160*a54f4eaeSMogball Value rhsImagRealRatio = 161*a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal); 162*a54f4eaeSMogball Value rhsImagRealDenom = rewriter.create<arith::AddFOp>( 163*a54f4eaeSMogball loc, rhsReal, 164*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag)); 165*a54f4eaeSMogball Value realNumerator2 = rewriter.create<arith::AddFOp>( 166*a54f4eaeSMogball loc, lhsReal, 167*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio)); 168942be7cbSAdrian Kuegel Value resultReal2 = 169*a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom); 170*a54f4eaeSMogball Value imagNumerator2 = rewriter.create<arith::SubFOp>( 171*a54f4eaeSMogball loc, lhsImag, 172*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio)); 173942be7cbSAdrian Kuegel Value resultImag2 = 174*a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom); 175942be7cbSAdrian Kuegel 176942be7cbSAdrian Kuegel // Consider corner cases. 177942be7cbSAdrian Kuegel // Case 1. Zero denominator, numerator contains at most one NaN value. 178*a54f4eaeSMogball Value zero = rewriter.create<arith::ConstantOp>( 179*a54f4eaeSMogball loc, elementType, rewriter.getZeroAttr(elementType)); 180*a54f4eaeSMogball Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal); 181*a54f4eaeSMogball Value rhsRealIsZero = rewriter.create<arith::CmpFOp>( 182*a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); 183*a54f4eaeSMogball Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag); 184*a54f4eaeSMogball Value rhsImagIsZero = rewriter.create<arith::CmpFOp>( 185*a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); 186*a54f4eaeSMogball Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>( 187*a54f4eaeSMogball loc, arith::CmpFPredicate::ORD, lhsReal, zero); 188*a54f4eaeSMogball Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>( 189*a54f4eaeSMogball loc, arith::CmpFPredicate::ORD, lhsImag, zero); 190942be7cbSAdrian Kuegel Value lhsContainsNotNaNValue = 191*a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 192*a54f4eaeSMogball Value resultIsInfinity = rewriter.create<arith::AndIOp>( 193942be7cbSAdrian Kuegel loc, lhsContainsNotNaNValue, 194*a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero)); 195*a54f4eaeSMogball Value inf = rewriter.create<arith::ConstantOp>( 196942be7cbSAdrian Kuegel loc, elementType, 197942be7cbSAdrian Kuegel rewriter.getFloatAttr( 198942be7cbSAdrian Kuegel elementType, APFloat::getInf(elementType.getFloatSemantics()))); 199*a54f4eaeSMogball Value infWithSignOfRhsReal = 200*a54f4eaeSMogball rewriter.create<math::CopySignOp>(loc, inf, rhsReal); 201942be7cbSAdrian Kuegel Value infinityResultReal = 202*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal); 203942be7cbSAdrian Kuegel Value infinityResultImag = 204*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag); 205942be7cbSAdrian Kuegel 206942be7cbSAdrian Kuegel // Case 2. Infinite numerator, finite denominator. 207*a54f4eaeSMogball Value rhsRealFinite = rewriter.create<arith::CmpFOp>( 208*a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); 209*a54f4eaeSMogball Value rhsImagFinite = rewriter.create<arith::CmpFOp>( 210*a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); 211*a54f4eaeSMogball Value rhsFinite = 212*a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite); 213*a54f4eaeSMogball Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal); 214*a54f4eaeSMogball Value lhsRealInfinite = rewriter.create<arith::CmpFOp>( 215*a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 216*a54f4eaeSMogball Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag); 217*a54f4eaeSMogball Value lhsImagInfinite = rewriter.create<arith::CmpFOp>( 218*a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 219942be7cbSAdrian Kuegel Value lhsInfinite = 220*a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite); 221942be7cbSAdrian Kuegel Value infNumFiniteDenom = 222*a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite); 223*a54f4eaeSMogball Value one = rewriter.create<arith::ConstantOp>( 224942be7cbSAdrian Kuegel loc, elementType, rewriter.getFloatAttr(elementType, 1)); 225*a54f4eaeSMogball Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 226942be7cbSAdrian Kuegel loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero), 227942be7cbSAdrian Kuegel lhsReal); 228*a54f4eaeSMogball Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 229942be7cbSAdrian Kuegel loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero), 230942be7cbSAdrian Kuegel lhsImag); 231942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsReal = 232*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal); 233942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsImag = 234*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag); 235*a54f4eaeSMogball Value resultReal3 = rewriter.create<arith::MulFOp>( 236942be7cbSAdrian Kuegel loc, inf, 237*a54f4eaeSMogball rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 238942be7cbSAdrian Kuegel lhsImagIsInfWithSignTimesRhsImag)); 239942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsImag = 240*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag); 241942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsReal = 242*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal); 243*a54f4eaeSMogball Value resultImag3 = rewriter.create<arith::MulFOp>( 244942be7cbSAdrian Kuegel loc, inf, 245*a54f4eaeSMogball rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 246942be7cbSAdrian Kuegel lhsRealIsInfWithSignTimesRhsImag)); 247942be7cbSAdrian Kuegel 248942be7cbSAdrian Kuegel // Case 3: Finite numerator, infinite denominator. 249*a54f4eaeSMogball Value lhsRealFinite = rewriter.create<arith::CmpFOp>( 250*a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); 251*a54f4eaeSMogball Value lhsImagFinite = rewriter.create<arith::CmpFOp>( 252*a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); 253*a54f4eaeSMogball Value lhsFinite = 254*a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite); 255*a54f4eaeSMogball Value rhsRealInfinite = rewriter.create<arith::CmpFOp>( 256*a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 257*a54f4eaeSMogball Value rhsImagInfinite = rewriter.create<arith::CmpFOp>( 258*a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 259942be7cbSAdrian Kuegel Value rhsInfinite = 260*a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite); 261942be7cbSAdrian Kuegel Value finiteNumInfiniteDenom = 262*a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite); 263*a54f4eaeSMogball Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 264942be7cbSAdrian Kuegel loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero), 265942be7cbSAdrian Kuegel rhsReal); 266*a54f4eaeSMogball Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 267942be7cbSAdrian Kuegel loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero), 268942be7cbSAdrian Kuegel rhsImag); 269942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsReal = 270*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign); 271942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsImag = 272*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign); 273*a54f4eaeSMogball Value resultReal4 = rewriter.create<arith::MulFOp>( 274942be7cbSAdrian Kuegel loc, zero, 275*a54f4eaeSMogball rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 276942be7cbSAdrian Kuegel rhsImagIsInfWithSignTimesLhsImag)); 277942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsImag = 278*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign); 279942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsReal = 280*a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign); 281*a54f4eaeSMogball Value resultImag4 = rewriter.create<arith::MulFOp>( 282942be7cbSAdrian Kuegel loc, zero, 283*a54f4eaeSMogball rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 284942be7cbSAdrian Kuegel rhsImagIsInfWithSignTimesLhsReal)); 285942be7cbSAdrian Kuegel 286*a54f4eaeSMogball Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>( 287*a54f4eaeSMogball loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 288942be7cbSAdrian Kuegel Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 289942be7cbSAdrian Kuegel resultReal1, resultReal2); 290942be7cbSAdrian Kuegel Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 291942be7cbSAdrian Kuegel resultImag1, resultImag2); 292942be7cbSAdrian Kuegel Value resultRealSpecialCase3 = rewriter.create<SelectOp>( 293942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultReal4, resultReal); 294942be7cbSAdrian Kuegel Value resultImagSpecialCase3 = rewriter.create<SelectOp>( 295942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultImag4, resultImag); 296942be7cbSAdrian Kuegel Value resultRealSpecialCase2 = rewriter.create<SelectOp>( 297942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 298942be7cbSAdrian Kuegel Value resultImagSpecialCase2 = rewriter.create<SelectOp>( 299942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 300942be7cbSAdrian Kuegel Value resultRealSpecialCase1 = rewriter.create<SelectOp>( 301942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 302942be7cbSAdrian Kuegel Value resultImagSpecialCase1 = rewriter.create<SelectOp>( 303942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 304942be7cbSAdrian Kuegel 305*a54f4eaeSMogball Value resultRealIsNaN = rewriter.create<arith::CmpFOp>( 306*a54f4eaeSMogball loc, arith::CmpFPredicate::UNO, resultReal, zero); 307*a54f4eaeSMogball Value resultImagIsNaN = rewriter.create<arith::CmpFOp>( 308*a54f4eaeSMogball loc, arith::CmpFPredicate::UNO, resultImag, zero); 309942be7cbSAdrian Kuegel Value resultIsNaN = 310*a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN); 311942be7cbSAdrian Kuegel Value resultRealWithSpecialCases = rewriter.create<SelectOp>( 312942be7cbSAdrian Kuegel loc, resultIsNaN, resultRealSpecialCase1, resultReal); 313942be7cbSAdrian Kuegel Value resultImagWithSpecialCases = rewriter.create<SelectOp>( 314942be7cbSAdrian Kuegel loc, resultIsNaN, resultImagSpecialCase1, resultImag); 315942be7cbSAdrian Kuegel 316942be7cbSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>( 317942be7cbSAdrian Kuegel op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 318942be7cbSAdrian Kuegel return success(); 319942be7cbSAdrian Kuegel } 320942be7cbSAdrian Kuegel }; 32173cbc91cSAdrian Kuegel 32273cbc91cSAdrian Kuegel struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 32373cbc91cSAdrian Kuegel using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 32473cbc91cSAdrian Kuegel 32573cbc91cSAdrian Kuegel LogicalResult 326b54c724bSRiver Riddle matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, 32773cbc91cSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 32873cbc91cSAdrian Kuegel auto loc = op.getLoc(); 329b54c724bSRiver Riddle auto type = adaptor.complex().getType().cast<ComplexType>(); 33073cbc91cSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 33173cbc91cSAdrian Kuegel 33273cbc91cSAdrian Kuegel Value real = 333b54c724bSRiver Riddle rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex()); 33473cbc91cSAdrian Kuegel Value imag = 335b54c724bSRiver Riddle rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex()); 33673cbc91cSAdrian Kuegel Value expReal = rewriter.create<math::ExpOp>(loc, real); 33773cbc91cSAdrian Kuegel Value cosImag = rewriter.create<math::CosOp>(loc, imag); 338*a54f4eaeSMogball Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag); 33973cbc91cSAdrian Kuegel Value sinImag = rewriter.create<math::SinOp>(loc, imag); 340*a54f4eaeSMogball Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag); 34173cbc91cSAdrian Kuegel 34273cbc91cSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 34373cbc91cSAdrian Kuegel resultImag); 34473cbc91cSAdrian Kuegel return success(); 34573cbc91cSAdrian Kuegel } 34673cbc91cSAdrian Kuegel }; 347662e074dSAdrian Kuegel 348380fa71fSAdrian Kuegel struct LogOpConversion : public OpConversionPattern<complex::LogOp> { 349380fa71fSAdrian Kuegel using OpConversionPattern<complex::LogOp>::OpConversionPattern; 350380fa71fSAdrian Kuegel 351380fa71fSAdrian Kuegel LogicalResult 352b54c724bSRiver Riddle matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, 353380fa71fSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 354b54c724bSRiver Riddle auto type = adaptor.complex().getType().cast<ComplexType>(); 355380fa71fSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 356380fa71fSAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 357380fa71fSAdrian Kuegel 358b54c724bSRiver Riddle Value abs = b.create<complex::AbsOp>(elementType, adaptor.complex()); 359380fa71fSAdrian Kuegel Value resultReal = b.create<math::LogOp>(elementType, abs); 360b54c724bSRiver Riddle Value real = b.create<complex::ReOp>(elementType, adaptor.complex()); 361b54c724bSRiver Riddle Value imag = b.create<complex::ImOp>(elementType, adaptor.complex()); 362380fa71fSAdrian Kuegel Value resultImag = b.create<math::Atan2Op>(elementType, imag, real); 363380fa71fSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 364380fa71fSAdrian Kuegel resultImag); 365380fa71fSAdrian Kuegel return success(); 366380fa71fSAdrian Kuegel } 367380fa71fSAdrian Kuegel }; 368380fa71fSAdrian Kuegel 3696e80e3bdSAdrian Kuegel struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { 3706e80e3bdSAdrian Kuegel using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; 3716e80e3bdSAdrian Kuegel 3726e80e3bdSAdrian Kuegel LogicalResult 373b54c724bSRiver Riddle matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, 3746e80e3bdSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 375b54c724bSRiver Riddle auto type = adaptor.complex().getType().cast<ComplexType>(); 3766e80e3bdSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 3776e80e3bdSAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 3786e80e3bdSAdrian Kuegel 379b54c724bSRiver Riddle Value real = b.create<complex::ReOp>(elementType, adaptor.complex()); 380b54c724bSRiver Riddle Value imag = b.create<complex::ImOp>(elementType, adaptor.complex()); 381*a54f4eaeSMogball Value one = b.create<arith::ConstantOp>(elementType, 382*a54f4eaeSMogball b.getFloatAttr(elementType, 1)); 383*a54f4eaeSMogball Value realPlusOne = b.create<arith::AddFOp>(real, one); 3846e80e3bdSAdrian Kuegel Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag); 3856e80e3bdSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex); 3866e80e3bdSAdrian Kuegel return success(); 3876e80e3bdSAdrian Kuegel } 3886e80e3bdSAdrian Kuegel }; 3896e80e3bdSAdrian Kuegel 390bf17ee19SAdrian Kuegel struct MulOpConversion : public OpConversionPattern<complex::MulOp> { 391bf17ee19SAdrian Kuegel using OpConversionPattern<complex::MulOp>::OpConversionPattern; 392bf17ee19SAdrian Kuegel 393bf17ee19SAdrian Kuegel LogicalResult 394b54c724bSRiver Riddle matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 395bf17ee19SAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 396bf17ee19SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 397b54c724bSRiver Riddle auto type = adaptor.lhs().getType().cast<ComplexType>(); 398bf17ee19SAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 399bf17ee19SAdrian Kuegel 400b54c724bSRiver Riddle Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.lhs()); 401*a54f4eaeSMogball Value lhsRealAbs = b.create<math::AbsOp>(lhsReal); 402b54c724bSRiver Riddle Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.lhs()); 403*a54f4eaeSMogball Value lhsImagAbs = b.create<math::AbsOp>(lhsImag); 404b54c724bSRiver Riddle Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.rhs()); 405*a54f4eaeSMogball Value rhsRealAbs = b.create<math::AbsOp>(rhsReal); 406b54c724bSRiver Riddle Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.rhs()); 407*a54f4eaeSMogball Value rhsImagAbs = b.create<math::AbsOp>(rhsImag); 408bf17ee19SAdrian Kuegel 409*a54f4eaeSMogball Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 410*a54f4eaeSMogball Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal); 411*a54f4eaeSMogball Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 412*a54f4eaeSMogball Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag); 413*a54f4eaeSMogball Value real = 414*a54f4eaeSMogball b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 415bf17ee19SAdrian Kuegel 416*a54f4eaeSMogball Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 417*a54f4eaeSMogball Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal); 418*a54f4eaeSMogball Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 419*a54f4eaeSMogball Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag); 420*a54f4eaeSMogball Value imag = 421*a54f4eaeSMogball b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 422bf17ee19SAdrian Kuegel 423bf17ee19SAdrian Kuegel // Handle cases where the "naive" calculation results in NaN values. 424*a54f4eaeSMogball Value realIsNan = 425*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real); 426*a54f4eaeSMogball Value imagIsNan = 427*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag); 428*a54f4eaeSMogball Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan); 429bf17ee19SAdrian Kuegel 430*a54f4eaeSMogball Value inf = b.create<arith::ConstantOp>( 431bf17ee19SAdrian Kuegel elementType, 432bf17ee19SAdrian Kuegel b.getFloatAttr(elementType, 433bf17ee19SAdrian Kuegel APFloat::getInf(elementType.getFloatSemantics()))); 434bf17ee19SAdrian Kuegel 435bf17ee19SAdrian Kuegel // Case 1. `lhsReal` or `lhsImag` are infinite. 436*a54f4eaeSMogball Value lhsRealIsInf = 437*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 438*a54f4eaeSMogball Value lhsImagIsInf = 439*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 440*a54f4eaeSMogball Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf); 441*a54f4eaeSMogball Value rhsRealIsNan = 442*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal); 443*a54f4eaeSMogball Value rhsImagIsNan = 444*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag); 445*a54f4eaeSMogball Value zero = 446*a54f4eaeSMogball b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 447*a54f4eaeSMogball Value one = b.create<arith::ConstantOp>(elementType, 448*a54f4eaeSMogball b.getFloatAttr(elementType, 1)); 449bf17ee19SAdrian Kuegel Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero); 450bf17ee19SAdrian Kuegel lhsReal = b.create<SelectOp>( 451*a54f4eaeSMogball lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal), 452*a54f4eaeSMogball lhsReal); 453bf17ee19SAdrian Kuegel Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero); 454bf17ee19SAdrian Kuegel lhsImag = b.create<SelectOp>( 455*a54f4eaeSMogball lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag), 456*a54f4eaeSMogball lhsImag); 457*a54f4eaeSMogball Value lhsIsInfAndRhsRealIsNan = 458*a54f4eaeSMogball b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan); 459*a54f4eaeSMogball rhsReal = 460*a54f4eaeSMogball b.create<SelectOp>(lhsIsInfAndRhsRealIsNan, 461*a54f4eaeSMogball b.create<math::CopySignOp>(zero, rhsReal), rhsReal); 462*a54f4eaeSMogball Value lhsIsInfAndRhsImagIsNan = 463*a54f4eaeSMogball b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan); 464*a54f4eaeSMogball rhsImag = 465*a54f4eaeSMogball b.create<SelectOp>(lhsIsInfAndRhsImagIsNan, 466*a54f4eaeSMogball b.create<math::CopySignOp>(zero, rhsImag), rhsImag); 467bf17ee19SAdrian Kuegel 468bf17ee19SAdrian Kuegel // Case 2. `rhsReal` or `rhsImag` are infinite. 469*a54f4eaeSMogball Value rhsRealIsInf = 470*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 471*a54f4eaeSMogball Value rhsImagIsInf = 472*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 473*a54f4eaeSMogball Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf); 474*a54f4eaeSMogball Value lhsRealIsNan = 475*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal); 476*a54f4eaeSMogball Value lhsImagIsNan = 477*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag); 478bf17ee19SAdrian Kuegel Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero); 479bf17ee19SAdrian Kuegel rhsReal = b.create<SelectOp>( 480*a54f4eaeSMogball rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal), 481*a54f4eaeSMogball rhsReal); 482bf17ee19SAdrian Kuegel Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero); 483bf17ee19SAdrian Kuegel rhsImag = b.create<SelectOp>( 484*a54f4eaeSMogball rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag), 485*a54f4eaeSMogball rhsImag); 486*a54f4eaeSMogball Value rhsIsInfAndLhsRealIsNan = 487*a54f4eaeSMogball b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan); 488*a54f4eaeSMogball lhsReal = 489*a54f4eaeSMogball b.create<SelectOp>(rhsIsInfAndLhsRealIsNan, 490*a54f4eaeSMogball b.create<math::CopySignOp>(zero, lhsReal), lhsReal); 491*a54f4eaeSMogball Value rhsIsInfAndLhsImagIsNan = 492*a54f4eaeSMogball b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan); 493*a54f4eaeSMogball lhsImag = 494*a54f4eaeSMogball b.create<SelectOp>(rhsIsInfAndLhsImagIsNan, 495*a54f4eaeSMogball b.create<math::CopySignOp>(zero, lhsImag), lhsImag); 496*a54f4eaeSMogball Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf); 497bf17ee19SAdrian Kuegel 498bf17ee19SAdrian Kuegel // Case 3. One of the pairwise products of left hand side with right hand 499bf17ee19SAdrian Kuegel // side is infinite. 500*a54f4eaeSMogball Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>( 501*a54f4eaeSMogball arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); 502*a54f4eaeSMogball Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>( 503*a54f4eaeSMogball arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); 504*a54f4eaeSMogball Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf, 505*a54f4eaeSMogball lhsImagTimesRhsImagIsInf); 506*a54f4eaeSMogball Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>( 507*a54f4eaeSMogball arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); 508*a54f4eaeSMogball isSpecialCase = 509*a54f4eaeSMogball b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf); 510*a54f4eaeSMogball Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>( 511*a54f4eaeSMogball arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); 512*a54f4eaeSMogball isSpecialCase = 513*a54f4eaeSMogball b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf); 514bf17ee19SAdrian Kuegel Type i1Type = b.getI1Type(); 515*a54f4eaeSMogball Value notRecalc = b.create<arith::XOrIOp>( 516*a54f4eaeSMogball recalc, 517*a54f4eaeSMogball b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1))); 518*a54f4eaeSMogball isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc); 519bf17ee19SAdrian Kuegel Value isSpecialCaseAndLhsRealIsNan = 520*a54f4eaeSMogball b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan); 521*a54f4eaeSMogball lhsReal = 522*a54f4eaeSMogball b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan, 523*a54f4eaeSMogball b.create<math::CopySignOp>(zero, lhsReal), lhsReal); 524bf17ee19SAdrian Kuegel Value isSpecialCaseAndLhsImagIsNan = 525*a54f4eaeSMogball b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan); 526*a54f4eaeSMogball lhsImag = 527*a54f4eaeSMogball b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan, 528*a54f4eaeSMogball b.create<math::CopySignOp>(zero, lhsImag), lhsImag); 529bf17ee19SAdrian Kuegel Value isSpecialCaseAndRhsRealIsNan = 530*a54f4eaeSMogball b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan); 531*a54f4eaeSMogball rhsReal = 532*a54f4eaeSMogball b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan, 533*a54f4eaeSMogball b.create<math::CopySignOp>(zero, rhsReal), rhsReal); 534bf17ee19SAdrian Kuegel Value isSpecialCaseAndRhsImagIsNan = 535*a54f4eaeSMogball b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan); 536*a54f4eaeSMogball rhsImag = 537*a54f4eaeSMogball b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan, 538*a54f4eaeSMogball b.create<math::CopySignOp>(zero, rhsImag), rhsImag); 539*a54f4eaeSMogball recalc = b.create<arith::OrIOp>(recalc, isSpecialCase); 540*a54f4eaeSMogball recalc = b.create<arith::AndIOp>(isNan, recalc); 541bf17ee19SAdrian Kuegel 542bf17ee19SAdrian Kuegel // Recalculate real part. 543*a54f4eaeSMogball lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 544*a54f4eaeSMogball lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 545*a54f4eaeSMogball Value newReal = 546*a54f4eaeSMogball b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 547*a54f4eaeSMogball real = 548*a54f4eaeSMogball b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newReal), real); 549bf17ee19SAdrian Kuegel 550bf17ee19SAdrian Kuegel // Recalculate imag part. 551*a54f4eaeSMogball lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 552*a54f4eaeSMogball lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 553*a54f4eaeSMogball Value newImag = 554*a54f4eaeSMogball b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 555*a54f4eaeSMogball imag = 556*a54f4eaeSMogball b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newImag), imag); 557bf17ee19SAdrian Kuegel 558bf17ee19SAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); 559bf17ee19SAdrian Kuegel return success(); 560bf17ee19SAdrian Kuegel } 561bf17ee19SAdrian Kuegel }; 562bf17ee19SAdrian Kuegel 563662e074dSAdrian Kuegel struct NegOpConversion : public OpConversionPattern<complex::NegOp> { 564662e074dSAdrian Kuegel using OpConversionPattern<complex::NegOp>::OpConversionPattern; 565662e074dSAdrian Kuegel 566662e074dSAdrian Kuegel LogicalResult 567b54c724bSRiver Riddle matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, 568662e074dSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 569662e074dSAdrian Kuegel auto loc = op.getLoc(); 570b54c724bSRiver Riddle auto type = adaptor.complex().getType().cast<ComplexType>(); 571662e074dSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 572662e074dSAdrian Kuegel 573662e074dSAdrian Kuegel Value real = 574b54c724bSRiver Riddle rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex()); 575662e074dSAdrian Kuegel Value imag = 576b54c724bSRiver Riddle rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex()); 577*a54f4eaeSMogball Value negReal = rewriter.create<arith::NegFOp>(loc, real); 578*a54f4eaeSMogball Value negImag = rewriter.create<arith::NegFOp>(loc, imag); 579662e074dSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); 580662e074dSAdrian Kuegel return success(); 581662e074dSAdrian Kuegel } 582662e074dSAdrian Kuegel }; 583f112bd61SAdrian Kuegel 584f112bd61SAdrian Kuegel struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 585f112bd61SAdrian Kuegel using OpConversionPattern<complex::SignOp>::OpConversionPattern; 586f112bd61SAdrian Kuegel 587f112bd61SAdrian Kuegel LogicalResult 588b54c724bSRiver Riddle matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, 589f112bd61SAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 590b54c724bSRiver Riddle auto type = adaptor.complex().getType().cast<ComplexType>(); 591f112bd61SAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 592f112bd61SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 593f112bd61SAdrian Kuegel 594b54c724bSRiver Riddle Value real = b.create<complex::ReOp>(elementType, adaptor.complex()); 595b54c724bSRiver Riddle Value imag = b.create<complex::ImOp>(elementType, adaptor.complex()); 596*a54f4eaeSMogball Value zero = 597*a54f4eaeSMogball b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 598*a54f4eaeSMogball Value realIsZero = 599*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 600*a54f4eaeSMogball Value imagIsZero = 601*a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 602*a54f4eaeSMogball Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 603b54c724bSRiver Riddle auto abs = b.create<complex::AbsOp>(elementType, adaptor.complex()); 604*a54f4eaeSMogball Value realSign = b.create<arith::DivFOp>(real, abs); 605*a54f4eaeSMogball Value imagSign = b.create<arith::DivFOp>(imag, abs); 606f112bd61SAdrian Kuegel Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 607b54c724bSRiver Riddle rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, adaptor.complex(), sign); 608f112bd61SAdrian Kuegel return success(); 609f112bd61SAdrian Kuegel } 610f112bd61SAdrian Kuegel }; 6112ea7fb7bSAdrian Kuegel } // namespace 6122ea7fb7bSAdrian Kuegel 6132ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns( 6142ea7fb7bSAdrian Kuegel RewritePatternSet &patterns) { 615f112bd61SAdrian Kuegel // clang-format off 616f112bd61SAdrian Kuegel patterns.add< 617f112bd61SAdrian Kuegel AbsOpConversion, 618*a54f4eaeSMogball ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, 619*a54f4eaeSMogball ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, 620*a54f4eaeSMogball BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, 621*a54f4eaeSMogball BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, 622f112bd61SAdrian Kuegel DivOpConversion, 623f112bd61SAdrian Kuegel ExpOpConversion, 624380fa71fSAdrian Kuegel LogOpConversion, 6256e80e3bdSAdrian Kuegel Log1pOpConversion, 626bf17ee19SAdrian Kuegel MulOpConversion, 627f112bd61SAdrian Kuegel NegOpConversion, 628f112bd61SAdrian Kuegel SignOpConversion>(patterns.getContext()); 629f112bd61SAdrian Kuegel // clang-format on 6302ea7fb7bSAdrian Kuegel } 6312ea7fb7bSAdrian Kuegel 6322ea7fb7bSAdrian Kuegel namespace { 6332ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass 6342ea7fb7bSAdrian Kuegel : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 6352ea7fb7bSAdrian Kuegel void runOnFunction() override; 6362ea7fb7bSAdrian Kuegel }; 6372ea7fb7bSAdrian Kuegel 6382ea7fb7bSAdrian Kuegel void ConvertComplexToStandardPass::runOnFunction() { 6392ea7fb7bSAdrian Kuegel auto function = getFunction(); 6402ea7fb7bSAdrian Kuegel 6412ea7fb7bSAdrian Kuegel // Convert to the Standard dialect using the converter defined above. 6422ea7fb7bSAdrian Kuegel RewritePatternSet patterns(&getContext()); 6432ea7fb7bSAdrian Kuegel populateComplexToStandardConversionPatterns(patterns); 6442ea7fb7bSAdrian Kuegel 6452ea7fb7bSAdrian Kuegel ConversionTarget target(getContext()); 646*a54f4eaeSMogball target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect, 647*a54f4eaeSMogball math::MathDialect>(); 648fb978f09SAdrian Kuegel target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 6492ea7fb7bSAdrian Kuegel if (failed(applyPartialConversion(function, target, std::move(patterns)))) 6502ea7fb7bSAdrian Kuegel signalPassFailure(); 6512ea7fb7bSAdrian Kuegel } 6522ea7fb7bSAdrian Kuegel } // namespace 6532ea7fb7bSAdrian Kuegel 6542ea7fb7bSAdrian Kuegel std::unique_ptr<OperationPass<FuncOp>> 6552ea7fb7bSAdrian Kuegel mlir::createConvertComplexToStandardPass() { 6562ea7fb7bSAdrian Kuegel return std::make_unique<ConvertComplexToStandardPass>(); 6572ea7fb7bSAdrian Kuegel } 658