1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
26   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
30                   ConversionPatternRewriter &rewriter) const override {
31     auto loc = op.getLoc();
32     auto type = op.getType();
33 
34     Value real =
35         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
36     Value imag =
37         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
38     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
39     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
40     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
41 
42     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
43     return success();
44   }
45 };
46 
47 template <typename ComparisonOp, arith::CmpFPredicate p>
48 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
49   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
50   using ResultCombiner =
51       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
52                          arith::AndIOp, arith::OrIOp>;
53 
54   LogicalResult
55   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
56                   ConversionPatternRewriter &rewriter) const override {
57     auto loc = op.getLoc();
58     auto type = adaptor.getLhs()
59                     .getType()
60                     .template cast<ComplexType>()
61                     .getElementType();
62 
63     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
64     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
65     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
66     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
67     Value realComparison =
68         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
69     Value imagComparison =
70         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
71 
72     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
73                                                 imagComparison);
74     return success();
75   }
76 };
77 
78 // Default conversion which applies the BinaryStandardOp separately on the real
79 // and imaginary parts. Can for example be used for complex::AddOp and
80 // complex::SubOp.
81 template <typename BinaryComplexOp, typename BinaryStandardOp>
82 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
83   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
84 
85   LogicalResult
86   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
87                   ConversionPatternRewriter &rewriter) const override {
88     auto type = adaptor.getLhs().getType().template cast<ComplexType>();
89     auto elementType = type.getElementType().template cast<FloatType>();
90     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
91 
92     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
93     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
94     Value resultReal =
95         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
96     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
97     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
98     Value resultImag =
99         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
100     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
101                                                    resultImag);
102     return success();
103   }
104 };
105 
/// Lowers `complex.div` to arith/math ops on the real and imaginary parts.
///
/// The main result is computed with Smith's algorithm. Both candidate results
/// are built, and one is chosen by comparing |re(rhs)| with |im(rhs)|. Three
/// corner-case results are also built unconditionally:
///   1. zero denominator,
///   2. infinite numerator with a finite denominator,
///   3. finite numerator with an infinite denominator.
/// They are only used when BOTH components of Smith's result are NaN.
/// NOTE(review): these corner cases appear to mirror the infinity-recovery
/// logic of the example complex division in C99 Annex G; cross-check against
/// it before modifying.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is a slightly smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    // Both variants are emitted unconditionally (there is no branching in the
    // generated IR); a select further below picks one of them.
    //
    // Variant 1 (used when |rhsReal| < |rhsImag|).
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Variant 2 (used otherwise, including when the comparison involves NaN).
    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // Condition: rhsReal == 0 && rhsImag == 0 && (lhsReal or lhsImag is not
    // NaN). Result: copysign(inf, rhsReal) multiplied into each lhs component.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the constant zero is true iff the value is not NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // "Finite" is tested with ONE (ordered not-equal) against inf, so a NaN
    // component is not considered finite. The lhs components are replaced by
    // copysign(1 if infinite else 0, component), and the result is
    //   real = inf * (lhsReal' * rhsReal + lhsImag' * rhsImag)
    //   imag = inf * (lhsImag' * rhsReal - lhsReal' * rhsImag).
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3. Finite numerator, infinite denominator.
    // The rhs components are replaced by copysign(1 if infinite else 0,
    // component), and the result is a signed zero:
    //   real = 0 * (lhsReal * rhsReal' + lhsImag * rhsImag')
    //   imag = 0 * (lhsImag * rhsReal' - lhsReal * rhsImag').
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick Smith's variant, then layer the special cases on top with priority
    // case 1 > case 2 > case 3 (the outermost select wins).
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case result replaces Smith's result only when both of
    // Smith's components are NaN (UNO against zero is true iff NaN).
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
324 
325 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
326   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
327 
328   LogicalResult
329   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
330                   ConversionPatternRewriter &rewriter) const override {
331     auto loc = op.getLoc();
332     auto type = adaptor.getComplex().getType().cast<ComplexType>();
333     auto elementType = type.getElementType().cast<FloatType>();
334 
335     Value real =
336         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
337     Value imag =
338         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
339     Value expReal = rewriter.create<math::ExpOp>(loc, real);
340     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
341     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
342     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
343     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
344 
345     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
346                                                    resultImag);
347     return success();
348   }
349 };
350 
351 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
352   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
353 
354   LogicalResult
355   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
356                   ConversionPatternRewriter &rewriter) const override {
357     auto type = adaptor.getComplex().getType().cast<ComplexType>();
358     auto elementType = type.getElementType().cast<FloatType>();
359     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
360 
361     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
362     Value resultReal = b.create<math::LogOp>(elementType, abs);
363     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
364     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
365     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
366     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
367                                                    resultImag);
368     return success();
369   }
370 };
371 
372 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
373   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
374 
375   LogicalResult
376   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
377                   ConversionPatternRewriter &rewriter) const override {
378     auto type = adaptor.getComplex().getType().cast<ComplexType>();
379     auto elementType = type.getElementType().cast<FloatType>();
380     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
381 
382     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
383     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
384     Value one = b.create<arith::ConstantOp>(elementType,
385                                             b.getFloatAttr(elementType, 1));
386     Value realPlusOne = b.create<arith::AddFOp>(real, one);
387     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
388     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
389     return success();
390   }
391 };
392 
/// Lowers `complex.mul` to arith/math ops on the real and imaginary parts.
///
/// The product is first computed with the naive formula
///   (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
/// If BOTH components of that result are NaN and an operand (or a pairwise
/// product) is infinite, the product is recomputed on modified inputs. Infinite
/// components are "boxed" to copysign(1 or 0, x), NaN components become
/// copysign(0, x), and the recomputed terms are scaled by +inf. This yields an
/// infinity instead of NaN + NaN i. All of this is emitted unconditionally;
/// selects pick the final values.
/// NOTE(review): this appears to mirror the example complex multiplication in
/// C99 Annex G; cross-check against it before modifying.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);

    // Naive product; the absolute values of the four pairwise products are
    // kept for the overflow check in case 3.
    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
    Value real =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
    Value imag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // Recalculation only takes effect when both components are NaN.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Box lhs (inf -> +-1, finite -> +-0) and zero out NaN components of rhs.
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat =
        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat =
        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Box rhs and zero out NaN components of lhs. The lhs NaN checks below
    // deliberately look at lhs values that case 1 may already have rewritten.
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat =
        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat =
        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    // Only applies when neither case 1 nor case 2 did (masked by notRecalc).
    // In that situation no operand was rewritten above, so the NaN flags that
    // were computed earlier still describe the current values.
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    Type i1Type = b.getI1Type();
    // Logical NOT of an i1, expressed as XOR with true.
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);
    // Recompute only if some case applied AND the naive result was NaN + NaN i.
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    // inf * (a'c' - b'd') on the boxed/zeroed inputs.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value newReal =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    // inf * (b'c' + a'd') on the boxed/zeroed inputs.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value newImag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
569 
570 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
571   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
572 
573   LogicalResult
574   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
575                   ConversionPatternRewriter &rewriter) const override {
576     auto loc = op.getLoc();
577     auto type = adaptor.getComplex().getType().cast<ComplexType>();
578     auto elementType = type.getElementType().cast<FloatType>();
579 
580     Value real =
581         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
582     Value imag =
583         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
584     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
585     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
586     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
587     return success();
588   }
589 };
590 
591 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
592   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
593 
594   LogicalResult
595   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
596                   ConversionPatternRewriter &rewriter) const override {
597     auto type = adaptor.getComplex().getType().cast<ComplexType>();
598     auto elementType = type.getElementType().cast<FloatType>();
599     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
600 
601     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
602     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
603     Value zero =
604         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
605     Value realIsZero =
606         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
607     Value imagIsZero =
608         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
609     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
610     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
611     Value realSign = b.create<arith::DivFOp>(real, abs);
612     Value imagSign = b.create<arith::DivFOp>(imag, abs);
613     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
614     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
615                                                  adaptor.getComplex(), sign);
616     return success();
617   }
618 };
619 } // namespace
620 
621 void mlir::populateComplexToStandardConversionPatterns(
622     RewritePatternSet &patterns) {
623   // clang-format off
624   patterns.add<
625       AbsOpConversion,
626       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
627       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
628       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
629       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
630       DivOpConversion,
631       ExpOpConversion,
632       LogOpConversion,
633       Log1pOpConversion,
634       MulOpConversion,
635       NegOpConversion,
636       SignOpConversion>(patterns.getContext());
637   // clang-format on
638 }
639 
640 namespace {
641 struct ConvertComplexToStandardPass
642     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
643   void runOnOperation() override;
644 };
645 
646 void ConvertComplexToStandardPass::runOnOperation() {
647   // Convert to the Standard dialect using the converter defined above.
648   RewritePatternSet patterns(&getContext());
649   populateComplexToStandardConversionPatterns(patterns);
650 
651   ConversionTarget target(getContext());
652   target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>();
653   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
654   if (failed(
655           applyPartialConversion(getOperation(), target, std::move(patterns))))
656     signalPassFailure();
657 }
658 } // namespace
659 
660 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
661   return std::make_unique<ConvertComplexToStandardPass>();
662 }
663