//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include #include #include "../PassDetail.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; namespace { struct AbsOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::AbsOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { complex::AbsOp::Adaptor transformed(operands); auto loc = op.getLoc(); auto type = op.getType(); Value real = rewriter.create(loc, type, transformed.complex()); Value imag = rewriter.create(loc, type, transformed.complex()); Value realSqr = rewriter.create(loc, real, real); Value imagSqr = rewriter.create(loc, imag, imag); Value sqNorm = rewriter.create(loc, realSqr, imagSqr); rewriter.replaceOpWithNewOp(op, sqNorm); return success(); } }; template struct ComparisonOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; using ResultCombiner = std::conditional_t::value, AndOp, OrOp>; LogicalResult matchAndRewrite(ComparisonOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { typename ComparisonOp::Adaptor transformed(operands); auto loc = op.getLoc(); auto type = transformed.lhs() .getType() .template cast() .getElementType(); Value realLhs = rewriter.create(loc, type, transformed.lhs()); Value imagLhs = rewriter.create(loc, type, transformed.lhs()); Value realRhs = rewriter.create(loc, type, transformed.rhs()); Value imagRhs = rewriter.create(loc, type, transformed.rhs()); Value realComparison = rewriter.create(loc, p, realLhs, realRhs); Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, realComparison, imagComparison); return success(); } }; struct DivOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::DivOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { complex::DivOp::Adaptor transformed(operands); auto loc = op.getLoc(); auto type = transformed.lhs().getType().template cast(); auto elementType = type.getElementType().cast(); Value lhsReal = rewriter.create(loc, elementType, transformed.lhs()); Value lhsImag = rewriter.create(loc, elementType, transformed.lhs()); Value rhsReal = rewriter.create(loc, elementType, transformed.rhs()); Value rhsImag = rewriter.create(loc, elementType, transformed.rhs()); // Smith's algorithm to divide complex numbers. It is just a bit smarter // way to compute the following formula: // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) // = ((lhsReal * rhsReal + lhsImag * rhsImag) + // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 // // Depending on whether |rhsReal| < |rhsImag| we compute either // rhsRealImagRatio = rhsReal / rhsImag // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom // // or // // rhsImagRealRatio = rhsImag / rhsReal // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom // // See https://dl.acm.org/citation.cfm?id=368661 for more details. Value rhsRealImagRatio = rewriter.create(loc, rhsReal, rhsImag); Value rhsRealImagDenom = rewriter.create( loc, rhsImag, rewriter.create(loc, rhsRealImagRatio, rhsReal)); Value realNumerator1 = rewriter.create( loc, rewriter.create(loc, lhsReal, rhsRealImagRatio), lhsImag); Value resultReal1 = rewriter.create(loc, realNumerator1, rhsRealImagDenom); Value imagNumerator1 = rewriter.create( loc, rewriter.create(loc, lhsImag, rhsRealImagRatio), lhsReal); Value resultImag1 = rewriter.create(loc, imagNumerator1, rhsRealImagDenom); Value rhsImagRealRatio = rewriter.create(loc, rhsImag, rhsReal); Value rhsImagRealDenom = rewriter.create( loc, rhsReal, rewriter.create(loc, rhsImagRealRatio, rhsImag)); Value realNumerator2 = rewriter.create( loc, lhsReal, rewriter.create(loc, lhsImag, rhsImagRealRatio)); Value resultReal2 = rewriter.create(loc, realNumerator2, rhsImagRealDenom); Value imagNumerator2 = rewriter.create( loc, lhsImag, rewriter.create(loc, lhsReal, rhsImagRealRatio)); Value resultImag2 = rewriter.create(loc, imagNumerator2, rhsImagRealDenom); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. Value zero = rewriter.create(loc, elementType, rewriter.getZeroAttr(elementType)); Value rhsRealAbs = rewriter.create(loc, rhsReal); Value rhsRealIsZero = rewriter.create(loc, CmpFPredicate::OEQ, rhsRealAbs, zero); Value rhsImagAbs = rewriter.create(loc, rhsImag); Value rhsImagIsZero = rewriter.create(loc, CmpFPredicate::OEQ, rhsImagAbs, zero); Value lhsRealIsNotNaN = rewriter.create(loc, CmpFPredicate::ORD, lhsReal, zero); Value lhsImagIsNotNaN = rewriter.create(loc, CmpFPredicate::ORD, lhsImag, zero); Value lhsContainsNotNaNValue = rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); Value resultIsInfinity = rewriter.create( loc, lhsContainsNotNaNValue, rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); Value inf = rewriter.create( loc, elementType, rewriter.getFloatAttr( elementType, APFloat::getInf(elementType.getFloatSemantics()))); Value infWithSignOfRhsReal = rewriter.create(loc, inf, rhsReal); Value infinityResultReal = rewriter.create(loc, infWithSignOfRhsReal, lhsReal); Value infinityResultImag = rewriter.create(loc, infWithSignOfRhsReal, lhsImag); // Case 2. Infinite numerator, finite denominator. Value rhsRealFinite = rewriter.create(loc, CmpFPredicate::ONE, rhsRealAbs, inf); Value rhsImagFinite = rewriter.create(loc, CmpFPredicate::ONE, rhsImagAbs, inf); Value rhsFinite = rewriter.create(loc, rhsRealFinite, rhsImagFinite); Value lhsRealAbs = rewriter.create(loc, lhsReal); Value lhsRealInfinite = rewriter.create(loc, CmpFPredicate::OEQ, lhsRealAbs, inf); Value lhsImagAbs = rewriter.create(loc, lhsImag); Value lhsImagInfinite = rewriter.create(loc, CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsInfinite = rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = rewriter.create(loc, lhsInfinite, rhsFinite); Value one = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 1)); Value lhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsRealInfinite, one, zero), lhsReal); Value lhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsImagInfinite, one, zero), lhsImag); Value lhsRealIsInfWithSignTimesRhsReal = rewriter.create(loc, lhsRealIsInfWithSign, rhsReal); Value lhsImagIsInfWithSignTimesRhsImag = rewriter.create(loc, lhsImagIsInfWithSign, rhsImag); Value resultReal3 = rewriter.create( loc, inf, rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, lhsImagIsInfWithSignTimesRhsImag)); Value lhsRealIsInfWithSignTimesRhsImag = rewriter.create(loc, lhsRealIsInfWithSign, rhsImag); Value lhsImagIsInfWithSignTimesRhsReal = rewriter.create(loc, lhsImagIsInfWithSign, rhsReal); Value resultImag3 = rewriter.create( loc, inf, rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, lhsRealIsInfWithSignTimesRhsImag)); // Case 3: Finite numerator, infinite denominator. Value lhsRealFinite = rewriter.create(loc, CmpFPredicate::ONE, lhsRealAbs, inf); Value lhsImagFinite = rewriter.create(loc, CmpFPredicate::ONE, lhsImagAbs, inf); Value lhsFinite = rewriter.create(loc, lhsRealFinite, lhsImagFinite); Value rhsRealInfinite = rewriter.create(loc, CmpFPredicate::OEQ, rhsRealAbs, inf); Value rhsImagInfinite = rewriter.create(loc, CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsInfinite = rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = rewriter.create(loc, lhsFinite, rhsInfinite); Value rhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsRealInfinite, one, zero), rhsReal); Value rhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsImagInfinite, one, zero), rhsImag); Value rhsRealIsInfWithSignTimesLhsReal = rewriter.create(loc, lhsReal, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsImag = rewriter.create(loc, lhsImag, rhsImagIsInfWithSign); Value resultReal4 = rewriter.create( loc, zero, rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, rhsImagIsInfWithSignTimesLhsImag)); Value rhsRealIsInfWithSignTimesLhsImag = rewriter.create(loc, lhsImag, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsReal = rewriter.create(loc, lhsReal, rhsImagIsInfWithSign); Value resultImag4 = rewriter.create( loc, zero, rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, rhsImagIsInfWithSignTimesLhsReal)); Value realAbsSmallerThanImagAbs = rewriter.create( loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); Value resultReal = rewriter.create(loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); Value resultImag = rewriter.create(loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); Value resultRealSpecialCase3 = rewriter.create( loc, finiteNumInfiniteDenom, resultReal4, resultReal); Value resultImagSpecialCase3 = rewriter.create( loc, finiteNumInfiniteDenom, resultImag4, resultImag); Value resultRealSpecialCase2 = rewriter.create( loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); Value resultImagSpecialCase2 = rewriter.create( loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); Value resultRealSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); Value resultImagSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); Value resultRealIsNaN = rewriter.create(loc, CmpFPredicate::UNO, resultReal, zero); Value resultImagIsNaN = rewriter.create(loc, CmpFPredicate::UNO, resultImag, zero); Value resultIsNaN = rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); Value resultRealWithSpecialCases = rewriter.create( loc, resultIsNaN, resultRealSpecialCase1, resultReal); Value resultImagWithSpecialCases = rewriter.create( loc, resultIsNaN, resultImagSpecialCase1, resultImag); rewriter.replaceOpWithNewOp( op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); return success(); } }; } // namespace void mlir::populateComplexToStandardConversionPatterns( RewritePatternSet &patterns) { patterns.add, ComparisonOpConversion, DivOpConversion>(patterns.getContext()); } namespace { struct ConvertComplexToStandardPass : public ConvertComplexToStandardBase { void runOnFunction() override; }; void ConvertComplexToStandardPass::runOnFunction() { auto function = getFunction(); // Convert to the Standard dialect using the converter defined above. RewritePatternSet patterns(&getContext()); populateComplexToStandardConversionPatterns(patterns); ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); } } // namespace std::unique_ptr> mlir::createConvertComplexToStandardPass() { return std::make_unique(); }