//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"

#include <memory>
#include <type_traits>
#include <utility>

#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
252ea7fb7bSAdrian Kuegel struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
262ea7fb7bSAdrian Kuegel   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
272ea7fb7bSAdrian Kuegel 
282ea7fb7bSAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::AbsOpConversion29b54c724bSRiver Riddle   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
302ea7fb7bSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
312ea7fb7bSAdrian Kuegel     auto loc = op.getLoc();
322ea7fb7bSAdrian Kuegel     auto type = op.getType();
332ea7fb7bSAdrian Kuegel 
34c0342a2dSJacques Pienaar     Value real =
35c0342a2dSJacques Pienaar         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
36c0342a2dSJacques Pienaar     Value imag =
37c0342a2dSJacques Pienaar         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
38a54f4eaeSMogball     Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
39a54f4eaeSMogball     Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
40a54f4eaeSMogball     Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
412ea7fb7bSAdrian Kuegel 
422ea7fb7bSAdrian Kuegel     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
432ea7fb7bSAdrian Kuegel     return success();
442ea7fb7bSAdrian Kuegel   }
452ea7fb7bSAdrian Kuegel };
46ac00cb0dSAdrian Kuegel 
47f711785eSAlexander Belyaev // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
48f711785eSAlexander Belyaev struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
49f711785eSAlexander Belyaev   using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
50f711785eSAlexander Belyaev 
51f711785eSAlexander Belyaev   LogicalResult
matchAndRewrite__anon597693150111::Atan2OpConversion52f711785eSAlexander Belyaev   matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
53f711785eSAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
54f711785eSAlexander Belyaev     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
55f711785eSAlexander Belyaev 
56f711785eSAlexander Belyaev     auto type = op.getType().cast<ComplexType>();
57f711785eSAlexander Belyaev     Type elementType = type.getElementType();
58f711785eSAlexander Belyaev 
59f711785eSAlexander Belyaev     Value lhs = adaptor.getLhs();
60f711785eSAlexander Belyaev     Value rhs = adaptor.getRhs();
61f711785eSAlexander Belyaev 
62f711785eSAlexander Belyaev     Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
63f711785eSAlexander Belyaev     Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
64f711785eSAlexander Belyaev     Value rhsSquaredPlusLhsSquared =
65f711785eSAlexander Belyaev         b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
66f711785eSAlexander Belyaev     Value sqrtOfRhsSquaredPlusLhsSquared =
67f711785eSAlexander Belyaev         b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
68f711785eSAlexander Belyaev 
69f711785eSAlexander Belyaev     Value zero =
70f711785eSAlexander Belyaev         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
71f711785eSAlexander Belyaev     Value one = b.create<arith::ConstantOp>(elementType,
72f711785eSAlexander Belyaev                                             b.getFloatAttr(elementType, 1));
73f711785eSAlexander Belyaev     Value i = b.create<complex::CreateOp>(type, zero, one);
74f711785eSAlexander Belyaev     Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
75f711785eSAlexander Belyaev     Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
76f711785eSAlexander Belyaev 
77f711785eSAlexander Belyaev     Value divResult =
78f711785eSAlexander Belyaev         b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
79f711785eSAlexander Belyaev     Value logResult = b.create<complex::LogOp>(divResult);
80f711785eSAlexander Belyaev 
81f711785eSAlexander Belyaev     Value negativeOne = b.create<arith::ConstantOp>(
82f711785eSAlexander Belyaev         elementType, b.getFloatAttr(elementType, -1));
83f711785eSAlexander Belyaev     Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
84f711785eSAlexander Belyaev 
85f711785eSAlexander Belyaev     rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
86f711785eSAlexander Belyaev     return success();
87f711785eSAlexander Belyaev   }
88f711785eSAlexander Belyaev };
89f711785eSAlexander Belyaev 
90a54f4eaeSMogball template <typename ComparisonOp, arith::CmpFPredicate p>
91fb8b2b86SAdrian Kuegel struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
92fb8b2b86SAdrian Kuegel   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
93fb8b2b86SAdrian Kuegel   using ResultCombiner =
94fb8b2b86SAdrian Kuegel       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
95a54f4eaeSMogball                          arith::AndIOp, arith::OrIOp>;
96ac00cb0dSAdrian Kuegel 
97ac00cb0dSAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::ComparisonOpConversion98b54c724bSRiver Riddle   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
99ac00cb0dSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
100ac00cb0dSAdrian Kuegel     auto loc = op.getLoc();
101c0342a2dSJacques Pienaar     auto type = adaptor.getLhs()
102c0342a2dSJacques Pienaar                     .getType()
103c0342a2dSJacques Pienaar                     .template cast<ComplexType>()
104c0342a2dSJacques Pienaar                     .getElementType();
105ac00cb0dSAdrian Kuegel 
106c0342a2dSJacques Pienaar     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
107c0342a2dSJacques Pienaar     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
108c0342a2dSJacques Pienaar     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
109c0342a2dSJacques Pienaar     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
110a54f4eaeSMogball     Value realComparison =
111a54f4eaeSMogball         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
112a54f4eaeSMogball     Value imagComparison =
113a54f4eaeSMogball         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
114ac00cb0dSAdrian Kuegel 
115fb8b2b86SAdrian Kuegel     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
116fb8b2b86SAdrian Kuegel                                                 imagComparison);
117ac00cb0dSAdrian Kuegel     return success();
118ac00cb0dSAdrian Kuegel   }
119ac00cb0dSAdrian Kuegel };
120942be7cbSAdrian Kuegel 
121fb978f09SAdrian Kuegel // Default conversion which applies the BinaryStandardOp separately on the real
122fb978f09SAdrian Kuegel // and imaginary parts. Can for example be used for complex::AddOp and
123fb978f09SAdrian Kuegel // complex::SubOp.
124fb978f09SAdrian Kuegel template <typename BinaryComplexOp, typename BinaryStandardOp>
125fb978f09SAdrian Kuegel struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
126fb978f09SAdrian Kuegel   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
127fb978f09SAdrian Kuegel 
128fb978f09SAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::BinaryComplexOpConversion129b54c724bSRiver Riddle   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
130fb978f09SAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
131c0342a2dSJacques Pienaar     auto type = adaptor.getLhs().getType().template cast<ComplexType>();
132fb978f09SAdrian Kuegel     auto elementType = type.getElementType().template cast<FloatType>();
133fb978f09SAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
134fb978f09SAdrian Kuegel 
135c0342a2dSJacques Pienaar     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
136c0342a2dSJacques Pienaar     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
137fb978f09SAdrian Kuegel     Value resultReal =
138fb978f09SAdrian Kuegel         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
139c0342a2dSJacques Pienaar     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
140c0342a2dSJacques Pienaar     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
141fb978f09SAdrian Kuegel     Value resultImag =
142fb978f09SAdrian Kuegel         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
143fb978f09SAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
144fb978f09SAdrian Kuegel                                                    resultImag);
145fb978f09SAdrian Kuegel     return success();
146fb978f09SAdrian Kuegel   }
147fb978f09SAdrian Kuegel };
148fb978f09SAdrian Kuegel 
149672b908bSGoran Flegar template <typename TrigonometricOp>
150672b908bSGoran Flegar struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
151672b908bSGoran Flegar   using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
152672b908bSGoran Flegar 
153672b908bSGoran Flegar   using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
154672b908bSGoran Flegar 
155672b908bSGoran Flegar   LogicalResult
matchAndRewrite__anon597693150111::TrigonometricOpConversion156672b908bSGoran Flegar   matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
157672b908bSGoran Flegar                   ConversionPatternRewriter &rewriter) const override {
158672b908bSGoran Flegar     auto loc = op.getLoc();
159672b908bSGoran Flegar     auto type = adaptor.getComplex().getType().template cast<ComplexType>();
160672b908bSGoran Flegar     auto elementType = type.getElementType().template cast<FloatType>();
161672b908bSGoran Flegar 
162672b908bSGoran Flegar     Value real =
163672b908bSGoran Flegar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
164672b908bSGoran Flegar     Value imag =
165672b908bSGoran Flegar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
166672b908bSGoran Flegar 
167672b908bSGoran Flegar     // Trigonometric ops use a set of common building blocks to convert to real
168672b908bSGoran Flegar     // ops. Here we create these building blocks and call into an op-specific
169672b908bSGoran Flegar     // implementation in the subclass to combine them.
170672b908bSGoran Flegar     Value half = rewriter.create<arith::ConstantOp>(
171672b908bSGoran Flegar         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
172672b908bSGoran Flegar     Value exp = rewriter.create<math::ExpOp>(loc, imag);
173672b908bSGoran Flegar     Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
174672b908bSGoran Flegar     Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
175672b908bSGoran Flegar     Value sin = rewriter.create<math::SinOp>(loc, real);
176672b908bSGoran Flegar     Value cos = rewriter.create<math::CosOp>(loc, real);
177672b908bSGoran Flegar 
178672b908bSGoran Flegar     auto resultPair =
179672b908bSGoran Flegar         combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
180672b908bSGoran Flegar 
181672b908bSGoran Flegar     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
182672b908bSGoran Flegar                                                    resultPair.second);
183672b908bSGoran Flegar     return success();
184672b908bSGoran Flegar   }
185672b908bSGoran Flegar 
186672b908bSGoran Flegar   virtual std::pair<Value, Value>
187672b908bSGoran Flegar   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
188672b908bSGoran Flegar           Value cos, ConversionPatternRewriter &rewriter) const = 0;
189672b908bSGoran Flegar };
190672b908bSGoran Flegar 
191672b908bSGoran Flegar struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
192672b908bSGoran Flegar   using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
193672b908bSGoran Flegar 
194672b908bSGoran Flegar   std::pair<Value, Value>
combine__anon597693150111::CosOpConversion195672b908bSGoran Flegar   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
196672b908bSGoran Flegar           Value cos, ConversionPatternRewriter &rewriter) const override {
197672b908bSGoran Flegar     // Complex cosine is defined as;
198672b908bSGoran Flegar     //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
199672b908bSGoran Flegar     // Plugging in:
200672b908bSGoran Flegar     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
201672b908bSGoran Flegar     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
202672b908bSGoran Flegar     // and defining t := exp(y)
203672b908bSGoran Flegar     // We get:
204672b908bSGoran Flegar     //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
205672b908bSGoran Flegar     //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
206672b908bSGoran Flegar     Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
207672b908bSGoran Flegar     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
208672b908bSGoran Flegar     Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
209672b908bSGoran Flegar     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
210672b908bSGoran Flegar     return {resultReal, resultImag};
211672b908bSGoran Flegar   }
212672b908bSGoran Flegar };
213672b908bSGoran Flegar 
214942be7cbSAdrian Kuegel struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
215942be7cbSAdrian Kuegel   using OpConversionPattern<complex::DivOp>::OpConversionPattern;
216942be7cbSAdrian Kuegel 
217942be7cbSAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::DivOpConversion218b54c724bSRiver Riddle   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
219942be7cbSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
220942be7cbSAdrian Kuegel     auto loc = op.getLoc();
221c0342a2dSJacques Pienaar     auto type = adaptor.getLhs().getType().cast<ComplexType>();
222942be7cbSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
223942be7cbSAdrian Kuegel 
224942be7cbSAdrian Kuegel     Value lhsReal =
225c0342a2dSJacques Pienaar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
226942be7cbSAdrian Kuegel     Value lhsImag =
227c0342a2dSJacques Pienaar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
228942be7cbSAdrian Kuegel     Value rhsReal =
229c0342a2dSJacques Pienaar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
230942be7cbSAdrian Kuegel     Value rhsImag =
231c0342a2dSJacques Pienaar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
232942be7cbSAdrian Kuegel 
233942be7cbSAdrian Kuegel     // Smith's algorithm to divide complex numbers. It is just a bit smarter
234942be7cbSAdrian Kuegel     // way to compute the following formula:
235942be7cbSAdrian Kuegel     //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
236942be7cbSAdrian Kuegel     //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
237942be7cbSAdrian Kuegel     //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
238942be7cbSAdrian Kuegel     //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
239942be7cbSAdrian Kuegel     //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
240942be7cbSAdrian Kuegel     //
241942be7cbSAdrian Kuegel     // Depending on whether |rhsReal| < |rhsImag| we compute either
242942be7cbSAdrian Kuegel     //   rhsRealImagRatio = rhsReal / rhsImag
243942be7cbSAdrian Kuegel     //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
244942be7cbSAdrian Kuegel     //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
245942be7cbSAdrian Kuegel     //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
246942be7cbSAdrian Kuegel     //
247942be7cbSAdrian Kuegel     // or
248942be7cbSAdrian Kuegel     //
249942be7cbSAdrian Kuegel     //   rhsImagRealRatio = rhsImag / rhsReal
250942be7cbSAdrian Kuegel     //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
251942be7cbSAdrian Kuegel     //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
252942be7cbSAdrian Kuegel     //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
253942be7cbSAdrian Kuegel     //
254942be7cbSAdrian Kuegel     // See https://dl.acm.org/citation.cfm?id=368661 for more details.
255a54f4eaeSMogball     Value rhsRealImagRatio =
256a54f4eaeSMogball         rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
257a54f4eaeSMogball     Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
258a54f4eaeSMogball         loc, rhsImag,
259a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
260a54f4eaeSMogball     Value realNumerator1 = rewriter.create<arith::AddFOp>(
261a54f4eaeSMogball         loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
262a54f4eaeSMogball         lhsImag);
263942be7cbSAdrian Kuegel     Value resultReal1 =
264a54f4eaeSMogball         rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
265a54f4eaeSMogball     Value imagNumerator1 = rewriter.create<arith::SubFOp>(
266a54f4eaeSMogball         loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
267a54f4eaeSMogball         lhsReal);
268942be7cbSAdrian Kuegel     Value resultImag1 =
269a54f4eaeSMogball         rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
270942be7cbSAdrian Kuegel 
271a54f4eaeSMogball     Value rhsImagRealRatio =
272a54f4eaeSMogball         rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
273a54f4eaeSMogball     Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
274a54f4eaeSMogball         loc, rhsReal,
275a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
276a54f4eaeSMogball     Value realNumerator2 = rewriter.create<arith::AddFOp>(
277a54f4eaeSMogball         loc, lhsReal,
278a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
279942be7cbSAdrian Kuegel     Value resultReal2 =
280a54f4eaeSMogball         rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
281a54f4eaeSMogball     Value imagNumerator2 = rewriter.create<arith::SubFOp>(
282a54f4eaeSMogball         loc, lhsImag,
283a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
284942be7cbSAdrian Kuegel     Value resultImag2 =
285a54f4eaeSMogball         rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
286942be7cbSAdrian Kuegel 
287942be7cbSAdrian Kuegel     // Consider corner cases.
288942be7cbSAdrian Kuegel     // Case 1. Zero denominator, numerator contains at most one NaN value.
289a54f4eaeSMogball     Value zero = rewriter.create<arith::ConstantOp>(
290a54f4eaeSMogball         loc, elementType, rewriter.getZeroAttr(elementType));
291a54f4eaeSMogball     Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
292a54f4eaeSMogball     Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
293a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
294a54f4eaeSMogball     Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
295a54f4eaeSMogball     Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
296a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
297a54f4eaeSMogball     Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
298a54f4eaeSMogball         loc, arith::CmpFPredicate::ORD, lhsReal, zero);
299a54f4eaeSMogball     Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
300a54f4eaeSMogball         loc, arith::CmpFPredicate::ORD, lhsImag, zero);
301942be7cbSAdrian Kuegel     Value lhsContainsNotNaNValue =
302a54f4eaeSMogball         rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
303a54f4eaeSMogball     Value resultIsInfinity = rewriter.create<arith::AndIOp>(
304942be7cbSAdrian Kuegel         loc, lhsContainsNotNaNValue,
305a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
306a54f4eaeSMogball     Value inf = rewriter.create<arith::ConstantOp>(
307942be7cbSAdrian Kuegel         loc, elementType,
308942be7cbSAdrian Kuegel         rewriter.getFloatAttr(
309942be7cbSAdrian Kuegel             elementType, APFloat::getInf(elementType.getFloatSemantics())));
310a54f4eaeSMogball     Value infWithSignOfRhsReal =
311a54f4eaeSMogball         rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
312942be7cbSAdrian Kuegel     Value infinityResultReal =
313a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
314942be7cbSAdrian Kuegel     Value infinityResultImag =
315a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
316942be7cbSAdrian Kuegel 
317942be7cbSAdrian Kuegel     // Case 2. Infinite numerator, finite denominator.
318a54f4eaeSMogball     Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
319a54f4eaeSMogball         loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
320a54f4eaeSMogball     Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
321a54f4eaeSMogball         loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
322a54f4eaeSMogball     Value rhsFinite =
323a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
324a54f4eaeSMogball     Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
325a54f4eaeSMogball     Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
326a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
327a54f4eaeSMogball     Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
328a54f4eaeSMogball     Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
329a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
330942be7cbSAdrian Kuegel     Value lhsInfinite =
331a54f4eaeSMogball         rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
332942be7cbSAdrian Kuegel     Value infNumFiniteDenom =
333a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
334a54f4eaeSMogball     Value one = rewriter.create<arith::ConstantOp>(
335942be7cbSAdrian Kuegel         loc, elementType, rewriter.getFloatAttr(elementType, 1));
336a54f4eaeSMogball     Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
337dec8af70SRiver Riddle         loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
338942be7cbSAdrian Kuegel         lhsReal);
339a54f4eaeSMogball     Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
340dec8af70SRiver Riddle         loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
341942be7cbSAdrian Kuegel         lhsImag);
342942be7cbSAdrian Kuegel     Value lhsRealIsInfWithSignTimesRhsReal =
343a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
344942be7cbSAdrian Kuegel     Value lhsImagIsInfWithSignTimesRhsImag =
345a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
346a54f4eaeSMogball     Value resultReal3 = rewriter.create<arith::MulFOp>(
347942be7cbSAdrian Kuegel         loc, inf,
348a54f4eaeSMogball         rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
349942be7cbSAdrian Kuegel                                        lhsImagIsInfWithSignTimesRhsImag));
350942be7cbSAdrian Kuegel     Value lhsRealIsInfWithSignTimesRhsImag =
351a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
352942be7cbSAdrian Kuegel     Value lhsImagIsInfWithSignTimesRhsReal =
353a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
354a54f4eaeSMogball     Value resultImag3 = rewriter.create<arith::MulFOp>(
355942be7cbSAdrian Kuegel         loc, inf,
356a54f4eaeSMogball         rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
357942be7cbSAdrian Kuegel                                        lhsRealIsInfWithSignTimesRhsImag));
358942be7cbSAdrian Kuegel 
359942be7cbSAdrian Kuegel     // Case 3: Finite numerator, infinite denominator.
360a54f4eaeSMogball     Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
361a54f4eaeSMogball         loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
362a54f4eaeSMogball     Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
363a54f4eaeSMogball         loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
364a54f4eaeSMogball     Value lhsFinite =
365a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
366a54f4eaeSMogball     Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
367a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
368a54f4eaeSMogball     Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
369a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
370942be7cbSAdrian Kuegel     Value rhsInfinite =
371a54f4eaeSMogball         rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
372942be7cbSAdrian Kuegel     Value finiteNumInfiniteDenom =
373a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
374a54f4eaeSMogball     Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
375dec8af70SRiver Riddle         loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
376942be7cbSAdrian Kuegel         rhsReal);
377a54f4eaeSMogball     Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
378dec8af70SRiver Riddle         loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
379942be7cbSAdrian Kuegel         rhsImag);
380942be7cbSAdrian Kuegel     Value rhsRealIsInfWithSignTimesLhsReal =
381a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
382942be7cbSAdrian Kuegel     Value rhsImagIsInfWithSignTimesLhsImag =
383a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
384a54f4eaeSMogball     Value resultReal4 = rewriter.create<arith::MulFOp>(
385942be7cbSAdrian Kuegel         loc, zero,
386a54f4eaeSMogball         rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
387942be7cbSAdrian Kuegel                                        rhsImagIsInfWithSignTimesLhsImag));
388942be7cbSAdrian Kuegel     Value rhsRealIsInfWithSignTimesLhsImag =
389a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
390942be7cbSAdrian Kuegel     Value rhsImagIsInfWithSignTimesLhsReal =
391a54f4eaeSMogball         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
392a54f4eaeSMogball     Value resultImag4 = rewriter.create<arith::MulFOp>(
393942be7cbSAdrian Kuegel         loc, zero,
394a54f4eaeSMogball         rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
395942be7cbSAdrian Kuegel                                        rhsImagIsInfWithSignTimesLhsReal));
396942be7cbSAdrian Kuegel 
397a54f4eaeSMogball     Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
398a54f4eaeSMogball         loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
399dec8af70SRiver Riddle     Value resultReal = rewriter.create<arith::SelectOp>(
400dec8af70SRiver Riddle         loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
401dec8af70SRiver Riddle     Value resultImag = rewriter.create<arith::SelectOp>(
402dec8af70SRiver Riddle         loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
403dec8af70SRiver Riddle     Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
404942be7cbSAdrian Kuegel         loc, finiteNumInfiniteDenom, resultReal4, resultReal);
405dec8af70SRiver Riddle     Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
406942be7cbSAdrian Kuegel         loc, finiteNumInfiniteDenom, resultImag4, resultImag);
407dec8af70SRiver Riddle     Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
408942be7cbSAdrian Kuegel         loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
409dec8af70SRiver Riddle     Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
410942be7cbSAdrian Kuegel         loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
411dec8af70SRiver Riddle     Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
412942be7cbSAdrian Kuegel         loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
413dec8af70SRiver Riddle     Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
414942be7cbSAdrian Kuegel         loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
415942be7cbSAdrian Kuegel 
416a54f4eaeSMogball     Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
417a54f4eaeSMogball         loc, arith::CmpFPredicate::UNO, resultReal, zero);
418a54f4eaeSMogball     Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
419a54f4eaeSMogball         loc, arith::CmpFPredicate::UNO, resultImag, zero);
420942be7cbSAdrian Kuegel     Value resultIsNaN =
421a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
422dec8af70SRiver Riddle     Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
423942be7cbSAdrian Kuegel         loc, resultIsNaN, resultRealSpecialCase1, resultReal);
424dec8af70SRiver Riddle     Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
425942be7cbSAdrian Kuegel         loc, resultIsNaN, resultImagSpecialCase1, resultImag);
426942be7cbSAdrian Kuegel 
427942be7cbSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(
428942be7cbSAdrian Kuegel         op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
429942be7cbSAdrian Kuegel     return success();
430942be7cbSAdrian Kuegel   }
431942be7cbSAdrian Kuegel };
43273cbc91cSAdrian Kuegel 
/// Rewrites complex.exp into real-valued math/arith ops using
///   exp(a + i*b) = exp(a) * cos(b) + i * exp(a) * sin(b),
/// and re-packs the two components with complex.create.
43373cbc91cSAdrian Kuegel struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
43473cbc91cSAdrian Kuegel   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
43573cbc91cSAdrian Kuegel 
43673cbc91cSAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::ExpOpConversion437b54c724bSRiver Riddle   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
43873cbc91cSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
43973cbc91cSAdrian Kuegel     auto loc = op.getLoc();
440c0342a2dSJacques Pienaar     auto type = adaptor.getComplex().getType().cast<ComplexType>();
44173cbc91cSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
44273cbc91cSAdrian Kuegel 
    // Split the operand into its real part a and imaginary part b.
44373cbc91cSAdrian Kuegel     Value real =
444c0342a2dSJacques Pienaar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
44573cbc91cSAdrian Kuegel     Value imag =
446c0342a2dSJacques Pienaar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
    // exp(a) is computed once and shared by both result components.
44773cbc91cSAdrian Kuegel     Value expReal = rewriter.create<math::ExpOp>(loc, real);
44873cbc91cSAdrian Kuegel     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
449a54f4eaeSMogball     Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
45073cbc91cSAdrian Kuegel     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
451a54f4eaeSMogball     Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
45273cbc91cSAdrian Kuegel 
45373cbc91cSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
45473cbc91cSAdrian Kuegel                                                    resultImag);
45573cbc91cSAdrian Kuegel     return success();
45673cbc91cSAdrian Kuegel   }
45773cbc91cSAdrian Kuegel };
458662e074dSAdrian Kuegel 
/// Rewrites complex.expm1(z) as complex.exp(z) with 1 subtracted from the real
/// part; the imaginary part of exp(z) is passed through unchanged.
/// NOTE(review): the subtraction happens after exp(z) is fully evaluated, so
/// this presumably gives no extra precision for small |z| -- confirm whether
/// callers rely on that.
459338e76f8Sbixia1 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
460338e76f8Sbixia1   using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
461338e76f8Sbixia1 
462338e76f8Sbixia1   LogicalResult
matchAndRewrite__anon597693150111::Expm1OpConversion463338e76f8Sbixia1   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
464338e76f8Sbixia1                   ConversionPatternRewriter &rewriter) const override {
465338e76f8Sbixia1     auto type = adaptor.getComplex().getType().cast<ComplexType>();
466338e76f8Sbixia1     auto elementType = type.getElementType().cast<FloatType>();
467338e76f8Sbixia1 
468338e76f8Sbixia1     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // The emitted complex.exp is a new complex-dialect op, not real-valued
    // arithmetic; it still has to be rewritten by ExpOpConversion.
469338e76f8Sbixia1     Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
470338e76f8Sbixia1 
471338e76f8Sbixia1     Value real = b.create<complex::ReOp>(elementType, exp);
472338e76f8Sbixia1     Value one = b.create<arith::ConstantOp>(elementType,
473338e76f8Sbixia1                                             b.getFloatAttr(elementType, 1));
474338e76f8Sbixia1     Value realMinusOne = b.create<arith::SubFOp>(real, one);
475338e76f8Sbixia1     Value imag = b.create<complex::ImOp>(elementType, exp);
476338e76f8Sbixia1 
477338e76f8Sbixia1     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
478338e76f8Sbixia1                                                    imag);
479338e76f8Sbixia1     return success();
480338e76f8Sbixia1   }
481338e76f8Sbixia1 };
482338e76f8Sbixia1 
/// Rewrites complex.log(z) using
///   log(z) = log(|z|) + i * atan2(Im(z), Re(z)).
/// The magnitude is obtained through complex.abs, which is itself a
/// complex-dialect op.
483380fa71fSAdrian Kuegel struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
484380fa71fSAdrian Kuegel   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
485380fa71fSAdrian Kuegel 
486380fa71fSAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::LogOpConversion487b54c724bSRiver Riddle   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
488380fa71fSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
489c0342a2dSJacques Pienaar     auto type = adaptor.getComplex().getType().cast<ComplexType>();
490380fa71fSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
491380fa71fSAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
492380fa71fSAdrian Kuegel 
    // Real part of the result: log of the magnitude.
493c0342a2dSJacques Pienaar     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
494380fa71fSAdrian Kuegel     Value resultReal = b.create<math::LogOp>(elementType, abs);
    // Imaginary part of the result: the argument, atan2(imag, real).
495c0342a2dSJacques Pienaar     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
496c0342a2dSJacques Pienaar     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
497380fa71fSAdrian Kuegel     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
498380fa71fSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
499380fa71fSAdrian Kuegel                                                    resultImag);
500380fa71fSAdrian Kuegel     return success();
501380fa71fSAdrian Kuegel   }
502380fa71fSAdrian Kuegel };
503380fa71fSAdrian Kuegel 
/// Rewrites complex.log1p(z) as complex.log applied to the complex number
/// (Re(z) + 1) + i * Im(z). The replacement is a complex.log op, i.e. it
/// remains in the complex dialect.
5046e80e3bdSAdrian Kuegel struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
5056e80e3bdSAdrian Kuegel   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
5066e80e3bdSAdrian Kuegel 
5076e80e3bdSAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::Log1pOpConversion508b54c724bSRiver Riddle   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
5096e80e3bdSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
510c0342a2dSJacques Pienaar     auto type = adaptor.getComplex().getType().cast<ComplexType>();
5116e80e3bdSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
5126e80e3bdSAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
5136e80e3bdSAdrian Kuegel 
514c0342a2dSJacques Pienaar     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
515c0342a2dSJacques Pienaar     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
516a54f4eaeSMogball     Value one = b.create<arith::ConstantOp>(elementType,
517a54f4eaeSMogball                                             b.getFloatAttr(elementType, 1));
    // Adding the real constant 1 only affects the real component.
518a54f4eaeSMogball     Value realPlusOne = b.create<arith::AddFOp>(real, one);
5196e80e3bdSAdrian Kuegel     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
5206e80e3bdSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
5216e80e3bdSAdrian Kuegel     return success();
5226e80e3bdSAdrian Kuegel   }
5236e80e3bdSAdrian Kuegel };
5246e80e3bdSAdrian Kuegel 
/// Rewrites complex.mul into real-valued arith/math ops.
///
/// The "naive" product
///   (a + ib) * (c + id) = (a*c - b*d) + i * (b*c + a*d)
/// is computed first. If both components of that product are NaN, the operands
/// are patched up in three mutually exclusive cases (an infinite lhs
/// component, an infinite rhs component, or an infinite pairwise product), the
/// product is recomputed from the patched operands, and each component is
/// multiplied by infinity to produce the final result.
/// NOTE(review): this appears to mirror the infinity-recovery procedure for
/// complex multiplication described in C99 Annex G -- confirm.
525bf17ee19SAdrian Kuegel struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
526bf17ee19SAdrian Kuegel   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
527bf17ee19SAdrian Kuegel 
528bf17ee19SAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::MulOpConversion529b54c724bSRiver Riddle   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
530bf17ee19SAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
531bf17ee19SAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
532c0342a2dSJacques Pienaar     auto type = adaptor.getLhs().getType().cast<ComplexType>();
533bf17ee19SAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
534bf17ee19SAdrian Kuegel 
    // Components of both operands plus their absolute values. The absolute
    // values are taken from the original (unpatched) components and are used
    // below for the infinity tests.
535c0342a2dSJacques Pienaar     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
536a54f4eaeSMogball     Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
537c0342a2dSJacques Pienaar     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
538a54f4eaeSMogball     Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
539c0342a2dSJacques Pienaar     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
540a54f4eaeSMogball     Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
541c0342a2dSJacques Pienaar     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
542a54f4eaeSMogball     Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);
543bf17ee19SAdrian Kuegel 
    // Naive real part: a*c - b*d. The absolute values of the partial products
    // are kept for the case 3 infinity test.
544a54f4eaeSMogball     Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
545a54f4eaeSMogball     Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
546a54f4eaeSMogball     Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
547a54f4eaeSMogball     Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
548a54f4eaeSMogball     Value real =
549a54f4eaeSMogball         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
550bf17ee19SAdrian Kuegel 
    // Naive imaginary part: b*c + a*d.
551a54f4eaeSMogball     Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
552a54f4eaeSMogball     Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
553a54f4eaeSMogball     Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
554a54f4eaeSMogball     Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
555a54f4eaeSMogball     Value imag =
556a54f4eaeSMogball         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
557bf17ee19SAdrian Kuegel 
558bf17ee19SAdrian Kuegel     // Handle cases where the "naive" calculation results in NaN values.
    // `x UNO x` is true exactly when x is NaN. Recovery is only considered
    // when BOTH components of the naive result are NaN.
559a54f4eaeSMogball     Value realIsNan =
560a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
561a54f4eaeSMogball     Value imagIsNan =
562a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
563a54f4eaeSMogball     Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
564bf17ee19SAdrian Kuegel 
565a54f4eaeSMogball     Value inf = b.create<arith::ConstantOp>(
566bf17ee19SAdrian Kuegel         elementType,
567bf17ee19SAdrian Kuegel         b.getFloatAttr(elementType,
568bf17ee19SAdrian Kuegel                        APFloat::getInf(elementType.getFloatSemantics())));
569bf17ee19SAdrian Kuegel 
570bf17ee19SAdrian Kuegel     // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Each lhs component is replaced by copysign(isInf ? 1 : 0, component),
    // and NaN rhs components are replaced by a zero carrying their sign.
571a54f4eaeSMogball     Value lhsRealIsInf =
572a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
573a54f4eaeSMogball     Value lhsImagIsInf =
574a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
575a54f4eaeSMogball     Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
576a54f4eaeSMogball     Value rhsRealIsNan =
577a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
578a54f4eaeSMogball     Value rhsImagIsNan =
579a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
580a54f4eaeSMogball     Value zero =
581a54f4eaeSMogball         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
582a54f4eaeSMogball     Value one = b.create<arith::ConstantOp>(elementType,
583a54f4eaeSMogball                                             b.getFloatAttr(elementType, 1));
584dec8af70SRiver Riddle     Value lhsRealIsInfFloat =
585dec8af70SRiver Riddle         b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
586dec8af70SRiver Riddle     lhsReal = b.create<arith::SelectOp>(
587a54f4eaeSMogball         lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
588a54f4eaeSMogball         lhsReal);
589dec8af70SRiver Riddle     Value lhsImagIsInfFloat =
590dec8af70SRiver Riddle         b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
591dec8af70SRiver Riddle     lhsImag = b.create<arith::SelectOp>(
592a54f4eaeSMogball         lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
593a54f4eaeSMogball         lhsImag);
594a54f4eaeSMogball     Value lhsIsInfAndRhsRealIsNan =
595a54f4eaeSMogball         b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
596dec8af70SRiver Riddle     rhsReal = b.create<arith::SelectOp>(
597dec8af70SRiver Riddle         lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
598dec8af70SRiver Riddle         rhsReal);
599a54f4eaeSMogball     Value lhsIsInfAndRhsImagIsNan =
600a54f4eaeSMogball         b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
601dec8af70SRiver Riddle     rhsImag = b.create<arith::SelectOp>(
602dec8af70SRiver Riddle         lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
603dec8af70SRiver Riddle         rhsImag);
604bf17ee19SAdrian Kuegel 
605bf17ee19SAdrian Kuegel     // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Mirror image of case 1. Note that the lhs NaN tests below read the lhs
    // components as already patched by case 1, while the rhs infinity tests
    // use the absolute values of the original rhs components.
606a54f4eaeSMogball     Value rhsRealIsInf =
607a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
608a54f4eaeSMogball     Value rhsImagIsInf =
609a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
610a54f4eaeSMogball     Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
611a54f4eaeSMogball     Value lhsRealIsNan =
612a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
613a54f4eaeSMogball     Value lhsImagIsNan =
614a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
615dec8af70SRiver Riddle     Value rhsRealIsInfFloat =
616dec8af70SRiver Riddle         b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
617dec8af70SRiver Riddle     rhsReal = b.create<arith::SelectOp>(
618a54f4eaeSMogball         rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
619a54f4eaeSMogball         rhsReal);
620dec8af70SRiver Riddle     Value rhsImagIsInfFloat =
621dec8af70SRiver Riddle         b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
622dec8af70SRiver Riddle     rhsImag = b.create<arith::SelectOp>(
623a54f4eaeSMogball         rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
624a54f4eaeSMogball         rhsImag);
625a54f4eaeSMogball     Value rhsIsInfAndLhsRealIsNan =
626a54f4eaeSMogball         b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
627dec8af70SRiver Riddle     lhsReal = b.create<arith::SelectOp>(
628dec8af70SRiver Riddle         rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
629dec8af70SRiver Riddle         lhsReal);
630a54f4eaeSMogball     Value rhsIsInfAndLhsImagIsNan =
631a54f4eaeSMogball         b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
632dec8af70SRiver Riddle     lhsImag = b.create<arith::SelectOp>(
633dec8af70SRiver Riddle         rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
634dec8af70SRiver Riddle         lhsImag);
    // `recalc` records whether case 1 or case 2 applied.
635a54f4eaeSMogball     Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
636bf17ee19SAdrian Kuegel 
637bf17ee19SAdrian Kuegel     // Case 3. One of the pairwise products of left hand side with right hand
638bf17ee19SAdrian Kuegel     // side is infinite.
    // The tests use the absolute values of the naive partial products computed
    // at the top of this function, not products of the patched operands.
639a54f4eaeSMogball     Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
640a54f4eaeSMogball         arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
641a54f4eaeSMogball     Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
642a54f4eaeSMogball         arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
643a54f4eaeSMogball     Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
644a54f4eaeSMogball                                                  lhsImagTimesRhsImagIsInf);
645a54f4eaeSMogball     Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
646a54f4eaeSMogball         arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
647a54f4eaeSMogball     isSpecialCase =
648a54f4eaeSMogball         b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
649a54f4eaeSMogball     Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
650a54f4eaeSMogball         arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
651a54f4eaeSMogball     isSpecialCase =
652a54f4eaeSMogball         b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    // XOR with the i1 constant 1 is a logical NOT: case 3 only applies when
    // neither case 1 nor case 2 did.
653bf17ee19SAdrian Kuegel     Type i1Type = b.getI1Type();
654a54f4eaeSMogball     Value notRecalc = b.create<arith::XOrIOp>(
655a54f4eaeSMogball         recalc,
656a54f4eaeSMogball         b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
657a54f4eaeSMogball     isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    // In case 3 every NaN operand component becomes a zero with the same sign.
658bf17ee19SAdrian Kuegel     Value isSpecialCaseAndLhsRealIsNan =
659a54f4eaeSMogball         b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
660dec8af70SRiver Riddle     lhsReal = b.create<arith::SelectOp>(
661dec8af70SRiver Riddle         isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
662dec8af70SRiver Riddle         lhsReal);
663bf17ee19SAdrian Kuegel     Value isSpecialCaseAndLhsImagIsNan =
664a54f4eaeSMogball         b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
665dec8af70SRiver Riddle     lhsImag = b.create<arith::SelectOp>(
666dec8af70SRiver Riddle         isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
667dec8af70SRiver Riddle         lhsImag);
668bf17ee19SAdrian Kuegel     Value isSpecialCaseAndRhsRealIsNan =
669a54f4eaeSMogball         b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
670dec8af70SRiver Riddle     rhsReal = b.create<arith::SelectOp>(
671dec8af70SRiver Riddle         isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
672dec8af70SRiver Riddle         rhsReal);
673bf17ee19SAdrian Kuegel     Value isSpecialCaseAndRhsImagIsNan =
674a54f4eaeSMogball         b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
675dec8af70SRiver Riddle     rhsImag = b.create<arith::SelectOp>(
676dec8af70SRiver Riddle         isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
677dec8af70SRiver Riddle         rhsImag);
    // The recomputed values are only selected when one of the three cases
    // applied AND the naive result was NaN in both components.
678a54f4eaeSMogball     recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
679a54f4eaeSMogball     recalc = b.create<arith::AndIOp>(isNan, recalc);
680bf17ee19SAdrian Kuegel 
681bf17ee19SAdrian Kuegel     // Recalculate real part.
    // Uses the patched operands; the recomputed value is scaled by infinity.
682a54f4eaeSMogball     lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
683a54f4eaeSMogball     lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
684a54f4eaeSMogball     Value newReal =
685a54f4eaeSMogball         b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
686dec8af70SRiver Riddle     real = b.create<arith::SelectOp>(
687dec8af70SRiver Riddle         recalc, b.create<arith::MulFOp>(inf, newReal), real);
688bf17ee19SAdrian Kuegel 
689bf17ee19SAdrian Kuegel     // Recalculate imag part.
690a54f4eaeSMogball     lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
691a54f4eaeSMogball     lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
692a54f4eaeSMogball     Value newImag =
693a54f4eaeSMogball         b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
694dec8af70SRiver Riddle     imag = b.create<arith::SelectOp>(
695dec8af70SRiver Riddle         recalc, b.create<arith::MulFOp>(inf, newImag), imag);
696bf17ee19SAdrian Kuegel 
697bf17ee19SAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
698bf17ee19SAdrian Kuegel     return success();
699bf17ee19SAdrian Kuegel   }
700bf17ee19SAdrian Kuegel };
701bf17ee19SAdrian Kuegel 
/// Rewrites complex.neg by negating the real and the imaginary part
/// independently with arith.negf and re-packing them with complex.create.
702662e074dSAdrian Kuegel struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
703662e074dSAdrian Kuegel   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
704662e074dSAdrian Kuegel 
705662e074dSAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::NegOpConversion706b54c724bSRiver Riddle   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
707662e074dSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
708662e074dSAdrian Kuegel     auto loc = op.getLoc();
709c0342a2dSJacques Pienaar     auto type = adaptor.getComplex().getType().cast<ComplexType>();
710662e074dSAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
711662e074dSAdrian Kuegel 
712662e074dSAdrian Kuegel     Value real =
713c0342a2dSJacques Pienaar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
714662e074dSAdrian Kuegel     Value imag =
715c0342a2dSJacques Pienaar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
716a54f4eaeSMogball     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
717a54f4eaeSMogball     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
718662e074dSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
719662e074dSAdrian Kuegel     return success();
720662e074dSAdrian Kuegel   }
721662e074dSAdrian Kuegel };
722f112bd61SAdrian Kuegel 
/// Supplies the sine-specific `combine` step for TrigonometricOpConversion:
/// given the exponential terms and sin/cos of the real part, returns the
/// {real, imaginary} pair of the complex sine.
/// NOTE(review): per the derivation below, `scaledExp` is presumably 0.5*t and
/// `reciprocalExp` 0.5/t with t = exp(y); they are produced by the base class,
/// which is defined elsewhere in this file -- confirm there.
723672b908bSGoran Flegar struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
724672b908bSGoran Flegar   using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
725672b908bSGoran Flegar 
726672b908bSGoran Flegar   std::pair<Value, Value>
combine__anon597693150111::SinOpConversion727672b908bSGoran Flegar   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
728672b908bSGoran Flegar           Value cos, ConversionPatternRewriter &rewriter) const override {
729672b908bSGoran Flegar     // Complex sine is defined as:
730672b908bSGoran Flegar     //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
731672b908bSGoran Flegar     // Plugging in:
732672b908bSGoran Flegar     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
733672b908bSGoran Flegar     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
734672b908bSGoran Flegar     // and defining t := exp(y)
735672b908bSGoran Flegar     // We get:
736672b908bSGoran Flegar     //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
737672b908bSGoran Flegar     //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
738672b908bSGoran Flegar     Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
739672b908bSGoran Flegar     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
740672b908bSGoran Flegar     Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
741672b908bSGoran Flegar     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
742672b908bSGoran Flegar     return {resultReal, resultImag};
743672b908bSGoran Flegar   }
744672b908bSGoran Flegar };
745672b908bSGoran Flegar 
746f711785eSAlexander Belyaev // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
//
// Rewrites complex.sqrt(z) for z = re + i*im. With s = sqrt((|re| + |z|) / 2):
//   re >= 0:  result = s + i * im / (2 * s)
//   re <  0:  result = im / (2 * t) + i * t, where t = -s if im < 0, else s
// An input whose real and imaginary parts both compare equal to zero is
// mapped to 0 + 0i explicitly.
747f711785eSAlexander Belyaev struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
748f711785eSAlexander Belyaev   using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
749f711785eSAlexander Belyaev 
750f711785eSAlexander Belyaev   LogicalResult
matchAndRewrite__anon597693150111::SqrtOpConversion751f711785eSAlexander Belyaev   matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
752f711785eSAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
753f711785eSAlexander Belyaev     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
754f711785eSAlexander Belyaev 
755f711785eSAlexander Belyaev     auto type = op.getType().cast<ComplexType>();
756f711785eSAlexander Belyaev     Type elementType = type.getElementType();
757f711785eSAlexander Belyaev     Value arg = adaptor.getComplex();
758f711785eSAlexander Belyaev 
759f711785eSAlexander Belyaev     Value zero =
760f711785eSAlexander Belyaev         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
761f711785eSAlexander Belyaev 
762f711785eSAlexander Belyaev     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
763f711785eSAlexander Belyaev     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
764f711785eSAlexander Belyaev 
    // |re| + |z|, where |z| comes from complex.abs.
765f711785eSAlexander Belyaev     Value absLhs = b.create<math::AbsOp>(real);
766f711785eSAlexander Belyaev     Value absArg = b.create<complex::AbsOp>(elementType, arg);
767f711785eSAlexander Belyaev     Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
768f711785eSAlexander Belyaev 
    // s = sqrt((|re| + |z|) / 2).
769*b7f93c28SJeff Niu     Value half = b.create<arith::ConstantOp>(elementType,
770*b7f93c28SJeff Niu                                              b.getFloatAttr(elementType, 0.5));
771f711785eSAlexander Belyaev     Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
772f711785eSAlexander Belyaev     Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
773f711785eSAlexander Belyaev 
774f711785eSAlexander Belyaev     Value realIsNegative =
775f711785eSAlexander Belyaev         b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
776f711785eSAlexander Belyaev     Value imagIsNegative =
777f711785eSAlexander Belyaev         b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
778f711785eSAlexander Belyaev 
    // Start with the re >= 0 form; both results are patched below for re < 0.
779f711785eSAlexander Belyaev     Value resultReal = sqrtAddAbs;
780f711785eSAlexander Belyaev 
781f711785eSAlexander Belyaev     Value imagDivTwoResultReal = b.create<arith::DivFOp>(
782f711785eSAlexander Belyaev         imag, b.create<arith::AddFOp>(resultReal, resultReal));
783f711785eSAlexander Belyaev 
784f711785eSAlexander Belyaev     Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
785f711785eSAlexander Belyaev 
    // For re < 0 the imaginary part is +/- s, negative when im < 0.
786f711785eSAlexander Belyaev     Value resultImag = b.create<arith::SelectOp>(
787f711785eSAlexander Belyaev         realIsNegative,
788f711785eSAlexander Belyaev         b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
789f711785eSAlexander Belyaev                                   resultReal),
790f711785eSAlexander Belyaev         imagDivTwoResultReal);
791f711785eSAlexander Belyaev 
    // For re < 0 the real part is im / (2 * resultImag).
792f711785eSAlexander Belyaev     resultReal = b.create<arith::SelectOp>(
793f711785eSAlexander Belyaev         realIsNegative,
794f711785eSAlexander Belyaev         b.create<arith::DivFOp>(
795f711785eSAlexander Belyaev             imag, b.create<arith::AddFOp>(resultImag, resultImag)),
796f711785eSAlexander Belyaev         resultReal);
797f711785eSAlexander Belyaev 
    // For a zero argument s is 0, so the divisions above compute 0 / 0; select
    // an explicit zero result in that case instead.
798f711785eSAlexander Belyaev     Value realIsZero =
799f711785eSAlexander Belyaev         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
800f711785eSAlexander Belyaev     Value imagIsZero =
801f711785eSAlexander Belyaev         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
802f711785eSAlexander Belyaev     Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
803f711785eSAlexander Belyaev 
804f711785eSAlexander Belyaev     resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
805f711785eSAlexander Belyaev     resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
806f711785eSAlexander Belyaev 
807f711785eSAlexander Belyaev     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
808f711785eSAlexander Belyaev                                                    resultImag);
809f711785eSAlexander Belyaev     return success();
810f711785eSAlexander Belyaev   }
811f711785eSAlexander Belyaev };
812f711785eSAlexander Belyaev 
/// Rewrites complex.sign(z) as z / |z| (each component divided by the
/// magnitude from complex.abs). When both components of z compare equal to
/// zero, the input value itself is returned instead of the quotient.
813f112bd61SAdrian Kuegel struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
814f112bd61SAdrian Kuegel   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
815f112bd61SAdrian Kuegel 
816f112bd61SAdrian Kuegel   LogicalResult
matchAndRewrite__anon597693150111::SignOpConversion817b54c724bSRiver Riddle   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
818f112bd61SAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
819c0342a2dSJacques Pienaar     auto type = adaptor.getComplex().getType().cast<ComplexType>();
820f112bd61SAdrian Kuegel     auto elementType = type.getElementType().cast<FloatType>();
821f112bd61SAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
822f112bd61SAdrian Kuegel 
823c0342a2dSJacques Pienaar     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
824c0342a2dSJacques Pienaar     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
825a54f4eaeSMogball     Value zero =
826a54f4eaeSMogball         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
827a54f4eaeSMogball     Value realIsZero =
828a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
829a54f4eaeSMogball     Value imagIsZero =
830a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
831a54f4eaeSMogball     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
    // The division is emitted unconditionally; the select below discards its
    // result for a zero input.
832c0342a2dSJacques Pienaar     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
833a54f4eaeSMogball     Value realSign = b.create<arith::DivFOp>(real, abs);
834a54f4eaeSMogball     Value imagSign = b.create<arith::DivFOp>(imag, abs);
835f112bd61SAdrian Kuegel     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
836dec8af70SRiver Riddle     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
837dec8af70SRiver Riddle                                                  adaptor.getComplex(), sign);
838f112bd61SAdrian Kuegel     return success();
839f112bd61SAdrian Kuegel   }
840f112bd61SAdrian Kuegel };
8416d75c897Slewuathe 
/// Rewrites complex.tan(z) as complex.sin(z) / complex.cos(z). All three
/// emitted ops (sin, cos, div) are complex-dialect ops.
8426d75c897Slewuathe struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
8436d75c897Slewuathe   using OpConversionPattern<complex::TanOp>::OpConversionPattern;
8446d75c897Slewuathe 
8456d75c897Slewuathe   LogicalResult
matchAndRewrite__anon597693150111::TanOpConversion8466d75c897Slewuathe   matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
8476d75c897Slewuathe                   ConversionPatternRewriter &rewriter) const override {
8486d75c897Slewuathe     auto loc = op.getLoc();
8496d75c897Slewuathe     Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
8506d75c897Slewuathe     Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
8516d75c897Slewuathe     rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
8526d75c897Slewuathe     return success();
8536d75c897Slewuathe   }
8546d75c897Slewuathe };
855ffb8eecdSlewuathe 
/// Rewrites complex.tanh into a complex.div of two complex values built from
/// the real tanh of the real part and the real tangent (sin / cos) of the
/// imaginary part.
856ffb8eecdSlewuathe struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
857ffb8eecdSlewuathe   using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
858ffb8eecdSlewuathe 
859ffb8eecdSlewuathe   LogicalResult
matchAndRewrite__anon597693150111::TanhOpConversion860ffb8eecdSlewuathe   matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
861ffb8eecdSlewuathe                   ConversionPatternRewriter &rewriter) const override {
862ffb8eecdSlewuathe     auto loc = op.getLoc();
863ffb8eecdSlewuathe     auto type = adaptor.getComplex().getType().cast<ComplexType>();
864ffb8eecdSlewuathe     auto elementType = type.getElementType().cast<FloatType>();
865ffb8eecdSlewuathe 
866ffb8eecdSlewuathe     // The hyperbolic tangent for complex number can be calculated as follows.
867ffb8eecdSlewuathe     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y))
868ffb8eecdSlewuathe     // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
869ffb8eecdSlewuathe     Value real =
870ffb8eecdSlewuathe         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
871ffb8eecdSlewuathe     Value imag =
872ffb8eecdSlewuathe         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
873ffb8eecdSlewuathe     Value tanhA = rewriter.create<math::TanhOp>(loc, real);
874ffb8eecdSlewuathe     Value cosB = rewriter.create<math::CosOp>(loc, imag);
875ffb8eecdSlewuathe     Value sinB = rewriter.create<math::SinOp>(loc, imag);
876ffb8eecdSlewuathe     Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
    // Numerator: tanh(x) + i * tan(y).
877ffb8eecdSlewuathe     Value numerator =
878ffb8eecdSlewuathe         rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
879ffb8eecdSlewuathe     Value one = rewriter.create<arith::ConstantOp>(
880ffb8eecdSlewuathe         loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // Denominator: 1 + i * (tanh(x) * tan(y)); the product is the imaginary
    // component of the complex value created here.
881ffb8eecdSlewuathe     Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
882ffb8eecdSlewuathe     Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
883ffb8eecdSlewuathe     rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
884ffb8eecdSlewuathe     return success();
885ffb8eecdSlewuathe   }
886ffb8eecdSlewuathe };
887ffb8eecdSlewuathe 
88862a34f6aSlewuathe struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
88962a34f6aSlewuathe   using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
89062a34f6aSlewuathe 
89162a34f6aSlewuathe   LogicalResult
matchAndRewrite__anon597693150111::ConjOpConversion89262a34f6aSlewuathe   matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
89362a34f6aSlewuathe                   ConversionPatternRewriter &rewriter) const override {
89462a34f6aSlewuathe     auto loc = op.getLoc();
89562a34f6aSlewuathe     auto type = adaptor.getComplex().getType().cast<ComplexType>();
89662a34f6aSlewuathe     auto elementType = type.getElementType().cast<FloatType>();
89762a34f6aSlewuathe     Value real =
89862a34f6aSlewuathe         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
89962a34f6aSlewuathe     Value imag =
90062a34f6aSlewuathe         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
90162a34f6aSlewuathe     Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
90262a34f6aSlewuathe 
90362a34f6aSlewuathe     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
90462a34f6aSlewuathe 
90562a34f6aSlewuathe     return success();
90662a34f6aSlewuathe   }
90762a34f6aSlewuathe };
90862a34f6aSlewuathe 
9096c6eddb6Sbixia1 /// Coverts x^y = (a+bi)^(c+di) to
9106c6eddb6Sbixia1 ///    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
9116c6eddb6Sbixia1 ///    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
powOpConversionImpl(mlir::ImplicitLocOpBuilder & builder,ComplexType type,Value a,Value b,Value c,Value d)9126c6eddb6Sbixia1 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
9136c6eddb6Sbixia1                                  ComplexType type, Value a, Value b, Value c,
9146c6eddb6Sbixia1                                  Value d) {
9156c6eddb6Sbixia1   auto elementType = type.getElementType().cast<FloatType>();
9166c6eddb6Sbixia1 
9176c6eddb6Sbixia1   // Compute (a*a+b*b)^(0.5c).
9186c6eddb6Sbixia1   Value aaPbb = builder.create<arith::AddFOp>(
9196c6eddb6Sbixia1       builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
9206c6eddb6Sbixia1   Value half = builder.create<arith::ConstantOp>(
9216c6eddb6Sbixia1       elementType, builder.getFloatAttr(elementType, 0.5));
9226c6eddb6Sbixia1   Value halfC = builder.create<arith::MulFOp>(half, c);
9236c6eddb6Sbixia1   Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
9246c6eddb6Sbixia1 
9256c6eddb6Sbixia1   // Compute exp(-d*atan2(b,a)).
9266c6eddb6Sbixia1   Value negD = builder.create<arith::NegFOp>(d);
9276c6eddb6Sbixia1   Value argX = builder.create<math::Atan2Op>(b, a);
9286c6eddb6Sbixia1   Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
9296c6eddb6Sbixia1   Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
9306c6eddb6Sbixia1 
9316c6eddb6Sbixia1   // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
9326c6eddb6Sbixia1   Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
9336c6eddb6Sbixia1 
9346c6eddb6Sbixia1   // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
9356c6eddb6Sbixia1   Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
9366c6eddb6Sbixia1   Value halfD = builder.create<arith::MulFOp>(half, d);
9376c6eddb6Sbixia1   Value q = builder.create<arith::AddFOp>(
9386c6eddb6Sbixia1       builder.create<arith::MulFOp>(c, argX),
9396c6eddb6Sbixia1       builder.create<arith::MulFOp>(halfD, lnAaPbb));
9406c6eddb6Sbixia1 
9416c6eddb6Sbixia1   Value cosQ = builder.create<math::CosOp>(q);
9426c6eddb6Sbixia1   Value sinQ = builder.create<math::SinOp>(q);
9436c6eddb6Sbixia1   Value zero = builder.create<arith::ConstantOp>(
9446c6eddb6Sbixia1       elementType, builder.getFloatAttr(elementType, 0));
9456c6eddb6Sbixia1   Value one = builder.create<arith::ConstantOp>(
9466c6eddb6Sbixia1       elementType, builder.getFloatAttr(elementType, 1));
9476c6eddb6Sbixia1 
9486c6eddb6Sbixia1   Value xEqZero =
9496c6eddb6Sbixia1       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
9506c6eddb6Sbixia1   Value yGeZero = builder.create<arith::AndIOp>(
9516c6eddb6Sbixia1       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero),
9526c6eddb6Sbixia1       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero));
9536c6eddb6Sbixia1   Value cEqZero =
9546c6eddb6Sbixia1       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
9556c6eddb6Sbixia1   Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
9566c6eddb6Sbixia1   Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
9576c6eddb6Sbixia1   Value complexOther = builder.create<complex::CreateOp>(
9586c6eddb6Sbixia1       type, builder.create<arith::MulFOp>(coeff, cosQ),
9596c6eddb6Sbixia1       builder.create<arith::MulFOp>(coeff, sinQ));
9606c6eddb6Sbixia1 
9616c6eddb6Sbixia1   // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
9626c6eddb6Sbixia1   // Branch Cuts for Complex Elementary Functions or Much Ado About
9636c6eddb6Sbixia1   // Nothing's Sign Bit, W. Kahan, Section 10.
9646c6eddb6Sbixia1   return builder.create<arith::SelectOp>(
9656c6eddb6Sbixia1       builder.create<arith::AndIOp>(xEqZero, yGeZero),
9666c6eddb6Sbixia1       builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero),
9676c6eddb6Sbixia1       complexOther);
9686c6eddb6Sbixia1 }
9696c6eddb6Sbixia1 
9706c6eddb6Sbixia1 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
9716c6eddb6Sbixia1   using OpConversionPattern<complex::PowOp>::OpConversionPattern;
9726c6eddb6Sbixia1 
9736c6eddb6Sbixia1   LogicalResult
matchAndRewrite__anon597693150111::PowOpConversion9746c6eddb6Sbixia1   matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
9756c6eddb6Sbixia1                   ConversionPatternRewriter &rewriter) const override {
9766c6eddb6Sbixia1     mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
9776c6eddb6Sbixia1     auto type = adaptor.getLhs().getType().cast<ComplexType>();
9786c6eddb6Sbixia1     auto elementType = type.getElementType().cast<FloatType>();
9796c6eddb6Sbixia1 
9806c6eddb6Sbixia1     Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
9816c6eddb6Sbixia1     Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
9826c6eddb6Sbixia1     Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
9836c6eddb6Sbixia1     Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
9846c6eddb6Sbixia1 
9856c6eddb6Sbixia1     rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
9866c6eddb6Sbixia1     return success();
9876c6eddb6Sbixia1   }
9886c6eddb6Sbixia1 };
9896c6eddb6Sbixia1 
9906c6eddb6Sbixia1 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
9916c6eddb6Sbixia1   using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
9926c6eddb6Sbixia1 
9936c6eddb6Sbixia1   LogicalResult
matchAndRewrite__anon597693150111::RsqrtOpConversion9946c6eddb6Sbixia1   matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
9956c6eddb6Sbixia1                   ConversionPatternRewriter &rewriter) const override {
9966c6eddb6Sbixia1     mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
9976c6eddb6Sbixia1     auto type = adaptor.getComplex().getType().cast<ComplexType>();
9986c6eddb6Sbixia1     auto elementType = type.getElementType().cast<FloatType>();
9996c6eddb6Sbixia1 
10006c6eddb6Sbixia1     Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
10016c6eddb6Sbixia1     Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
10026c6eddb6Sbixia1     Value c = builder.create<arith::ConstantOp>(
10036c6eddb6Sbixia1         elementType, builder.getFloatAttr(elementType, -0.5));
10046c6eddb6Sbixia1     Value d = builder.create<arith::ConstantOp>(
10056c6eddb6Sbixia1         elementType, builder.getFloatAttr(elementType, 0));
10066c6eddb6Sbixia1 
10076c6eddb6Sbixia1     rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
10086c6eddb6Sbixia1     return success();
10096c6eddb6Sbixia1   }
10106c6eddb6Sbixia1 };
10116c6eddb6Sbixia1 
10128fa2e679SLewuathe struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
10138fa2e679SLewuathe   using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
10148fa2e679SLewuathe 
10158fa2e679SLewuathe   LogicalResult
matchAndRewrite__anon597693150111::AngleOpConversion10168fa2e679SLewuathe   matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
10178fa2e679SLewuathe                   ConversionPatternRewriter &rewriter) const override {
10188fa2e679SLewuathe     auto loc = op.getLoc();
10198fa2e679SLewuathe     auto type = op.getType();
10208fa2e679SLewuathe 
10218fa2e679SLewuathe     Value real =
10228fa2e679SLewuathe         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
10238fa2e679SLewuathe     Value imag =
10248fa2e679SLewuathe         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
10258fa2e679SLewuathe 
10268fa2e679SLewuathe     rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
10278fa2e679SLewuathe 
10288fa2e679SLewuathe     return success();
10298fa2e679SLewuathe   }
10308fa2e679SLewuathe };
10318fa2e679SLewuathe 
10322ea7fb7bSAdrian Kuegel } // namespace
10332ea7fb7bSAdrian Kuegel 
populateComplexToStandardConversionPatterns(RewritePatternSet & patterns)10342ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns(
10352ea7fb7bSAdrian Kuegel     RewritePatternSet &patterns) {
1036f112bd61SAdrian Kuegel   // clang-format off
1037f112bd61SAdrian Kuegel   patterns.add<
1038f112bd61SAdrian Kuegel       AbsOpConversion,
10398fa2e679SLewuathe       AngleOpConversion,
1040f711785eSAlexander Belyaev       Atan2OpConversion,
1041a54f4eaeSMogball       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1042a54f4eaeSMogball       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
104362a34f6aSlewuathe       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
104462a34f6aSlewuathe       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
104562a34f6aSlewuathe       ConjOpConversion,
1046672b908bSGoran Flegar       CosOpConversion,
1047f112bd61SAdrian Kuegel       DivOpConversion,
1048f112bd61SAdrian Kuegel       ExpOpConversion,
1049338e76f8Sbixia1       Expm1OpConversion,
10506e80e3bdSAdrian Kuegel       Log1pOpConversion,
105162a34f6aSlewuathe       LogOpConversion,
1052bf17ee19SAdrian Kuegel       MulOpConversion,
1053f112bd61SAdrian Kuegel       NegOpConversion,
1054672b908bSGoran Flegar       SignOpConversion,
10556d75c897Slewuathe       SinOpConversion,
1056f711785eSAlexander Belyaev       SqrtOpConversion,
1057ffb8eecdSlewuathe       TanOpConversion,
10586c6eddb6Sbixia1       TanhOpConversion,
10596c6eddb6Sbixia1       PowOpConversion,
10606c6eddb6Sbixia1       RsqrtOpConversion
106162a34f6aSlewuathe   >(patterns.getContext());
1062f112bd61SAdrian Kuegel   // clang-format on
10632ea7fb7bSAdrian Kuegel }
10642ea7fb7bSAdrian Kuegel 
10652ea7fb7bSAdrian Kuegel namespace {
10662ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass
10672ea7fb7bSAdrian Kuegel     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
106841574554SRiver Riddle   void runOnOperation() override;
10692ea7fb7bSAdrian Kuegel };
10702ea7fb7bSAdrian Kuegel 
runOnOperation()107141574554SRiver Riddle void ConvertComplexToStandardPass::runOnOperation() {
10722ea7fb7bSAdrian Kuegel   // Convert to the Standard dialect using the converter defined above.
10732ea7fb7bSAdrian Kuegel   RewritePatternSet patterns(&getContext());
10742ea7fb7bSAdrian Kuegel   populateComplexToStandardConversionPatterns(patterns);
10752ea7fb7bSAdrian Kuegel 
10762ea7fb7bSAdrian Kuegel   ConversionTarget target(getContext());
10771f971e23SRiver Riddle   target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>();
1078fb978f09SAdrian Kuegel   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
107947f175b0SRiver Riddle   if (failed(
108047f175b0SRiver Riddle           applyPartialConversion(getOperation(), target, std::move(patterns))))
10812ea7fb7bSAdrian Kuegel     signalPassFailure();
10822ea7fb7bSAdrian Kuegel }
10832ea7fb7bSAdrian Kuegel } // namespace
10842ea7fb7bSAdrian Kuegel 
createConvertComplexToStandardPass()108547f175b0SRiver Riddle std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
10862ea7fb7bSAdrian Kuegel   return std::make_unique<ConvertComplexToStandardPass>();
10872ea7fb7bSAdrian Kuegel }
1088