1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
26   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
30                   ConversionPatternRewriter &rewriter) const override {
31     auto loc = op.getLoc();
32     auto type = op.getType();
33 
34     Value real =
35         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
36     Value imag =
37         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
38     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
39     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
40     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
41 
42     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
43     return success();
44   }
45 };
46 
47 template <typename ComparisonOp, arith::CmpFPredicate p>
48 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
49   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
50   using ResultCombiner =
51       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
52                          arith::AndIOp, arith::OrIOp>;
53 
54   LogicalResult
55   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
56                   ConversionPatternRewriter &rewriter) const override {
57     auto loc = op.getLoc();
58     auto type = adaptor.getLhs()
59                     .getType()
60                     .template cast<ComplexType>()
61                     .getElementType();
62 
63     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
64     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
65     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
66     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
67     Value realComparison =
68         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
69     Value imagComparison =
70         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
71 
72     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
73                                                 imagComparison);
74     return success();
75   }
76 };
77 
78 // Default conversion which applies the BinaryStandardOp separately on the real
79 // and imaginary parts. Can for example be used for complex::AddOp and
80 // complex::SubOp.
81 template <typename BinaryComplexOp, typename BinaryStandardOp>
82 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
83   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
84 
85   LogicalResult
86   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
87                   ConversionPatternRewriter &rewriter) const override {
88     auto type = adaptor.getLhs().getType().template cast<ComplexType>();
89     auto elementType = type.getElementType().template cast<FloatType>();
90     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
91 
92     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
93     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
94     Value resultReal =
95         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
96     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
97     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
98     Value resultImag =
99         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
100     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
101                                                    resultImag);
102     return success();
103   }
104 };
105 
106 template <typename TrigonometricOp>
107 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
108   using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
109 
110   using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
111 
112   LogicalResult
113   matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
114                   ConversionPatternRewriter &rewriter) const override {
115     auto loc = op.getLoc();
116     auto type = adaptor.getComplex().getType().template cast<ComplexType>();
117     auto elementType = type.getElementType().template cast<FloatType>();
118 
119     Value real =
120         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
121     Value imag =
122         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
123 
124     // Trigonometric ops use a set of common building blocks to convert to real
125     // ops. Here we create these building blocks and call into an op-specific
126     // implementation in the subclass to combine them.
127     Value half = rewriter.create<arith::ConstantOp>(
128         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
129     Value exp = rewriter.create<math::ExpOp>(loc, imag);
130     Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
131     Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
132     Value sin = rewriter.create<math::SinOp>(loc, real);
133     Value cos = rewriter.create<math::CosOp>(loc, real);
134 
135     auto resultPair =
136         combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
137 
138     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
139                                                    resultPair.second);
140     return success();
141   }
142 
143   virtual std::pair<Value, Value>
144   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
145           Value cos, ConversionPatternRewriter &rewriter) const = 0;
146 };
147 
148 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
149   using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
150 
151   std::pair<Value, Value>
152   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
153           Value cos, ConversionPatternRewriter &rewriter) const override {
154     // Complex cosine is defined as;
155     //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
156     // Plugging in:
157     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
158     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
159     // and defining t := exp(y)
160     // We get:
161     //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
162     //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
163     Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
164     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
165     Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
166     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
167     return {resultReal, resultImag};
168   }
169 };
170 
// Lowers `complex.div` to arith/math ops on the element type.
//
// The regular path uses Smith's algorithm: both scaling variants are emitted
// and the one matching |rhsReal| < |rhsImag| is selected. The special-case
// results are used only when both components of that result are NaN. They are
// chosen in priority order:
//   1. zero denominator and a numerator with at most one NaN component,
//   2. infinite numerator and finite denominator,
//   3. finite numerator and infinite denominator.
// NOTE(review): the special cases appear to follow the corner-case handling
// of C99/C11 Annex G (`_Cdivd`); confirm before relying on exact semantics.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Split both operands into their real and imaginary parts.
    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Both variants are emitted unconditionally; a select further below picks
    // the one for the actual operands.
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases. Each case computes a candidate result plus a
    // predicate. The candidates are only used if the Smith result is NaN in
    // both components (see the final selects).
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // The result is an infinity with the sign of rhsReal, multiplied by lhs.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD(v, 0) is true iff v is not NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // Infinite lhs parts are replaced by +/-1 and finite ones by +/-0. The
    // product with the conjugated denominator is then scaled by +inf.
    // ONE(|v|, inf) is false for both infinities and NaN, i.e. "v is finite".
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3. Finite numerator, infinite denominator.
    // Infinite rhs parts are replaced by +/-1 and finite ones by +/-0. The
    // analogous product is then scaled by 0, giving a signed zero.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith variant that divides by the rhs component of larger
    // magnitude. Variant 2 is used on ties or when a NaN is involved, because
    // OLT is false for those.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Layer the special cases. Case 1 takes precedence over case 2, which
    // takes precedence over case 3; otherwise the Smith result is kept.
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // Only fall back to the special-case values when the Smith result is NaN
    // in both components (UNO(v, 0) is true iff v is NaN).
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
389 
390 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
391   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
392 
393   LogicalResult
394   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
395                   ConversionPatternRewriter &rewriter) const override {
396     auto loc = op.getLoc();
397     auto type = adaptor.getComplex().getType().cast<ComplexType>();
398     auto elementType = type.getElementType().cast<FloatType>();
399 
400     Value real =
401         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
402     Value imag =
403         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
404     Value expReal = rewriter.create<math::ExpOp>(loc, real);
405     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
406     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
407     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
408     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
409 
410     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
411                                                    resultImag);
412     return success();
413   }
414 };
415 
416 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
417   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
418 
419   LogicalResult
420   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
421                   ConversionPatternRewriter &rewriter) const override {
422     auto type = adaptor.getComplex().getType().cast<ComplexType>();
423     auto elementType = type.getElementType().cast<FloatType>();
424     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
425 
426     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
427     Value resultReal = b.create<math::LogOp>(elementType, abs);
428     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
429     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
430     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
431     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
432                                                    resultImag);
433     return success();
434   }
435 };
436 
437 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
438   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
439 
440   LogicalResult
441   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
442                   ConversionPatternRewriter &rewriter) const override {
443     auto type = adaptor.getComplex().getType().cast<ComplexType>();
444     auto elementType = type.getElementType().cast<FloatType>();
445     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
446 
447     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
448     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
449     Value one = b.create<arith::ConstantOp>(elementType,
450                                             b.getFloatAttr(elementType, 1));
451     Value realPlusOne = b.create<arith::AddFOp>(real, one);
452     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
453     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
454     return success();
455   }
456 };
457 
// Lowers `complex.mul` to arith/math ops on the element type.
//
// The product is first computed with the textbook formula
//   (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
// If both components of that result are NaN and an infinity is involved (an
// infinite operand, or an infinite pairwise product), the operands are
// repaired and the product is recomputed and multiplied by +inf. Repairing
// replaces infinite parts by +/-1, the finite parts of that operand by +/-0,
// and NaN parts by a signed 0.
// NOTE(review): this appears to mirror C99/C11 Annex G `_Cmultd`; confirm.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);

    // Naive product. The absolute values of the four partial products feed
    // the "infinite partial product" check in case 3.
    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
    Value real =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
    Value imag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // Recalculation is only considered when BOTH components are NaN.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // lhs parts become copysign(isinf ? 1 : 0, part); NaN rhs parts become a
    // signed zero.
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat =
        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat =
        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Symmetric to case 1 with the roles of lhs and rhs swapped. The NaN
    // checks on lhs look at lhs as already updated by case 1.
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat =
        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat =
        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite. This only applies if neither case 1 nor case 2 did;
    // NaN parts of both operands then become a signed zero.
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    // notRecalc = !recalc, expressed as an XOR with i1 true.
    Type i1Type = b.getI1Type();
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);
    // The repaired operands are only used when the naive result was NaN in
    // both components and one of the three cases applied.
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value newReal =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value newImag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
634 
635 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
636   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
637 
638   LogicalResult
639   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
640                   ConversionPatternRewriter &rewriter) const override {
641     auto loc = op.getLoc();
642     auto type = adaptor.getComplex().getType().cast<ComplexType>();
643     auto elementType = type.getElementType().cast<FloatType>();
644 
645     Value real =
646         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
647     Value imag =
648         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
649     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
650     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
651     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
652     return success();
653   }
654 };
655 
656 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
657   using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
658 
659   std::pair<Value, Value>
660   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
661           Value cos, ConversionPatternRewriter &rewriter) const override {
662     // Complex sine is defined as;
663     //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
664     // Plugging in:
665     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
666     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
667     // and defining t := exp(y)
668     // We get:
669     //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
670     //   Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
671     Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
672     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
673     Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
674     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
675     return {resultReal, resultImag};
676   }
677 };
678 
679 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
680   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
681 
682   LogicalResult
683   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
684                   ConversionPatternRewriter &rewriter) const override {
685     auto type = adaptor.getComplex().getType().cast<ComplexType>();
686     auto elementType = type.getElementType().cast<FloatType>();
687     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
688 
689     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
690     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
691     Value zero =
692         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
693     Value realIsZero =
694         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
695     Value imagIsZero =
696         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
697     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
698     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
699     Value realSign = b.create<arith::DivFOp>(real, abs);
700     Value imagSign = b.create<arith::DivFOp>(imag, abs);
701     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
702     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
703                                                  adaptor.getComplex(), sign);
704     return success();
705   }
706 };
707 } // namespace
708 
709 void mlir::populateComplexToStandardConversionPatterns(
710     RewritePatternSet &patterns) {
711   // clang-format off
712   patterns.add<
713       AbsOpConversion,
714       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
715       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
716       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
717       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
718       CosOpConversion,
719       DivOpConversion,
720       ExpOpConversion,
721       LogOpConversion,
722       Log1pOpConversion,
723       MulOpConversion,
724       NegOpConversion,
725       SignOpConversion,
726       SinOpConversion>(patterns.getContext());
727   // clang-format on
728 }
729 
730 namespace {
731 struct ConvertComplexToStandardPass
732     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
733   void runOnOperation() override;
734 };
735 
736 void ConvertComplexToStandardPass::runOnOperation() {
737   // Convert to the Standard dialect using the converter defined above.
738   RewritePatternSet patterns(&getContext());
739   populateComplexToStandardConversionPatterns(patterns);
740 
741   ConversionTarget target(getContext());
742   target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>();
743   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
744   if (failed(
745           applyPartialConversion(getOperation(), target, std::move(patterns))))
746     signalPassFailure();
747 }
748 } // namespace
749 
750 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
751   return std::make_unique<ConvertComplexToStandardPass>();
752 }
753