1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
26   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
30                   ConversionPatternRewriter &rewriter) const override {
31     auto loc = op.getLoc();
32     auto type = op.getType();
33 
34     Value real =
35         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
36     Value imag =
37         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
38     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
39     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
40     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
41 
42     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
43     return success();
44   }
45 };
46 
47 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
48 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
49   using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
50 
51   LogicalResult
52   matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
53                   ConversionPatternRewriter &rewriter) const override {
54     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
55 
56     auto type = op.getType().cast<ComplexType>();
57     Type elementType = type.getElementType();
58 
59     Value lhs = adaptor.getLhs();
60     Value rhs = adaptor.getRhs();
61 
62     Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
63     Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
64     Value rhsSquaredPlusLhsSquared =
65         b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
66     Value sqrtOfRhsSquaredPlusLhsSquared =
67         b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
68 
69     Value zero =
70         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
71     Value one = b.create<arith::ConstantOp>(elementType,
72                                             b.getFloatAttr(elementType, 1));
73     Value i = b.create<complex::CreateOp>(type, zero, one);
74     Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
75     Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
76 
77     Value divResult =
78         b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
79     Value logResult = b.create<complex::LogOp>(divResult);
80 
81     Value negativeOne = b.create<arith::ConstantOp>(
82         elementType, b.getFloatAttr(elementType, -1));
83     Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
84 
85     rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
86     return success();
87   }
88 };
89 
90 template <typename ComparisonOp, arith::CmpFPredicate p>
91 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
92   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
93   using ResultCombiner =
94       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
95                          arith::AndIOp, arith::OrIOp>;
96 
97   LogicalResult
98   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
99                   ConversionPatternRewriter &rewriter) const override {
100     auto loc = op.getLoc();
101     auto type = adaptor.getLhs()
102                     .getType()
103                     .template cast<ComplexType>()
104                     .getElementType();
105 
106     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
107     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
108     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
109     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
110     Value realComparison =
111         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
112     Value imagComparison =
113         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
114 
115     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
116                                                 imagComparison);
117     return success();
118   }
119 };
120 
121 // Default conversion which applies the BinaryStandardOp separately on the real
122 // and imaginary parts. Can for example be used for complex::AddOp and
123 // complex::SubOp.
124 template <typename BinaryComplexOp, typename BinaryStandardOp>
125 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
126   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
127 
128   LogicalResult
129   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
130                   ConversionPatternRewriter &rewriter) const override {
131     auto type = adaptor.getLhs().getType().template cast<ComplexType>();
132     auto elementType = type.getElementType().template cast<FloatType>();
133     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
134 
135     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
136     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
137     Value resultReal =
138         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
139     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
140     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
141     Value resultImag =
142         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
143     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
144                                                    resultImag);
145     return success();
146   }
147 };
148 
149 template <typename TrigonometricOp>
150 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
151   using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
152 
153   using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
154 
155   LogicalResult
156   matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
157                   ConversionPatternRewriter &rewriter) const override {
158     auto loc = op.getLoc();
159     auto type = adaptor.getComplex().getType().template cast<ComplexType>();
160     auto elementType = type.getElementType().template cast<FloatType>();
161 
162     Value real =
163         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
164     Value imag =
165         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
166 
167     // Trigonometric ops use a set of common building blocks to convert to real
168     // ops. Here we create these building blocks and call into an op-specific
169     // implementation in the subclass to combine them.
170     Value half = rewriter.create<arith::ConstantOp>(
171         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
172     Value exp = rewriter.create<math::ExpOp>(loc, imag);
173     Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
174     Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
175     Value sin = rewriter.create<math::SinOp>(loc, real);
176     Value cos = rewriter.create<math::CosOp>(loc, real);
177 
178     auto resultPair =
179         combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
180 
181     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
182                                                    resultPair.second);
183     return success();
184   }
185 
186   virtual std::pair<Value, Value>
187   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
188           Value cos, ConversionPatternRewriter &rewriter) const = 0;
189 };
190 
191 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
192   using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
193 
194   std::pair<Value, Value>
195   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
196           Value cos, ConversionPatternRewriter &rewriter) const override {
197     // Complex cosine is defined as;
198     //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
199     // Plugging in:
200     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
201     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
202     // and defining t := exp(y)
203     // We get:
204     //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
205     //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
206     Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
207     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
208     Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
209     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
210     return {resultReal, resultImag};
211   }
212 };
213 
// Lowers complex.div to arith/math ops on the real and imaginary parts.
//
// The regular quotient is computed with Smith's algorithm. Both variants are
// emitted, and one is selected at runtime by comparing |rhsReal| with
// |rhsImag|. If both components of that quotient are NaN, the result is taken
// from three special cases instead: zero denominator, infinite numerator with
// finite denominator, and finite numerator with infinite denominator.
// NOTE(review): the special cases appear to mirror the NaN-recovery steps of
// the C99 Annex G complex-division example; confirm before relying on exact
// equivalence.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Split both operands into their real and imaginary components.
    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Variant 1 (selected below when |rhsReal| < |rhsImag|).
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Variant 2 (selected otherwise, including when either magnitude is NaN).
    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // Result: copysign(inf, rhsReal) * lhsReal, copysign(inf, rhsReal) *
    // lhsImag.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the zero constant is true exactly when the value is not NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // ONE (ordered not-equal) is false for NaN, so a NaN component is not
    // treated as finite here.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // Replace each lhs component by +-1 if it is infinite and by +-0
    // otherwise, keeping the component's sign.
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    // With a, b the sign-carrying values above and c, d = rhsReal, rhsImag:
    //   resultReal3 = inf * (a * c + b * d)
    //   resultImag3 = inf * (b * c - a * d)
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    // Each rhs component is replaced by +-1 if infinite and by +-0
    // otherwise. The result is 0 * (...), i.e. a zero carrying the sign of
    // the bracketed expression.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith variant. OLT is false when either magnitude is NaN, so
    // variant 2 is used in that case.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Nest the special cases so that case 1 takes priority over case 2, and
    // case 2 over case 3. If none applies, the Smith result falls through.
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case chain replaces the Smith result only when both of its
    // components are NaN (UNO against zero is true exactly for NaN).
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
432 
433 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
434   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
435 
436   LogicalResult
437   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
438                   ConversionPatternRewriter &rewriter) const override {
439     auto loc = op.getLoc();
440     auto type = adaptor.getComplex().getType().cast<ComplexType>();
441     auto elementType = type.getElementType().cast<FloatType>();
442 
443     Value real =
444         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
445     Value imag =
446         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
447     Value expReal = rewriter.create<math::ExpOp>(loc, real);
448     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
449     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
450     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
451     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
452 
453     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
454                                                    resultImag);
455     return success();
456   }
457 };
458 
459 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
460   using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
461 
462   LogicalResult
463   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
464                   ConversionPatternRewriter &rewriter) const override {
465     auto type = adaptor.getComplex().getType().cast<ComplexType>();
466     auto elementType = type.getElementType().cast<FloatType>();
467 
468     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
469     Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
470 
471     Value real = b.create<complex::ReOp>(elementType, exp);
472     Value one = b.create<arith::ConstantOp>(elementType,
473                                             b.getFloatAttr(elementType, 1));
474     Value realMinusOne = b.create<arith::SubFOp>(real, one);
475     Value imag = b.create<complex::ImOp>(elementType, exp);
476 
477     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
478                                                    imag);
479     return success();
480   }
481 };
482 
483 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
484   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
485 
486   LogicalResult
487   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
488                   ConversionPatternRewriter &rewriter) const override {
489     auto type = adaptor.getComplex().getType().cast<ComplexType>();
490     auto elementType = type.getElementType().cast<FloatType>();
491     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
492 
493     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
494     Value resultReal = b.create<math::LogOp>(elementType, abs);
495     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
496     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
497     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
498     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
499                                                    resultImag);
500     return success();
501   }
502 };
503 
504 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
505   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
506 
507   LogicalResult
508   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
509                   ConversionPatternRewriter &rewriter) const override {
510     auto type = adaptor.getComplex().getType().cast<ComplexType>();
511     auto elementType = type.getElementType().cast<FloatType>();
512     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
513 
514     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
515     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
516     Value one = b.create<arith::ConstantOp>(elementType,
517                                             b.getFloatAttr(elementType, 1));
518     Value realPlusOne = b.create<arith::AddFOp>(real, one);
519     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
520     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
521     return success();
522   }
523 };
524 
525 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
526   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
527 
528   LogicalResult
529   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
530                   ConversionPatternRewriter &rewriter) const override {
531     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
532     auto type = adaptor.getLhs().getType().cast<ComplexType>();
533     auto elementType = type.getElementType().cast<FloatType>();
534 
535     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
536     Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
537     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
538     Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
539     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
540     Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
541     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
542     Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);
543 
544     Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
545     Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
546     Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
547     Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
548     Value real =
549         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
550 
551     Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
552     Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
553     Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
554     Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
555     Value imag =
556         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
557 
558     // Handle cases where the "naive" calculation results in NaN values.
559     Value realIsNan =
560         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
561     Value imagIsNan =
562         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
563     Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
564 
565     Value inf = b.create<arith::ConstantOp>(
566         elementType,
567         b.getFloatAttr(elementType,
568                        APFloat::getInf(elementType.getFloatSemantics())));
569 
570     // Case 1. `lhsReal` or `lhsImag` are infinite.
571     Value lhsRealIsInf =
572         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
573     Value lhsImagIsInf =
574         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
575     Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
576     Value rhsRealIsNan =
577         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
578     Value rhsImagIsNan =
579         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
580     Value zero =
581         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
582     Value one = b.create<arith::ConstantOp>(elementType,
583                                             b.getFloatAttr(elementType, 1));
584     Value lhsRealIsInfFloat =
585         b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
586     lhsReal = b.create<arith::SelectOp>(
587         lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
588         lhsReal);
589     Value lhsImagIsInfFloat =
590         b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
591     lhsImag = b.create<arith::SelectOp>(
592         lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
593         lhsImag);
594     Value lhsIsInfAndRhsRealIsNan =
595         b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
596     rhsReal = b.create<arith::SelectOp>(
597         lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
598         rhsReal);
599     Value lhsIsInfAndRhsImagIsNan =
600         b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
601     rhsImag = b.create<arith::SelectOp>(
602         lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
603         rhsImag);
604 
605     // Case 2. `rhsReal` or `rhsImag` are infinite.
606     Value rhsRealIsInf =
607         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
608     Value rhsImagIsInf =
609         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
610     Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
611     Value lhsRealIsNan =
612         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
613     Value lhsImagIsNan =
614         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
615     Value rhsRealIsInfFloat =
616         b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
617     rhsReal = b.create<arith::SelectOp>(
618         rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
619         rhsReal);
620     Value rhsImagIsInfFloat =
621         b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
622     rhsImag = b.create<arith::SelectOp>(
623         rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
624         rhsImag);
625     Value rhsIsInfAndLhsRealIsNan =
626         b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
627     lhsReal = b.create<arith::SelectOp>(
628         rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
629         lhsReal);
630     Value rhsIsInfAndLhsImagIsNan =
631         b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
632     lhsImag = b.create<arith::SelectOp>(
633         rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
634         lhsImag);
635     Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
636 
637     // Case 3. One of the pairwise products of left hand side with right hand
638     // side is infinite.
639     Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
640         arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
641     Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
642         arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
643     Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
644                                                  lhsImagTimesRhsImagIsInf);
645     Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
646         arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
647     isSpecialCase =
648         b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
649     Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
650         arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
651     isSpecialCase =
652         b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
653     Type i1Type = b.getI1Type();
654     Value notRecalc = b.create<arith::XOrIOp>(
655         recalc,
656         b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
657     isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
658     Value isSpecialCaseAndLhsRealIsNan =
659         b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
660     lhsReal = b.create<arith::SelectOp>(
661         isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
662         lhsReal);
663     Value isSpecialCaseAndLhsImagIsNan =
664         b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
665     lhsImag = b.create<arith::SelectOp>(
666         isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
667         lhsImag);
668     Value isSpecialCaseAndRhsRealIsNan =
669         b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
670     rhsReal = b.create<arith::SelectOp>(
671         isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
672         rhsReal);
673     Value isSpecialCaseAndRhsImagIsNan =
674         b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
675     rhsImag = b.create<arith::SelectOp>(
676         isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
677         rhsImag);
678     recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
679     recalc = b.create<arith::AndIOp>(isNan, recalc);
680 
681     // Recalculate real part.
682     lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
683     lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
684     Value newReal =
685         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
686     real = b.create<arith::SelectOp>(
687         recalc, b.create<arith::MulFOp>(inf, newReal), real);
688 
689     // Recalculate imag part.
690     lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
691     lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
692     Value newImag =
693         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
694     imag = b.create<arith::SelectOp>(
695         recalc, b.create<arith::MulFOp>(inf, newImag), imag);
696 
697     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
698     return success();
699   }
700 };
701 
702 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
703   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
704 
705   LogicalResult
706   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
707                   ConversionPatternRewriter &rewriter) const override {
708     auto loc = op.getLoc();
709     auto type = adaptor.getComplex().getType().cast<ComplexType>();
710     auto elementType = type.getElementType().cast<FloatType>();
711 
712     Value real =
713         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
714     Value imag =
715         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
716     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
717     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
718     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
719     return success();
720   }
721 };
722 
723 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
724   using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
725 
726   std::pair<Value, Value>
727   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
728           Value cos, ConversionPatternRewriter &rewriter) const override {
729     // Complex sine is defined as;
730     //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
731     // Plugging in:
732     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
733     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
734     // and defining t := exp(y)
735     // We get:
736     //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
737     //   Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
738     Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
739     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
740     Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
741     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
742     return {resultReal, resultImag};
743   }
744 };
745 
746 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
747 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
748   using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
749 
750   LogicalResult
751   matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
752                   ConversionPatternRewriter &rewriter) const override {
753     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
754 
755     auto type = op.getType().cast<ComplexType>();
756     Type elementType = type.getElementType();
757     Value arg = adaptor.getComplex();
758 
759     Value zero =
760         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
761 
762     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
763     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
764 
765     Value absLhs = b.create<math::AbsOp>(real);
766     Value absArg = b.create<complex::AbsOp>(elementType, arg);
767     Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
768 
769     Value half = b.create<arith::ConstantOp>(elementType,
770                                              b.getFloatAttr(elementType, 0.5));
771     Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
772     Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
773 
774     Value realIsNegative =
775         b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
776     Value imagIsNegative =
777         b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
778 
779     Value resultReal = sqrtAddAbs;
780 
781     Value imagDivTwoResultReal = b.create<arith::DivFOp>(
782         imag, b.create<arith::AddFOp>(resultReal, resultReal));
783 
784     Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
785 
786     Value resultImag = b.create<arith::SelectOp>(
787         realIsNegative,
788         b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
789                                   resultReal),
790         imagDivTwoResultReal);
791 
792     resultReal = b.create<arith::SelectOp>(
793         realIsNegative,
794         b.create<arith::DivFOp>(
795             imag, b.create<arith::AddFOp>(resultImag, resultImag)),
796         resultReal);
797 
798     Value realIsZero =
799         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
800     Value imagIsZero =
801         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
802     Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
803 
804     resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
805     resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
806 
807     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
808                                                    resultImag);
809     return success();
810   }
811 };
812 
813 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
814   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
815 
816   LogicalResult
817   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
818                   ConversionPatternRewriter &rewriter) const override {
819     auto type = adaptor.getComplex().getType().cast<ComplexType>();
820     auto elementType = type.getElementType().cast<FloatType>();
821     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
822 
823     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
824     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
825     Value zero =
826         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
827     Value realIsZero =
828         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
829     Value imagIsZero =
830         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
831     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
832     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
833     Value realSign = b.create<arith::DivFOp>(real, abs);
834     Value imagSign = b.create<arith::DivFOp>(imag, abs);
835     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
836     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
837                                                  adaptor.getComplex(), sign);
838     return success();
839   }
840 };
841 
842 struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
843   using OpConversionPattern<complex::TanOp>::OpConversionPattern;
844 
845   LogicalResult
846   matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
847                   ConversionPatternRewriter &rewriter) const override {
848     auto loc = op.getLoc();
849     Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
850     Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
851     rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
852     return success();
853   }
854 };
855 
856 struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
857   using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
858 
859   LogicalResult
860   matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
861                   ConversionPatternRewriter &rewriter) const override {
862     auto loc = op.getLoc();
863     auto type = adaptor.getComplex().getType().cast<ComplexType>();
864     auto elementType = type.getElementType().cast<FloatType>();
865 
866     // The hyperbolic tangent for complex number can be calculated as follows.
867     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
868     // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
869     Value real =
870         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
871     Value imag =
872         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
873     Value tanhA = rewriter.create<math::TanhOp>(loc, real);
874     Value cosB = rewriter.create<math::CosOp>(loc, imag);
875     Value sinB = rewriter.create<math::SinOp>(loc, imag);
876     Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
877     Value numerator =
878         rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
879     Value one = rewriter.create<arith::ConstantOp>(
880         loc, elementType, rewriter.getFloatAttr(elementType, 1));
881     Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
882     Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
883     rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
884     return success();
885   }
886 };
887 
888 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
889   using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
890 
891   LogicalResult
892   matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
893                   ConversionPatternRewriter &rewriter) const override {
894     auto loc = op.getLoc();
895     auto type = adaptor.getComplex().getType().cast<ComplexType>();
896     auto elementType = type.getElementType().cast<FloatType>();
897     Value real =
898         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
899     Value imag =
900         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
901     Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
902 
903     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
904 
905     return success();
906   }
907 };
908 
909 /// Coverts x^y = (a+bi)^(c+di) to
910 ///    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
911 ///    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
912 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
913                                  ComplexType type, Value a, Value b, Value c,
914                                  Value d) {
915   auto elementType = type.getElementType().cast<FloatType>();
916 
917   // Compute (a*a+b*b)^(0.5c).
918   Value aaPbb = builder.create<arith::AddFOp>(
919       builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
920   Value half = builder.create<arith::ConstantOp>(
921       elementType, builder.getFloatAttr(elementType, 0.5));
922   Value halfC = builder.create<arith::MulFOp>(half, c);
923   Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
924 
925   // Compute exp(-d*atan2(b,a)).
926   Value negD = builder.create<arith::NegFOp>(d);
927   Value argX = builder.create<math::Atan2Op>(b, a);
928   Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
929   Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
930 
931   // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
932   Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
933 
934   // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
935   Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
936   Value halfD = builder.create<arith::MulFOp>(half, d);
937   Value q = builder.create<arith::AddFOp>(
938       builder.create<arith::MulFOp>(c, argX),
939       builder.create<arith::MulFOp>(halfD, lnAaPbb));
940 
941   Value cosQ = builder.create<math::CosOp>(q);
942   Value sinQ = builder.create<math::SinOp>(q);
943   Value zero = builder.create<arith::ConstantOp>(
944       elementType, builder.getFloatAttr(elementType, 0));
945   Value one = builder.create<arith::ConstantOp>(
946       elementType, builder.getFloatAttr(elementType, 1));
947 
948   Value xEqZero =
949       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
950   Value yGeZero = builder.create<arith::AndIOp>(
951       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero),
952       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero));
953   Value cEqZero =
954       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
955   Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
956   Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
957   Value complexOther = builder.create<complex::CreateOp>(
958       type, builder.create<arith::MulFOp>(coeff, cosQ),
959       builder.create<arith::MulFOp>(coeff, sinQ));
960 
961   // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
962   // Branch Cuts for Complex Elementary Functions or Much Ado About
963   // Nothing's Sign Bit, W. Kahan, Section 10.
964   return builder.create<arith::SelectOp>(
965       builder.create<arith::AndIOp>(xEqZero, yGeZero),
966       builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero),
967       complexOther);
968 }
969 
970 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
971   using OpConversionPattern<complex::PowOp>::OpConversionPattern;
972 
973   LogicalResult
974   matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
975                   ConversionPatternRewriter &rewriter) const override {
976     mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
977     auto type = adaptor.getLhs().getType().cast<ComplexType>();
978     auto elementType = type.getElementType().cast<FloatType>();
979 
980     Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
981     Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
982     Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
983     Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
984 
985     rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
986     return success();
987   }
988 };
989 
990 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
991   using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
992 
993   LogicalResult
994   matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
995                   ConversionPatternRewriter &rewriter) const override {
996     mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
997     auto type = adaptor.getComplex().getType().cast<ComplexType>();
998     auto elementType = type.getElementType().cast<FloatType>();
999 
1000     Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
1001     Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
1002     Value c = builder.create<arith::ConstantOp>(
1003         elementType, builder.getFloatAttr(elementType, -0.5));
1004     Value d = builder.create<arith::ConstantOp>(
1005         elementType, builder.getFloatAttr(elementType, 0));
1006 
1007     rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
1008     return success();
1009   }
1010 };
1011 
1012 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1013   using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1014 
1015   LogicalResult
1016   matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1017                   ConversionPatternRewriter &rewriter) const override {
1018     auto loc = op.getLoc();
1019     auto type = op.getType();
1020 
1021     Value real =
1022         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1023     Value imag =
1024         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1025 
1026     rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
1027 
1028     return success();
1029   }
1030 };
1031 
1032 } // namespace
1033 
1034 void mlir::populateComplexToStandardConversionPatterns(
1035     RewritePatternSet &patterns) {
1036   // clang-format off
1037   patterns.add<
1038       AbsOpConversion,
1039       AngleOpConversion,
1040       Atan2OpConversion,
1041       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1042       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1043       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1044       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1045       ConjOpConversion,
1046       CosOpConversion,
1047       DivOpConversion,
1048       ExpOpConversion,
1049       Expm1OpConversion,
1050       Log1pOpConversion,
1051       LogOpConversion,
1052       MulOpConversion,
1053       NegOpConversion,
1054       SignOpConversion,
1055       SinOpConversion,
1056       SqrtOpConversion,
1057       TanOpConversion,
1058       TanhOpConversion,
1059       PowOpConversion,
1060       RsqrtOpConversion
1061   >(patterns.getContext());
1062   // clang-format on
1063 }
1064 
1065 namespace {
1066 struct ConvertComplexToStandardPass
1067     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1068   void runOnOperation() override;
1069 };
1070 
1071 void ConvertComplexToStandardPass::runOnOperation() {
1072   // Convert to the Standard dialect using the converter defined above.
1073   RewritePatternSet patterns(&getContext());
1074   populateComplexToStandardConversionPatterns(patterns);
1075 
1076   ConversionTarget target(getContext());
1077   target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>();
1078   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1079   if (failed(
1080           applyPartialConversion(getOperation(), target, std::move(patterns))))
1081     signalPassFailure();
1082 }
1083 } // namespace
1084 
1085 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
1086   return std::make_unique<ConvertComplexToStandardPass>();
1087 }
1088