1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
27   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
28 
29   LogicalResult
30   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
31                   ConversionPatternRewriter &rewriter) const override {
32     auto loc = op.getLoc();
33     auto type = op.getType();
34 
35     Value real =
36         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
37     Value imag =
38         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
39     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
40     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
41     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
42 
43     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
44     return success();
45   }
46 };
47 
48 template <typename ComparisonOp, arith::CmpFPredicate p>
49 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
50   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
51   using ResultCombiner =
52       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
53                          arith::AndIOp, arith::OrIOp>;
54 
55   LogicalResult
56   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
57                   ConversionPatternRewriter &rewriter) const override {
58     auto loc = op.getLoc();
59     auto type = adaptor.getLhs()
60                     .getType()
61                     .template cast<ComplexType>()
62                     .getElementType();
63 
64     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
65     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
66     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
67     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
68     Value realComparison =
69         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
70     Value imagComparison =
71         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
72 
73     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
74                                                 imagComparison);
75     return success();
76   }
77 };
78 
79 // Default conversion which applies the BinaryStandardOp separately on the real
80 // and imaginary parts. Can for example be used for complex::AddOp and
81 // complex::SubOp.
82 template <typename BinaryComplexOp, typename BinaryStandardOp>
83 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
84   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
85 
86   LogicalResult
87   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
88                   ConversionPatternRewriter &rewriter) const override {
89     auto type = adaptor.getLhs().getType().template cast<ComplexType>();
90     auto elementType = type.getElementType().template cast<FloatType>();
91     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
92 
93     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
94     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
95     Value resultReal =
96         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
97     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
98     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
99     Value resultImag =
100         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
101     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
102                                                    resultImag);
103     return success();
104   }
105 };
106 
/// Lowers `complex.div` to arith/math/std ops.
///
/// The quotient is computed with Smith's algorithm. Both branches are emitted
/// and one is chosen with a select. If both components of that Smith result
/// are NaN, a select chain substitutes special-case values instead. Earlier
/// cases take precedence over later ones:
///   1. zero denominator, numerator with at least one non-NaN component;
///   2. infinite numerator, finite denominator;
///   3. finite numerator, infinite denominator.
/// NOTE(review): the special-case formulas look like those of the C99 Annex G
/// reference complex division routine — confirm before claiming conformance.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    // Both branches are emitted unconditionally; a select further down picks
    // the one matching the magnitude comparison.
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the (non-NaN) zero constant is true iff the operand is not
    // NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    // Result is copysign(inf, rhsReal) * lhs, applied to each component.
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // ONE is an ordered predicate, so a NaN component is treated as not finite.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // Each lhs component becomes copysign(isInf ? 1 : 0, component).
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    // Each rhs component becomes copysign(isInf ? 1 : 0, component); the
    // result is then 0 * (...), i.e. a signed zero.
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith branch whose ratio divides by the larger-magnitude
    // denominator component (branch 1 divides by rhsImag).
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultReal1, resultReal2);
    Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultImag1, resultImag2);
    // Layer the special cases from lowest to highest priority so that case 1
    // overrides case 2, which overrides case 3.
    Value resultRealSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case values are only used when the Smith result is NaN in
    // both components; otherwise the Smith result is kept.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
325 
326 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
327   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
328 
329   LogicalResult
330   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
331                   ConversionPatternRewriter &rewriter) const override {
332     auto loc = op.getLoc();
333     auto type = adaptor.getComplex().getType().cast<ComplexType>();
334     auto elementType = type.getElementType().cast<FloatType>();
335 
336     Value real =
337         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
338     Value imag =
339         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
340     Value expReal = rewriter.create<math::ExpOp>(loc, real);
341     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
342     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
343     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
344     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
345 
346     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
347                                                    resultImag);
348     return success();
349   }
350 };
351 
352 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
353   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
354 
355   LogicalResult
356   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
357                   ConversionPatternRewriter &rewriter) const override {
358     auto type = adaptor.getComplex().getType().cast<ComplexType>();
359     auto elementType = type.getElementType().cast<FloatType>();
360     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
361 
362     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
363     Value resultReal = b.create<math::LogOp>(elementType, abs);
364     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
365     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
366     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
367     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
368                                                    resultImag);
369     return success();
370   }
371 };
372 
373 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
374   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
375 
376   LogicalResult
377   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
378                   ConversionPatternRewriter &rewriter) const override {
379     auto type = adaptor.getComplex().getType().cast<ComplexType>();
380     auto elementType = type.getElementType().cast<FloatType>();
381     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
382 
383     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
384     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
385     Value one = b.create<arith::ConstantOp>(elementType,
386                                             b.getFloatAttr(elementType, 1));
387     Value realPlusOne = b.create<arith::AddFOp>(real, one);
388     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
389     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
390     return success();
391   }
392 };
393 
/// Lowers `complex.mul` to arith/math/std ops.
///
/// The product is first computed naively as
///   (lhsReal * rhsReal - lhsImag * rhsImag) +
///   (lhsImag * rhsReal + lhsReal * rhsImag) * i.
/// If both components of that result are NaN and one of the three cases below
/// applies, the operands are adjusted and the product is recomputed and scaled
/// by +inf. In an adjusted operand that has an infinite component, each
/// component becomes copysign(isInf ? 1 : 0, component). NaN components of the
/// other operand become signed zeros.
/// NOTE(review): this recovery appears to follow the C99 Annex G reference
/// complex multiplication routine — confirm before claiming conformance.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);

    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
    Value real =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
    Value imag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // UNO of a value with itself is true iff that value is NaN.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal =
        b.create<SelectOp>(lhsIsInfAndRhsRealIsNan,
                           b.create<math::CopySignOp>(zero, rhsReal), rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag =
        b.create<SelectOp>(lhsIsInfAndRhsImagIsNan,
                           b.create<math::CopySignOp>(zero, rhsImag), rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // (The NaN checks below look at lhs values as possibly updated by case 1.)
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal =
        b.create<SelectOp>(rhsIsInfAndLhsRealIsNan,
                           b.create<math::CopySignOp>(zero, lhsReal), lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag =
        b.create<SelectOp>(rhsIsInfAndLhsImagIsNan,
                           b.create<math::CopySignOp>(zero, lhsImag), lhsImag);
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    // The products tested here are the ones from the naive computation.
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    // Case 3 only applies when neither case 1 nor case 2 did. XOR with an i1
    // `true` constant is logical negation.
    Type i1Type = b.getI1Type();
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal =
        b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan,
                           b.create<math::CopySignOp>(zero, lhsReal), lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag =
        b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan,
                           b.create<math::CopySignOp>(zero, lhsImag), lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal =
        b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan,
                           b.create<math::CopySignOp>(zero, rhsReal), rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag =
        b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan,
                           b.create<math::CopySignOp>(zero, rhsImag), rhsImag);
    // Recompute only if some case applied AND the naive result was NaN in
    // both components.
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value newReal =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real =
        b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value newImag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag =
        b.create<SelectOp>(recalc, b.create<arith::MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
566 
567 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
568   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
569 
570   LogicalResult
571   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
572                   ConversionPatternRewriter &rewriter) const override {
573     auto loc = op.getLoc();
574     auto type = adaptor.getComplex().getType().cast<ComplexType>();
575     auto elementType = type.getElementType().cast<FloatType>();
576 
577     Value real =
578         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
579     Value imag =
580         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
581     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
582     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
583     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
584     return success();
585   }
586 };
587 
588 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
589   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
590 
591   LogicalResult
592   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
593                   ConversionPatternRewriter &rewriter) const override {
594     auto type = adaptor.getComplex().getType().cast<ComplexType>();
595     auto elementType = type.getElementType().cast<FloatType>();
596     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
597 
598     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
599     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
600     Value zero =
601         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
602     Value realIsZero =
603         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
604     Value imagIsZero =
605         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
606     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
607     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
608     Value realSign = b.create<arith::DivFOp>(real, abs);
609     Value imagSign = b.create<arith::DivFOp>(imag, abs);
610     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
611     rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, adaptor.getComplex(),
612                                           sign);
613     return success();
614   }
615 };
616 } // namespace
617 
618 void mlir::populateComplexToStandardConversionPatterns(
619     RewritePatternSet &patterns) {
620   // clang-format off
621   patterns.add<
622       AbsOpConversion,
623       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
624       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
625       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
626       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
627       DivOpConversion,
628       ExpOpConversion,
629       LogOpConversion,
630       Log1pOpConversion,
631       MulOpConversion,
632       NegOpConversion,
633       SignOpConversion>(patterns.getContext());
634   // clang-format on
635 }
636 
637 namespace {
638 struct ConvertComplexToStandardPass
639     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
640   void runOnOperation() override;
641 };
642 
643 void ConvertComplexToStandardPass::runOnOperation() {
644   auto function = getOperation();
645 
646   // Convert to the Standard dialect using the converter defined above.
647   RewritePatternSet patterns(&getContext());
648   populateComplexToStandardConversionPatterns(patterns);
649 
650   ConversionTarget target(getContext());
651   target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect,
652                          math::MathDialect>();
653   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
654   if (failed(applyPartialConversion(function, target, std::move(patterns))))
655     signalPassFailure();
656 }
657 } // namespace
658 
659 std::unique_ptr<OperationPass<FuncOp>>
660 mlir::createConvertComplexToStandardPass() {
661   return std::make_unique<ConvertComplexToStandardPass>();
662 }
663