//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include #include #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; namespace { struct AbsOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = op.getType(); Value real = rewriter.create(loc, type, adaptor.complex()); Value imag = rewriter.create(loc, type, adaptor.complex()); Value realSqr = rewriter.create(loc, real, real); Value imagSqr = rewriter.create(loc, imag, imag); Value sqNorm = rewriter.create(loc, realSqr, imagSqr); rewriter.replaceOpWithNewOp(op, sqNorm); return success(); } }; template struct ComparisonOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; using ResultCombiner = std::conditional_t::value, arith::AndIOp, arith::OrIOp>; LogicalResult matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = adaptor.lhs().getType().template cast().getElementType(); Value realLhs = rewriter.create(loc, type, adaptor.lhs()); Value imagLhs = rewriter.create(loc, type, adaptor.lhs()); Value realRhs = rewriter.create(loc, type, adaptor.rhs()); Value imagRhs = rewriter.create(loc, type, adaptor.rhs()); Value realComparison = rewriter.create(loc, p, realLhs, realRhs); Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, realComparison, imagComparison); return success(); } }; // Default conversion which applies the BinaryStandardOp separately on the real // and imaginary parts. Can for example be used for complex::AddOp and // complex::SubOp. template struct BinaryComplexOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = adaptor.lhs().getType().template cast(); auto elementType = type.getElementType().template cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value realLhs = b.create(elementType, adaptor.lhs()); Value realRhs = b.create(elementType, adaptor.rhs()); Value resultReal = b.create(elementType, realLhs, realRhs); Value imagLhs = b.create(elementType, adaptor.lhs()); Value imagRhs = b.create(elementType, adaptor.rhs()); Value resultImag = b.create(elementType, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); } }; struct DivOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = adaptor.lhs().getType().cast(); auto elementType = type.getElementType().cast(); Value lhsReal = rewriter.create(loc, elementType, adaptor.lhs()); Value lhsImag = rewriter.create(loc, elementType, adaptor.lhs()); Value rhsReal = rewriter.create(loc, elementType, adaptor.rhs()); Value rhsImag = rewriter.create(loc, elementType, adaptor.rhs()); // Smith's algorithm to divide complex numbers. It is just a bit smarter // way to compute the following formula: // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) // = ((lhsReal * rhsReal + lhsImag * rhsImag) + // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 // // Depending on whether |rhsReal| < |rhsImag| we compute either // rhsRealImagRatio = rhsReal / rhsImag // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom // // or // // rhsImagRealRatio = rhsImag / rhsReal // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom // // See https://dl.acm.org/citation.cfm?id=368661 for more details. Value rhsRealImagRatio = rewriter.create(loc, rhsReal, rhsImag); Value rhsRealImagDenom = rewriter.create( loc, rhsImag, rewriter.create(loc, rhsRealImagRatio, rhsReal)); Value realNumerator1 = rewriter.create( loc, rewriter.create(loc, lhsReal, rhsRealImagRatio), lhsImag); Value resultReal1 = rewriter.create(loc, realNumerator1, rhsRealImagDenom); Value imagNumerator1 = rewriter.create( loc, rewriter.create(loc, lhsImag, rhsRealImagRatio), lhsReal); Value resultImag1 = rewriter.create(loc, imagNumerator1, rhsRealImagDenom); Value rhsImagRealRatio = rewriter.create(loc, rhsImag, rhsReal); Value rhsImagRealDenom = rewriter.create( loc, rhsReal, rewriter.create(loc, rhsImagRealRatio, rhsImag)); Value realNumerator2 = rewriter.create( loc, lhsReal, rewriter.create(loc, lhsImag, rhsImagRealRatio)); Value resultReal2 = rewriter.create(loc, realNumerator2, rhsImagRealDenom); Value imagNumerator2 = rewriter.create( loc, lhsImag, rewriter.create(loc, lhsReal, rhsImagRealRatio)); Value resultImag2 = rewriter.create(loc, imagNumerator2, rhsImagRealDenom); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. Value zero = rewriter.create( loc, elementType, rewriter.getZeroAttr(elementType)); Value rhsRealAbs = rewriter.create(loc, rhsReal); Value rhsRealIsZero = rewriter.create( loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); Value rhsImagAbs = rewriter.create(loc, rhsImag); Value rhsImagIsZero = rewriter.create( loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); Value lhsRealIsNotNaN = rewriter.create( loc, arith::CmpFPredicate::ORD, lhsReal, zero); Value lhsImagIsNotNaN = rewriter.create( loc, arith::CmpFPredicate::ORD, lhsImag, zero); Value lhsContainsNotNaNValue = rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); Value resultIsInfinity = rewriter.create( loc, lhsContainsNotNaNValue, rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); Value inf = rewriter.create( loc, elementType, rewriter.getFloatAttr( elementType, APFloat::getInf(elementType.getFloatSemantics()))); Value infWithSignOfRhsReal = rewriter.create(loc, inf, rhsReal); Value infinityResultReal = rewriter.create(loc, infWithSignOfRhsReal, lhsReal); Value infinityResultImag = rewriter.create(loc, infWithSignOfRhsReal, lhsImag); // Case 2. Infinite numerator, finite denominator. Value rhsRealFinite = rewriter.create( loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); Value rhsImagFinite = rewriter.create( loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); Value rhsFinite = rewriter.create(loc, rhsRealFinite, rhsImagFinite); Value lhsRealAbs = rewriter.create(loc, lhsReal); Value lhsRealInfinite = rewriter.create( loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); Value lhsImagAbs = rewriter.create(loc, lhsImag); Value lhsImagInfinite = rewriter.create( loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsInfinite = rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = rewriter.create(loc, lhsInfinite, rhsFinite); Value one = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 1)); Value lhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsRealInfinite, one, zero), lhsReal); Value lhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsImagInfinite, one, zero), lhsImag); Value lhsRealIsInfWithSignTimesRhsReal = rewriter.create(loc, lhsRealIsInfWithSign, rhsReal); Value lhsImagIsInfWithSignTimesRhsImag = rewriter.create(loc, lhsImagIsInfWithSign, rhsImag); Value resultReal3 = rewriter.create( loc, inf, rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, lhsImagIsInfWithSignTimesRhsImag)); Value lhsRealIsInfWithSignTimesRhsImag = rewriter.create(loc, lhsRealIsInfWithSign, rhsImag); Value lhsImagIsInfWithSignTimesRhsReal = rewriter.create(loc, lhsImagIsInfWithSign, rhsReal); Value resultImag3 = rewriter.create( loc, inf, rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, lhsRealIsInfWithSignTimesRhsImag)); // Case 3: Finite numerator, infinite denominator. Value lhsRealFinite = rewriter.create( loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); Value lhsImagFinite = rewriter.create( loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); Value lhsFinite = rewriter.create(loc, lhsRealFinite, lhsImagFinite); Value rhsRealInfinite = rewriter.create( loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); Value rhsImagInfinite = rewriter.create( loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsInfinite = rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = rewriter.create(loc, lhsFinite, rhsInfinite); Value rhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsRealInfinite, one, zero), rhsReal); Value rhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsImagInfinite, one, zero), rhsImag); Value rhsRealIsInfWithSignTimesLhsReal = rewriter.create(loc, lhsReal, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsImag = rewriter.create(loc, lhsImag, rhsImagIsInfWithSign); Value resultReal4 = rewriter.create( loc, zero, rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, rhsImagIsInfWithSignTimesLhsImag)); Value rhsRealIsInfWithSignTimesLhsImag = rewriter.create(loc, lhsImag, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsReal = rewriter.create(loc, lhsReal, rhsImagIsInfWithSign); Value resultImag4 = rewriter.create( loc, zero, rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, rhsImagIsInfWithSignTimesLhsReal)); Value realAbsSmallerThanImagAbs = rewriter.create( loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); Value resultReal = rewriter.create(loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); Value resultImag = rewriter.create(loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); Value resultRealSpecialCase3 = rewriter.create( loc, finiteNumInfiniteDenom, resultReal4, resultReal); Value resultImagSpecialCase3 = rewriter.create( loc, finiteNumInfiniteDenom, resultImag4, resultImag); Value resultRealSpecialCase2 = rewriter.create( loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); Value resultImagSpecialCase2 = rewriter.create( loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); Value resultRealSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); Value resultImagSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); Value resultRealIsNaN = rewriter.create( loc, arith::CmpFPredicate::UNO, resultReal, zero); Value resultImagIsNaN = rewriter.create( loc, arith::CmpFPredicate::UNO, resultImag, zero); Value resultIsNaN = rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); Value resultRealWithSpecialCases = rewriter.create( loc, resultIsNaN, resultRealSpecialCase1, resultReal); Value resultImagWithSpecialCases = rewriter.create( loc, resultIsNaN, resultImagSpecialCase1, resultImag); rewriter.replaceOpWithNewOp( op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); return success(); } }; struct ExpOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); Value real = rewriter.create(loc, elementType, adaptor.complex()); Value imag = rewriter.create(loc, elementType, adaptor.complex()); Value expReal = rewriter.create(loc, real); Value cosImag = rewriter.create(loc, imag); Value resultReal = rewriter.create(loc, expReal, cosImag); Value sinImag = rewriter.create(loc, imag); Value resultImag = rewriter.create(loc, expReal, sinImag); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); } }; struct LogOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value abs = b.create(elementType, adaptor.complex()); Value resultReal = b.create(elementType, abs); Value real = b.create(elementType, adaptor.complex()); Value imag = b.create(elementType, adaptor.complex()); Value resultImag = b.create(elementType, imag, real); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); } }; struct Log1pOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value real = b.create(elementType, adaptor.complex()); Value imag = b.create(elementType, adaptor.complex()); Value one = b.create(elementType, b.getFloatAttr(elementType, 1)); Value realPlusOne = b.create(real, one); Value newComplex = b.create(type, realPlusOne, imag); rewriter.replaceOpWithNewOp(op, type, newComplex); return success(); } }; struct MulOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto type = adaptor.lhs().getType().cast(); auto elementType = type.getElementType().cast(); Value lhsReal = b.create(elementType, adaptor.lhs()); Value lhsRealAbs = b.create(lhsReal); Value lhsImag = b.create(elementType, adaptor.lhs()); Value lhsImagAbs = b.create(lhsImag); Value rhsReal = b.create(elementType, adaptor.rhs()); Value rhsRealAbs = b.create(rhsReal); Value rhsImag = b.create(elementType, adaptor.rhs()); Value rhsImagAbs = b.create(rhsImag); Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); Value lhsRealTimesRhsRealAbs = b.create(lhsRealTimesRhsReal); Value lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); Value lhsImagTimesRhsImagAbs = b.create(lhsImagTimesRhsImag); Value real = b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); Value lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); Value lhsImagTimesRhsRealAbs = b.create(lhsImagTimesRhsReal); Value lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); Value lhsRealTimesRhsImagAbs = b.create(lhsRealTimesRhsImag); Value imag = b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); // Handle cases where the "naive" calculation results in NaN values. Value realIsNan = b.create(arith::CmpFPredicate::UNO, real, real); Value imagIsNan = b.create(arith::CmpFPredicate::UNO, imag, imag); Value isNan = b.create(realIsNan, imagIsNan); Value inf = b.create( elementType, b.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); // Case 1. `lhsReal` or `lhsImag` are infinite. Value lhsRealIsInf = b.create(arith::CmpFPredicate::OEQ, lhsRealAbs, inf); Value lhsImagIsInf = b.create(arith::CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsIsInf = b.create(lhsRealIsInf, lhsImagIsInf); Value rhsRealIsNan = b.create(arith::CmpFPredicate::UNO, rhsReal, rhsReal); Value rhsImagIsNan = b.create(arith::CmpFPredicate::UNO, rhsImag, rhsImag); Value zero = b.create(elementType, b.getZeroAttr(elementType)); Value one = b.create(elementType, b.getFloatAttr(elementType, 1)); Value lhsRealIsInfFloat = b.create(lhsRealIsInf, one, zero); lhsReal = b.create( lhsIsInf, b.create(lhsRealIsInfFloat, lhsReal), lhsReal); Value lhsImagIsInfFloat = b.create(lhsImagIsInf, one, zero); lhsImag = b.create( lhsIsInf, b.create(lhsImagIsInfFloat, lhsImag), lhsImag); Value lhsIsInfAndRhsRealIsNan = b.create(lhsIsInf, rhsRealIsNan); rhsReal = b.create(lhsIsInfAndRhsRealIsNan, b.create(zero, rhsReal), rhsReal); Value lhsIsInfAndRhsImagIsNan = b.create(lhsIsInf, rhsImagIsNan); rhsImag = b.create(lhsIsInfAndRhsImagIsNan, b.create(zero, rhsImag), rhsImag); // Case 2. `rhsReal` or `rhsImag` are infinite. Value rhsRealIsInf = b.create(arith::CmpFPredicate::OEQ, rhsRealAbs, inf); Value rhsImagIsInf = b.create(arith::CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsIsInf = b.create(rhsRealIsInf, rhsImagIsInf); Value lhsRealIsNan = b.create(arith::CmpFPredicate::UNO, lhsReal, lhsReal); Value lhsImagIsNan = b.create(arith::CmpFPredicate::UNO, lhsImag, lhsImag); Value rhsRealIsInfFloat = b.create(rhsRealIsInf, one, zero); rhsReal = b.create( rhsIsInf, b.create(rhsRealIsInfFloat, rhsReal), rhsReal); Value rhsImagIsInfFloat = b.create(rhsImagIsInf, one, zero); rhsImag = b.create( rhsIsInf, b.create(rhsImagIsInfFloat, rhsImag), rhsImag); Value rhsIsInfAndLhsRealIsNan = b.create(rhsIsInf, lhsRealIsNan); lhsReal = b.create(rhsIsInfAndLhsRealIsNan, b.create(zero, lhsReal), lhsReal); Value rhsIsInfAndLhsImagIsNan = b.create(rhsIsInf, lhsImagIsNan); lhsImag = b.create(rhsIsInfAndLhsImagIsNan, b.create(zero, lhsImag), lhsImag); Value recalc = b.create(lhsIsInf, rhsIsInf); // Case 3. One of the pairwise products of left hand side with right hand // side is infinite. Value lhsRealTimesRhsRealIsInf = b.create( arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); Value lhsImagTimesRhsImagIsInf = b.create( arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); Value isSpecialCase = b.create(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf); Value lhsRealTimesRhsImagIsInf = b.create( arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); isSpecialCase = b.create(isSpecialCase, lhsRealTimesRhsImagIsInf); Value lhsImagTimesRhsRealIsInf = b.create( arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); isSpecialCase = b.create(isSpecialCase, lhsImagTimesRhsRealIsInf); Type i1Type = b.getI1Type(); Value notRecalc = b.create( recalc, b.create(i1Type, b.getIntegerAttr(i1Type, 1))); isSpecialCase = b.create(isSpecialCase, notRecalc); Value isSpecialCaseAndLhsRealIsNan = b.create(isSpecialCase, lhsRealIsNan); lhsReal = b.create(isSpecialCaseAndLhsRealIsNan, b.create(zero, lhsReal), lhsReal); Value isSpecialCaseAndLhsImagIsNan = b.create(isSpecialCase, lhsImagIsNan); lhsImag = b.create(isSpecialCaseAndLhsImagIsNan, b.create(zero, lhsImag), lhsImag); Value isSpecialCaseAndRhsRealIsNan = b.create(isSpecialCase, rhsRealIsNan); rhsReal = b.create(isSpecialCaseAndRhsRealIsNan, b.create(zero, rhsReal), rhsReal); Value isSpecialCaseAndRhsImagIsNan = b.create(isSpecialCase, rhsImagIsNan); rhsImag = b.create(isSpecialCaseAndRhsImagIsNan, b.create(zero, rhsImag), rhsImag); recalc = b.create(recalc, isSpecialCase); recalc = b.create(isNan, recalc); // Recalculate real part. lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); Value newReal = b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); real = b.create(recalc, b.create(inf, newReal), real); // Recalculate imag part. lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); Value newImag = b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); imag = b.create(recalc, b.create(inf, newImag), imag); rewriter.replaceOpWithNewOp(op, type, real, imag); return success(); } }; struct NegOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); Value real = rewriter.create(loc, elementType, adaptor.complex()); Value imag = rewriter.create(loc, elementType, adaptor.complex()); Value negReal = rewriter.create(loc, real); Value negImag = rewriter.create(loc, imag); rewriter.replaceOpWithNewOp(op, type, negReal, negImag); return success(); } }; struct SignOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value real = b.create(elementType, adaptor.complex()); Value imag = b.create(elementType, adaptor.complex()); Value zero = b.create(elementType, b.getZeroAttr(elementType)); Value realIsZero = b.create(arith::CmpFPredicate::OEQ, real, zero); Value imagIsZero = b.create(arith::CmpFPredicate::OEQ, imag, zero); Value isZero = b.create(realIsZero, imagIsZero); auto abs = b.create(elementType, adaptor.complex()); Value realSign = b.create(real, abs); Value imagSign = b.create(imag, abs); Value sign = b.create(type, realSign, imagSign); rewriter.replaceOpWithNewOp(op, isZero, adaptor.complex(), sign); return success(); } }; } // namespace void mlir::populateComplexToStandardConversionPatterns( RewritePatternSet &patterns) { // clang-format off patterns.add< AbsOpConversion, ComparisonOpConversion, ComparisonOpConversion, BinaryComplexOpConversion, BinaryComplexOpConversion, DivOpConversion, ExpOpConversion, LogOpConversion, Log1pOpConversion, MulOpConversion, NegOpConversion, SignOpConversion>(patterns.getContext()); // clang-format on } namespace { struct ConvertComplexToStandardPass : public ConvertComplexToStandardBase { void runOnFunction() override; }; void ConvertComplexToStandardPass::runOnFunction() { auto function = getFunction(); // Convert to the Standard dialect using the converter defined above. RewritePatternSet patterns(&getContext()); populateComplexToStandardConversionPatterns(patterns); ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); } } // namespace std::unique_ptr> mlir::createConvertComplexToStandardPass() { return std::make_unique(); }