1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
27   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
28 
29   LogicalResult
30   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
31                   ConversionPatternRewriter &rewriter) const override {
32     auto loc = op.getLoc();
33     auto type = op.getType();
34 
35     Value real = rewriter.create<complex::ReOp>(loc, type, adaptor.complex());
36     Value imag = rewriter.create<complex::ImOp>(loc, type, adaptor.complex());
37     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
38     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
39     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
40 
41     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
42     return success();
43   }
44 };
45 
46 template <typename ComparisonOp, arith::CmpFPredicate p>
47 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
48   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
49   using ResultCombiner =
50       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
51                          arith::AndIOp, arith::OrIOp>;
52 
53   LogicalResult
54   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
55                   ConversionPatternRewriter &rewriter) const override {
56     auto loc = op.getLoc();
57     auto type =
58         adaptor.lhs().getType().template cast<ComplexType>().getElementType();
59 
60     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.lhs());
61     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.lhs());
62     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.rhs());
63     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.rhs());
64     Value realComparison =
65         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
66     Value imagComparison =
67         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
68 
69     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
70                                                 imagComparison);
71     return success();
72   }
73 };
74 
75 // Default conversion which applies the BinaryStandardOp separately on the real
76 // and imaginary parts. Can for example be used for complex::AddOp and
77 // complex::SubOp.
78 template <typename BinaryComplexOp, typename BinaryStandardOp>
79 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
80   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
81 
82   LogicalResult
83   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
84                   ConversionPatternRewriter &rewriter) const override {
85     auto type = adaptor.lhs().getType().template cast<ComplexType>();
86     auto elementType = type.getElementType().template cast<FloatType>();
87     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
88 
89     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.lhs());
90     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.rhs());
91     Value resultReal =
92         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
93     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.lhs());
94     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.rhs());
95     Value resultImag =
96         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
97     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
98                                                    resultImag);
99     return success();
100   }
101 };
102 
103 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
104   using OpConversionPattern<complex::DivOp>::OpConversionPattern;
105 
106   LogicalResult
107   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
108                   ConversionPatternRewriter &rewriter) const override {
109     auto loc = op.getLoc();
110     auto type = adaptor.lhs().getType().cast<ComplexType>();
111     auto elementType = type.getElementType().cast<FloatType>();
112 
113     Value lhsReal =
114         rewriter.create<complex::ReOp>(loc, elementType, adaptor.lhs());
115     Value lhsImag =
116         rewriter.create<complex::ImOp>(loc, elementType, adaptor.lhs());
117     Value rhsReal =
118         rewriter.create<complex::ReOp>(loc, elementType, adaptor.rhs());
119     Value rhsImag =
120         rewriter.create<complex::ImOp>(loc, elementType, adaptor.rhs());
121 
122     // Smith's algorithm to divide complex numbers. It is just a bit smarter
123     // way to compute the following formula:
124     //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
125     //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
126     //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
127     //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
128     //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
129     //
130     // Depending on whether |rhsReal| < |rhsImag| we compute either
131     //   rhsRealImagRatio = rhsReal / rhsImag
132     //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
133     //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
134     //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
135     //
136     // or
137     //
138     //   rhsImagRealRatio = rhsImag / rhsReal
139     //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
140     //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
141     //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
142     //
143     // See https://dl.acm.org/citation.cfm?id=368661 for more details.
144     Value rhsRealImagRatio =
145         rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
146     Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
147         loc, rhsImag,
148         rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
149     Value realNumerator1 = rewriter.create<arith::AddFOp>(
150         loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
151         lhsImag);
152     Value resultReal1 =
153         rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
154     Value imagNumerator1 = rewriter.create<arith::SubFOp>(
155         loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
156         lhsReal);
157     Value resultImag1 =
158         rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
159 
160     Value rhsImagRealRatio =
161         rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
162     Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
163         loc, rhsReal,
164         rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
165     Value realNumerator2 = rewriter.create<arith::AddFOp>(
166         loc, lhsReal,
167         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
168     Value resultReal2 =
169         rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
170     Value imagNumerator2 = rewriter.create<arith::SubFOp>(
171         loc, lhsImag,
172         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
173     Value resultImag2 =
174         rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
175 
176     // Consider corner cases.
177     // Case 1. Zero denominator, numerator contains at most one NaN value.
178     Value zero = rewriter.create<arith::ConstantOp>(
179         loc, elementType, rewriter.getZeroAttr(elementType));
180     Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
181     Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
182         loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
183     Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
184     Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
185         loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
186     Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
187         loc, arith::CmpFPredicate::ORD, lhsReal, zero);
188     Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
189         loc, arith::CmpFPredicate::ORD, lhsImag, zero);
190     Value lhsContainsNotNaNValue =
191         rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
192     Value resultIsInfinity = rewriter.create<arith::AndIOp>(
193         loc, lhsContainsNotNaNValue,
194         rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
195     Value inf = rewriter.create<arith::ConstantOp>(
196         loc, elementType,
197         rewriter.getFloatAttr(
198             elementType, APFloat::getInf(elementType.getFloatSemantics())));
199     Value infWithSignOfRhsReal =
200         rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
201     Value infinityResultReal =
202         rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
203     Value infinityResultImag =
204         rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
205 
206     // Case 2. Infinite numerator, finite denominator.
207     Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
208         loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
209     Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
210         loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
211     Value rhsFinite =
212         rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
213     Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
214     Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
215         loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
216     Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
217     Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
218         loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
219     Value lhsInfinite =
220         rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
221     Value infNumFiniteDenom =
222         rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
223     Value one = rewriter.create<arith::ConstantOp>(
224         loc, elementType, rewriter.getFloatAttr(elementType, 1));
225     Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
226         loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero),
227         lhsReal);
228     Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
229         loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero),
230         lhsImag);
231     Value lhsRealIsInfWithSignTimesRhsReal =
232         rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
233     Value lhsImagIsInfWithSignTimesRhsImag =
234         rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
235     Value resultReal3 = rewriter.create<arith::MulFOp>(
236         loc, inf,
237         rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
238                                        lhsImagIsInfWithSignTimesRhsImag));
239     Value lhsRealIsInfWithSignTimesRhsImag =
240         rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
241     Value lhsImagIsInfWithSignTimesRhsReal =
242         rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
243     Value resultImag3 = rewriter.create<arith::MulFOp>(
244         loc, inf,
245         rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
246                                        lhsRealIsInfWithSignTimesRhsImag));
247 
248     // Case 3: Finite numerator, infinite denominator.
249     Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
250         loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
251     Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
252         loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
253     Value lhsFinite =
254         rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
255     Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
256         loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
257     Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
258         loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
259     Value rhsInfinite =
260         rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
261     Value finiteNumInfiniteDenom =
262         rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
263     Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
264         loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero),
265         rhsReal);
266     Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
267         loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero),
268         rhsImag);
269     Value rhsRealIsInfWithSignTimesLhsReal =
270         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
271     Value rhsImagIsInfWithSignTimesLhsImag =
272         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
273     Value resultReal4 = rewriter.create<arith::MulFOp>(
274         loc, zero,
275         rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
276                                        rhsImagIsInfWithSignTimesLhsImag));
277     Value rhsRealIsInfWithSignTimesLhsImag =
278         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
279     Value rhsImagIsInfWithSignTimesLhsReal =
280         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
281     Value resultImag4 = rewriter.create<arith::MulFOp>(
282         loc, zero,
283         rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
284                                        rhsImagIsInfWithSignTimesLhsReal));
285 
286     Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
287         loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
288     Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
289                                                  resultReal1, resultReal2);
290     Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
291                                                  resultImag1, resultImag2);
292     Value resultRealSpecialCase3 = rewriter.create<SelectOp>(
293         loc, finiteNumInfiniteDenom, resultReal4, resultReal);
294     Value resultImagSpecialCase3 = rewriter.create<SelectOp>(
295         loc, finiteNumInfiniteDenom, resultImag4, resultImag);
296     Value resultRealSpecialCase2 = rewriter.create<SelectOp>(
297         loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
298     Value resultImagSpecialCase2 = rewriter.create<SelectOp>(
299         loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
300     Value resultRealSpecialCase1 = rewriter.create<SelectOp>(
301         loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
302     Value resultImagSpecialCase1 = rewriter.create<SelectOp>(
303         loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
304 
305     Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
306         loc, arith::CmpFPredicate::UNO, resultReal, zero);
307     Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
308         loc, arith::CmpFPredicate::UNO, resultImag, zero);
309     Value resultIsNaN =
310         rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
311     Value resultRealWithSpecialCases = rewriter.create<SelectOp>(
312         loc, resultIsNaN, resultRealSpecialCase1, resultReal);
313     Value resultImagWithSpecialCases = rewriter.create<SelectOp>(
314         loc, resultIsNaN, resultImagSpecialCase1, resultImag);
315 
316     rewriter.replaceOpWithNewOp<complex::CreateOp>(
317         op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
318     return success();
319   }
320 };
321 
322 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
323   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
324 
325   LogicalResult
326   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
327                   ConversionPatternRewriter &rewriter) const override {
328     auto loc = op.getLoc();
329     auto type = adaptor.complex().getType().cast<ComplexType>();
330     auto elementType = type.getElementType().cast<FloatType>();
331 
332     Value real =
333         rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
334     Value imag =
335         rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
336     Value expReal = rewriter.create<math::ExpOp>(loc, real);
337     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
338     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
339     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
340     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
341 
342     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
343                                                    resultImag);
344     return success();
345   }
346 };
347 
348 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
349   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
350 
351   LogicalResult
352   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
353                   ConversionPatternRewriter &rewriter) const override {
354     auto type = adaptor.complex().getType().cast<ComplexType>();
355     auto elementType = type.getElementType().cast<FloatType>();
356     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
357 
358     Value abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
359     Value resultReal = b.create<math::LogOp>(elementType, abs);
360     Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
361     Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
362     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
363     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
364                                                    resultImag);
365     return success();
366   }
367 };
368 
369 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
370   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
371 
372   LogicalResult
373   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
374                   ConversionPatternRewriter &rewriter) const override {
375     auto type = adaptor.complex().getType().cast<ComplexType>();
376     auto elementType = type.getElementType().cast<FloatType>();
377     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
378 
379     Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
380     Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
381     Value one = b.create<arith::ConstantOp>(elementType,
382                                             b.getFloatAttr(elementType, 1));
383     Value realPlusOne = b.create<arith::AddFOp>(real, one);
384     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
385     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
386     return success();
387   }
388 };
389 
390 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
391   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
392 
393   LogicalResult
394   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
395                   ConversionPatternRewriter &rewriter) const override {
396     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
397     auto type = adaptor.lhs().getType().cast<ComplexType>();
398     auto elementType = type.getElementType().cast<FloatType>();
399 
400     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.lhs());
401     Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
402     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.lhs());
403     Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
404     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.rhs());
405     Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
406     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.rhs());
407     Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);
408 
409     Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
410     Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
411     Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
412     Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
413     Value real =
414         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
415 
416     Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
417     Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
418     Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
419     Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
420     Value imag =
421         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
422 
423     // Handle cases where the "naive" calculation results in NaN values.
424     Value realIsNan =
425         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
426     Value imagIsNan =
427         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
428     Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
429 
430     Value inf = b.create<arith::ConstantOp>(
431         elementType,
432         b.getFloatAttr(elementType,
433                        APFloat::getInf(elementType.getFloatSemantics())));
434 
435     // Case 1. `lhsReal` or `lhsImag` are infinite.
436     Value lhsRealIsInf =
437         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
438     Value lhsImagIsInf =
439         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
440     Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
441     Value rhsRealIsNan =
442         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
443     Value rhsImagIsNan =
444         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
445     Value zero =
446         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
447     Value one = b.create<arith::ConstantOp>(elementType,
448                                             b.getFloatAttr(elementType, 1));
449     Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero);
450     lhsReal = b.create<SelectOp>(
451         lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
452         lhsReal);
453     Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero);
454     lhsImag = b.create<SelectOp>(
455         lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
456         lhsImag);
457     Value lhsIsInfAndRhsRealIsNan =
458         b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
459     rhsReal =
460         b.create<SelectOp>(lhsIsInfAndRhsRealIsNan,
461                            b.create<math::CopySignOp>(zero, rhsReal), rhsReal);
462     Value lhsIsInfAndRhsImagIsNan =
463         b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
464     rhsImag =
465         b.create<SelectOp>(lhsIsInfAndRhsImagIsNan,
466                            b.create<math::CopySignOp>(zero, rhsImag), rhsImag);
467 
468     // Case 2. `rhsReal` or `rhsImag` are infinite.
469     Value rhsRealIsInf =
470         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
471     Value rhsImagIsInf =
472         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
473     Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
474     Value lhsRealIsNan =
475         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
476     Value lhsImagIsNan =
477         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
478     Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero);
479     rhsReal = b.create<SelectOp>(
480         rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
481         rhsReal);
482     Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero);
483     rhsImag = b.create<SelectOp>(
484         rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
485         rhsImag);
486     Value rhsIsInfAndLhsRealIsNan =
487         b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
488     lhsReal =
489         b.create<SelectOp>(rhsIsInfAndLhsRealIsNan,
490                            b.create<math::CopySignOp>(zero, lhsReal), lhsReal);
491     Value rhsIsInfAndLhsImagIsNan =
492         b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
493     lhsImag =
494         b.create<SelectOp>(rhsIsInfAndLhsImagIsNan,
495                            b.create<math::CopySignOp>(zero, lhsImag), lhsImag);
496     Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
497 
498     // Case 3. One of the pairwise products of left hand side with right hand
499     // side is infinite.
500     Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
501         arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
502     Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
503         arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
504     Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
505                                                  lhsImagTimesRhsImagIsInf);
506     Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
507         arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
508     isSpecialCase =
509         b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
510     Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
511         arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
512     isSpecialCase =
513         b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
514     Type i1Type = b.getI1Type();
515     Value notRecalc = b.create<arith::XOrIOp>(
516         recalc,
517         b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
518     isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
519     Value isSpecialCaseAndLhsRealIsNan =
520         b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
521     lhsReal =
522         b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan,
523                            b.create<math::CopySignOp>(zero, lhsReal), lhsReal);
524     Value isSpecialCaseAndLhsImagIsNan =
525         b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
526     lhsImag =
527         b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan,
528                            b.create<math::CopySignOp>(zero, lhsImag), lhsImag);
529     Value isSpecialCaseAndRhsRealIsNan =
530         b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
531     rhsReal =
532         b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan,
533                            b.create<math::CopySignOp>(zero, rhsReal), rhsReal);
534     Value isSpecialCaseAndRhsImagIsNan =
535         b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
536     rhsImag =
537         b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan,
538                            b.create<math::CopySignOp>(zero, rhsImag), rhsImag);
539     recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
540     recalc = b.create<arith::AndIOp>(isNan, recalc);
541 
542     // Recalculate real part.
543     lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
544     lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
545     Value newReal =
546         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
547     real =
548         b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newReal), real);
549 
550     // Recalculate imag part.
551     lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
552     lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
553     Value newImag =
554         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
555     imag =
556         b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newImag), imag);
557 
558     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
559     return success();
560   }
561 };
562 
563 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
564   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
565 
566   LogicalResult
567   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
568                   ConversionPatternRewriter &rewriter) const override {
569     auto loc = op.getLoc();
570     auto type = adaptor.complex().getType().cast<ComplexType>();
571     auto elementType = type.getElementType().cast<FloatType>();
572 
573     Value real =
574         rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
575     Value imag =
576         rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
577     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
578     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
579     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
580     return success();
581   }
582 };
583 
584 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
585   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
586 
587   LogicalResult
588   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
589                   ConversionPatternRewriter &rewriter) const override {
590     auto type = adaptor.complex().getType().cast<ComplexType>();
591     auto elementType = type.getElementType().cast<FloatType>();
592     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
593 
594     Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
595     Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
596     Value zero =
597         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
598     Value realIsZero =
599         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
600     Value imagIsZero =
601         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
602     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
603     auto abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
604     Value realSign = b.create<arith::DivFOp>(real, abs);
605     Value imagSign = b.create<arith::DivFOp>(imag, abs);
606     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
607     rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, adaptor.complex(), sign);
608     return success();
609   }
610 };
611 } // namespace
612 
613 void mlir::populateComplexToStandardConversionPatterns(
614     RewritePatternSet &patterns) {
615   // clang-format off
616   patterns.add<
617       AbsOpConversion,
618       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
619       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
620       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
621       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
622       DivOpConversion,
623       ExpOpConversion,
624       LogOpConversion,
625       Log1pOpConversion,
626       MulOpConversion,
627       NegOpConversion,
628       SignOpConversion>(patterns.getContext());
629   // clang-format on
630 }
631 
632 namespace {
633 struct ConvertComplexToStandardPass
634     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
635   void runOnFunction() override;
636 };
637 
638 void ConvertComplexToStandardPass::runOnFunction() {
639   auto function = getFunction();
640 
641   // Convert to the Standard dialect using the converter defined above.
642   RewritePatternSet patterns(&getContext());
643   populateComplexToStandardConversionPatterns(patterns);
644 
645   ConversionTarget target(getContext());
646   target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect,
647                          math::MathDialect>();
648   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
649   if (failed(applyPartialConversion(function, target, std::move(patterns))))
650     signalPassFailure();
651 }
652 } // namespace
653 
654 std::unique_ptr<OperationPass<FuncOp>>
655 mlir::createConvertComplexToStandardPass() {
656   return std::make_unique<ConvertComplexToStandardPass>();
657 }
658