1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
26   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
30                   ConversionPatternRewriter &rewriter) const override {
31     complex::AbsOp::Adaptor transformed(operands);
32     auto loc = op.getLoc();
33     auto type = op.getType();
34 
35     Value real =
36         rewriter.create<complex::ReOp>(loc, type, transformed.complex());
37     Value imag =
38         rewriter.create<complex::ImOp>(loc, type, transformed.complex());
39     Value realSqr = rewriter.create<MulFOp>(loc, real, real);
40     Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
41     Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
42 
43     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
44     return success();
45   }
46 };
47 
48 template <typename ComparisonOp, CmpFPredicate p>
49 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
50   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
51   using ResultCombiner =
52       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
53                          AndOp, OrOp>;
54 
55   LogicalResult
56   matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
57                   ConversionPatternRewriter &rewriter) const override {
58     typename ComparisonOp::Adaptor transformed(operands);
59     auto loc = op.getLoc();
60     auto type = transformed.lhs()
61                     .getType()
62                     .template cast<ComplexType>()
63                     .getElementType();
64 
65     Value realLhs =
66         rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
67     Value imagLhs =
68         rewriter.create<complex::ImOp>(loc, type, transformed.lhs());
69     Value realRhs =
70         rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
71     Value imagRhs =
72         rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
73     Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
74     Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
75 
76     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
77                                                 imagComparison);
78     return success();
79   }
80 };
81 
82 // Default conversion which applies the BinaryStandardOp separately on the real
83 // and imaginary parts. Can for example be used for complex::AddOp and
84 // complex::SubOp.
85 template <typename BinaryComplexOp, typename BinaryStandardOp>
86 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
87   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
88 
89   LogicalResult
90   matchAndRewrite(BinaryComplexOp op, ArrayRef<Value> operands,
91                   ConversionPatternRewriter &rewriter) const override {
92     typename BinaryComplexOp::Adaptor transformed(operands);
93     auto type = transformed.lhs().getType().template cast<ComplexType>();
94     auto elementType = type.getElementType().template cast<FloatType>();
95     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
96 
97     Value realLhs = b.create<complex::ReOp>(elementType, transformed.lhs());
98     Value realRhs = b.create<complex::ReOp>(elementType, transformed.rhs());
99     Value resultReal =
100         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
101     Value imagLhs = b.create<complex::ImOp>(elementType, transformed.lhs());
102     Value imagRhs = b.create<complex::ImOp>(elementType, transformed.rhs());
103     Value resultImag =
104         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
105     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
106                                                    resultImag);
107     return success();
108   }
109 };
110 
/// Lowers `complex.div` to Standard-dialect floating-point arithmetic on the
/// real and imaginary parts.
///
/// Pipeline:
///   1. The quotient is computed with Smith's algorithm. Both branches of the
///      algorithm are materialized, and one is picked with `select`.
///   2. If both parts of that quotient are NaN, three special cases are
///      checked in priority order and substituted:
///        - zero denominator,
///        - infinite numerator with finite denominator,
///        - finite numerator with infinite denominator.
///
/// NOTE(review): the special cases appear to follow the C11 Annex G reference
/// division algorithm (without its logb-based scaling). Confirm before relying
/// on exact corner-case semantics.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::DivOp::Adaptor transformed(operands);
    auto loc = op.getLoc();
    auto type = transformed.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs());

    // Smith's algorithm for dividing complex numbers. It evaluates
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    // without forming ||rhs||^2 directly. Instead it divides through by the
    // larger-magnitude component of rhs, which avoids the intermediate
    // overflow/underflow of the naive formula.
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // Both variants are emitted unconditionally; the choice is made with a
    // `select` further below.
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<AddFOp>(
        loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag);
    Value resultReal1 =
        rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<SubFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal);
    Value resultImag1 =
        rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<AddFOp>(
        loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<AddFOp>(
        loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<SubFOp>(
        loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases. Each case computes its own candidate result.
    // The candidates are only used when the Smith result is NaN in both parts
    // (see `resultIsNaN` below).
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // The result is an infinity carrying the sign of rhsReal, times each
    // numerator component.
    Value zero = rewriter.create<ConstantOp>(loc, elementType,
                                             rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal);
    Value rhsRealIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag);
    Value rhsImagIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against a non-NaN constant is true iff the operand is not NaN.
    Value lhsRealIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<AndOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // Each lhs component is replaced by copysign(isInf ? 1 : 0, component).
    // The result is inf * (a*c + b*d) + inf * (b*c - a*d) * i, where a/b are
    // those replacements and c/d are the rhs components.
    // ONE is false for NaN operands, so NaN components count as not finite.
    Value rhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal);
    Value lhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag);
    Value lhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    // Each rhs component is replaced by copysign(isInf ? 1 : 0, component).
    // The result is the signed-zero form 0 * (...), mirroring case 2.
    Value lhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith's-algorithm branch that divides by the larger-magnitude
    // rhs component.
    Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>(
        loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultReal1, resultReal2);
    Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultImag1, resultImag2);
    // Layer the special cases from lowest to highest priority, so that
    // case 1 wins over case 2, which wins over case 3.
    Value resultRealSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case values replace the Smith result only when both of its
    // parts are NaN.
    Value resultRealIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
319 
320 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
321   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
322 
323   LogicalResult
324   matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands,
325                   ConversionPatternRewriter &rewriter) const override {
326     complex::ExpOp::Adaptor transformed(operands);
327     auto loc = op.getLoc();
328     auto type = transformed.complex().getType().cast<ComplexType>();
329     auto elementType = type.getElementType().cast<FloatType>();
330 
331     Value real =
332         rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
333     Value imag =
334         rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
335     Value expReal = rewriter.create<math::ExpOp>(loc, real);
336     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
337     Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag);
338     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
339     Value resultImag = rewriter.create<MulFOp>(loc, expReal, sinImag);
340 
341     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
342                                                    resultImag);
343     return success();
344   }
345 };
346 
347 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
348   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
349 
350   LogicalResult
351   matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands,
352                   ConversionPatternRewriter &rewriter) const override {
353     complex::LogOp::Adaptor transformed(operands);
354     auto type = transformed.complex().getType().cast<ComplexType>();
355     auto elementType = type.getElementType().cast<FloatType>();
356     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
357 
358     Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
359     Value resultReal = b.create<math::LogOp>(elementType, abs);
360     Value real = b.create<complex::ReOp>(elementType, transformed.complex());
361     Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
362     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
363     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
364                                                    resultImag);
365     return success();
366   }
367 };
368 
369 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
370   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
371 
372   LogicalResult
373   matchAndRewrite(complex::Log1pOp op, ArrayRef<Value> operands,
374                   ConversionPatternRewriter &rewriter) const override {
375     complex::Log1pOp::Adaptor transformed(operands);
376     auto type = transformed.complex().getType().cast<ComplexType>();
377     auto elementType = type.getElementType().cast<FloatType>();
378     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
379 
380     Value real = b.create<complex::ReOp>(elementType, transformed.complex());
381     Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
382     Value one =
383         b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
384     Value realPlusOne = b.create<AddFOp>(real, one);
385     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
386     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
387     return success();
388   }
389 };
390 
/// Lowers `complex.mul` to Standard-dialect floating-point arithmetic.
///
/// The "naive" product (a + bi)(c + di) = (ac - bd) + (ad + bc)i is computed
/// first. If both parts of that product are NaN, the operands are sanitized
/// and the product is recomputed and scaled by +inf. Sanitizing means:
///   - infinite components become copysign(1 or 0, component);
///   - NaN components of the other operand become signed zeros.
/// The aim is that products involving infinities are not reported as
/// NaN + NaN i.
///
/// NOTE(review): the recovery logic appears to mirror the C11 Annex G
/// reference multiplication algorithm. Confirm before relying on exact
/// corner-case semantics.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::MulOp::Adaptor transformed(operands);
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = transformed.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, transformed.lhs());
    Value lhsRealAbs = b.create<AbsFOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, transformed.lhs());
    Value lhsImagAbs = b.create<AbsFOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, transformed.rhs());
    Value rhsRealAbs = b.create<AbsFOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, transformed.rhs());
    Value rhsImagAbs = b.create<AbsFOp>(rhsImag);

    // Naive real part: ac - bd.
    Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<AbsFOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<AbsFOp>(lhsImagTimesRhsImag);
    Value real = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    // Naive imaginary part: bc + ad.
    Value lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<AbsFOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<AbsFOp>(lhsRealTimesRhsImag);
    Value imag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // Recovery only applies when BOTH parts are NaN (UNO x, x <=> isnan(x)).
    Value realIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, real, real);
    Value imagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<AndOp>(realIsNan, imagIsNan);

    Value inf = b.create<ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Box the lhs components to signed 1/0 and zero out NaN components of rhs.
    Value lhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<OrOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one =
        b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<SelectOp>(
        lhsIsInf, b.create<CopySignOp>(lhsRealIsInfFloat, lhsReal), lhsReal);
    Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<SelectOp>(
        lhsIsInf, b.create<CopySignOp>(lhsImagIsInfFloat, lhsImag), lhsImag);
    Value lhsIsInfAndRhsRealIsNan = b.create<AndOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<SelectOp>(lhsIsInfAndRhsRealIsNan,
                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
    Value lhsIsInfAndRhsImagIsNan = b.create<AndOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<SelectOp>(lhsIsInfAndRhsImagIsNan,
                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Symmetric to case 1. The NaN checks on lhs use the values as updated by
    // case 1.
    Value rhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<OrOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<SelectOp>(
        rhsIsInf, b.create<CopySignOp>(rhsRealIsInfFloat, rhsReal), rhsReal);
    Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<SelectOp>(
        rhsIsInf, b.create<CopySignOp>(rhsImagIsInfFloat, rhsImag), rhsImag);
    Value rhsIsInfAndLhsRealIsNan = b.create<AndOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<SelectOp>(rhsIsInfAndLhsRealIsNan,
                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
    Value rhsIsInfAndLhsImagIsNan = b.create<AndOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<SelectOp>(rhsIsInfAndLhsImagIsNan,
                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
    Value recalc = b.create<OrOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    // Only considered when neither case 1 nor case 2 applied (`notRecalc`).
    // Uses the magnitudes of the original, unsanitized products.
    Value lhsRealTimesRhsRealIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase =
        b.create<OrOp>(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    Type i1Type = b.getI1Type();
    // Logical NOT via xor with i1 true.
    Value notRecalc = b.create<XOrOp>(
        recalc, b.create<ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<AndOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<AndOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan,
                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<AndOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan,
                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<AndOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan,
                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<AndOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan,
                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);
    recalc = b.create<OrOp>(recalc, isSpecialCase);
    // Recompute only if some case applied AND the naive result was NaN + NaN i.
    recalc = b.create<AndOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
    Value newReal = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
    Value newImag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
528 
529 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
530   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
531 
532   LogicalResult
533   matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands,
534                   ConversionPatternRewriter &rewriter) const override {
535     complex::NegOp::Adaptor transformed(operands);
536     auto loc = op.getLoc();
537     auto type = transformed.complex().getType().cast<ComplexType>();
538     auto elementType = type.getElementType().cast<FloatType>();
539 
540     Value real =
541         rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
542     Value imag =
543         rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
544     Value negReal = rewriter.create<NegFOp>(loc, real);
545     Value negImag = rewriter.create<NegFOp>(loc, imag);
546     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
547     return success();
548   }
549 };
550 
551 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
552   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
553 
554   LogicalResult
555   matchAndRewrite(complex::SignOp op, ArrayRef<Value> operands,
556                   ConversionPatternRewriter &rewriter) const override {
557     complex::SignOp::Adaptor transformed(operands);
558     auto type = transformed.complex().getType().cast<ComplexType>();
559     auto elementType = type.getElementType().cast<FloatType>();
560     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
561 
562     Value real = b.create<complex::ReOp>(elementType, transformed.complex());
563     Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
564     Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
565     Value realIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, real, zero);
566     Value imagIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, imag, zero);
567     Value isZero = b.create<AndOp>(realIsZero, imagIsZero);
568     auto abs = b.create<complex::AbsOp>(elementType, transformed.complex());
569     Value realSign = b.create<DivFOp>(real, abs);
570     Value imagSign = b.create<DivFOp>(imag, abs);
571     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
572     rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, transformed.complex(),
573                                           sign);
574     return success();
575   }
576 };
577 } // namespace
578 
579 void mlir::populateComplexToStandardConversionPatterns(
580     RewritePatternSet &patterns) {
581   // clang-format off
582   patterns.add<
583       AbsOpConversion,
584       ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
585       ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
586       BinaryComplexOpConversion<complex::AddOp, AddFOp>,
587       BinaryComplexOpConversion<complex::SubOp, SubFOp>,
588       DivOpConversion,
589       ExpOpConversion,
590       LogOpConversion,
591       Log1pOpConversion,
592       MulOpConversion,
593       NegOpConversion,
594       SignOpConversion>(patterns.getContext());
595   // clang-format on
596 }
597 
598 namespace {
599 struct ConvertComplexToStandardPass
600     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
601   void runOnFunction() override;
602 };
603 
604 void ConvertComplexToStandardPass::runOnFunction() {
605   auto function = getFunction();
606 
607   // Convert to the Standard dialect using the converter defined above.
608   RewritePatternSet patterns(&getContext());
609   populateComplexToStandardConversionPatterns(patterns);
610 
611   ConversionTarget target(getContext());
612   target.addLegalDialect<StandardOpsDialect, math::MathDialect>();
613   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
614   if (failed(applyPartialConversion(function, target, std::move(patterns))))
615     signalPassFailure();
616 }
617 } // namespace
618 
619 std::unique_ptr<OperationPass<FuncOp>>
620 mlir::createConvertComplexToStandardPass() {
621   return std::make_unique<ConvertComplexToStandardPass>();
622 }
623