12ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
22ea7fb7bSAdrian Kuegel //
32ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information.
52ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62ea7fb7bSAdrian Kuegel //
72ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===//
82ea7fb7bSAdrian Kuegel 
92ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
102ea7fb7bSAdrian Kuegel 
112ea7fb7bSAdrian Kuegel #include <memory>
12fb8b2b86SAdrian Kuegel #include <type_traits>
132ea7fb7bSAdrian Kuegel 
142ea7fb7bSAdrian Kuegel #include "../PassDetail.h"
152ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h"
162ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h"
172ea7fb7bSAdrian Kuegel #include "mlir/Dialect/StandardOps/IR/Ops.h"
182ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h"
192ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h"
202ea7fb7bSAdrian Kuegel 
212ea7fb7bSAdrian Kuegel using namespace mlir;
222ea7fb7bSAdrian Kuegel 
232ea7fb7bSAdrian Kuegel namespace {
242ea7fb7bSAdrian Kuegel struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
252ea7fb7bSAdrian Kuegel   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
262ea7fb7bSAdrian Kuegel 
272ea7fb7bSAdrian Kuegel   LogicalResult
282ea7fb7bSAdrian Kuegel   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
292ea7fb7bSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
302ea7fb7bSAdrian Kuegel     complex::AbsOp::Adaptor transformed(operands);
312ea7fb7bSAdrian Kuegel     auto loc = op.getLoc();
322ea7fb7bSAdrian Kuegel     auto type = op.getType();
332ea7fb7bSAdrian Kuegel 
342ea7fb7bSAdrian Kuegel     Value real =
352ea7fb7bSAdrian Kuegel         rewriter.create<complex::ReOp>(loc, type, transformed.complex());
362ea7fb7bSAdrian Kuegel     Value imag =
372ea7fb7bSAdrian Kuegel         rewriter.create<complex::ImOp>(loc, type, transformed.complex());
382ea7fb7bSAdrian Kuegel     Value realSqr = rewriter.create<MulFOp>(loc, real, real);
392ea7fb7bSAdrian Kuegel     Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
402ea7fb7bSAdrian Kuegel     Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
412ea7fb7bSAdrian Kuegel 
422ea7fb7bSAdrian Kuegel     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
432ea7fb7bSAdrian Kuegel     return success();
442ea7fb7bSAdrian Kuegel   }
452ea7fb7bSAdrian Kuegel };
46ac00cb0dSAdrian Kuegel 
47fb8b2b86SAdrian Kuegel template <typename ComparisonOp, CmpFPredicate p>
48fb8b2b86SAdrian Kuegel struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
49fb8b2b86SAdrian Kuegel   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
50fb8b2b86SAdrian Kuegel   using ResultCombiner =
51fb8b2b86SAdrian Kuegel       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
52fb8b2b86SAdrian Kuegel                          AndOp, OrOp>;
53ac00cb0dSAdrian Kuegel 
54ac00cb0dSAdrian Kuegel   LogicalResult
55fb8b2b86SAdrian Kuegel   matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
56ac00cb0dSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
57fb8b2b86SAdrian Kuegel     typename ComparisonOp::Adaptor transformed(operands);
58ac00cb0dSAdrian Kuegel     auto loc = op.getLoc();
59fb8b2b86SAdrian Kuegel     auto type = transformed.lhs()
60fb8b2b86SAdrian Kuegel                     .getType()
61fb8b2b86SAdrian Kuegel                     .template cast<ComplexType>()
62fb8b2b86SAdrian Kuegel                     .getElementType();
63ac00cb0dSAdrian Kuegel 
64ac00cb0dSAdrian Kuegel     Value realLhs =
65ac00cb0dSAdrian Kuegel         rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
66ac00cb0dSAdrian Kuegel     Value imagLhs =
67ac00cb0dSAdrian Kuegel         rewriter.create<complex::ImOp>(loc, type, transformed.lhs());
68ac00cb0dSAdrian Kuegel     Value realRhs =
69ac00cb0dSAdrian Kuegel         rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
70ac00cb0dSAdrian Kuegel     Value imagRhs =
71ac00cb0dSAdrian Kuegel         rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
72fb8b2b86SAdrian Kuegel     Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
73fb8b2b86SAdrian Kuegel     Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
74ac00cb0dSAdrian Kuegel 
75fb8b2b86SAdrian Kuegel     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
76fb8b2b86SAdrian Kuegel                                                 imagComparison);
77ac00cb0dSAdrian Kuegel     return success();
78ac00cb0dSAdrian Kuegel   }
79ac00cb0dSAdrian Kuegel };
80942be7cbSAdrian Kuegel 
/// Lowers `complex.div` to real arithmetic on the operands' components.
///
/// The regular quotient is computed with Smith's algorithm. Both branches are
/// materialized and the one dividing by the larger-magnitude rhs part is
/// picked with `select`. A chain of `select`s then substitutes special-case
/// values for zero or infinite operands, but only where both components of
/// the Smith result are NaN.
/// NOTE(review): the special-case handling appears to follow the C99 Annex G
/// style (as in compiler-rt's __divdc3); confirm before relying on exact
/// semantics.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::DivOp::Adaptor transformed(operands);
    auto loc = op.getLoc();
    auto type = transformed.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Split both operands into their real and imaginary parts.
    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.

    // Branch 1 (selected when |rhsReal| < |rhsImag|): divides by rhsImag.
    Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<AddFOp>(
        loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag);
    Value resultReal1 =
        rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<SubFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal);
    Value resultImag1 =
        rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Branch 2 (selected otherwise): divides by rhsReal.
    Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<AddFOp>(
        loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<AddFOp>(
        loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<SubFOp>(
        loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // Applies when both rhs parts compare equal to zero and at least one lhs
    // part is not NaN. The result is copysign(inf, rhsReal) * lhs, computed
    // component-wise.
    Value zero = rewriter.create<ConstantOp>(loc, elementType,
                                             rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal);
    Value rhsRealIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag);
    Value rhsImagIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the (non-NaN) zero constant is true iff the value is not
    // NaN.
    Value lhsRealIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<AndOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // Applies when some lhs part is +-inf and both rhs parts are finite. ONE is
    // an ordered predicate, so a NaN rhs part does not count as finite. Each
    // lhs part x is replaced by copysign(isinf(x) ? 1 : 0, x), and the result
    // is inf * (lhs' * conj(rhs)).
    Value rhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal);
    Value lhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag);
    Value lhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    // Applies when both lhs parts are finite and some rhs part is +-inf. Each
    // rhs part y is replaced by copysign(isinf(y) ? 1 : 0, y), and the result
    // is 0 * (lhs * conj(rhs')), i.e. zeros carrying the appropriate sign.
    Value lhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith branch that divides by the larger-magnitude rhs part.
    Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>(
        loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultReal1, resultReal2);
    Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultImag1, resultImag2);
    // Nest the special cases so that Case 1 takes precedence over Case 2,
    // which in turn takes precedence over Case 3.
    Value resultRealSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case values replace the Smith result only when both of its
    // components are NaN; otherwise the Smith result is used unchanged.
    Value resultRealIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
289*73cbc91cSAdrian Kuegel 
290*73cbc91cSAdrian Kuegel struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
291*73cbc91cSAdrian Kuegel   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
292*73cbc91cSAdrian Kuegel 
293*73cbc91cSAdrian Kuegel   LogicalResult
294*73cbc91cSAdrian Kuegel   matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands,
295*73cbc91cSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
296*73cbc91cSAdrian Kuegel     complex::ExpOp::Adaptor transformed(operands);
297*73cbc91cSAdrian Kuegel     auto loc = op.getLoc();
298*73cbc91cSAdrian Kuegel     auto type = transformed.complex().getType().cast<ComplexType>();
299*73cbc91cSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
300*73cbc91cSAdrian Kuegel 
301*73cbc91cSAdrian Kuegel     Value real =
302*73cbc91cSAdrian Kuegel         rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
303*73cbc91cSAdrian Kuegel     Value imag =
304*73cbc91cSAdrian Kuegel         rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
305*73cbc91cSAdrian Kuegel     Value expReal = rewriter.create<math::ExpOp>(loc, real);
306*73cbc91cSAdrian Kuegel     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
307*73cbc91cSAdrian Kuegel     Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag);
308*73cbc91cSAdrian Kuegel     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
309*73cbc91cSAdrian Kuegel     Value resultImag = rewriter.create<MulFOp>(loc, expReal, sinImag);
310*73cbc91cSAdrian Kuegel 
311*73cbc91cSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
312*73cbc91cSAdrian Kuegel                                                    resultImag);
313*73cbc91cSAdrian Kuegel     return success();
314*73cbc91cSAdrian Kuegel   }
315*73cbc91cSAdrian Kuegel };
3162ea7fb7bSAdrian Kuegel } // namespace
3172ea7fb7bSAdrian Kuegel 
3182ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns(
3192ea7fb7bSAdrian Kuegel     RewritePatternSet &patterns) {
320fb8b2b86SAdrian Kuegel   patterns.add<AbsOpConversion,
321fb8b2b86SAdrian Kuegel                ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
322942be7cbSAdrian Kuegel                ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
323*73cbc91cSAdrian Kuegel                DivOpConversion, ExpOpConversion>(patterns.getContext());
3242ea7fb7bSAdrian Kuegel }
3252ea7fb7bSAdrian Kuegel 
3262ea7fb7bSAdrian Kuegel namespace {
3272ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass
3282ea7fb7bSAdrian Kuegel     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
3292ea7fb7bSAdrian Kuegel   void runOnFunction() override;
3302ea7fb7bSAdrian Kuegel };
3312ea7fb7bSAdrian Kuegel 
3322ea7fb7bSAdrian Kuegel void ConvertComplexToStandardPass::runOnFunction() {
3332ea7fb7bSAdrian Kuegel   auto function = getFunction();
3342ea7fb7bSAdrian Kuegel 
3352ea7fb7bSAdrian Kuegel   // Convert to the Standard dialect using the converter defined above.
3362ea7fb7bSAdrian Kuegel   RewritePatternSet patterns(&getContext());
3372ea7fb7bSAdrian Kuegel   populateComplexToStandardConversionPatterns(patterns);
3382ea7fb7bSAdrian Kuegel 
3392ea7fb7bSAdrian Kuegel   ConversionTarget target(getContext());
3402ea7fb7bSAdrian Kuegel   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
3412ea7fb7bSAdrian Kuegel                          complex::ComplexDialect>();
342942be7cbSAdrian Kuegel   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
343*73cbc91cSAdrian Kuegel                       complex::ExpOp, complex::NotEqualOp>();
3442ea7fb7bSAdrian Kuegel   if (failed(applyPartialConversion(function, target, std::move(patterns))))
3452ea7fb7bSAdrian Kuegel     signalPassFailure();
3462ea7fb7bSAdrian Kuegel }
3472ea7fb7bSAdrian Kuegel } // namespace
3482ea7fb7bSAdrian Kuegel 
3492ea7fb7bSAdrian Kuegel std::unique_ptr<OperationPass<FuncOp>>
3502ea7fb7bSAdrian Kuegel mlir::createConvertComplexToStandardPass() {
3512ea7fb7bSAdrian Kuegel   return std::make_unique<ConvertComplexToStandardPass>();
3522ea7fb7bSAdrian Kuegel }
353