12ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
22ea7fb7bSAdrian Kuegel //
32ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information.
52ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62ea7fb7bSAdrian Kuegel //
72ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===//
82ea7fb7bSAdrian Kuegel 
92ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
102ea7fb7bSAdrian Kuegel 
112ea7fb7bSAdrian Kuegel #include <memory>
12fb8b2b86SAdrian Kuegel #include <type_traits>
132ea7fb7bSAdrian Kuegel 
142ea7fb7bSAdrian Kuegel #include "../PassDetail.h"
15*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
162ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h"
172ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h"
182ea7fb7bSAdrian Kuegel #include "mlir/Dialect/StandardOps/IR/Ops.h"
19f112bd61SAdrian Kuegel #include "mlir/IR/ImplicitLocOpBuilder.h"
202ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h"
212ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h"
222ea7fb7bSAdrian Kuegel 
232ea7fb7bSAdrian Kuegel using namespace mlir;
242ea7fb7bSAdrian Kuegel 
252ea7fb7bSAdrian Kuegel namespace {
262ea7fb7bSAdrian Kuegel struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
272ea7fb7bSAdrian Kuegel   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
282ea7fb7bSAdrian Kuegel 
292ea7fb7bSAdrian Kuegel   LogicalResult
30b54c724bSRiver Riddle   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
312ea7fb7bSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
322ea7fb7bSAdrian Kuegel     auto loc = op.getLoc();
332ea7fb7bSAdrian Kuegel     auto type = op.getType();
342ea7fb7bSAdrian Kuegel 
35b54c724bSRiver Riddle     Value real = rewriter.create<complex::ReOp>(loc, type, adaptor.complex());
36b54c724bSRiver Riddle     Value imag = rewriter.create<complex::ImOp>(loc, type, adaptor.complex());
37*a54f4eaeSMogball     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
38*a54f4eaeSMogball     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
39*a54f4eaeSMogball     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
402ea7fb7bSAdrian Kuegel 
412ea7fb7bSAdrian Kuegel     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
422ea7fb7bSAdrian Kuegel     return success();
432ea7fb7bSAdrian Kuegel   }
442ea7fb7bSAdrian Kuegel };
45ac00cb0dSAdrian Kuegel 
46*a54f4eaeSMogball template <typename ComparisonOp, arith::CmpFPredicate p>
47fb8b2b86SAdrian Kuegel struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
48fb8b2b86SAdrian Kuegel   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
49fb8b2b86SAdrian Kuegel   using ResultCombiner =
50fb8b2b86SAdrian Kuegel       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
51*a54f4eaeSMogball                          arith::AndIOp, arith::OrIOp>;
52ac00cb0dSAdrian Kuegel 
53ac00cb0dSAdrian Kuegel   LogicalResult
54b54c724bSRiver Riddle   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
55ac00cb0dSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
56ac00cb0dSAdrian Kuegel     auto loc = op.getLoc();
57b54c724bSRiver Riddle     auto type =
58b54c724bSRiver Riddle         adaptor.lhs().getType().template cast<ComplexType>().getElementType();
59ac00cb0dSAdrian Kuegel 
60b54c724bSRiver Riddle     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.lhs());
61b54c724bSRiver Riddle     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.lhs());
62b54c724bSRiver Riddle     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.rhs());
63b54c724bSRiver Riddle     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.rhs());
64*a54f4eaeSMogball     Value realComparison =
65*a54f4eaeSMogball         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
66*a54f4eaeSMogball     Value imagComparison =
67*a54f4eaeSMogball         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
68ac00cb0dSAdrian Kuegel 
69fb8b2b86SAdrian Kuegel     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
70fb8b2b86SAdrian Kuegel                                                 imagComparison);
71ac00cb0dSAdrian Kuegel     return success();
72ac00cb0dSAdrian Kuegel   }
73ac00cb0dSAdrian Kuegel };
74942be7cbSAdrian Kuegel 
75fb978f09SAdrian Kuegel // Default conversion which applies the BinaryStandardOp separately on the real
76fb978f09SAdrian Kuegel // and imaginary parts. Can for example be used for complex::AddOp and
77fb978f09SAdrian Kuegel // complex::SubOp.
78fb978f09SAdrian Kuegel template <typename BinaryComplexOp, typename BinaryStandardOp>
79fb978f09SAdrian Kuegel struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
80fb978f09SAdrian Kuegel   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
81fb978f09SAdrian Kuegel 
82fb978f09SAdrian Kuegel   LogicalResult
83b54c724bSRiver Riddle   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
84fb978f09SAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
85b54c724bSRiver Riddle     auto type = adaptor.lhs().getType().template cast<ComplexType>();
86fb978f09SAdrian Kuegel     auto elementType = type.getElementType().template cast<FloatType>();
87fb978f09SAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
88fb978f09SAdrian Kuegel 
89b54c724bSRiver Riddle     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.lhs());
90b54c724bSRiver Riddle     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.rhs());
91fb978f09SAdrian Kuegel     Value resultReal =
92fb978f09SAdrian Kuegel         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
93b54c724bSRiver Riddle     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.lhs());
94b54c724bSRiver Riddle     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.rhs());
95fb978f09SAdrian Kuegel     Value resultImag =
96fb978f09SAdrian Kuegel         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
97fb978f09SAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
98fb978f09SAdrian Kuegel                                                    resultImag);
99fb978f09SAdrian Kuegel     return success();
100fb978f09SAdrian Kuegel   }
101fb978f09SAdrian Kuegel };
102fb978f09SAdrian Kuegel 
// Lowers complex.div into scalar float arithmetic on the operands' real and
// imaginary parts.
//
// The quotient is first computed with both variants of Smith's algorithm, and
// the variant whose divisor ratio has the smaller magnitude is selected. Only
// when that result is NaN in BOTH components are three corner cases consulted
// to recover an infinite or zero result instead:
//   1. zero denominator, numerator not entirely NaN;
//   2. infinite numerator, finite denominator;
//   3. finite numerator, infinite denominator.
// NOTE(review): the corner-case recovery appears to mirror the `_Cdivd`
// reference code in ISO C99 Annex G; confirm before relying on that.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Split both operands into scalar components.
    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.lhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.lhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.rhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.rhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Both variants are emitted unconditionally; a select further below
    // picks one of them.
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // Result: copysign(inf, rhsReal) multiplied into each numerator part.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the non-NaN constant `zero` is true iff the value is not
    // NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // ONE (ordered not-equal) against inf is false for NaN, so a NaN part
    // does not count as finite here.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // Replace each numerator part by +-1 if it is infinite and by +-0
    // otherwise, keeping the sign of the original part.
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    // resultReal3 = inf * (a * rhsReal + b * rhsImag), where a and b are the
    // +-1/+-0 replacements computed above.
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    // resultImag3 = inf * (b * rhsReal - a * rhsImag).
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    // Same +-1/+-0 replacement as in case 2, applied to the denominator.
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    // Multiplying by zero yields a (signed) zero result for each part.
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith variant whose ratio has magnitude below one: variant 1
    // (ratio rhsReal / rhsImag) when |rhsReal| < |rhsImag|, else variant 2.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultReal1, resultReal2);
    Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultImag1, resultImag2);
    // Layer the corner cases: case 1 takes precedence over case 2, which
    // takes precedence over case 3, which falls back to the Smith result.
    Value resultRealSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The corner-case values are used only when the Smith result is NaN in
    // both components; otherwise the Smith result is kept unchanged.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
32173cbc91cSAdrian Kuegel 
32273cbc91cSAdrian Kuegel struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
32373cbc91cSAdrian Kuegel   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
32473cbc91cSAdrian Kuegel 
32573cbc91cSAdrian Kuegel   LogicalResult
326b54c724bSRiver Riddle   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
32773cbc91cSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
32873cbc91cSAdrian Kuegel     auto loc = op.getLoc();
329b54c724bSRiver Riddle     auto type = adaptor.complex().getType().cast<ComplexType>();
33073cbc91cSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
33173cbc91cSAdrian Kuegel 
33273cbc91cSAdrian Kuegel     Value real =
333b54c724bSRiver Riddle         rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
33473cbc91cSAdrian Kuegel     Value imag =
335b54c724bSRiver Riddle         rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
33673cbc91cSAdrian Kuegel     Value expReal = rewriter.create<math::ExpOp>(loc, real);
33773cbc91cSAdrian Kuegel     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
338*a54f4eaeSMogball     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
33973cbc91cSAdrian Kuegel     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
340*a54f4eaeSMogball     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
34173cbc91cSAdrian Kuegel 
34273cbc91cSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
34373cbc91cSAdrian Kuegel                                                    resultImag);
34473cbc91cSAdrian Kuegel     return success();
34573cbc91cSAdrian Kuegel   }
34673cbc91cSAdrian Kuegel };
347662e074dSAdrian Kuegel 
348380fa71fSAdrian Kuegel struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
349380fa71fSAdrian Kuegel   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
350380fa71fSAdrian Kuegel 
351380fa71fSAdrian Kuegel   LogicalResult
352b54c724bSRiver Riddle   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
353380fa71fSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
354b54c724bSRiver Riddle     auto type = adaptor.complex().getType().cast<ComplexType>();
355380fa71fSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
356380fa71fSAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
357380fa71fSAdrian Kuegel 
358b54c724bSRiver Riddle     Value abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
359380fa71fSAdrian Kuegel     Value resultReal = b.create<math::LogOp>(elementType, abs);
360b54c724bSRiver Riddle     Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
361b54c724bSRiver Riddle     Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
362380fa71fSAdrian Kuegel     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
363380fa71fSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
364380fa71fSAdrian Kuegel                                                    resultImag);
365380fa71fSAdrian Kuegel     return success();
366380fa71fSAdrian Kuegel   }
367380fa71fSAdrian Kuegel };
368380fa71fSAdrian Kuegel 
3696e80e3bdSAdrian Kuegel struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
3706e80e3bdSAdrian Kuegel   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
3716e80e3bdSAdrian Kuegel 
3726e80e3bdSAdrian Kuegel   LogicalResult
373b54c724bSRiver Riddle   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
3746e80e3bdSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
375b54c724bSRiver Riddle     auto type = adaptor.complex().getType().cast<ComplexType>();
3766e80e3bdSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
3776e80e3bdSAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
3786e80e3bdSAdrian Kuegel 
379b54c724bSRiver Riddle     Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
380b54c724bSRiver Riddle     Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
381*a54f4eaeSMogball     Value one = b.create<arith::ConstantOp>(elementType,
382*a54f4eaeSMogball                                             b.getFloatAttr(elementType, 1));
383*a54f4eaeSMogball     Value realPlusOne = b.create<arith::AddFOp>(real, one);
3846e80e3bdSAdrian Kuegel     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
3856e80e3bdSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
3866e80e3bdSAdrian Kuegel     return success();
3876e80e3bdSAdrian Kuegel   }
3886e80e3bdSAdrian Kuegel };
3896e80e3bdSAdrian Kuegel 
390bf17ee19SAdrian Kuegel struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
391bf17ee19SAdrian Kuegel   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
392bf17ee19SAdrian Kuegel 
393bf17ee19SAdrian Kuegel   LogicalResult
394b54c724bSRiver Riddle   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
395bf17ee19SAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
396bf17ee19SAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
397b54c724bSRiver Riddle     auto type = adaptor.lhs().getType().cast<ComplexType>();
398bf17ee19SAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
399bf17ee19SAdrian Kuegel 
400b54c724bSRiver Riddle     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.lhs());
401*a54f4eaeSMogball     Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
402b54c724bSRiver Riddle     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.lhs());
403*a54f4eaeSMogball     Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
404b54c724bSRiver Riddle     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.rhs());
405*a54f4eaeSMogball     Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
406b54c724bSRiver Riddle     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.rhs());
407*a54f4eaeSMogball     Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);
408bf17ee19SAdrian Kuegel 
409*a54f4eaeSMogball     Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
410*a54f4eaeSMogball     Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
411*a54f4eaeSMogball     Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
412*a54f4eaeSMogball     Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
413*a54f4eaeSMogball     Value real =
414*a54f4eaeSMogball         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
415bf17ee19SAdrian Kuegel 
416*a54f4eaeSMogball     Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
417*a54f4eaeSMogball     Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
418*a54f4eaeSMogball     Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
419*a54f4eaeSMogball     Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
420*a54f4eaeSMogball     Value imag =
421*a54f4eaeSMogball         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
422bf17ee19SAdrian Kuegel 
423bf17ee19SAdrian Kuegel     // Handle cases where the "naive" calculation results in NaN values.
424*a54f4eaeSMogball     Value realIsNan =
425*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
426*a54f4eaeSMogball     Value imagIsNan =
427*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
428*a54f4eaeSMogball     Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
429bf17ee19SAdrian Kuegel 
430*a54f4eaeSMogball     Value inf = b.create<arith::ConstantOp>(
431bf17ee19SAdrian Kuegel         elementType,
432bf17ee19SAdrian Kuegel         b.getFloatAttr(elementType,
433bf17ee19SAdrian Kuegel                        APFloat::getInf(elementType.getFloatSemantics())));
434bf17ee19SAdrian Kuegel 
435bf17ee19SAdrian Kuegel     // Case 1. `lhsReal` or `lhsImag` are infinite.
436*a54f4eaeSMogball     Value lhsRealIsInf =
437*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
438*a54f4eaeSMogball     Value lhsImagIsInf =
439*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
440*a54f4eaeSMogball     Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
441*a54f4eaeSMogball     Value rhsRealIsNan =
442*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
443*a54f4eaeSMogball     Value rhsImagIsNan =
444*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
445*a54f4eaeSMogball     Value zero =
446*a54f4eaeSMogball         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
447*a54f4eaeSMogball     Value one = b.create<arith::ConstantOp>(elementType,
448*a54f4eaeSMogball                                             b.getFloatAttr(elementType, 1));
449bf17ee19SAdrian Kuegel     Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero);
450bf17ee19SAdrian Kuegel     lhsReal = b.create<SelectOp>(
451*a54f4eaeSMogball         lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
452*a54f4eaeSMogball         lhsReal);
453bf17ee19SAdrian Kuegel     Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero);
454bf17ee19SAdrian Kuegel     lhsImag = b.create<SelectOp>(
455*a54f4eaeSMogball         lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
456*a54f4eaeSMogball         lhsImag);
457*a54f4eaeSMogball     Value lhsIsInfAndRhsRealIsNan =
458*a54f4eaeSMogball         b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
459*a54f4eaeSMogball     rhsReal =
460*a54f4eaeSMogball         b.create<SelectOp>(lhsIsInfAndRhsRealIsNan,
461*a54f4eaeSMogball                            b.create<math::CopySignOp>(zero, rhsReal), rhsReal);
462*a54f4eaeSMogball     Value lhsIsInfAndRhsImagIsNan =
463*a54f4eaeSMogball         b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
464*a54f4eaeSMogball     rhsImag =
465*a54f4eaeSMogball         b.create<SelectOp>(lhsIsInfAndRhsImagIsNan,
466*a54f4eaeSMogball                            b.create<math::CopySignOp>(zero, rhsImag), rhsImag);
467bf17ee19SAdrian Kuegel 
468bf17ee19SAdrian Kuegel     // Case 2. `rhsReal` or `rhsImag` are infinite.
469*a54f4eaeSMogball     Value rhsRealIsInf =
470*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
471*a54f4eaeSMogball     Value rhsImagIsInf =
472*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
473*a54f4eaeSMogball     Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
474*a54f4eaeSMogball     Value lhsRealIsNan =
475*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
476*a54f4eaeSMogball     Value lhsImagIsNan =
477*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
478bf17ee19SAdrian Kuegel     Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero);
479bf17ee19SAdrian Kuegel     rhsReal = b.create<SelectOp>(
480*a54f4eaeSMogball         rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
481*a54f4eaeSMogball         rhsReal);
482bf17ee19SAdrian Kuegel     Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero);
483bf17ee19SAdrian Kuegel     rhsImag = b.create<SelectOp>(
484*a54f4eaeSMogball         rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
485*a54f4eaeSMogball         rhsImag);
486*a54f4eaeSMogball     Value rhsIsInfAndLhsRealIsNan =
487*a54f4eaeSMogball         b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
488*a54f4eaeSMogball     lhsReal =
489*a54f4eaeSMogball         b.create<SelectOp>(rhsIsInfAndLhsRealIsNan,
490*a54f4eaeSMogball                            b.create<math::CopySignOp>(zero, lhsReal), lhsReal);
491*a54f4eaeSMogball     Value rhsIsInfAndLhsImagIsNan =
492*a54f4eaeSMogball         b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
493*a54f4eaeSMogball     lhsImag =
494*a54f4eaeSMogball         b.create<SelectOp>(rhsIsInfAndLhsImagIsNan,
495*a54f4eaeSMogball                            b.create<math::CopySignOp>(zero, lhsImag), lhsImag);
496*a54f4eaeSMogball     Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
497bf17ee19SAdrian Kuegel 
498bf17ee19SAdrian Kuegel     // Case 3. One of the pairwise products of left hand side with right hand
499bf17ee19SAdrian Kuegel     // side is infinite.
500*a54f4eaeSMogball     Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
501*a54f4eaeSMogball         arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
502*a54f4eaeSMogball     Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
503*a54f4eaeSMogball         arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
504*a54f4eaeSMogball     Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
505*a54f4eaeSMogball                                                  lhsImagTimesRhsImagIsInf);
506*a54f4eaeSMogball     Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
507*a54f4eaeSMogball         arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
508*a54f4eaeSMogball     isSpecialCase =
509*a54f4eaeSMogball         b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
510*a54f4eaeSMogball     Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
511*a54f4eaeSMogball         arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
512*a54f4eaeSMogball     isSpecialCase =
513*a54f4eaeSMogball         b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
514bf17ee19SAdrian Kuegel     Type i1Type = b.getI1Type();
515*a54f4eaeSMogball     Value notRecalc = b.create<arith::XOrIOp>(
516*a54f4eaeSMogball         recalc,
517*a54f4eaeSMogball         b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
518*a54f4eaeSMogball     isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
519bf17ee19SAdrian Kuegel     Value isSpecialCaseAndLhsRealIsNan =
520*a54f4eaeSMogball         b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
521*a54f4eaeSMogball     lhsReal =
522*a54f4eaeSMogball         b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan,
523*a54f4eaeSMogball                            b.create<math::CopySignOp>(zero, lhsReal), lhsReal);
524bf17ee19SAdrian Kuegel     Value isSpecialCaseAndLhsImagIsNan =
525*a54f4eaeSMogball         b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
526*a54f4eaeSMogball     lhsImag =
527*a54f4eaeSMogball         b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan,
528*a54f4eaeSMogball                            b.create<math::CopySignOp>(zero, lhsImag), lhsImag);
529bf17ee19SAdrian Kuegel     Value isSpecialCaseAndRhsRealIsNan =
530*a54f4eaeSMogball         b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
531*a54f4eaeSMogball     rhsReal =
532*a54f4eaeSMogball         b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan,
533*a54f4eaeSMogball                            b.create<math::CopySignOp>(zero, rhsReal), rhsReal);
534bf17ee19SAdrian Kuegel     Value isSpecialCaseAndRhsImagIsNan =
535*a54f4eaeSMogball         b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
536*a54f4eaeSMogball     rhsImag =
537*a54f4eaeSMogball         b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan,
538*a54f4eaeSMogball                            b.create<math::CopySignOp>(zero, rhsImag), rhsImag);
539*a54f4eaeSMogball     recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
540*a54f4eaeSMogball     recalc = b.create<arith::AndIOp>(isNan, recalc);
541bf17ee19SAdrian Kuegel 
542bf17ee19SAdrian Kuegel     // Recalculate real part.
543*a54f4eaeSMogball     lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
544*a54f4eaeSMogball     lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
545*a54f4eaeSMogball     Value newReal =
546*a54f4eaeSMogball         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
547*a54f4eaeSMogball     real =
548*a54f4eaeSMogball         b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newReal), real);
549bf17ee19SAdrian Kuegel 
550bf17ee19SAdrian Kuegel     // Recalculate imag part.
551*a54f4eaeSMogball     lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
552*a54f4eaeSMogball     lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
553*a54f4eaeSMogball     Value newImag =
554*a54f4eaeSMogball         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
555*a54f4eaeSMogball     imag =
556*a54f4eaeSMogball         b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newImag), imag);
557bf17ee19SAdrian Kuegel 
558bf17ee19SAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
559bf17ee19SAdrian Kuegel     return success();
560bf17ee19SAdrian Kuegel   }
561bf17ee19SAdrian Kuegel };
562bf17ee19SAdrian Kuegel 
563662e074dSAdrian Kuegel struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
564662e074dSAdrian Kuegel   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
565662e074dSAdrian Kuegel 
566662e074dSAdrian Kuegel   LogicalResult
567b54c724bSRiver Riddle   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
568662e074dSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
569662e074dSAdrian Kuegel     auto loc = op.getLoc();
570b54c724bSRiver Riddle     auto type = adaptor.complex().getType().cast<ComplexType>();
571662e074dSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
572662e074dSAdrian Kuegel 
573662e074dSAdrian Kuegel     Value real =
574b54c724bSRiver Riddle         rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
575662e074dSAdrian Kuegel     Value imag =
576b54c724bSRiver Riddle         rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
577*a54f4eaeSMogball     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
578*a54f4eaeSMogball     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
579662e074dSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
580662e074dSAdrian Kuegel     return success();
581662e074dSAdrian Kuegel   }
582662e074dSAdrian Kuegel };
583f112bd61SAdrian Kuegel 
584f112bd61SAdrian Kuegel struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
585f112bd61SAdrian Kuegel   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
586f112bd61SAdrian Kuegel 
587f112bd61SAdrian Kuegel   LogicalResult
588b54c724bSRiver Riddle   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
589f112bd61SAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
590b54c724bSRiver Riddle     auto type = adaptor.complex().getType().cast<ComplexType>();
591f112bd61SAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
592f112bd61SAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
593f112bd61SAdrian Kuegel 
594b54c724bSRiver Riddle     Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
595b54c724bSRiver Riddle     Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
596*a54f4eaeSMogball     Value zero =
597*a54f4eaeSMogball         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
598*a54f4eaeSMogball     Value realIsZero =
599*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
600*a54f4eaeSMogball     Value imagIsZero =
601*a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
602*a54f4eaeSMogball     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
603b54c724bSRiver Riddle     auto abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
604*a54f4eaeSMogball     Value realSign = b.create<arith::DivFOp>(real, abs);
605*a54f4eaeSMogball     Value imagSign = b.create<arith::DivFOp>(imag, abs);
606f112bd61SAdrian Kuegel     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
607b54c724bSRiver Riddle     rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, adaptor.complex(), sign);
608f112bd61SAdrian Kuegel     return success();
609f112bd61SAdrian Kuegel   }
610f112bd61SAdrian Kuegel };
6112ea7fb7bSAdrian Kuegel } // namespace
6122ea7fb7bSAdrian Kuegel 
6132ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns(
6142ea7fb7bSAdrian Kuegel     RewritePatternSet &patterns) {
615f112bd61SAdrian Kuegel   // clang-format off
616f112bd61SAdrian Kuegel   patterns.add<
617f112bd61SAdrian Kuegel       AbsOpConversion,
618*a54f4eaeSMogball       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
619*a54f4eaeSMogball       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
620*a54f4eaeSMogball       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
621*a54f4eaeSMogball       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
622f112bd61SAdrian Kuegel       DivOpConversion,
623f112bd61SAdrian Kuegel       ExpOpConversion,
624380fa71fSAdrian Kuegel       LogOpConversion,
6256e80e3bdSAdrian Kuegel       Log1pOpConversion,
626bf17ee19SAdrian Kuegel       MulOpConversion,
627f112bd61SAdrian Kuegel       NegOpConversion,
628f112bd61SAdrian Kuegel       SignOpConversion>(patterns.getContext());
629f112bd61SAdrian Kuegel   // clang-format on
6302ea7fb7bSAdrian Kuegel }
6312ea7fb7bSAdrian Kuegel 
6322ea7fb7bSAdrian Kuegel namespace {
6332ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass
6342ea7fb7bSAdrian Kuegel     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
6352ea7fb7bSAdrian Kuegel   void runOnFunction() override;
6362ea7fb7bSAdrian Kuegel };
6372ea7fb7bSAdrian Kuegel 
6382ea7fb7bSAdrian Kuegel void ConvertComplexToStandardPass::runOnFunction() {
6392ea7fb7bSAdrian Kuegel   auto function = getFunction();
6402ea7fb7bSAdrian Kuegel 
6412ea7fb7bSAdrian Kuegel   // Convert to the Standard dialect using the converter defined above.
6422ea7fb7bSAdrian Kuegel   RewritePatternSet patterns(&getContext());
6432ea7fb7bSAdrian Kuegel   populateComplexToStandardConversionPatterns(patterns);
6442ea7fb7bSAdrian Kuegel 
6452ea7fb7bSAdrian Kuegel   ConversionTarget target(getContext());
646*a54f4eaeSMogball   target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect,
647*a54f4eaeSMogball                          math::MathDialect>();
648fb978f09SAdrian Kuegel   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
6492ea7fb7bSAdrian Kuegel   if (failed(applyPartialConversion(function, target, std::move(patterns))))
6502ea7fb7bSAdrian Kuegel     signalPassFailure();
6512ea7fb7bSAdrian Kuegel }
6522ea7fb7bSAdrian Kuegel } // namespace
6532ea7fb7bSAdrian Kuegel 
6542ea7fb7bSAdrian Kuegel std::unique_ptr<OperationPass<FuncOp>>
6552ea7fb7bSAdrian Kuegel mlir::createConvertComplexToStandardPass() {
6562ea7fb7bSAdrian Kuegel   return std::make_unique<ConvertComplexToStandardPass>();
6572ea7fb7bSAdrian Kuegel }
658