1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
26   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
30                   ConversionPatternRewriter &rewriter) const override {
31     auto loc = op.getLoc();
32     auto type = op.getType();
33 
34     Value real =
35         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
36     Value imag =
37         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
38     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
39     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
40     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
41 
42     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
43     return success();
44   }
45 };
46 
47 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
48 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
49   using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
50 
51   LogicalResult
52   matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
53                   ConversionPatternRewriter &rewriter) const override {
54     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
55 
56     auto type = op.getType().cast<ComplexType>();
57     Type elementType = type.getElementType();
58 
59     Value lhs = adaptor.getLhs();
60     Value rhs = adaptor.getRhs();
61 
62     Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
63     Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
64     Value rhsSquaredPlusLhsSquared =
65         b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
66     Value sqrtOfRhsSquaredPlusLhsSquared =
67         b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
68 
69     Value zero =
70         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
71     Value one = b.create<arith::ConstantOp>(elementType,
72                                             b.getFloatAttr(elementType, 1));
73     Value i = b.create<complex::CreateOp>(type, zero, one);
74     Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
75     Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
76 
77     Value divResult =
78         b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
79     Value logResult = b.create<complex::LogOp>(divResult);
80 
81     Value negativeOne = b.create<arith::ConstantOp>(
82         elementType, b.getFloatAttr(elementType, -1));
83     Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
84 
85     rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
86     return success();
87   }
88 };
89 
90 template <typename ComparisonOp, arith::CmpFPredicate p>
91 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
92   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
93   using ResultCombiner =
94       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
95                          arith::AndIOp, arith::OrIOp>;
96 
97   LogicalResult
98   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
99                   ConversionPatternRewriter &rewriter) const override {
100     auto loc = op.getLoc();
101     auto type = adaptor.getLhs()
102                     .getType()
103                     .template cast<ComplexType>()
104                     .getElementType();
105 
106     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
107     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
108     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
109     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
110     Value realComparison =
111         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
112     Value imagComparison =
113         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
114 
115     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
116                                                 imagComparison);
117     return success();
118   }
119 };
120 
121 // Default conversion which applies the BinaryStandardOp separately on the real
122 // and imaginary parts. Can for example be used for complex::AddOp and
123 // complex::SubOp.
124 template <typename BinaryComplexOp, typename BinaryStandardOp>
125 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
126   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
127 
128   LogicalResult
129   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
130                   ConversionPatternRewriter &rewriter) const override {
131     auto type = adaptor.getLhs().getType().template cast<ComplexType>();
132     auto elementType = type.getElementType().template cast<FloatType>();
133     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
134 
135     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
136     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
137     Value resultReal =
138         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
139     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
140     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
141     Value resultImag =
142         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
143     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
144                                                    resultImag);
145     return success();
146   }
147 };
148 
149 template <typename TrigonometricOp>
150 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
151   using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
152 
153   using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
154 
155   LogicalResult
156   matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
157                   ConversionPatternRewriter &rewriter) const override {
158     auto loc = op.getLoc();
159     auto type = adaptor.getComplex().getType().template cast<ComplexType>();
160     auto elementType = type.getElementType().template cast<FloatType>();
161 
162     Value real =
163         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
164     Value imag =
165         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
166 
167     // Trigonometric ops use a set of common building blocks to convert to real
168     // ops. Here we create these building blocks and call into an op-specific
169     // implementation in the subclass to combine them.
170     Value half = rewriter.create<arith::ConstantOp>(
171         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
172     Value exp = rewriter.create<math::ExpOp>(loc, imag);
173     Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
174     Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
175     Value sin = rewriter.create<math::SinOp>(loc, real);
176     Value cos = rewriter.create<math::CosOp>(loc, real);
177 
178     auto resultPair =
179         combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
180 
181     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
182                                                    resultPair.second);
183     return success();
184   }
185 
186   virtual std::pair<Value, Value>
187   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
188           Value cos, ConversionPatternRewriter &rewriter) const = 0;
189 };
190 
191 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
192   using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
193 
194   std::pair<Value, Value>
195   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
196           Value cos, ConversionPatternRewriter &rewriter) const override {
197     // Complex cosine is defined as;
198     //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
199     // Plugging in:
200     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
201     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
202     // and defining t := exp(y)
203     // We get:
204     //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
205     //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
206     Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
207     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
208     Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
209     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
210     return {resultReal, resultImag};
211   }
212 };
213 
// Lowers `complex.div` to arith/math ops on the real and imaginary parts of
// the operands. The quotient is computed with Smith's algorithm (both
// variants are emitted and one is picked with a select). Only when BOTH
// components of that quotient are NaN is the result replaced by the outcome
// of a chain of special cases, checked in this priority order:
//   1. zero denominator with a numerator that has a non-NaN component,
//   2. infinite numerator with finite denominator,
//   3. finite numerator with infinite denominator,
// falling back to the Smith quotient if none of them applies.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Decompose both operands into their scalar components.
    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.

    // Variant 1 (selected further below when |rhsReal| < |rhsImag|).
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Variant 2 (selected further below otherwise).
    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // The result is (inf with the sign of rhsReal) multiplied by each
    // component of the numerator.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // An ordered comparison against the constant zero is used as the
    // "is not NaN" test for each numerator component.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // Each numerator component is replaced by copysign(1 or 0, component)
    // (1 if that component is infinite, 0 otherwise); the complex product
    // formula is then applied with the denominator and scaled by inf.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    // Mirror image of case 2: the denominator components are replaced by
    // copysign(1 or 0, component) and the combined value is scaled by zero.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith variant based on |rhsReal| < |rhsImag|.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Chain the special cases from lowest (case 3) to highest (case 1)
    // priority; each select falls through to the previous candidate.
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case chain is only consulted when both components of the
    // Smith quotient are NaN; otherwise the Smith quotient is used as is.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
432 
433 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
434   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
435 
436   LogicalResult
437   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
438                   ConversionPatternRewriter &rewriter) const override {
439     auto loc = op.getLoc();
440     auto type = adaptor.getComplex().getType().cast<ComplexType>();
441     auto elementType = type.getElementType().cast<FloatType>();
442 
443     Value real =
444         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
445     Value imag =
446         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
447     Value expReal = rewriter.create<math::ExpOp>(loc, real);
448     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
449     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
450     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
451     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
452 
453     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
454                                                    resultImag);
455     return success();
456   }
457 };
458 
459 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
460   using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
461 
462   LogicalResult
463   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
464                   ConversionPatternRewriter &rewriter) const override {
465     auto type = adaptor.getComplex().getType().cast<ComplexType>();
466     auto elementType = type.getElementType().cast<FloatType>();
467 
468     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
469     Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
470 
471     Value real = b.create<complex::ReOp>(elementType, exp);
472     Value one = b.create<arith::ConstantOp>(elementType,
473                                             b.getFloatAttr(elementType, 1));
474     Value realMinusOne = b.create<arith::SubFOp>(real, one);
475     Value imag = b.create<complex::ImOp>(elementType, exp);
476 
477     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
478                                                    imag);
479     return success();
480   }
481 };
482 
483 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
484   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
485 
486   LogicalResult
487   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
488                   ConversionPatternRewriter &rewriter) const override {
489     auto type = adaptor.getComplex().getType().cast<ComplexType>();
490     auto elementType = type.getElementType().cast<FloatType>();
491     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
492 
493     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
494     Value resultReal = b.create<math::LogOp>(elementType, abs);
495     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
496     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
497     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
498     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
499                                                    resultImag);
500     return success();
501   }
502 };
503 
504 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
505   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
506 
507   LogicalResult
508   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
509                   ConversionPatternRewriter &rewriter) const override {
510     auto type = adaptor.getComplex().getType().cast<ComplexType>();
511     auto elementType = type.getElementType().cast<FloatType>();
512     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
513 
514     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
515     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
516     Value one = b.create<arith::ConstantOp>(elementType,
517                                             b.getFloatAttr(elementType, 1));
518     Value realPlusOne = b.create<arith::AddFOp>(real, one);
519     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
520     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
521     return success();
522   }
523 };
524 
525 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
526   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
527 
528   LogicalResult
529   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
530                   ConversionPatternRewriter &rewriter) const override {
531     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
532     auto type = adaptor.getLhs().getType().cast<ComplexType>();
533     auto elementType = type.getElementType().cast<FloatType>();
534 
535     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
536     Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
537     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
538     Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
539     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
540     Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
541     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
542     Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);
543 
544     Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
545     Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
546     Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
547     Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
548     Value real =
549         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
550 
551     Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
552     Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
553     Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
554     Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
555     Value imag =
556         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
557 
558     // Handle cases where the "naive" calculation results in NaN values.
559     Value realIsNan =
560         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
561     Value imagIsNan =
562         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
563     Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
564 
565     Value inf = b.create<arith::ConstantOp>(
566         elementType,
567         b.getFloatAttr(elementType,
568                        APFloat::getInf(elementType.getFloatSemantics())));
569 
570     // Case 1. `lhsReal` or `lhsImag` are infinite.
571     Value lhsRealIsInf =
572         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
573     Value lhsImagIsInf =
574         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
575     Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
576     Value rhsRealIsNan =
577         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
578     Value rhsImagIsNan =
579         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
580     Value zero =
581         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
582     Value one = b.create<arith::ConstantOp>(elementType,
583                                             b.getFloatAttr(elementType, 1));
584     Value lhsRealIsInfFloat =
585         b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
586     lhsReal = b.create<arith::SelectOp>(
587         lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
588         lhsReal);
589     Value lhsImagIsInfFloat =
590         b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
591     lhsImag = b.create<arith::SelectOp>(
592         lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
593         lhsImag);
594     Value lhsIsInfAndRhsRealIsNan =
595         b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
596     rhsReal = b.create<arith::SelectOp>(
597         lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
598         rhsReal);
599     Value lhsIsInfAndRhsImagIsNan =
600         b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
601     rhsImag = b.create<arith::SelectOp>(
602         lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
603         rhsImag);
604 
605     // Case 2. `rhsReal` or `rhsImag` are infinite.
606     Value rhsRealIsInf =
607         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
608     Value rhsImagIsInf =
609         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
610     Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
611     Value lhsRealIsNan =
612         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
613     Value lhsImagIsNan =
614         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
615     Value rhsRealIsInfFloat =
616         b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
617     rhsReal = b.create<arith::SelectOp>(
618         rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
619         rhsReal);
620     Value rhsImagIsInfFloat =
621         b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
622     rhsImag = b.create<arith::SelectOp>(
623         rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
624         rhsImag);
625     Value rhsIsInfAndLhsRealIsNan =
626         b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
627     lhsReal = b.create<arith::SelectOp>(
628         rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
629         lhsReal);
630     Value rhsIsInfAndLhsImagIsNan =
631         b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
632     lhsImag = b.create<arith::SelectOp>(
633         rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
634         lhsImag);
635     Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
636 
637     // Case 3. One of the pairwise products of left hand side with right hand
638     // side is infinite.
639     Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
640         arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
641     Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
642         arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
643     Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
644                                                  lhsImagTimesRhsImagIsInf);
645     Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
646         arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
647     isSpecialCase =
648         b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
649     Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
650         arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
651     isSpecialCase =
652         b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
653     Type i1Type = b.getI1Type();
654     Value notRecalc = b.create<arith::XOrIOp>(
655         recalc,
656         b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
657     isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
658     Value isSpecialCaseAndLhsRealIsNan =
659         b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
660     lhsReal = b.create<arith::SelectOp>(
661         isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
662         lhsReal);
663     Value isSpecialCaseAndLhsImagIsNan =
664         b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
665     lhsImag = b.create<arith::SelectOp>(
666         isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
667         lhsImag);
668     Value isSpecialCaseAndRhsRealIsNan =
669         b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
670     rhsReal = b.create<arith::SelectOp>(
671         isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
672         rhsReal);
673     Value isSpecialCaseAndRhsImagIsNan =
674         b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
675     rhsImag = b.create<arith::SelectOp>(
676         isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
677         rhsImag);
678     recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
679     recalc = b.create<arith::AndIOp>(isNan, recalc);
680 
681     // Recalculate real part.
682     lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
683     lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
684     Value newReal =
685         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
686     real = b.create<arith::SelectOp>(
687         recalc, b.create<arith::MulFOp>(inf, newReal), real);
688 
689     // Recalculate imag part.
690     lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
691     lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
692     Value newImag =
693         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
694     imag = b.create<arith::SelectOp>(
695         recalc, b.create<arith::MulFOp>(inf, newImag), imag);
696 
697     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
698     return success();
699   }
700 };
701 
702 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
703   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
704 
705   LogicalResult
706   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
707                   ConversionPatternRewriter &rewriter) const override {
708     auto loc = op.getLoc();
709     auto type = adaptor.getComplex().getType().cast<ComplexType>();
710     auto elementType = type.getElementType().cast<FloatType>();
711 
712     Value real =
713         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
714     Value imag =
715         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
716     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
717     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
718     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
719     return success();
720   }
721 };
722 
723 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
724   using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
725 
726   std::pair<Value, Value>
727   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
728           Value cos, ConversionPatternRewriter &rewriter) const override {
729     // Complex sine is defined as;
730     //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
731     // Plugging in:
732     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
733     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
734     // and defining t := exp(y)
735     // We get:
736     //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
737     //   Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
738     Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
739     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
740     Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
741     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
742     return {resultReal, resultImag};
743   }
744 };
745 
746 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
747 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
748   using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
749 
750   LogicalResult
751   matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
752                   ConversionPatternRewriter &rewriter) const override {
753     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
754 
755     auto type = op.getType().cast<ComplexType>();
756     Type elementType = type.getElementType();
757     Value arg = adaptor.getComplex();
758 
759     Value zero =
760         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
761 
762     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
763     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
764 
765     Value absLhs = b.create<math::AbsOp>(real);
766     Value absArg = b.create<complex::AbsOp>(elementType, arg);
767     Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
768 
769     Value half = b.create<arith::ConstantOp>(
770                         elementType, b.getFloatAttr(elementType, 0.5));
771     Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
772     Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
773 
774     Value realIsNegative =
775         b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
776     Value imagIsNegative =
777         b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
778 
779     Value resultReal = sqrtAddAbs;
780 
781     Value imagDivTwoResultReal = b.create<arith::DivFOp>(
782         imag, b.create<arith::AddFOp>(resultReal, resultReal));
783 
784     Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
785 
786     Value resultImag = b.create<arith::SelectOp>(
787         realIsNegative,
788         b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
789                                   resultReal),
790         imagDivTwoResultReal);
791 
792     resultReal = b.create<arith::SelectOp>(
793         realIsNegative,
794         b.create<arith::DivFOp>(
795             imag, b.create<arith::AddFOp>(resultImag, resultImag)),
796         resultReal);
797 
798     Value realIsZero =
799         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
800     Value imagIsZero =
801         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
802     Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
803 
804     resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
805     resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
806 
807     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
808                                                    resultImag);
809     return success();
810   }
811 };
812 
813 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
814   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
815 
816   LogicalResult
817   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
818                   ConversionPatternRewriter &rewriter) const override {
819     auto type = adaptor.getComplex().getType().cast<ComplexType>();
820     auto elementType = type.getElementType().cast<FloatType>();
821     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
822 
823     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
824     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
825     Value zero =
826         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
827     Value realIsZero =
828         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
829     Value imagIsZero =
830         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
831     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
832     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
833     Value realSign = b.create<arith::DivFOp>(real, abs);
834     Value imagSign = b.create<arith::DivFOp>(imag, abs);
835     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
836     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
837                                                  adaptor.getComplex(), sign);
838     return success();
839   }
840 };
841 
842 struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
843   using OpConversionPattern<complex::TanOp>::OpConversionPattern;
844 
845   LogicalResult
846   matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
847                   ConversionPatternRewriter &rewriter) const override {
848     auto loc = op.getLoc();
849     Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
850     Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
851     rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
852     return success();
853   }
854 };
855 
856 struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
857   using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
858 
859   LogicalResult
860   matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
861                   ConversionPatternRewriter &rewriter) const override {
862     auto loc = op.getLoc();
863     auto type = adaptor.getComplex().getType().cast<ComplexType>();
864     auto elementType = type.getElementType().cast<FloatType>();
865 
866     // The hyperbolic tangent for complex number can be calculated as follows.
867     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
868     // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
869     Value real =
870         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
871     Value imag =
872         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
873     Value tanhA = rewriter.create<math::TanhOp>(loc, real);
874     Value cosB = rewriter.create<math::CosOp>(loc, imag);
875     Value sinB = rewriter.create<math::SinOp>(loc, imag);
876     Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
877     Value numerator =
878         rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
879     Value one = rewriter.create<arith::ConstantOp>(
880         loc, elementType, rewriter.getFloatAttr(elementType, 1));
881     Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
882     Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
883     rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
884     return success();
885   }
886 };
887 
888 } // namespace
889 
890 void mlir::populateComplexToStandardConversionPatterns(
891     RewritePatternSet &patterns) {
892   // clang-format off
893   patterns.add<
894       AbsOpConversion,
895       Atan2OpConversion,
896       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
897       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
898       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
899       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
900       CosOpConversion,
901       DivOpConversion,
902       ExpOpConversion,
903       Expm1OpConversion,
904       LogOpConversion,
905       Log1pOpConversion,
906       MulOpConversion,
907       NegOpConversion,
908       SignOpConversion,
909       SinOpConversion,
910       SqrtOpConversion,
911       TanOpConversion,
912       TanhOpConversion>(patterns.getContext());
913   // clang-format on
914 }
915 
916 namespace {
917 struct ConvertComplexToStandardPass
918     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
919   void runOnOperation() override;
920 };
921 
922 void ConvertComplexToStandardPass::runOnOperation() {
923   // Convert to the Standard dialect using the converter defined above.
924   RewritePatternSet patterns(&getContext());
925   populateComplexToStandardConversionPatterns(patterns);
926 
927   ConversionTarget target(getContext());
928   target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>();
929   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
930   if (failed(
931           applyPartialConversion(getOperation(), target, std::move(patterns))))
932     signalPassFailure();
933 }
934 } // namespace
935 
936 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
937   return std::make_unique<ConvertComplexToStandardPass>();
938 }
939