1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
26   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
30                   ConversionPatternRewriter &rewriter) const override {
31     auto loc = op.getLoc();
32     auto type = op.getType();
33 
34     Value real = rewriter.create<complex::ReOp>(loc, type, adaptor.complex());
35     Value imag = rewriter.create<complex::ImOp>(loc, type, adaptor.complex());
36     Value realSqr = rewriter.create<MulFOp>(loc, real, real);
37     Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
38     Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
39 
40     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
41     return success();
42   }
43 };
44 
45 template <typename ComparisonOp, CmpFPredicate p>
46 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
47   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
48   using ResultCombiner =
49       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
50                          AndOp, OrOp>;
51 
52   LogicalResult
53   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
54                   ConversionPatternRewriter &rewriter) const override {
55     auto loc = op.getLoc();
56     auto type =
57         adaptor.lhs().getType().template cast<ComplexType>().getElementType();
58 
59     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.lhs());
60     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.lhs());
61     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.rhs());
62     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.rhs());
63     Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
64     Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
65 
66     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
67                                                 imagComparison);
68     return success();
69   }
70 };
71 
72 // Default conversion which applies the BinaryStandardOp separately on the real
73 // and imaginary parts. Can for example be used for complex::AddOp and
74 // complex::SubOp.
75 template <typename BinaryComplexOp, typename BinaryStandardOp>
76 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
77   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
78 
79   LogicalResult
80   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
81                   ConversionPatternRewriter &rewriter) const override {
82     auto type = adaptor.lhs().getType().template cast<ComplexType>();
83     auto elementType = type.getElementType().template cast<FloatType>();
84     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
85 
86     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.lhs());
87     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.rhs());
88     Value resultReal =
89         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
90     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.lhs());
91     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.rhs());
92     Value resultImag =
93         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
94     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
95                                                    resultImag);
96     return success();
97   }
98 };
99 
/// Lowers `complex.div` to scalar floating-point arithmetic.
///
/// The quotient is first computed with Smith's algorithm, in both of its
/// branches; the branch dividing by the larger-magnitude component of the
/// denominator is selected. Three IEEE-754 special cases (zero denominator,
/// infinite numerator over finite denominator, finite numerator over infinite
/// denominator) are then substituted, but only when BOTH components of the
/// Smith result are NaN. This mirrors the special-case structure of the
/// `_Cdivd` example in C99 Annex G, without its logb-based scaling.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.lhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.lhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.rhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.rhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Both branches are emitted unconditionally; the choice between them is
    // made with `select` further below.
    // Branch 1 (used when |rhsReal| < |rhsImag|): scale by rhsReal / rhsImag.
    Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<AddFOp>(
        loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag);
    Value resultReal1 =
        rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<SubFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal);
    Value resultImag1 =
        rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Branch 2 (used otherwise): scale by rhsImag / rhsReal.
    Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<AddFOp>(
        loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<AddFOp>(
        loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<SubFOp>(
        loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    Value zero = rewriter.create<ConstantOp>(loc, elementType,
                                             rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal);
    Value rhsRealIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag);
    Value rhsImagIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the constant zero is true exactly when the first operand
    // is not NaN.
    Value lhsRealIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<AndOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    // Result is (copysign(inf, rhsReal) * lhsReal,
    //            copysign(inf, rhsReal) * lhsImag).
    Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // `|x| ONE inf` is false for NaN as well as for +/-inf, so it tests
    // "x is finite".
    Value rhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal);
    Value lhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag);
    Value lhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // Map each lhs component to +/-1 if it is infinite and +/-0 otherwise,
    // keeping its sign; then result = inf * (a*c + b*d, b*c - a*d).
    Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    Value lhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite);
    // Same +/-1 / +/-0 mapping as Case 2, applied to rhs; the products are
    // scaled by zero, so the result is a correctly signed zero.
    Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith branch that divides by the larger-magnitude component
    // of the denominator.
    Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>(
        loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultReal1, resultReal2);
    Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultImag1, resultImag2);
    // Layer the special cases from lowest to highest priority, so that
    // Case 1 overrides Case 2, which overrides Case 3.
    Value resultRealSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case values are used only when both components of the
    // Smith result are NaN; otherwise the Smith result is kept.
    Value resultRealIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
307 
308 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
309   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
310 
311   LogicalResult
312   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
313                   ConversionPatternRewriter &rewriter) const override {
314     auto loc = op.getLoc();
315     auto type = adaptor.complex().getType().cast<ComplexType>();
316     auto elementType = type.getElementType().cast<FloatType>();
317 
318     Value real =
319         rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
320     Value imag =
321         rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
322     Value expReal = rewriter.create<math::ExpOp>(loc, real);
323     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
324     Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag);
325     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
326     Value resultImag = rewriter.create<MulFOp>(loc, expReal, sinImag);
327 
328     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
329                                                    resultImag);
330     return success();
331   }
332 };
333 
334 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
335   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
336 
337   LogicalResult
338   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
339                   ConversionPatternRewriter &rewriter) const override {
340     auto type = adaptor.complex().getType().cast<ComplexType>();
341     auto elementType = type.getElementType().cast<FloatType>();
342     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
343 
344     Value abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
345     Value resultReal = b.create<math::LogOp>(elementType, abs);
346     Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
347     Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
348     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
349     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
350                                                    resultImag);
351     return success();
352   }
353 };
354 
355 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
356   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
357 
358   LogicalResult
359   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
360                   ConversionPatternRewriter &rewriter) const override {
361     auto type = adaptor.complex().getType().cast<ComplexType>();
362     auto elementType = type.getElementType().cast<FloatType>();
363     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
364 
365     Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
366     Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
367     Value one =
368         b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
369     Value realPlusOne = b.create<AddFOp>(real, one);
370     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
371     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
372     return success();
373   }
374 };
375 
/// Lowers `complex.mul` to scalar floating-point arithmetic.
///
/// The product is first computed with the textbook formula
///   (a + bi)(c + di) = (ac - bd) + (bc + ad)i.
/// If BOTH components of that result are NaN, it is recomputed after
/// normalising the operands: infinite components become +/-1 (others +/-0),
/// and NaN components of the other operand become +/-0. The recomputed value
/// is scaled by inf. This mirrors the `_Cmultd` example in C99 Annex G.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.lhs());
    Value lhsRealAbs = b.create<AbsFOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.lhs());
    Value lhsImagAbs = b.create<AbsFOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.rhs());
    Value rhsRealAbs = b.create<AbsFOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.rhs());
    Value rhsImagAbs = b.create<AbsFOp>(rhsImag);

    // Naive product. The magnitudes of the four partial products are kept
    // for the Case 3 check below.
    Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<AbsFOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<AbsFOp>(lhsImagTimesRhsImag);
    Value real = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    Value lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<AbsFOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<AbsFOp>(lhsRealTimesRhsImag);
    Value imag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // `UNO x, x` is true exactly when x is NaN.
    Value realIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, real, real);
    Value imagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<AndOp>(realIsNan, imagIsNan);

    Value inf = b.create<ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Replace each lhs component by copysign(isInf ? 1 : 0, component) and
    // any NaN rhs component by copysign(0, component).
    Value lhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<OrOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one =
        b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<SelectOp>(
        lhsIsInf, b.create<CopySignOp>(lhsRealIsInfFloat, lhsReal), lhsReal);
    Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<SelectOp>(
        lhsIsInf, b.create<CopySignOp>(lhsImagIsInfFloat, lhsImag), lhsImag);
    Value lhsIsInfAndRhsRealIsNan = b.create<AndOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<SelectOp>(lhsIsInfAndRhsRealIsNan,
                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
    Value lhsIsInfAndRhsImagIsNan = b.create<AndOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<SelectOp>(lhsIsInfAndRhsImagIsNan,
                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Symmetric to Case 1 with the roles of lhs and rhs swapped. The NaN
    // checks on lhs below see the values possibly rewritten by Case 1.
    Value rhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<OrOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<SelectOp>(
        rhsIsInf, b.create<CopySignOp>(rhsRealIsInfFloat, rhsReal), rhsReal);
    Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<SelectOp>(
        rhsIsInf, b.create<CopySignOp>(rhsImagIsInfFloat, rhsImag), rhsImag);
    Value rhsIsInfAndLhsRealIsNan = b.create<AndOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<SelectOp>(rhsIsInfAndLhsRealIsNan,
                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
    Value rhsIsInfAndLhsImagIsNan = b.create<AndOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<SelectOp>(rhsIsInfAndLhsImagIsNan,
                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
    Value recalc = b.create<OrOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    // Only applies when neither Case 1 nor Case 2 triggered; then all NaN
    // components are replaced by +/-0.
    Value lhsRealTimesRhsRealIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase =
        b.create<OrOp>(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    Type i1Type = b.getI1Type();
    // XOR with the i1 constant `true` computes the logical NOT of `recalc`.
    Value notRecalc = b.create<XOrOp>(
        recalc, b.create<ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<AndOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<AndOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan,
                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<AndOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan,
                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<AndOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan,
                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<AndOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan,
                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);
    recalc = b.create<OrOp>(recalc, isSpecialCase);
    // The recomputed value is only used when the naive result was NaN in
    // both components.
    recalc = b.create<AndOp>(isNan, recalc);

    // Recalculate real part.
    // Multiplying by inf restores the magnitude lost by normalising infinite
    // components to +/-1.
    lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
    Value newReal = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
    Value newImag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
512 
513 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
514   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
515 
516   LogicalResult
517   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
518                   ConversionPatternRewriter &rewriter) const override {
519     auto loc = op.getLoc();
520     auto type = adaptor.complex().getType().cast<ComplexType>();
521     auto elementType = type.getElementType().cast<FloatType>();
522 
523     Value real =
524         rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
525     Value imag =
526         rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
527     Value negReal = rewriter.create<NegFOp>(loc, real);
528     Value negImag = rewriter.create<NegFOp>(loc, imag);
529     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
530     return success();
531   }
532 };
533 
534 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
535   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
536 
537   LogicalResult
538   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
539                   ConversionPatternRewriter &rewriter) const override {
540     auto type = adaptor.complex().getType().cast<ComplexType>();
541     auto elementType = type.getElementType().cast<FloatType>();
542     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
543 
544     Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
545     Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
546     Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
547     Value realIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, real, zero);
548     Value imagIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, imag, zero);
549     Value isZero = b.create<AndOp>(realIsZero, imagIsZero);
550     auto abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
551     Value realSign = b.create<DivFOp>(real, abs);
552     Value imagSign = b.create<DivFOp>(imag, abs);
553     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
554     rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, adaptor.complex(), sign);
555     return success();
556   }
557 };
558 } // namespace
559 
560 void mlir::populateComplexToStandardConversionPatterns(
561     RewritePatternSet &patterns) {
562   // clang-format off
563   patterns.add<
564       AbsOpConversion,
565       ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
566       ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
567       BinaryComplexOpConversion<complex::AddOp, AddFOp>,
568       BinaryComplexOpConversion<complex::SubOp, SubFOp>,
569       DivOpConversion,
570       ExpOpConversion,
571       LogOpConversion,
572       Log1pOpConversion,
573       MulOpConversion,
574       NegOpConversion,
575       SignOpConversion>(patterns.getContext());
576   // clang-format on
577 }
578 
579 namespace {
580 struct ConvertComplexToStandardPass
581     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
582   void runOnFunction() override;
583 };
584 
585 void ConvertComplexToStandardPass::runOnFunction() {
586   auto function = getFunction();
587 
588   // Convert to the Standard dialect using the converter defined above.
589   RewritePatternSet patterns(&getContext());
590   populateComplexToStandardConversionPatterns(patterns);
591 
592   ConversionTarget target(getContext());
593   target.addLegalDialect<StandardOpsDialect, math::MathDialect>();
594   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
595   if (failed(applyPartialConversion(function, target, std::move(patterns))))
596     signalPassFailure();
597 }
598 } // namespace
599 
600 std::unique_ptr<OperationPass<FuncOp>>
601 mlir::createConvertComplexToStandardPass() {
602   return std::make_unique<ConvertComplexToStandardPass>();
603 }
604