1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
26   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
30                   ConversionPatternRewriter &rewriter) const override {
31     auto loc = op.getLoc();
32     auto type = op.getType();
33 
34     Value real =
35         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
36     Value imag =
37         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
38     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
39     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
40     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
41 
42     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
43     return success();
44   }
45 };
46 
47 template <typename ComparisonOp, arith::CmpFPredicate p>
48 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
49   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
50   using ResultCombiner =
51       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
52                          arith::AndIOp, arith::OrIOp>;
53 
54   LogicalResult
55   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
56                   ConversionPatternRewriter &rewriter) const override {
57     auto loc = op.getLoc();
58     auto type = adaptor.getLhs()
59                     .getType()
60                     .template cast<ComplexType>()
61                     .getElementType();
62 
63     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
64     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
65     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
66     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
67     Value realComparison =
68         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
69     Value imagComparison =
70         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
71 
72     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
73                                                 imagComparison);
74     return success();
75   }
76 };
77 
78 // Default conversion which applies the BinaryStandardOp separately on the real
79 // and imaginary parts. Can for example be used for complex::AddOp and
80 // complex::SubOp.
81 template <typename BinaryComplexOp, typename BinaryStandardOp>
82 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
83   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
84 
85   LogicalResult
86   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
87                   ConversionPatternRewriter &rewriter) const override {
88     auto type = adaptor.getLhs().getType().template cast<ComplexType>();
89     auto elementType = type.getElementType().template cast<FloatType>();
90     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
91 
92     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
93     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
94     Value resultReal =
95         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
96     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
97     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
98     Value resultImag =
99         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
100     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
101                                                    resultImag);
102     return success();
103   }
104 };
105 
106 template <typename TrigonometricOp>
107 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
108   using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
109 
110   using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
111 
112   LogicalResult
113   matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
114                   ConversionPatternRewriter &rewriter) const override {
115     auto loc = op.getLoc();
116     auto type = adaptor.getComplex().getType().template cast<ComplexType>();
117     auto elementType = type.getElementType().template cast<FloatType>();
118 
119     Value real =
120         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
121     Value imag =
122         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
123 
124     // Trigonometric ops use a set of common building blocks to convert to real
125     // ops. Here we create these building blocks and call into an op-specific
126     // implementation in the subclass to combine them.
127     Value half = rewriter.create<arith::ConstantOp>(
128         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
129     Value exp = rewriter.create<math::ExpOp>(loc, imag);
130     Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
131     Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
132     Value sin = rewriter.create<math::SinOp>(loc, real);
133     Value cos = rewriter.create<math::CosOp>(loc, real);
134 
135     auto resultPair =
136         combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
137 
138     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
139                                                    resultPair.second);
140     return success();
141   }
142 
143   virtual std::pair<Value, Value>
144   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
145           Value cos, ConversionPatternRewriter &rewriter) const = 0;
146 };
147 
148 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
149   using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
150 
151   std::pair<Value, Value>
152   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
153           Value cos, ConversionPatternRewriter &rewriter) const override {
154     // Complex cosine is defined as;
155     //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
156     // Plugging in:
157     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
158     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
159     // and defining t := exp(y)
160     // We get:
161     //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
162     //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
163     Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
164     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
165     Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
166     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
167     return {resultReal, resultImag};
168   }
169 };
170 
/// Lowers complex.div to real arithmetic.
///
/// The generic quotient is computed with Smith's algorithm. Both branches are
/// emitted and one is selected on |rhsReal| < |rhsImag|. If the generic result
/// is NaN in BOTH components, it is replaced by a special-case result. The
/// special cases are checked in this priority order:
///   1. rhs is zero (both parts compare equal to 0) and lhs has at least one
///      non-NaN part: result = copysign(inf, rhsReal) * lhs.
///   2. lhs has an infinite part and rhs is finite: lhs parts are replaced by
///      copysign(isInf ? 1 : 0, part) and the product is scaled by +inf.
///   3. lhs is finite and rhs has an infinite part: rhs parts are replaced by
///      copysign(isInf ? 1 : 0, part) and the product is scaled by 0.
/// NOTE(review): this special-case structure appears to mirror the C99
/// Annex G recovery in compiler-rt's __divdc3; confirm before relying on exact
/// agreement.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Split both operands into real and imaginary components.
    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Both branches are always emitted; the choice is made by a select below.
    // Branch 1 (selected when |rhsReal| < |rhsImag|).
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Branch 2 (selected otherwise, including when either |rhs| part is NaN,
    // because the OLT comparison below is then false).
    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the constant zero (never NaN) is true iff the value is not
    // NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // ONE against +inf is true iff the value is finite: it is false both for
    // NaN (unordered) and for +inf (equal).
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // Replace each lhs part by +-1 if it is infinite and by +-0 otherwise,
    // keeping the sign of the original part.
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    // Replace each rhs part by +-1 if it is infinite and by +-0 otherwise,
    // keeping the sign of the original part.
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith branch, then layer the special cases on top. Each later
    // select overrides the earlier ones, so Case 1 has the highest priority,
    // then Case 2, then Case 3.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case results are used only when the generic Smith result
    // is NaN in both components; otherwise the generic result is kept.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
389 
390 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
391   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
392 
393   LogicalResult
394   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
395                   ConversionPatternRewriter &rewriter) const override {
396     auto loc = op.getLoc();
397     auto type = adaptor.getComplex().getType().cast<ComplexType>();
398     auto elementType = type.getElementType().cast<FloatType>();
399 
400     Value real =
401         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
402     Value imag =
403         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
404     Value expReal = rewriter.create<math::ExpOp>(loc, real);
405     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
406     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
407     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
408     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
409 
410     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
411                                                    resultImag);
412     return success();
413   }
414 };
415 
416 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
417   using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
418 
419   LogicalResult
420   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
421                   ConversionPatternRewriter &rewriter) const override {
422     auto type = adaptor.getComplex().getType().cast<ComplexType>();
423     auto elementType = type.getElementType().cast<FloatType>();
424 
425     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
426     Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
427 
428     Value real = b.create<complex::ReOp>(elementType, exp);
429     Value one = b.create<arith::ConstantOp>(elementType,
430                                             b.getFloatAttr(elementType, 1));
431     Value realMinusOne = b.create<arith::SubFOp>(real, one);
432     Value imag = b.create<complex::ImOp>(elementType, exp);
433 
434     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
435                                                    imag);
436     return success();
437   }
438 };
439 
440 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
441   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
442 
443   LogicalResult
444   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
445                   ConversionPatternRewriter &rewriter) const override {
446     auto type = adaptor.getComplex().getType().cast<ComplexType>();
447     auto elementType = type.getElementType().cast<FloatType>();
448     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
449 
450     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
451     Value resultReal = b.create<math::LogOp>(elementType, abs);
452     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
453     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
454     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
455     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
456                                                    resultImag);
457     return success();
458   }
459 };
460 
461 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
462   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
463 
464   LogicalResult
465   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
466                   ConversionPatternRewriter &rewriter) const override {
467     auto type = adaptor.getComplex().getType().cast<ComplexType>();
468     auto elementType = type.getElementType().cast<FloatType>();
469     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
470 
471     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
472     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
473     Value one = b.create<arith::ConstantOp>(elementType,
474                                             b.getFloatAttr(elementType, 1));
475     Value realPlusOne = b.create<arith::AddFOp>(real, one);
476     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
477     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
478     return success();
479   }
480 };
481 
482 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
483   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
484 
485   LogicalResult
486   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
487                   ConversionPatternRewriter &rewriter) const override {
488     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
489     auto type = adaptor.getLhs().getType().cast<ComplexType>();
490     auto elementType = type.getElementType().cast<FloatType>();
491 
492     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
493     Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
494     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
495     Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
496     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
497     Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
498     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
499     Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);
500 
501     Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
502     Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
503     Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
504     Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
505     Value real =
506         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
507 
508     Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
509     Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
510     Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
511     Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
512     Value imag =
513         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
514 
515     // Handle cases where the "naive" calculation results in NaN values.
516     Value realIsNan =
517         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
518     Value imagIsNan =
519         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
520     Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
521 
522     Value inf = b.create<arith::ConstantOp>(
523         elementType,
524         b.getFloatAttr(elementType,
525                        APFloat::getInf(elementType.getFloatSemantics())));
526 
527     // Case 1. `lhsReal` or `lhsImag` are infinite.
528     Value lhsRealIsInf =
529         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
530     Value lhsImagIsInf =
531         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
532     Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
533     Value rhsRealIsNan =
534         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
535     Value rhsImagIsNan =
536         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
537     Value zero =
538         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
539     Value one = b.create<arith::ConstantOp>(elementType,
540                                             b.getFloatAttr(elementType, 1));
541     Value lhsRealIsInfFloat =
542         b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
543     lhsReal = b.create<arith::SelectOp>(
544         lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
545         lhsReal);
546     Value lhsImagIsInfFloat =
547         b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
548     lhsImag = b.create<arith::SelectOp>(
549         lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
550         lhsImag);
551     Value lhsIsInfAndRhsRealIsNan =
552         b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
553     rhsReal = b.create<arith::SelectOp>(
554         lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
555         rhsReal);
556     Value lhsIsInfAndRhsImagIsNan =
557         b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
558     rhsImag = b.create<arith::SelectOp>(
559         lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
560         rhsImag);
561 
562     // Case 2. `rhsReal` or `rhsImag` are infinite.
563     Value rhsRealIsInf =
564         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
565     Value rhsImagIsInf =
566         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
567     Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
568     Value lhsRealIsNan =
569         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
570     Value lhsImagIsNan =
571         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
572     Value rhsRealIsInfFloat =
573         b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
574     rhsReal = b.create<arith::SelectOp>(
575         rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
576         rhsReal);
577     Value rhsImagIsInfFloat =
578         b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
579     rhsImag = b.create<arith::SelectOp>(
580         rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
581         rhsImag);
582     Value rhsIsInfAndLhsRealIsNan =
583         b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
584     lhsReal = b.create<arith::SelectOp>(
585         rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
586         lhsReal);
587     Value rhsIsInfAndLhsImagIsNan =
588         b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
589     lhsImag = b.create<arith::SelectOp>(
590         rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
591         lhsImag);
592     Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
593 
594     // Case 3. One of the pairwise products of left hand side with right hand
595     // side is infinite.
596     Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
597         arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
598     Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
599         arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
600     Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
601                                                  lhsImagTimesRhsImagIsInf);
602     Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
603         arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
604     isSpecialCase =
605         b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
606     Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
607         arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
608     isSpecialCase =
609         b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
610     Type i1Type = b.getI1Type();
611     Value notRecalc = b.create<arith::XOrIOp>(
612         recalc,
613         b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
614     isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
615     Value isSpecialCaseAndLhsRealIsNan =
616         b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
617     lhsReal = b.create<arith::SelectOp>(
618         isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
619         lhsReal);
620     Value isSpecialCaseAndLhsImagIsNan =
621         b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
622     lhsImag = b.create<arith::SelectOp>(
623         isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
624         lhsImag);
625     Value isSpecialCaseAndRhsRealIsNan =
626         b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
627     rhsReal = b.create<arith::SelectOp>(
628         isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
629         rhsReal);
630     Value isSpecialCaseAndRhsImagIsNan =
631         b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
632     rhsImag = b.create<arith::SelectOp>(
633         isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
634         rhsImag);
635     recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
636     recalc = b.create<arith::AndIOp>(isNan, recalc);
637 
638     // Recalculate real part.
639     lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
640     lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
641     Value newReal =
642         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
643     real = b.create<arith::SelectOp>(
644         recalc, b.create<arith::MulFOp>(inf, newReal), real);
645 
646     // Recalculate imag part.
647     lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
648     lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
649     Value newImag =
650         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
651     imag = b.create<arith::SelectOp>(
652         recalc, b.create<arith::MulFOp>(inf, newImag), imag);
653 
654     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
655     return success();
656   }
657 };
658 
659 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
660   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
661 
662   LogicalResult
663   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
664                   ConversionPatternRewriter &rewriter) const override {
665     auto loc = op.getLoc();
666     auto type = adaptor.getComplex().getType().cast<ComplexType>();
667     auto elementType = type.getElementType().cast<FloatType>();
668 
669     Value real =
670         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
671     Value imag =
672         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
673     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
674     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
675     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
676     return success();
677   }
678 };
679 
680 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
681   using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
682 
683   std::pair<Value, Value>
684   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
685           Value cos, ConversionPatternRewriter &rewriter) const override {
686     // Complex sine is defined as;
687     //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
688     // Plugging in:
689     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
690     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
691     // and defining t := exp(y)
692     // We get:
693     //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
694     //   Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
695     Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
696     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
697     Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
698     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
699     return {resultReal, resultImag};
700   }
701 };
702 
703 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
704   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
705 
706   LogicalResult
707   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
708                   ConversionPatternRewriter &rewriter) const override {
709     auto type = adaptor.getComplex().getType().cast<ComplexType>();
710     auto elementType = type.getElementType().cast<FloatType>();
711     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
712 
713     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
714     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
715     Value zero =
716         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
717     Value realIsZero =
718         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
719     Value imagIsZero =
720         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
721     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
722     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
723     Value realSign = b.create<arith::DivFOp>(real, abs);
724     Value imagSign = b.create<arith::DivFOp>(imag, abs);
725     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
726     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
727                                                  adaptor.getComplex(), sign);
728     return success();
729   }
730 };
731 
732 struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
733   using OpConversionPattern<complex::TanOp>::OpConversionPattern;
734 
735   LogicalResult
736   matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
737                   ConversionPatternRewriter &rewriter) const override {
738     auto loc = op.getLoc();
739     Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
740     Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
741     rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
742     return success();
743   }
744 };
745 
746 struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
747   using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
748 
749   LogicalResult
750   matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
751                   ConversionPatternRewriter &rewriter) const override {
752     auto loc = op.getLoc();
753     auto type = adaptor.getComplex().getType().cast<ComplexType>();
754     auto elementType = type.getElementType().cast<FloatType>();
755 
756     // The hyperbolic tangent for complex number can be calculated as follows.
757     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
758     // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
759     Value real =
760         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
761     Value imag =
762         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
763     Value tanhA = rewriter.create<math::TanhOp>(loc, real);
764     Value cosB = rewriter.create<math::CosOp>(loc, imag);
765     Value sinB = rewriter.create<math::SinOp>(loc, imag);
766     Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
767     Value numerator =
768         rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
769     Value one = rewriter.create<arith::ConstantOp>(
770         loc, elementType, rewriter.getFloatAttr(elementType, 1));
771     Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
772     Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
773     rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
774     return success();
775   }
776 };
777 
778 } // namespace
779 
780 void mlir::populateComplexToStandardConversionPatterns(
781     RewritePatternSet &patterns) {
782   // clang-format off
783   patterns.add<
784       AbsOpConversion,
785       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
786       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
787       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
788       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
789       CosOpConversion,
790       DivOpConversion,
791       ExpOpConversion,
792       Expm1OpConversion,
793       LogOpConversion,
794       Log1pOpConversion,
795       MulOpConversion,
796       NegOpConversion,
797       SignOpConversion,
798       SinOpConversion,
799       TanOpConversion,
800       TanhOpConversion>(patterns.getContext());
801   // clang-format on
802 }
803 
804 namespace {
805 struct ConvertComplexToStandardPass
806     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
807   void runOnOperation() override;
808 };
809 
810 void ConvertComplexToStandardPass::runOnOperation() {
811   // Convert to the Standard dialect using the converter defined above.
812   RewritePatternSet patterns(&getContext());
813   populateComplexToStandardConversionPatterns(patterns);
814 
815   ConversionTarget target(getContext());
816   target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>();
817   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
818   if (failed(
819           applyPartialConversion(getOperation(), target, std::move(patterns))))
820     signalPassFailure();
821 }
822 } // namespace
823 
824 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
825   return std::make_unique<ConvertComplexToStandardPass>();
826 }
827