12ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
22ea7fb7bSAdrian Kuegel //
32ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information.
52ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62ea7fb7bSAdrian Kuegel //
72ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===//
82ea7fb7bSAdrian Kuegel
92ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
102ea7fb7bSAdrian Kuegel
112ea7fb7bSAdrian Kuegel #include <memory>
12fb8b2b86SAdrian Kuegel #include <type_traits>
132ea7fb7bSAdrian Kuegel
142ea7fb7bSAdrian Kuegel #include "../PassDetail.h"
15a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
162ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h"
172ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h"
18f112bd61SAdrian Kuegel #include "mlir/IR/ImplicitLocOpBuilder.h"
192ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h"
202ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h"
212ea7fb7bSAdrian Kuegel
222ea7fb7bSAdrian Kuegel using namespace mlir;
232ea7fb7bSAdrian Kuegel
242ea7fb7bSAdrian Kuegel namespace {
252ea7fb7bSAdrian Kuegel struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
262ea7fb7bSAdrian Kuegel using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
272ea7fb7bSAdrian Kuegel
282ea7fb7bSAdrian Kuegel LogicalResult
matchAndRewrite__anon597693150111::AbsOpConversion29b54c724bSRiver Riddle matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
302ea7fb7bSAdrian Kuegel ConversionPatternRewriter &rewriter) const override {
312ea7fb7bSAdrian Kuegel auto loc = op.getLoc();
322ea7fb7bSAdrian Kuegel auto type = op.getType();
332ea7fb7bSAdrian Kuegel
34c0342a2dSJacques Pienaar Value real =
35c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
36c0342a2dSJacques Pienaar Value imag =
37c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
38a54f4eaeSMogball Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
39a54f4eaeSMogball Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
40a54f4eaeSMogball Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
412ea7fb7bSAdrian Kuegel
422ea7fb7bSAdrian Kuegel rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
432ea7fb7bSAdrian Kuegel return success();
442ea7fb7bSAdrian Kuegel }
452ea7fb7bSAdrian Kuegel };
46ac00cb0dSAdrian Kuegel
47f711785eSAlexander Belyaev // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
48f711785eSAlexander Belyaev struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
49f711785eSAlexander Belyaev using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
50f711785eSAlexander Belyaev
51f711785eSAlexander Belyaev LogicalResult
matchAndRewrite__anon597693150111::Atan2OpConversion52f711785eSAlexander Belyaev matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
53f711785eSAlexander Belyaev ConversionPatternRewriter &rewriter) const override {
54f711785eSAlexander Belyaev mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
55f711785eSAlexander Belyaev
56f711785eSAlexander Belyaev auto type = op.getType().cast<ComplexType>();
57f711785eSAlexander Belyaev Type elementType = type.getElementType();
58f711785eSAlexander Belyaev
59f711785eSAlexander Belyaev Value lhs = adaptor.getLhs();
60f711785eSAlexander Belyaev Value rhs = adaptor.getRhs();
61f711785eSAlexander Belyaev
62f711785eSAlexander Belyaev Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
63f711785eSAlexander Belyaev Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
64f711785eSAlexander Belyaev Value rhsSquaredPlusLhsSquared =
65f711785eSAlexander Belyaev b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
66f711785eSAlexander Belyaev Value sqrtOfRhsSquaredPlusLhsSquared =
67f711785eSAlexander Belyaev b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
68f711785eSAlexander Belyaev
69f711785eSAlexander Belyaev Value zero =
70f711785eSAlexander Belyaev b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
71f711785eSAlexander Belyaev Value one = b.create<arith::ConstantOp>(elementType,
72f711785eSAlexander Belyaev b.getFloatAttr(elementType, 1));
73f711785eSAlexander Belyaev Value i = b.create<complex::CreateOp>(type, zero, one);
74f711785eSAlexander Belyaev Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
75f711785eSAlexander Belyaev Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
76f711785eSAlexander Belyaev
77f711785eSAlexander Belyaev Value divResult =
78f711785eSAlexander Belyaev b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
79f711785eSAlexander Belyaev Value logResult = b.create<complex::LogOp>(divResult);
80f711785eSAlexander Belyaev
81f711785eSAlexander Belyaev Value negativeOne = b.create<arith::ConstantOp>(
82f711785eSAlexander Belyaev elementType, b.getFloatAttr(elementType, -1));
83f711785eSAlexander Belyaev Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
84f711785eSAlexander Belyaev
85f711785eSAlexander Belyaev rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
86f711785eSAlexander Belyaev return success();
87f711785eSAlexander Belyaev }
88f711785eSAlexander Belyaev };
89f711785eSAlexander Belyaev
90a54f4eaeSMogball template <typename ComparisonOp, arith::CmpFPredicate p>
91fb8b2b86SAdrian Kuegel struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
92fb8b2b86SAdrian Kuegel using OpConversionPattern<ComparisonOp>::OpConversionPattern;
93fb8b2b86SAdrian Kuegel using ResultCombiner =
94fb8b2b86SAdrian Kuegel std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
95a54f4eaeSMogball arith::AndIOp, arith::OrIOp>;
96ac00cb0dSAdrian Kuegel
97ac00cb0dSAdrian Kuegel LogicalResult
matchAndRewrite__anon597693150111::ComparisonOpConversion98b54c724bSRiver Riddle matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
99ac00cb0dSAdrian Kuegel ConversionPatternRewriter &rewriter) const override {
100ac00cb0dSAdrian Kuegel auto loc = op.getLoc();
101c0342a2dSJacques Pienaar auto type = adaptor.getLhs()
102c0342a2dSJacques Pienaar .getType()
103c0342a2dSJacques Pienaar .template cast<ComplexType>()
104c0342a2dSJacques Pienaar .getElementType();
105ac00cb0dSAdrian Kuegel
106c0342a2dSJacques Pienaar Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
107c0342a2dSJacques Pienaar Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
108c0342a2dSJacques Pienaar Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
109c0342a2dSJacques Pienaar Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
110a54f4eaeSMogball Value realComparison =
111a54f4eaeSMogball rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
112a54f4eaeSMogball Value imagComparison =
113a54f4eaeSMogball rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
114ac00cb0dSAdrian Kuegel
115fb8b2b86SAdrian Kuegel rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
116fb8b2b86SAdrian Kuegel imagComparison);
117ac00cb0dSAdrian Kuegel return success();
118ac00cb0dSAdrian Kuegel }
119ac00cb0dSAdrian Kuegel };
120942be7cbSAdrian Kuegel
121fb978f09SAdrian Kuegel // Default conversion which applies the BinaryStandardOp separately on the real
122fb978f09SAdrian Kuegel // and imaginary parts. Can for example be used for complex::AddOp and
123fb978f09SAdrian Kuegel // complex::SubOp.
124fb978f09SAdrian Kuegel template <typename BinaryComplexOp, typename BinaryStandardOp>
125fb978f09SAdrian Kuegel struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
126fb978f09SAdrian Kuegel using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
127fb978f09SAdrian Kuegel
128fb978f09SAdrian Kuegel LogicalResult
matchAndRewrite__anon597693150111::BinaryComplexOpConversion129b54c724bSRiver Riddle matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
130fb978f09SAdrian Kuegel ConversionPatternRewriter &rewriter) const override {
131c0342a2dSJacques Pienaar auto type = adaptor.getLhs().getType().template cast<ComplexType>();
132fb978f09SAdrian Kuegel auto elementType = type.getElementType().template cast<FloatType>();
133fb978f09SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
134fb978f09SAdrian Kuegel
135c0342a2dSJacques Pienaar Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
136c0342a2dSJacques Pienaar Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
137fb978f09SAdrian Kuegel Value resultReal =
138fb978f09SAdrian Kuegel b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
139c0342a2dSJacques Pienaar Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
140c0342a2dSJacques Pienaar Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
141fb978f09SAdrian Kuegel Value resultImag =
142fb978f09SAdrian Kuegel b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
143fb978f09SAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
144fb978f09SAdrian Kuegel resultImag);
145fb978f09SAdrian Kuegel return success();
146fb978f09SAdrian Kuegel }
147fb978f09SAdrian Kuegel };
148fb978f09SAdrian Kuegel
149672b908bSGoran Flegar template <typename TrigonometricOp>
150672b908bSGoran Flegar struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
151672b908bSGoran Flegar using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
152672b908bSGoran Flegar
153672b908bSGoran Flegar using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
154672b908bSGoran Flegar
155672b908bSGoran Flegar LogicalResult
matchAndRewrite__anon597693150111::TrigonometricOpConversion156672b908bSGoran Flegar matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
157672b908bSGoran Flegar ConversionPatternRewriter &rewriter) const override {
158672b908bSGoran Flegar auto loc = op.getLoc();
159672b908bSGoran Flegar auto type = adaptor.getComplex().getType().template cast<ComplexType>();
160672b908bSGoran Flegar auto elementType = type.getElementType().template cast<FloatType>();
161672b908bSGoran Flegar
162672b908bSGoran Flegar Value real =
163672b908bSGoran Flegar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
164672b908bSGoran Flegar Value imag =
165672b908bSGoran Flegar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
166672b908bSGoran Flegar
167672b908bSGoran Flegar // Trigonometric ops use a set of common building blocks to convert to real
168672b908bSGoran Flegar // ops. Here we create these building blocks and call into an op-specific
169672b908bSGoran Flegar // implementation in the subclass to combine them.
170672b908bSGoran Flegar Value half = rewriter.create<arith::ConstantOp>(
171672b908bSGoran Flegar loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
172672b908bSGoran Flegar Value exp = rewriter.create<math::ExpOp>(loc, imag);
173672b908bSGoran Flegar Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
174672b908bSGoran Flegar Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
175672b908bSGoran Flegar Value sin = rewriter.create<math::SinOp>(loc, real);
176672b908bSGoran Flegar Value cos = rewriter.create<math::CosOp>(loc, real);
177672b908bSGoran Flegar
178672b908bSGoran Flegar auto resultPair =
179672b908bSGoran Flegar combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
180672b908bSGoran Flegar
181672b908bSGoran Flegar rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
182672b908bSGoran Flegar resultPair.second);
183672b908bSGoran Flegar return success();
184672b908bSGoran Flegar }
185672b908bSGoran Flegar
186672b908bSGoran Flegar virtual std::pair<Value, Value>
187672b908bSGoran Flegar combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
188672b908bSGoran Flegar Value cos, ConversionPatternRewriter &rewriter) const = 0;
189672b908bSGoran Flegar };
190672b908bSGoran Flegar
191672b908bSGoran Flegar struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
192672b908bSGoran Flegar using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
193672b908bSGoran Flegar
194672b908bSGoran Flegar std::pair<Value, Value>
combine__anon597693150111::CosOpConversion195672b908bSGoran Flegar combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
196672b908bSGoran Flegar Value cos, ConversionPatternRewriter &rewriter) const override {
197672b908bSGoran Flegar // Complex cosine is defined as;
198672b908bSGoran Flegar // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
199672b908bSGoran Flegar // Plugging in:
200672b908bSGoran Flegar // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
201672b908bSGoran Flegar // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
202672b908bSGoran Flegar // and defining t := exp(y)
203672b908bSGoran Flegar // We get:
204672b908bSGoran Flegar // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
205672b908bSGoran Flegar // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
206672b908bSGoran Flegar Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
207672b908bSGoran Flegar Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
208672b908bSGoran Flegar Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
209672b908bSGoran Flegar Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
210672b908bSGoran Flegar return {resultReal, resultImag};
211672b908bSGoran Flegar }
212672b908bSGoran Flegar };
213672b908bSGoran Flegar
// Lowers complex.div to arith/math ops on the real and imaginary parts.
//
// Both branches of Smith's algorithm are materialized and the one selected by
// |rhsReal| < |rhsImag| is used. Only when that result is NaN in BOTH
// components are special-case values substituted, in priority order:
//   1. zero divisor and a numerator with at least one non-NaN part,
//   2. infinite numerator and finite divisor,
//   3. finite numerator and infinite divisor.
// NOTE(review): this recovery scheme appears to mirror C99 Annex G /
// compiler-rt's __divdc3 (without its logb-based rescaling) — confirm before
// relying on exact agreement.
214942be7cbSAdrian Kuegel struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
215942be7cbSAdrian Kuegel using OpConversionPattern<complex::DivOp>::OpConversionPattern;
216942be7cbSAdrian Kuegel
217942be7cbSAdrian Kuegel LogicalResult
matchAndRewrite__anon597693150111::DivOpConversion218b54c724bSRiver Riddle matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
219942be7cbSAdrian Kuegel ConversionPatternRewriter &rewriter) const override {
220942be7cbSAdrian Kuegel auto loc = op.getLoc();
221c0342a2dSJacques Pienaar auto type = adaptor.getLhs().getType().cast<ComplexType>();
222942be7cbSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>();
223942be7cbSAdrian Kuegel
224942be7cbSAdrian Kuegel Value lhsReal =
225c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
226942be7cbSAdrian Kuegel Value lhsImag =
227c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
228942be7cbSAdrian Kuegel Value rhsReal =
229c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
230942be7cbSAdrian Kuegel Value rhsImag =
231c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
232942be7cbSAdrian Kuegel
233942be7cbSAdrian Kuegel // Smith's algorithm to divide complex numbers. It is just a bit smarter
234942be7cbSAdrian Kuegel // way to compute the following formula:
235942be7cbSAdrian Kuegel // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
236942be7cbSAdrian Kuegel // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
237942be7cbSAdrian Kuegel // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
238942be7cbSAdrian Kuegel // = ((lhsReal * rhsReal + lhsImag * rhsImag) +
239942be7cbSAdrian Kuegel // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
240942be7cbSAdrian Kuegel //
241942be7cbSAdrian Kuegel // Depending on whether |rhsReal| < |rhsImag| we compute either
242942be7cbSAdrian Kuegel // rhsRealImagRatio = rhsReal / rhsImag
243942be7cbSAdrian Kuegel // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
244942be7cbSAdrian Kuegel // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
245942be7cbSAdrian Kuegel // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
246942be7cbSAdrian Kuegel //
247942be7cbSAdrian Kuegel // or
248942be7cbSAdrian Kuegel //
249942be7cbSAdrian Kuegel // rhsImagRealRatio = rhsImag / rhsReal
250942be7cbSAdrian Kuegel // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
251942be7cbSAdrian Kuegel // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
252942be7cbSAdrian Kuegel // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
253942be7cbSAdrian Kuegel //
254942be7cbSAdrian Kuegel // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    // Branch 1: divide through by rhsImag (used when |rhsReal| < |rhsImag|).
255a54f4eaeSMogball Value rhsRealImagRatio =
256a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
257a54f4eaeSMogball Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
258a54f4eaeSMogball loc, rhsImag,
259a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
260a54f4eaeSMogball Value realNumerator1 = rewriter.create<arith::AddFOp>(
261a54f4eaeSMogball loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
262a54f4eaeSMogball lhsImag);
263942be7cbSAdrian Kuegel Value resultReal1 =
264a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
265a54f4eaeSMogball Value imagNumerator1 = rewriter.create<arith::SubFOp>(
266a54f4eaeSMogball loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
267a54f4eaeSMogball lhsReal);
268942be7cbSAdrian Kuegel Value resultImag1 =
269a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
270942be7cbSAdrian Kuegel
    // Branch 2: divide through by rhsReal (used otherwise).
271a54f4eaeSMogball Value rhsImagRealRatio =
272a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
273a54f4eaeSMogball Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
274a54f4eaeSMogball loc, rhsReal,
275a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
276a54f4eaeSMogball Value realNumerator2 = rewriter.create<arith::AddFOp>(
277a54f4eaeSMogball loc, lhsReal,
278a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
279942be7cbSAdrian Kuegel Value resultReal2 =
280a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
281a54f4eaeSMogball Value imagNumerator2 = rewriter.create<arith::SubFOp>(
282a54f4eaeSMogball loc, lhsImag,
283a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
284942be7cbSAdrian Kuegel Value resultImag2 =
285a54f4eaeSMogball rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
286942be7cbSAdrian Kuegel
287942be7cbSAdrian Kuegel // Consider corner cases.
288942be7cbSAdrian Kuegel // Case 1. Zero denominator, numerator contains at most one NaN value.
289a54f4eaeSMogball Value zero = rewriter.create<arith::ConstantOp>(
290a54f4eaeSMogball loc, elementType, rewriter.getZeroAttr(elementType));
291a54f4eaeSMogball Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
292a54f4eaeSMogball Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
293a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
294a54f4eaeSMogball Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
295a54f4eaeSMogball Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
296a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD is true iff neither operand is NaN; comparing against the constant
    // zero (never NaN) therefore tests "x is not NaN".
297a54f4eaeSMogball Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
298a54f4eaeSMogball loc, arith::CmpFPredicate::ORD, lhsReal, zero);
299a54f4eaeSMogball Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
300a54f4eaeSMogball loc, arith::CmpFPredicate::ORD, lhsImag, zero);
301942be7cbSAdrian Kuegel Value lhsContainsNotNaNValue =
302a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
303a54f4eaeSMogball Value resultIsInfinity = rewriter.create<arith::AndIOp>(
304942be7cbSAdrian Kuegel loc, lhsContainsNotNaNValue,
305a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
306a54f4eaeSMogball Value inf = rewriter.create<arith::ConstantOp>(
307942be7cbSAdrian Kuegel loc, elementType,
308942be7cbSAdrian Kuegel rewriter.getFloatAttr(
309942be7cbSAdrian Kuegel elementType, APFloat::getInf(elementType.getFloatSemantics())));
    // Case 1 result: each numerator part scaled by +/-inf, where the sign is
    // taken from rhsReal.
310a54f4eaeSMogball Value infWithSignOfRhsReal =
311a54f4eaeSMogball rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
312942be7cbSAdrian Kuegel Value infinityResultReal =
313a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
314942be7cbSAdrian Kuegel Value infinityResultImag =
315a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
316942be7cbSAdrian Kuegel
317942be7cbSAdrian Kuegel // Case 2. Infinite numerator, finite denominator.
    // ONE (ordered and not equal) is false for NaN, so these flags are true
    // only for non-NaN, non-infinite magnitudes.
318a54f4eaeSMogball Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
319a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
320a54f4eaeSMogball Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
321a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
322a54f4eaeSMogball Value rhsFinite =
323a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
324a54f4eaeSMogball Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
325a54f4eaeSMogball Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
326a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
327a54f4eaeSMogball Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
328a54f4eaeSMogball Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
329a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
330942be7cbSAdrian Kuegel Value lhsInfinite =
331a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
332942be7cbSAdrian Kuegel Value infNumFiniteDenom =
333a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
334a54f4eaeSMogball Value one = rewriter.create<arith::ConstantOp>(
335942be7cbSAdrian Kuegel loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // Map each numerator part to +/-1 if it is infinite and to +/-0
    // otherwise, keeping the sign of the original part.
336a54f4eaeSMogball Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
337dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
338942be7cbSAdrian Kuegel lhsReal);
339a54f4eaeSMogball Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
340dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
341942be7cbSAdrian Kuegel lhsImag);
    // Case 2 result: inf * (a*c + b*d) + i * inf * (b*c - a*d), with a, b
    // the +/-1 / +/-0 mapped numerator parts and c, d the divisor parts.
342942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsReal =
343a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
344942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsImag =
345a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
346a54f4eaeSMogball Value resultReal3 = rewriter.create<arith::MulFOp>(
347942be7cbSAdrian Kuegel loc, inf,
348a54f4eaeSMogball rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
349942be7cbSAdrian Kuegel lhsImagIsInfWithSignTimesRhsImag));
350942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsImag =
351a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
352942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsReal =
353a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
354a54f4eaeSMogball Value resultImag3 = rewriter.create<arith::MulFOp>(
355942be7cbSAdrian Kuegel loc, inf,
356a54f4eaeSMogball rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
357942be7cbSAdrian Kuegel lhsRealIsInfWithSignTimesRhsImag));
358942be7cbSAdrian Kuegel
359942be7cbSAdrian Kuegel // Case 3: Finite numerator, infinite denominator.
360a54f4eaeSMogball Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
361a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
362a54f4eaeSMogball Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
363a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
364a54f4eaeSMogball Value lhsFinite =
365a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
366a54f4eaeSMogball Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
367a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
368a54f4eaeSMogball Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
369a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
370942be7cbSAdrian Kuegel Value rhsInfinite =
371a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
372942be7cbSAdrian Kuegel Value finiteNumInfiniteDenom =
373a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    // Same +/-1 / +/-0 mapping as in case 2, applied to the divisor parts.
374a54f4eaeSMogball Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
375dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
376942be7cbSAdrian Kuegel rhsReal);
377a54f4eaeSMogball Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
378dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
379942be7cbSAdrian Kuegel rhsImag);
    // Case 3 result: the same cross products as case 2 but scaled by zero,
    // producing appropriately signed zeros.
380942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsReal =
381a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
382942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsImag =
383a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
384a54f4eaeSMogball Value resultReal4 = rewriter.create<arith::MulFOp>(
385942be7cbSAdrian Kuegel loc, zero,
386a54f4eaeSMogball rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
387942be7cbSAdrian Kuegel rhsImagIsInfWithSignTimesLhsImag));
388942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsImag =
389a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
390942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsReal =
391a54f4eaeSMogball rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
392a54f4eaeSMogball Value resultImag4 = rewriter.create<arith::MulFOp>(
393942be7cbSAdrian Kuegel loc, zero,
394a54f4eaeSMogball rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
395942be7cbSAdrian Kuegel rhsImagIsInfWithSignTimesLhsReal));
396942be7cbSAdrian Kuegel
    // Pick the Smith branch that divides by the larger-magnitude divisor part.
    // Equal magnitudes (and NaN comparisons) fall through to branch 2.
397a54f4eaeSMogball Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
398a54f4eaeSMogball loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
399dec8af70SRiver Riddle Value resultReal = rewriter.create<arith::SelectOp>(
400dec8af70SRiver Riddle loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
401dec8af70SRiver Riddle Value resultImag = rewriter.create<arith::SelectOp>(
402dec8af70SRiver Riddle loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Layer the special cases so that case 1 overrides case 2, which
    // overrides case 3, which overrides the Smith result.
403dec8af70SRiver Riddle Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
404942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultReal4, resultReal);
405dec8af70SRiver Riddle Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
406942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultImag4, resultImag);
407dec8af70SRiver Riddle Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
408942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
409dec8af70SRiver Riddle Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
410942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
411dec8af70SRiver Riddle Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
412942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
413dec8af70SRiver Riddle Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
414942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
415942be7cbSAdrian Kuegel
    // UNO against zero tests "is NaN". The special-case values replace the
    // Smith result only when both of its components are NaN.
416a54f4eaeSMogball Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
417a54f4eaeSMogball loc, arith::CmpFPredicate::UNO, resultReal, zero);
418a54f4eaeSMogball Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
419a54f4eaeSMogball loc, arith::CmpFPredicate::UNO, resultImag, zero);
420942be7cbSAdrian Kuegel Value resultIsNaN =
421a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
422dec8af70SRiver Riddle Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
423942be7cbSAdrian Kuegel loc, resultIsNaN, resultRealSpecialCase1, resultReal);
424dec8af70SRiver Riddle Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
425942be7cbSAdrian Kuegel loc, resultIsNaN, resultImagSpecialCase1, resultImag);
426942be7cbSAdrian Kuegel
427942be7cbSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(
428942be7cbSAdrian Kuegel op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
429942be7cbSAdrian Kuegel return success();
430942be7cbSAdrian Kuegel }
431942be7cbSAdrian Kuegel };
43273cbc91cSAdrian Kuegel
43373cbc91cSAdrian Kuegel struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
43473cbc91cSAdrian Kuegel using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
43573cbc91cSAdrian Kuegel
43673cbc91cSAdrian Kuegel LogicalResult
matchAndRewrite__anon597693150111::ExpOpConversion437b54c724bSRiver Riddle matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
43873cbc91cSAdrian Kuegel ConversionPatternRewriter &rewriter) const override {
43973cbc91cSAdrian Kuegel auto loc = op.getLoc();
440c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>();
44173cbc91cSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>();
44273cbc91cSAdrian Kuegel
44373cbc91cSAdrian Kuegel Value real =
444c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
44573cbc91cSAdrian Kuegel Value imag =
446c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
44773cbc91cSAdrian Kuegel Value expReal = rewriter.create<math::ExpOp>(loc, real);
44873cbc91cSAdrian Kuegel Value cosImag = rewriter.create<math::CosOp>(loc, imag);
449a54f4eaeSMogball Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
45073cbc91cSAdrian Kuegel Value sinImag = rewriter.create<math::SinOp>(loc, imag);
451a54f4eaeSMogball Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
45273cbc91cSAdrian Kuegel
45373cbc91cSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
45473cbc91cSAdrian Kuegel resultImag);
45573cbc91cSAdrian Kuegel return success();
45673cbc91cSAdrian Kuegel }
45773cbc91cSAdrian Kuegel };
458662e074dSAdrian Kuegel
459338e76f8Sbixia1 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
460338e76f8Sbixia1 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
461338e76f8Sbixia1
462338e76f8Sbixia1 LogicalResult
matchAndRewrite__anon597693150111::Expm1OpConversion463338e76f8Sbixia1 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
464338e76f8Sbixia1 ConversionPatternRewriter &rewriter) const override {
465338e76f8Sbixia1 auto type = adaptor.getComplex().getType().cast<ComplexType>();
466338e76f8Sbixia1 auto elementType = type.getElementType().cast<FloatType>();
467338e76f8Sbixia1
468338e76f8Sbixia1 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
469338e76f8Sbixia1 Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
470338e76f8Sbixia1
471338e76f8Sbixia1 Value real = b.create<complex::ReOp>(elementType, exp);
472338e76f8Sbixia1 Value one = b.create<arith::ConstantOp>(elementType,
473338e76f8Sbixia1 b.getFloatAttr(elementType, 1));
474338e76f8Sbixia1 Value realMinusOne = b.create<arith::SubFOp>(real, one);
475338e76f8Sbixia1 Value imag = b.create<complex::ImOp>(elementType, exp);
476338e76f8Sbixia1
477338e76f8Sbixia1 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
478338e76f8Sbixia1 imag);
479338e76f8Sbixia1 return success();
480338e76f8Sbixia1 }
481338e76f8Sbixia1 };
482338e76f8Sbixia1
483380fa71fSAdrian Kuegel struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
484380fa71fSAdrian Kuegel using OpConversionPattern<complex::LogOp>::OpConversionPattern;
485380fa71fSAdrian Kuegel
486380fa71fSAdrian Kuegel LogicalResult
matchAndRewrite__anon597693150111::LogOpConversion487b54c724bSRiver Riddle matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
488380fa71fSAdrian Kuegel ConversionPatternRewriter &rewriter) const override {
489c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>();
490380fa71fSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>();
491380fa71fSAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
492380fa71fSAdrian Kuegel
493c0342a2dSJacques Pienaar Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
494380fa71fSAdrian Kuegel Value resultReal = b.create<math::LogOp>(elementType, abs);
495c0342a2dSJacques Pienaar Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
496c0342a2dSJacques Pienaar Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
497380fa71fSAdrian Kuegel Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
498380fa71fSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
499380fa71fSAdrian Kuegel resultImag);
500380fa71fSAdrian Kuegel return success();
501380fa71fSAdrian Kuegel }
502380fa71fSAdrian Kuegel };
503380fa71fSAdrian Kuegel
5046e80e3bdSAdrian Kuegel struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
5056e80e3bdSAdrian Kuegel using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
5066e80e3bdSAdrian Kuegel
5076e80e3bdSAdrian Kuegel LogicalResult
matchAndRewrite__anon597693150111::Log1pOpConversion508b54c724bSRiver Riddle matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
5096e80e3bdSAdrian Kuegel ConversionPatternRewriter &rewriter) const override {
510c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>();
5116e80e3bdSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>();
5126e80e3bdSAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
5136e80e3bdSAdrian Kuegel
514c0342a2dSJacques Pienaar Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
515c0342a2dSJacques Pienaar Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
516a54f4eaeSMogball Value one = b.create<arith::ConstantOp>(elementType,
517a54f4eaeSMogball b.getFloatAttr(elementType, 1));
518a54f4eaeSMogball Value realPlusOne = b.create<arith::AddFOp>(real, one);
5196e80e3bdSAdrian Kuegel Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
5206e80e3bdSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
5216e80e3bdSAdrian Kuegel return success();
5226e80e3bdSAdrian Kuegel }
5236e80e3bdSAdrian Kuegel };
5246e80e3bdSAdrian Kuegel
/// Lowers complex.mul to arith/math operations on the real and imaginary
/// parts of the operands.
///
/// For lhs = a + bi and rhs = c + di, the naive product is
///   (a*c - b*d) + (b*c + a*d)i.
/// When BOTH parts of that naive result are NaN, the operands are adjusted in
/// three cases (below) and the product is recomputed and multiplied by
/// infinity, so that e.g. an infinite operand yields an infinite result
/// instead of NaN + NaN i. All adjustments are expressed with selects: every
/// op is emitted unconditionally, and the recomputed values only replace the
/// naive ones when `recalc` is true.
/// NOTE(review): the structure appears to mirror the C99/C11 Annex G
/// reference multiplication algorithm; confirm against the standard before
/// relying on exact special-value semantics.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Split both operands into components; the absolute values are used
    // below for the infinity tests.
    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);

    // Naive real part: a*c - b*d.
    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
    Value real =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    // Naive imaginary part: b*c + a*d.
    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
    Value imag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // The recovery path is only taken when both parts are NaN (UNO of a value
    // with itself is true exactly when it is NaN).
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Box lhs components to +/-1 (infinite) or +/-0 (finite), keeping their
    // signs, and turn NaN rhs components into signed zeros.
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat =
        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat =
        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Symmetric to case 1 with the roles of lhs and rhs swapped. The
    // infinity tests use the ORIGINAL rhs magnitudes; the NaN tests on lhs
    // use the values possibly updated by case 1.
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat =
        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat =
        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    // Only applies when neither case 1 nor case 2 did (masked with !recalc).
    // In that situation, no operand was modified above. The NaN flags
    // computed earlier therefore still describe the current operands.
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    // Logical NOT of `recalc`, expressed as XOR with i1 true.
    Type i1Type = b.getI1Type();
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    // In the special case, replace any NaN component by a signed zero.
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);
    // Recompute only if some case applied AND the naive result was NaN + NaN i.
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value newReal =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value newImag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
701bf17ee19SAdrian Kuegel
702662e074dSAdrian Kuegel struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
703662e074dSAdrian Kuegel using OpConversionPattern<complex::NegOp>::OpConversionPattern;
704662e074dSAdrian Kuegel
705662e074dSAdrian Kuegel LogicalResult
matchAndRewrite__anon597693150111::NegOpConversion706b54c724bSRiver Riddle matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
707662e074dSAdrian Kuegel ConversionPatternRewriter &rewriter) const override {
708662e074dSAdrian Kuegel auto loc = op.getLoc();
709c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>();
710662e074dSAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>();
711662e074dSAdrian Kuegel
712662e074dSAdrian Kuegel Value real =
713c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
714662e074dSAdrian Kuegel Value imag =
715c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
716a54f4eaeSMogball Value negReal = rewriter.create<arith::NegFOp>(loc, real);
717a54f4eaeSMogball Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
718662e074dSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
719662e074dSAdrian Kuegel return success();
720662e074dSAdrian Kuegel }
721662e074dSAdrian Kuegel };
722f112bd61SAdrian Kuegel
723672b908bSGoran Flegar struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
724672b908bSGoran Flegar using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
725672b908bSGoran Flegar
726672b908bSGoran Flegar std::pair<Value, Value>
combine__anon597693150111::SinOpConversion727672b908bSGoran Flegar combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
728672b908bSGoran Flegar Value cos, ConversionPatternRewriter &rewriter) const override {
729672b908bSGoran Flegar // Complex sine is defined as;
730672b908bSGoran Flegar // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
731672b908bSGoran Flegar // Plugging in:
732672b908bSGoran Flegar // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
733672b908bSGoran Flegar // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
734672b908bSGoran Flegar // and defining t := exp(y)
735672b908bSGoran Flegar // We get:
736672b908bSGoran Flegar // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
737672b908bSGoran Flegar // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
738672b908bSGoran Flegar Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
739672b908bSGoran Flegar Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
740672b908bSGoran Flegar Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
741672b908bSGoran Flegar Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
742672b908bSGoran Flegar return {resultReal, resultImag};
743672b908bSGoran Flegar }
744672b908bSGoran Flegar };
745672b908bSGoran Flegar
746f711785eSAlexander Belyaev // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
747f711785eSAlexander Belyaev struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
748f711785eSAlexander Belyaev using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
749f711785eSAlexander Belyaev
750f711785eSAlexander Belyaev LogicalResult
matchAndRewrite__anon597693150111::SqrtOpConversion751f711785eSAlexander Belyaev matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
752f711785eSAlexander Belyaev ConversionPatternRewriter &rewriter) const override {
753f711785eSAlexander Belyaev mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
754f711785eSAlexander Belyaev
755f711785eSAlexander Belyaev auto type = op.getType().cast<ComplexType>();
756f711785eSAlexander Belyaev Type elementType = type.getElementType();
757f711785eSAlexander Belyaev Value arg = adaptor.getComplex();
758f711785eSAlexander Belyaev
759f711785eSAlexander Belyaev Value zero =
760f711785eSAlexander Belyaev b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
761f711785eSAlexander Belyaev
762f711785eSAlexander Belyaev Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
763f711785eSAlexander Belyaev Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
764f711785eSAlexander Belyaev
765f711785eSAlexander Belyaev Value absLhs = b.create<math::AbsOp>(real);
766f711785eSAlexander Belyaev Value absArg = b.create<complex::AbsOp>(elementType, arg);
767f711785eSAlexander Belyaev Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
768f711785eSAlexander Belyaev
769*b7f93c28SJeff Niu Value half = b.create<arith::ConstantOp>(elementType,
770*b7f93c28SJeff Niu b.getFloatAttr(elementType, 0.5));
771f711785eSAlexander Belyaev Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
772f711785eSAlexander Belyaev Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
773f711785eSAlexander Belyaev
774f711785eSAlexander Belyaev Value realIsNegative =
775f711785eSAlexander Belyaev b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
776f711785eSAlexander Belyaev Value imagIsNegative =
777f711785eSAlexander Belyaev b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
778f711785eSAlexander Belyaev
779f711785eSAlexander Belyaev Value resultReal = sqrtAddAbs;
780f711785eSAlexander Belyaev
781f711785eSAlexander Belyaev Value imagDivTwoResultReal = b.create<arith::DivFOp>(
782f711785eSAlexander Belyaev imag, b.create<arith::AddFOp>(resultReal, resultReal));
783f711785eSAlexander Belyaev
784f711785eSAlexander Belyaev Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
785f711785eSAlexander Belyaev
786f711785eSAlexander Belyaev Value resultImag = b.create<arith::SelectOp>(
787f711785eSAlexander Belyaev realIsNegative,
788f711785eSAlexander Belyaev b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
789f711785eSAlexander Belyaev resultReal),
790f711785eSAlexander Belyaev imagDivTwoResultReal);
791f711785eSAlexander Belyaev
792f711785eSAlexander Belyaev resultReal = b.create<arith::SelectOp>(
793f711785eSAlexander Belyaev realIsNegative,
794f711785eSAlexander Belyaev b.create<arith::DivFOp>(
795f711785eSAlexander Belyaev imag, b.create<arith::AddFOp>(resultImag, resultImag)),
796f711785eSAlexander Belyaev resultReal);
797f711785eSAlexander Belyaev
798f711785eSAlexander Belyaev Value realIsZero =
799f711785eSAlexander Belyaev b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
800f711785eSAlexander Belyaev Value imagIsZero =
801f711785eSAlexander Belyaev b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
802f711785eSAlexander Belyaev Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
803f711785eSAlexander Belyaev
804f711785eSAlexander Belyaev resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
805f711785eSAlexander Belyaev resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
806f711785eSAlexander Belyaev
807f711785eSAlexander Belyaev rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
808f711785eSAlexander Belyaev resultImag);
809f711785eSAlexander Belyaev return success();
810f711785eSAlexander Belyaev }
811f711785eSAlexander Belyaev };
812f711785eSAlexander Belyaev
813f112bd61SAdrian Kuegel struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
814f112bd61SAdrian Kuegel using OpConversionPattern<complex::SignOp>::OpConversionPattern;
815f112bd61SAdrian Kuegel
816f112bd61SAdrian Kuegel LogicalResult
matchAndRewrite__anon597693150111::SignOpConversion817b54c724bSRiver Riddle matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
818f112bd61SAdrian Kuegel ConversionPatternRewriter &rewriter) const override {
819c0342a2dSJacques Pienaar auto type = adaptor.getComplex().getType().cast<ComplexType>();
820f112bd61SAdrian Kuegel auto elementType = type.getElementType().cast<FloatType>();
821f112bd61SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
822f112bd61SAdrian Kuegel
823c0342a2dSJacques Pienaar Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
824c0342a2dSJacques Pienaar Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
825a54f4eaeSMogball Value zero =
826a54f4eaeSMogball b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
827a54f4eaeSMogball Value realIsZero =
828a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
829a54f4eaeSMogball Value imagIsZero =
830a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
831a54f4eaeSMogball Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
832c0342a2dSJacques Pienaar auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
833a54f4eaeSMogball Value realSign = b.create<arith::DivFOp>(real, abs);
834a54f4eaeSMogball Value imagSign = b.create<arith::DivFOp>(imag, abs);
835f112bd61SAdrian Kuegel Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
836dec8af70SRiver Riddle rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
837dec8af70SRiver Riddle adaptor.getComplex(), sign);
838f112bd61SAdrian Kuegel return success();
839f112bd61SAdrian Kuegel }
840f112bd61SAdrian Kuegel };
8416d75c897Slewuathe
8426d75c897Slewuathe struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
8436d75c897Slewuathe using OpConversionPattern<complex::TanOp>::OpConversionPattern;
8446d75c897Slewuathe
8456d75c897Slewuathe LogicalResult
matchAndRewrite__anon597693150111::TanOpConversion8466d75c897Slewuathe matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
8476d75c897Slewuathe ConversionPatternRewriter &rewriter) const override {
8486d75c897Slewuathe auto loc = op.getLoc();
8496d75c897Slewuathe Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
8506d75c897Slewuathe Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
8516d75c897Slewuathe rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
8526d75c897Slewuathe return success();
8536d75c897Slewuathe }
8546d75c897Slewuathe };
855ffb8eecdSlewuathe
856ffb8eecdSlewuathe struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
857ffb8eecdSlewuathe using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
858ffb8eecdSlewuathe
859ffb8eecdSlewuathe LogicalResult
matchAndRewrite__anon597693150111::TanhOpConversion860ffb8eecdSlewuathe matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
861ffb8eecdSlewuathe ConversionPatternRewriter &rewriter) const override {
862ffb8eecdSlewuathe auto loc = op.getLoc();
863ffb8eecdSlewuathe auto type = adaptor.getComplex().getType().cast<ComplexType>();
864ffb8eecdSlewuathe auto elementType = type.getElementType().cast<FloatType>();
865ffb8eecdSlewuathe
866ffb8eecdSlewuathe // The hyperbolic tangent for complex number can be calculated as follows.
867ffb8eecdSlewuathe // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
868ffb8eecdSlewuathe // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
869ffb8eecdSlewuathe Value real =
870ffb8eecdSlewuathe rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
871ffb8eecdSlewuathe Value imag =
872ffb8eecdSlewuathe rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
873ffb8eecdSlewuathe Value tanhA = rewriter.create<math::TanhOp>(loc, real);
874ffb8eecdSlewuathe Value cosB = rewriter.create<math::CosOp>(loc, imag);
875ffb8eecdSlewuathe Value sinB = rewriter.create<math::SinOp>(loc, imag);
876ffb8eecdSlewuathe Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
877ffb8eecdSlewuathe Value numerator =
878ffb8eecdSlewuathe rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
879ffb8eecdSlewuathe Value one = rewriter.create<arith::ConstantOp>(
880ffb8eecdSlewuathe loc, elementType, rewriter.getFloatAttr(elementType, 1));
881ffb8eecdSlewuathe Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
882ffb8eecdSlewuathe Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
883ffb8eecdSlewuathe rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
884ffb8eecdSlewuathe return success();
885ffb8eecdSlewuathe }
886ffb8eecdSlewuathe };
887ffb8eecdSlewuathe
88862a34f6aSlewuathe struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
88962a34f6aSlewuathe using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
89062a34f6aSlewuathe
89162a34f6aSlewuathe LogicalResult
matchAndRewrite__anon597693150111::ConjOpConversion89262a34f6aSlewuathe matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
89362a34f6aSlewuathe ConversionPatternRewriter &rewriter) const override {
89462a34f6aSlewuathe auto loc = op.getLoc();
89562a34f6aSlewuathe auto type = adaptor.getComplex().getType().cast<ComplexType>();
89662a34f6aSlewuathe auto elementType = type.getElementType().cast<FloatType>();
89762a34f6aSlewuathe Value real =
89862a34f6aSlewuathe rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
89962a34f6aSlewuathe Value imag =
90062a34f6aSlewuathe rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
90162a34f6aSlewuathe Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
90262a34f6aSlewuathe
90362a34f6aSlewuathe rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
90462a34f6aSlewuathe
90562a34f6aSlewuathe return success();
90662a34f6aSlewuathe }
90762a34f6aSlewuathe };
90862a34f6aSlewuathe
9096c6eddb6Sbixia1 /// Coverts x^y = (a+bi)^(c+di) to
9106c6eddb6Sbixia1 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
9116c6eddb6Sbixia1 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
powOpConversionImpl(mlir::ImplicitLocOpBuilder & builder,ComplexType type,Value a,Value b,Value c,Value d)9126c6eddb6Sbixia1 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
9136c6eddb6Sbixia1 ComplexType type, Value a, Value b, Value c,
9146c6eddb6Sbixia1 Value d) {
9156c6eddb6Sbixia1 auto elementType = type.getElementType().cast<FloatType>();
9166c6eddb6Sbixia1
9176c6eddb6Sbixia1 // Compute (a*a+b*b)^(0.5c).
9186c6eddb6Sbixia1 Value aaPbb = builder.create<arith::AddFOp>(
9196c6eddb6Sbixia1 builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
9206c6eddb6Sbixia1 Value half = builder.create<arith::ConstantOp>(
9216c6eddb6Sbixia1 elementType, builder.getFloatAttr(elementType, 0.5));
9226c6eddb6Sbixia1 Value halfC = builder.create<arith::MulFOp>(half, c);
9236c6eddb6Sbixia1 Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
9246c6eddb6Sbixia1
9256c6eddb6Sbixia1 // Compute exp(-d*atan2(b,a)).
9266c6eddb6Sbixia1 Value negD = builder.create<arith::NegFOp>(d);
9276c6eddb6Sbixia1 Value argX = builder.create<math::Atan2Op>(b, a);
9286c6eddb6Sbixia1 Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
9296c6eddb6Sbixia1 Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
9306c6eddb6Sbixia1
9316c6eddb6Sbixia1 // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
9326c6eddb6Sbixia1 Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
9336c6eddb6Sbixia1
9346c6eddb6Sbixia1 // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
9356c6eddb6Sbixia1 Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
9366c6eddb6Sbixia1 Value halfD = builder.create<arith::MulFOp>(half, d);
9376c6eddb6Sbixia1 Value q = builder.create<arith::AddFOp>(
9386c6eddb6Sbixia1 builder.create<arith::MulFOp>(c, argX),
9396c6eddb6Sbixia1 builder.create<arith::MulFOp>(halfD, lnAaPbb));
9406c6eddb6Sbixia1
9416c6eddb6Sbixia1 Value cosQ = builder.create<math::CosOp>(q);
9426c6eddb6Sbixia1 Value sinQ = builder.create<math::SinOp>(q);
9436c6eddb6Sbixia1 Value zero = builder.create<arith::ConstantOp>(
9446c6eddb6Sbixia1 elementType, builder.getFloatAttr(elementType, 0));
9456c6eddb6Sbixia1 Value one = builder.create<arith::ConstantOp>(
9466c6eddb6Sbixia1 elementType, builder.getFloatAttr(elementType, 1));
9476c6eddb6Sbixia1
9486c6eddb6Sbixia1 Value xEqZero =
9496c6eddb6Sbixia1 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
9506c6eddb6Sbixia1 Value yGeZero = builder.create<arith::AndIOp>(
9516c6eddb6Sbixia1 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero),
9526c6eddb6Sbixia1 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero));
9536c6eddb6Sbixia1 Value cEqZero =
9546c6eddb6Sbixia1 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
9556c6eddb6Sbixia1 Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
9566c6eddb6Sbixia1 Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
9576c6eddb6Sbixia1 Value complexOther = builder.create<complex::CreateOp>(
9586c6eddb6Sbixia1 type, builder.create<arith::MulFOp>(coeff, cosQ),
9596c6eddb6Sbixia1 builder.create<arith::MulFOp>(coeff, sinQ));
9606c6eddb6Sbixia1
9616c6eddb6Sbixia1 // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
9626c6eddb6Sbixia1 // Branch Cuts for Complex Elementary Functions or Much Ado About
9636c6eddb6Sbixia1 // Nothing's Sign Bit, W. Kahan, Section 10.
9646c6eddb6Sbixia1 return builder.create<arith::SelectOp>(
9656c6eddb6Sbixia1 builder.create<arith::AndIOp>(xEqZero, yGeZero),
9666c6eddb6Sbixia1 builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero),
9676c6eddb6Sbixia1 complexOther);
9686c6eddb6Sbixia1 }
9696c6eddb6Sbixia1
9706c6eddb6Sbixia1 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
9716c6eddb6Sbixia1 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
9726c6eddb6Sbixia1
9736c6eddb6Sbixia1 LogicalResult
matchAndRewrite__anon597693150111::PowOpConversion9746c6eddb6Sbixia1 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
9756c6eddb6Sbixia1 ConversionPatternRewriter &rewriter) const override {
9766c6eddb6Sbixia1 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
9776c6eddb6Sbixia1 auto type = adaptor.getLhs().getType().cast<ComplexType>();
9786c6eddb6Sbixia1 auto elementType = type.getElementType().cast<FloatType>();
9796c6eddb6Sbixia1
9806c6eddb6Sbixia1 Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
9816c6eddb6Sbixia1 Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
9826c6eddb6Sbixia1 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
9836c6eddb6Sbixia1 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
9846c6eddb6Sbixia1
9856c6eddb6Sbixia1 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
9866c6eddb6Sbixia1 return success();
9876c6eddb6Sbixia1 }
9886c6eddb6Sbixia1 };
9896c6eddb6Sbixia1
9906c6eddb6Sbixia1 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
9916c6eddb6Sbixia1 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
9926c6eddb6Sbixia1
9936c6eddb6Sbixia1 LogicalResult
matchAndRewrite__anon597693150111::RsqrtOpConversion9946c6eddb6Sbixia1 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
9956c6eddb6Sbixia1 ConversionPatternRewriter &rewriter) const override {
9966c6eddb6Sbixia1 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
9976c6eddb6Sbixia1 auto type = adaptor.getComplex().getType().cast<ComplexType>();
9986c6eddb6Sbixia1 auto elementType = type.getElementType().cast<FloatType>();
9996c6eddb6Sbixia1
10006c6eddb6Sbixia1 Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
10016c6eddb6Sbixia1 Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
10026c6eddb6Sbixia1 Value c = builder.create<arith::ConstantOp>(
10036c6eddb6Sbixia1 elementType, builder.getFloatAttr(elementType, -0.5));
10046c6eddb6Sbixia1 Value d = builder.create<arith::ConstantOp>(
10056c6eddb6Sbixia1 elementType, builder.getFloatAttr(elementType, 0));
10066c6eddb6Sbixia1
10076c6eddb6Sbixia1 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
10086c6eddb6Sbixia1 return success();
10096c6eddb6Sbixia1 }
10106c6eddb6Sbixia1 };
10116c6eddb6Sbixia1
10128fa2e679SLewuathe struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
10138fa2e679SLewuathe using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
10148fa2e679SLewuathe
10158fa2e679SLewuathe LogicalResult
matchAndRewrite__anon597693150111::AngleOpConversion10168fa2e679SLewuathe matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
10178fa2e679SLewuathe ConversionPatternRewriter &rewriter) const override {
10188fa2e679SLewuathe auto loc = op.getLoc();
10198fa2e679SLewuathe auto type = op.getType();
10208fa2e679SLewuathe
10218fa2e679SLewuathe Value real =
10228fa2e679SLewuathe rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
10238fa2e679SLewuathe Value imag =
10248fa2e679SLewuathe rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
10258fa2e679SLewuathe
10268fa2e679SLewuathe rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
10278fa2e679SLewuathe
10288fa2e679SLewuathe return success();
10298fa2e679SLewuathe }
10308fa2e679SLewuathe };
10318fa2e679SLewuathe
10322ea7fb7bSAdrian Kuegel } // namespace
10332ea7fb7bSAdrian Kuegel
populateComplexToStandardConversionPatterns(RewritePatternSet & patterns)10342ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns(
10352ea7fb7bSAdrian Kuegel RewritePatternSet &patterns) {
1036f112bd61SAdrian Kuegel // clang-format off
1037f112bd61SAdrian Kuegel patterns.add<
1038f112bd61SAdrian Kuegel AbsOpConversion,
10398fa2e679SLewuathe AngleOpConversion,
1040f711785eSAlexander Belyaev Atan2OpConversion,
1041a54f4eaeSMogball BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1042a54f4eaeSMogball BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
104362a34f6aSlewuathe ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
104462a34f6aSlewuathe ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
104562a34f6aSlewuathe ConjOpConversion,
1046672b908bSGoran Flegar CosOpConversion,
1047f112bd61SAdrian Kuegel DivOpConversion,
1048f112bd61SAdrian Kuegel ExpOpConversion,
1049338e76f8Sbixia1 Expm1OpConversion,
10506e80e3bdSAdrian Kuegel Log1pOpConversion,
105162a34f6aSlewuathe LogOpConversion,
1052bf17ee19SAdrian Kuegel MulOpConversion,
1053f112bd61SAdrian Kuegel NegOpConversion,
1054672b908bSGoran Flegar SignOpConversion,
10556d75c897Slewuathe SinOpConversion,
1056f711785eSAlexander Belyaev SqrtOpConversion,
1057ffb8eecdSlewuathe TanOpConversion,
10586c6eddb6Sbixia1 TanhOpConversion,
10596c6eddb6Sbixia1 PowOpConversion,
10606c6eddb6Sbixia1 RsqrtOpConversion
106162a34f6aSlewuathe >(patterns.getContext());
1062f112bd61SAdrian Kuegel // clang-format on
10632ea7fb7bSAdrian Kuegel }
10642ea7fb7bSAdrian Kuegel
10652ea7fb7bSAdrian Kuegel namespace {
10662ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass
10672ea7fb7bSAdrian Kuegel : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
106841574554SRiver Riddle void runOnOperation() override;
10692ea7fb7bSAdrian Kuegel };
10702ea7fb7bSAdrian Kuegel
runOnOperation()107141574554SRiver Riddle void ConvertComplexToStandardPass::runOnOperation() {
10722ea7fb7bSAdrian Kuegel // Convert to the Standard dialect using the converter defined above.
10732ea7fb7bSAdrian Kuegel RewritePatternSet patterns(&getContext());
10742ea7fb7bSAdrian Kuegel populateComplexToStandardConversionPatterns(patterns);
10752ea7fb7bSAdrian Kuegel
10762ea7fb7bSAdrian Kuegel ConversionTarget target(getContext());
10771f971e23SRiver Riddle target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>();
1078fb978f09SAdrian Kuegel target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
107947f175b0SRiver Riddle if (failed(
108047f175b0SRiver Riddle applyPartialConversion(getOperation(), target, std::move(patterns))))
10812ea7fb7bSAdrian Kuegel signalPassFailure();
10822ea7fb7bSAdrian Kuegel }
10832ea7fb7bSAdrian Kuegel } // namespace
10842ea7fb7bSAdrian Kuegel
createConvertComplexToStandardPass()108547f175b0SRiver Riddle std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
10862ea7fb7bSAdrian Kuegel return std::make_unique<ConvertComplexToStandardPass>();
10872ea7fb7bSAdrian Kuegel }
1088