12ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 22ea7fb7bSAdrian Kuegel // 32ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information. 52ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62ea7fb7bSAdrian Kuegel // 72ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===// 82ea7fb7bSAdrian Kuegel 92ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 102ea7fb7bSAdrian Kuegel 112ea7fb7bSAdrian Kuegel #include <memory> 12fb8b2b86SAdrian Kuegel #include <type_traits> 132ea7fb7bSAdrian Kuegel 142ea7fb7bSAdrian Kuegel #include "../PassDetail.h" 152ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h" 162ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h" 172ea7fb7bSAdrian Kuegel #include "mlir/Dialect/StandardOps/IR/Ops.h" 182ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h" 192ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h" 202ea7fb7bSAdrian Kuegel 212ea7fb7bSAdrian Kuegel using namespace mlir; 222ea7fb7bSAdrian Kuegel 232ea7fb7bSAdrian Kuegel namespace { 242ea7fb7bSAdrian Kuegel struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 252ea7fb7bSAdrian Kuegel using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 262ea7fb7bSAdrian Kuegel 272ea7fb7bSAdrian Kuegel LogicalResult 282ea7fb7bSAdrian Kuegel matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 292ea7fb7bSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 302ea7fb7bSAdrian Kuegel complex::AbsOp::Adaptor transformed(operands); 312ea7fb7bSAdrian Kuegel auto loc = op.getLoc(); 322ea7fb7bSAdrian Kuegel auto type = op.getType(); 332ea7fb7bSAdrian Kuegel 342ea7fb7bSAdrian Kuegel Value real = 352ea7fb7bSAdrian Kuegel rewriter.create<complex::ReOp>(loc, type, transformed.complex()); 362ea7fb7bSAdrian Kuegel Value imag = 372ea7fb7bSAdrian Kuegel rewriter.create<complex::ImOp>(loc, type, transformed.complex()); 382ea7fb7bSAdrian Kuegel Value realSqr = rewriter.create<MulFOp>(loc, real, real); 392ea7fb7bSAdrian Kuegel Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag); 402ea7fb7bSAdrian Kuegel Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr); 412ea7fb7bSAdrian Kuegel 422ea7fb7bSAdrian Kuegel rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 432ea7fb7bSAdrian Kuegel return success(); 442ea7fb7bSAdrian Kuegel } 452ea7fb7bSAdrian Kuegel }; 46ac00cb0dSAdrian Kuegel 47fb8b2b86SAdrian Kuegel template <typename ComparisonOp, CmpFPredicate p> 48fb8b2b86SAdrian Kuegel struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 49fb8b2b86SAdrian Kuegel using OpConversionPattern<ComparisonOp>::OpConversionPattern; 50fb8b2b86SAdrian Kuegel using ResultCombiner = 51fb8b2b86SAdrian Kuegel std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 52fb8b2b86SAdrian Kuegel AndOp, OrOp>; 53ac00cb0dSAdrian Kuegel 54ac00cb0dSAdrian Kuegel LogicalResult 55fb8b2b86SAdrian Kuegel matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands, 56ac00cb0dSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 57fb8b2b86SAdrian Kuegel typename ComparisonOp::Adaptor transformed(operands); 58ac00cb0dSAdrian Kuegel auto loc = op.getLoc(); 59fb8b2b86SAdrian Kuegel auto type = transformed.lhs() 60fb8b2b86SAdrian Kuegel .getType() 61fb8b2b86SAdrian Kuegel .template cast<ComplexType>() 62fb8b2b86SAdrian Kuegel .getElementType(); 63ac00cb0dSAdrian Kuegel 64ac00cb0dSAdrian Kuegel Value realLhs = 65ac00cb0dSAdrian Kuegel rewriter.create<complex::ReOp>(loc, type, transformed.lhs()); 66ac00cb0dSAdrian Kuegel Value imagLhs = 67ac00cb0dSAdrian Kuegel rewriter.create<complex::ImOp>(loc, type, transformed.lhs()); 68ac00cb0dSAdrian Kuegel Value realRhs = 69ac00cb0dSAdrian Kuegel rewriter.create<complex::ReOp>(loc, type, transformed.rhs()); 70ac00cb0dSAdrian Kuegel Value imagRhs = 71ac00cb0dSAdrian Kuegel rewriter.create<complex::ImOp>(loc, type, transformed.rhs()); 72fb8b2b86SAdrian Kuegel Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs); 73fb8b2b86SAdrian Kuegel Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs); 74ac00cb0dSAdrian Kuegel 75fb8b2b86SAdrian Kuegel rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 76fb8b2b86SAdrian Kuegel imagComparison); 77ac00cb0dSAdrian Kuegel return success(); 78ac00cb0dSAdrian Kuegel } 79ac00cb0dSAdrian Kuegel }; 80942be7cbSAdrian Kuegel 81942be7cbSAdrian Kuegel struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 82942be7cbSAdrian Kuegel using OpConversionPattern<complex::DivOp>::OpConversionPattern; 83942be7cbSAdrian Kuegel 84942be7cbSAdrian Kuegel LogicalResult 85942be7cbSAdrian Kuegel matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands, 86942be7cbSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 87942be7cbSAdrian Kuegel complex::DivOp::Adaptor transformed(operands); 88942be7cbSAdrian Kuegel auto loc = op.getLoc(); 89*73cbc91cSAdrian Kuegel auto type = transformed.lhs().getType().cast<ComplexType>(); 90942be7cbSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 91942be7cbSAdrian Kuegel 92942be7cbSAdrian Kuegel Value lhsReal = 93942be7cbSAdrian Kuegel rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs()); 94942be7cbSAdrian Kuegel Value lhsImag = 95942be7cbSAdrian Kuegel rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs()); 96942be7cbSAdrian Kuegel Value rhsReal = 97942be7cbSAdrian Kuegel rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs()); 98942be7cbSAdrian Kuegel Value rhsImag = 99942be7cbSAdrian Kuegel rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs()); 100942be7cbSAdrian Kuegel 101942be7cbSAdrian Kuegel // Smith's algorithm to divide complex numbers. It is just a bit smarter 102942be7cbSAdrian Kuegel // way to compute the following formula: 103942be7cbSAdrian Kuegel // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 104942be7cbSAdrian Kuegel // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 105942be7cbSAdrian Kuegel // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 106942be7cbSAdrian Kuegel // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 107942be7cbSAdrian Kuegel // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 108942be7cbSAdrian Kuegel // 109942be7cbSAdrian Kuegel // Depending on whether |rhsReal| < |rhsImag| we compute either 110942be7cbSAdrian Kuegel // rhsRealImagRatio = rhsReal / rhsImag 111942be7cbSAdrian Kuegel // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 112942be7cbSAdrian Kuegel // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 113942be7cbSAdrian Kuegel // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 114942be7cbSAdrian Kuegel // 115942be7cbSAdrian Kuegel // or 116942be7cbSAdrian Kuegel // 117942be7cbSAdrian Kuegel // rhsImagRealRatio = rhsImag / rhsReal 118942be7cbSAdrian Kuegel // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 119942be7cbSAdrian Kuegel // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 120942be7cbSAdrian Kuegel // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 121942be7cbSAdrian Kuegel // 122942be7cbSAdrian Kuegel // See https://dl.acm.org/citation.cfm?id=368661 for more details. 123942be7cbSAdrian Kuegel Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag); 124942be7cbSAdrian Kuegel Value rhsRealImagDenom = rewriter.create<AddFOp>( 125942be7cbSAdrian Kuegel loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal)); 126942be7cbSAdrian Kuegel Value realNumerator1 = rewriter.create<AddFOp>( 127942be7cbSAdrian Kuegel loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag); 128942be7cbSAdrian Kuegel Value resultReal1 = 129942be7cbSAdrian Kuegel rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom); 130942be7cbSAdrian Kuegel Value imagNumerator1 = rewriter.create<SubFOp>( 131942be7cbSAdrian Kuegel loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal); 132942be7cbSAdrian Kuegel Value resultImag1 = 133942be7cbSAdrian Kuegel rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom); 134942be7cbSAdrian Kuegel 135942be7cbSAdrian Kuegel Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal); 136942be7cbSAdrian Kuegel Value rhsImagRealDenom = rewriter.create<AddFOp>( 137942be7cbSAdrian Kuegel loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag)); 138942be7cbSAdrian Kuegel Value realNumerator2 = rewriter.create<AddFOp>( 139942be7cbSAdrian Kuegel loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio)); 140942be7cbSAdrian Kuegel Value resultReal2 = 141942be7cbSAdrian Kuegel rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom); 142942be7cbSAdrian Kuegel Value imagNumerator2 = rewriter.create<SubFOp>( 143942be7cbSAdrian Kuegel loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio)); 144942be7cbSAdrian Kuegel Value resultImag2 = 145942be7cbSAdrian Kuegel rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom); 146942be7cbSAdrian Kuegel 147942be7cbSAdrian Kuegel // Consider corner cases. 148942be7cbSAdrian Kuegel // Case 1. Zero denominator, numerator contains at most one NaN value. 149942be7cbSAdrian Kuegel Value zero = rewriter.create<ConstantOp>(loc, elementType, 150942be7cbSAdrian Kuegel rewriter.getZeroAttr(elementType)); 151942be7cbSAdrian Kuegel Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal); 152942be7cbSAdrian Kuegel Value rhsRealIsZero = 153942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero); 154942be7cbSAdrian Kuegel Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag); 155942be7cbSAdrian Kuegel Value rhsImagIsZero = 156942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero); 157942be7cbSAdrian Kuegel Value lhsRealIsNotNaN = 158942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero); 159942be7cbSAdrian Kuegel Value lhsImagIsNotNaN = 160942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero); 161942be7cbSAdrian Kuegel Value lhsContainsNotNaNValue = 162942be7cbSAdrian Kuegel rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 163942be7cbSAdrian Kuegel Value resultIsInfinity = rewriter.create<AndOp>( 164942be7cbSAdrian Kuegel loc, lhsContainsNotNaNValue, 165942be7cbSAdrian Kuegel rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero)); 166942be7cbSAdrian Kuegel Value inf = rewriter.create<ConstantOp>( 167942be7cbSAdrian Kuegel loc, elementType, 168942be7cbSAdrian Kuegel rewriter.getFloatAttr( 169942be7cbSAdrian Kuegel elementType, APFloat::getInf(elementType.getFloatSemantics()))); 170942be7cbSAdrian Kuegel Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal); 171942be7cbSAdrian Kuegel Value infinityResultReal = 172942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal); 173942be7cbSAdrian Kuegel Value infinityResultImag = 174942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag); 175942be7cbSAdrian Kuegel 176942be7cbSAdrian Kuegel // Case 2. Infinite numerator, finite denominator. 177942be7cbSAdrian Kuegel Value rhsRealFinite = 178942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf); 179942be7cbSAdrian Kuegel Value rhsImagFinite = 180942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf); 181942be7cbSAdrian Kuegel Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite); 182942be7cbSAdrian Kuegel Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal); 183942be7cbSAdrian Kuegel Value lhsRealInfinite = 184942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf); 185942be7cbSAdrian Kuegel Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag); 186942be7cbSAdrian Kuegel Value lhsImagInfinite = 187942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf); 188942be7cbSAdrian Kuegel Value lhsInfinite = 189942be7cbSAdrian Kuegel rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite); 190942be7cbSAdrian Kuegel Value infNumFiniteDenom = 191942be7cbSAdrian Kuegel rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite); 192942be7cbSAdrian Kuegel Value one = rewriter.create<ConstantOp>( 193942be7cbSAdrian Kuegel loc, elementType, rewriter.getFloatAttr(elementType, 1)); 194942be7cbSAdrian Kuegel Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>( 195942be7cbSAdrian Kuegel loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero), 196942be7cbSAdrian Kuegel lhsReal); 197942be7cbSAdrian Kuegel Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>( 198942be7cbSAdrian Kuegel loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero), 199942be7cbSAdrian Kuegel lhsImag); 200942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsReal = 201942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal); 202942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsImag = 203942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag); 204942be7cbSAdrian Kuegel Value resultReal3 = rewriter.create<MulFOp>( 205942be7cbSAdrian Kuegel loc, inf, 206942be7cbSAdrian Kuegel rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 207942be7cbSAdrian Kuegel lhsImagIsInfWithSignTimesRhsImag)); 208942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsImag = 209942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag); 210942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsReal = 211942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal); 212942be7cbSAdrian Kuegel Value resultImag3 = rewriter.create<MulFOp>( 213942be7cbSAdrian Kuegel loc, inf, 214942be7cbSAdrian Kuegel rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 215942be7cbSAdrian Kuegel lhsRealIsInfWithSignTimesRhsImag)); 216942be7cbSAdrian Kuegel 217942be7cbSAdrian Kuegel // Case 3: Finite numerator, infinite denominator. 218942be7cbSAdrian Kuegel Value lhsRealFinite = 219942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf); 220942be7cbSAdrian Kuegel Value lhsImagFinite = 221942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf); 222942be7cbSAdrian Kuegel Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite); 223942be7cbSAdrian Kuegel Value rhsRealInfinite = 224942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf); 225942be7cbSAdrian Kuegel Value rhsImagInfinite = 226942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf); 227942be7cbSAdrian Kuegel Value rhsInfinite = 228942be7cbSAdrian Kuegel rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite); 229942be7cbSAdrian Kuegel Value finiteNumInfiniteDenom = 230942be7cbSAdrian Kuegel rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite); 231942be7cbSAdrian Kuegel Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>( 232942be7cbSAdrian Kuegel loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero), 233942be7cbSAdrian Kuegel rhsReal); 234942be7cbSAdrian Kuegel Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>( 235942be7cbSAdrian Kuegel loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero), 236942be7cbSAdrian Kuegel rhsImag); 237942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsReal = 238942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign); 239942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsImag = 240942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign); 241942be7cbSAdrian Kuegel Value resultReal4 = rewriter.create<MulFOp>( 242942be7cbSAdrian Kuegel loc, zero, 243942be7cbSAdrian Kuegel rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 244942be7cbSAdrian Kuegel rhsImagIsInfWithSignTimesLhsImag)); 245942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsImag = 246942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign); 247942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsReal = 248942be7cbSAdrian Kuegel rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign); 249942be7cbSAdrian Kuegel Value resultImag4 = rewriter.create<MulFOp>( 250942be7cbSAdrian Kuegel loc, zero, 251942be7cbSAdrian Kuegel rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 252942be7cbSAdrian Kuegel rhsImagIsInfWithSignTimesLhsReal)); 253942be7cbSAdrian Kuegel 254942be7cbSAdrian Kuegel Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>( 255942be7cbSAdrian Kuegel loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 256942be7cbSAdrian Kuegel Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 257942be7cbSAdrian Kuegel resultReal1, resultReal2); 258942be7cbSAdrian Kuegel Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 259942be7cbSAdrian Kuegel resultImag1, resultImag2); 260942be7cbSAdrian Kuegel Value resultRealSpecialCase3 = rewriter.create<SelectOp>( 261942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultReal4, resultReal); 262942be7cbSAdrian Kuegel Value resultImagSpecialCase3 = rewriter.create<SelectOp>( 263942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultImag4, resultImag); 264942be7cbSAdrian Kuegel Value resultRealSpecialCase2 = rewriter.create<SelectOp>( 265942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 266942be7cbSAdrian Kuegel Value resultImagSpecialCase2 = rewriter.create<SelectOp>( 267942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 268942be7cbSAdrian Kuegel Value resultRealSpecialCase1 = rewriter.create<SelectOp>( 269942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 270942be7cbSAdrian Kuegel Value resultImagSpecialCase1 = rewriter.create<SelectOp>( 271942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 272942be7cbSAdrian Kuegel 273942be7cbSAdrian Kuegel Value resultRealIsNaN = 274942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero); 275942be7cbSAdrian Kuegel Value resultImagIsNaN = 276942be7cbSAdrian Kuegel rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero); 277942be7cbSAdrian Kuegel Value resultIsNaN = 278942be7cbSAdrian Kuegel rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN); 279942be7cbSAdrian Kuegel Value resultRealWithSpecialCases = rewriter.create<SelectOp>( 280942be7cbSAdrian Kuegel loc, resultIsNaN, resultRealSpecialCase1, resultReal); 281942be7cbSAdrian Kuegel Value resultImagWithSpecialCases = rewriter.create<SelectOp>( 282942be7cbSAdrian Kuegel loc, resultIsNaN, resultImagSpecialCase1, resultImag); 283942be7cbSAdrian Kuegel 284942be7cbSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>( 285942be7cbSAdrian Kuegel op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 286942be7cbSAdrian Kuegel return success(); 287942be7cbSAdrian Kuegel } 288942be7cbSAdrian Kuegel }; 289*73cbc91cSAdrian Kuegel 290*73cbc91cSAdrian Kuegel struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 291*73cbc91cSAdrian Kuegel using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 292*73cbc91cSAdrian Kuegel 293*73cbc91cSAdrian Kuegel LogicalResult 294*73cbc91cSAdrian Kuegel matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands, 295*73cbc91cSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 296*73cbc91cSAdrian Kuegel complex::ExpOp::Adaptor transformed(operands); 297*73cbc91cSAdrian Kuegel auto loc = op.getLoc(); 298*73cbc91cSAdrian Kuegel auto type = transformed.complex().getType().cast<ComplexType>(); 299*73cbc91cSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>(); 300*73cbc91cSAdrian Kuegel 301*73cbc91cSAdrian Kuegel Value real = 302*73cbc91cSAdrian Kuegel rewriter.create<complex::ReOp>(loc, elementType, transformed.complex()); 303*73cbc91cSAdrian Kuegel Value imag = 304*73cbc91cSAdrian Kuegel rewriter.create<complex::ImOp>(loc, elementType, transformed.complex()); 305*73cbc91cSAdrian Kuegel Value expReal = rewriter.create<math::ExpOp>(loc, real); 306*73cbc91cSAdrian Kuegel Value cosImag = rewriter.create<math::CosOp>(loc, imag); 307*73cbc91cSAdrian Kuegel Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag); 308*73cbc91cSAdrian Kuegel Value sinImag = rewriter.create<math::SinOp>(loc, imag); 309*73cbc91cSAdrian Kuegel Value resultImag = rewriter.create<MulFOp>(loc, expReal, sinImag); 310*73cbc91cSAdrian Kuegel 311*73cbc91cSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 312*73cbc91cSAdrian Kuegel resultImag); 313*73cbc91cSAdrian Kuegel return success(); 314*73cbc91cSAdrian Kuegel } 315*73cbc91cSAdrian Kuegel }; 3162ea7fb7bSAdrian Kuegel } // namespace 3172ea7fb7bSAdrian Kuegel 3182ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns( 3192ea7fb7bSAdrian Kuegel RewritePatternSet &patterns) { 320fb8b2b86SAdrian Kuegel patterns.add<AbsOpConversion, 321fb8b2b86SAdrian Kuegel ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>, 322942be7cbSAdrian Kuegel ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>, 323*73cbc91cSAdrian Kuegel DivOpConversion, ExpOpConversion>(patterns.getContext()); 3242ea7fb7bSAdrian Kuegel } 3252ea7fb7bSAdrian Kuegel 3262ea7fb7bSAdrian Kuegel namespace { 3272ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass 3282ea7fb7bSAdrian Kuegel : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 3292ea7fb7bSAdrian Kuegel void runOnFunction() override; 3302ea7fb7bSAdrian Kuegel }; 3312ea7fb7bSAdrian Kuegel 3322ea7fb7bSAdrian Kuegel void ConvertComplexToStandardPass::runOnFunction() { 3332ea7fb7bSAdrian Kuegel auto function = getFunction(); 3342ea7fb7bSAdrian Kuegel 3352ea7fb7bSAdrian Kuegel // Convert to the Standard dialect using the converter defined above. 3362ea7fb7bSAdrian Kuegel RewritePatternSet patterns(&getContext()); 3372ea7fb7bSAdrian Kuegel populateComplexToStandardConversionPatterns(patterns); 3382ea7fb7bSAdrian Kuegel 3392ea7fb7bSAdrian Kuegel ConversionTarget target(getContext()); 3402ea7fb7bSAdrian Kuegel target.addLegalDialect<StandardOpsDialect, math::MathDialect, 3412ea7fb7bSAdrian Kuegel complex::ComplexDialect>(); 342942be7cbSAdrian Kuegel target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp, 343*73cbc91cSAdrian Kuegel complex::ExpOp, complex::NotEqualOp>(); 3442ea7fb7bSAdrian Kuegel if (failed(applyPartialConversion(function, target, std::move(patterns)))) 3452ea7fb7bSAdrian Kuegel signalPassFailure(); 3462ea7fb7bSAdrian Kuegel } 3472ea7fb7bSAdrian Kuegel } // namespace 3482ea7fb7bSAdrian Kuegel 3492ea7fb7bSAdrian Kuegel std::unique_ptr<OperationPass<FuncOp>> 3502ea7fb7bSAdrian Kuegel mlir::createConvertComplexToStandardPass() { 3512ea7fb7bSAdrian Kuegel return std::make_unique<ConvertComplexToStandardPass>(); 3522ea7fb7bSAdrian Kuegel } 353