1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
27   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
28 
29   LogicalResult
30   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
31                   ConversionPatternRewriter &rewriter) const override {
32     auto loc = op.getLoc();
33     auto type = op.getType();
34 
35     Value real =
36         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
37     Value imag =
38         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
39     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
40     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
41     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
42 
43     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
44     return success();
45   }
46 };
47 
48 template <typename ComparisonOp, arith::CmpFPredicate p>
49 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
50   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
51   using ResultCombiner =
52       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
53                          arith::AndIOp, arith::OrIOp>;
54 
55   LogicalResult
56   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
57                   ConversionPatternRewriter &rewriter) const override {
58     auto loc = op.getLoc();
59     auto type = adaptor.getLhs()
60                     .getType()
61                     .template cast<ComplexType>()
62                     .getElementType();
63 
64     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
65     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
66     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
67     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
68     Value realComparison =
69         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
70     Value imagComparison =
71         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
72 
73     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
74                                                 imagComparison);
75     return success();
76   }
77 };
78 
79 // Default conversion which applies the BinaryStandardOp separately on the real
80 // and imaginary parts. Can for example be used for complex::AddOp and
81 // complex::SubOp.
82 template <typename BinaryComplexOp, typename BinaryStandardOp>
83 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
84   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
85 
86   LogicalResult
87   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
88                   ConversionPatternRewriter &rewriter) const override {
89     auto type = adaptor.getLhs().getType().template cast<ComplexType>();
90     auto elementType = type.getElementType().template cast<FloatType>();
91     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
92 
93     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
94     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
95     Value resultReal =
96         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
97     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
98     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
99     Value resultImag =
100         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
101     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
102                                                    resultImag);
103     return success();
104   }
105 };
106 
/// Lowers `complex.div` to arith/math ops on the real and imaginary parts.
///
/// The quotient is computed with both variants of Smith's algorithm: one
/// using the ratio rhsReal / rhsImag, the other using rhsImag / rhsReal.
/// The variant is chosen by comparing |rhsReal| with |rhsImag|.
///
/// Three special cases are also evaluated:
///   1. zero denominator, numerator has at least one non-NaN component;
///   2. infinite numerator, finite denominator;
///   3. finite numerator, infinite denominator.
/// Their results are only used when BOTH components of the Smith result are
/// NaN. Among the special cases, case 1 takes precedence over case 2, and
/// case 2 over case 3. Every branch is emitted unconditionally and the final
/// value is picked with `arith.select`.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm for dividing complex numbers. It is a smarter way to
    // evaluate the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Variant 1 (ratio rhsReal / rhsImag); selected when |rhsReal| < |rhsImag|.
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Variant 2 (ratio rhsImag / rhsReal); selected otherwise.
    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // The numerator components are multiplied by an infinity that carries the
    // sign of rhsReal.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the (non-NaN) zero constant is true iff the value is not NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // Infinite lhs components become copysign(1, x) and finite ones become
    // copysign(0, x). The product with rhs is then scaled by +inf.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    // Same substitution as case 2, applied to rhs; the result is scaled by
    // zero instead of infinity.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith variant whose ratio has magnitude <= 1.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Layer the special cases; later selects win, so case 1 has the highest
    // priority and case 3 the lowest.
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // Use the special-case result only if the Smith result is NaN in both
    // components; otherwise keep the Smith result.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
325 
326 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
327   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
328 
329   LogicalResult
330   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
331                   ConversionPatternRewriter &rewriter) const override {
332     auto loc = op.getLoc();
333     auto type = adaptor.getComplex().getType().cast<ComplexType>();
334     auto elementType = type.getElementType().cast<FloatType>();
335 
336     Value real =
337         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
338     Value imag =
339         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
340     Value expReal = rewriter.create<math::ExpOp>(loc, real);
341     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
342     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
343     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
344     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
345 
346     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
347                                                    resultImag);
348     return success();
349   }
350 };
351 
352 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
353   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
354 
355   LogicalResult
356   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
357                   ConversionPatternRewriter &rewriter) const override {
358     auto type = adaptor.getComplex().getType().cast<ComplexType>();
359     auto elementType = type.getElementType().cast<FloatType>();
360     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
361 
362     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
363     Value resultReal = b.create<math::LogOp>(elementType, abs);
364     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
365     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
366     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
367     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
368                                                    resultImag);
369     return success();
370   }
371 };
372 
373 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
374   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
375 
376   LogicalResult
377   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
378                   ConversionPatternRewriter &rewriter) const override {
379     auto type = adaptor.getComplex().getType().cast<ComplexType>();
380     auto elementType = type.getElementType().cast<FloatType>();
381     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
382 
383     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
384     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
385     Value one = b.create<arith::ConstantOp>(elementType,
386                                             b.getFloatAttr(elementType, 1));
387     Value realPlusOne = b.create<arith::AddFOp>(real, one);
388     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
389     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
390     return success();
391   }
392 };
393 
/// Lowers `complex.mul` using the textbook formula
///   (a + b*i)(c + d*i) = (a*c - b*d) + (a*d + b*c)*i.
///
/// When BOTH components of that result are NaN, the product is recomputed
/// after recovering infinities. NOTE(review): the structure matches the C99
/// Annex G reference multiplication (`_Cmultd`); confirm that this is the
/// intended reference. The recovery steps are:
///   1. If lhs has an infinite component, each lhs component becomes
///      copysign(1, x) if infinite or copysign(0, x) otherwise, and NaN rhs
///      components become copysign(0, x).
///   2. The same with the roles of lhs and rhs swapped.
///   3. If neither 1 nor 2 applied but one of the four partial products is
///      infinite, every NaN component becomes copysign(0, x).
/// If any step applied, the result is inf * (product of the rewritten
/// operands). All paths are emitted unconditionally and combined with
/// `arith.select`.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);

    // Naive product; the partial products' magnitudes are kept for case 3.
    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
    Value real =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
    Value imag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    // lhs components become copysign(1 if infinite else 0, x).
    Value lhsRealIsInfFloat =
        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat =
        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    // NaN rhs components become a signed zero.
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // The infinity checks use the original magnitudes; case 1 only replaces
    // NaN rhs components, so it cannot change which of them are infinite.
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    // These NaN checks deliberately see lhs as possibly rewritten by case 1.
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat =
        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat =
        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    // Case 3 only applies when neither case 1 nor case 2 did (recalc XOR 1
    // is the i1 negation of recalc). In that situation no operand was
    // rewritten, so the NaN flags computed above are still accurate.
    Type i1Type = b.getI1Type();
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);
    // Recompute only if some case applied AND the naive result was NaN in
    // both components.
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value newReal =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value newImag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
570 
571 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
572   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
573 
574   LogicalResult
575   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
576                   ConversionPatternRewriter &rewriter) const override {
577     auto loc = op.getLoc();
578     auto type = adaptor.getComplex().getType().cast<ComplexType>();
579     auto elementType = type.getElementType().cast<FloatType>();
580 
581     Value real =
582         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
583     Value imag =
584         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
585     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
586     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
587     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
588     return success();
589   }
590 };
591 
592 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
593   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
594 
595   LogicalResult
596   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
597                   ConversionPatternRewriter &rewriter) const override {
598     auto type = adaptor.getComplex().getType().cast<ComplexType>();
599     auto elementType = type.getElementType().cast<FloatType>();
600     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
601 
602     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
603     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
604     Value zero =
605         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
606     Value realIsZero =
607         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
608     Value imagIsZero =
609         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
610     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
611     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
612     Value realSign = b.create<arith::DivFOp>(real, abs);
613     Value imagSign = b.create<arith::DivFOp>(imag, abs);
614     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
615     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
616                                                  adaptor.getComplex(), sign);
617     return success();
618   }
619 };
620 } // namespace
621 
622 void mlir::populateComplexToStandardConversionPatterns(
623     RewritePatternSet &patterns) {
624   // clang-format off
625   patterns.add<
626       AbsOpConversion,
627       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
628       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
629       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
630       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
631       DivOpConversion,
632       ExpOpConversion,
633       LogOpConversion,
634       Log1pOpConversion,
635       MulOpConversion,
636       NegOpConversion,
637       SignOpConversion>(patterns.getContext());
638   // clang-format on
639 }
640 
641 namespace {
642 struct ConvertComplexToStandardPass
643     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
644   void runOnOperation() override;
645 };
646 
647 void ConvertComplexToStandardPass::runOnOperation() {
648   auto function = getOperation();
649 
650   // Convert to the Standard dialect using the converter defined above.
651   RewritePatternSet patterns(&getContext());
652   populateComplexToStandardConversionPatterns(patterns);
653 
654   ConversionTarget target(getContext());
655   target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect,
656                          math::MathDialect>();
657   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
658   if (failed(applyPartialConversion(function, target, std::move(patterns))))
659     signalPassFailure();
660 }
661 } // namespace
662 
663 std::unique_ptr<OperationPass<FuncOp>>
664 mlir::createConvertComplexToStandardPass() {
665   return std::make_unique<ConvertComplexToStandardPass>();
666 }
667