1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10
11 #include <memory>
12 #include <type_traits>
13
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21
22 using namespace mlir;
23
24 namespace {
25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
26 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
27
28 LogicalResult
matchAndRewrite__anon597693150111::AbsOpConversion29 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
30 ConversionPatternRewriter &rewriter) const override {
31 auto loc = op.getLoc();
32 auto type = op.getType();
33
34 Value real =
35 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
36 Value imag =
37 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
38 Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
39 Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
40 Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
41
42 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
43 return success();
44 }
45 };
46
47 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
48 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
49 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
50
51 LogicalResult
matchAndRewrite__anon597693150111::Atan2OpConversion52 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
55
56 auto type = op.getType().cast<ComplexType>();
57 Type elementType = type.getElementType();
58
59 Value lhs = adaptor.getLhs();
60 Value rhs = adaptor.getRhs();
61
62 Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
63 Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
64 Value rhsSquaredPlusLhsSquared =
65 b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
66 Value sqrtOfRhsSquaredPlusLhsSquared =
67 b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
68
69 Value zero =
70 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
71 Value one = b.create<arith::ConstantOp>(elementType,
72 b.getFloatAttr(elementType, 1));
73 Value i = b.create<complex::CreateOp>(type, zero, one);
74 Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
75 Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
76
77 Value divResult =
78 b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
79 Value logResult = b.create<complex::LogOp>(divResult);
80
81 Value negativeOne = b.create<arith::ConstantOp>(
82 elementType, b.getFloatAttr(elementType, -1));
83 Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
84
85 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
86 return success();
87 }
88 };
89
90 template <typename ComparisonOp, arith::CmpFPredicate p>
91 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
92 using OpConversionPattern<ComparisonOp>::OpConversionPattern;
93 using ResultCombiner =
94 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
95 arith::AndIOp, arith::OrIOp>;
96
97 LogicalResult
matchAndRewrite__anon597693150111::ComparisonOpConversion98 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
99 ConversionPatternRewriter &rewriter) const override {
100 auto loc = op.getLoc();
101 auto type = adaptor.getLhs()
102 .getType()
103 .template cast<ComplexType>()
104 .getElementType();
105
106 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
107 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
108 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
109 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
110 Value realComparison =
111 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
112 Value imagComparison =
113 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
114
115 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
116 imagComparison);
117 return success();
118 }
119 };
120
121 // Default conversion which applies the BinaryStandardOp separately on the real
122 // and imaginary parts. Can for example be used for complex::AddOp and
123 // complex::SubOp.
124 template <typename BinaryComplexOp, typename BinaryStandardOp>
125 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
126 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
127
128 LogicalResult
matchAndRewrite__anon597693150111::BinaryComplexOpConversion129 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
130 ConversionPatternRewriter &rewriter) const override {
131 auto type = adaptor.getLhs().getType().template cast<ComplexType>();
132 auto elementType = type.getElementType().template cast<FloatType>();
133 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
134
135 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
136 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
137 Value resultReal =
138 b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
139 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
140 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
141 Value resultImag =
142 b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
143 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
144 resultImag);
145 return success();
146 }
147 };
148
149 template <typename TrigonometricOp>
150 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
151 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
152
153 using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
154
155 LogicalResult
matchAndRewrite__anon597693150111::TrigonometricOpConversion156 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter) const override {
158 auto loc = op.getLoc();
159 auto type = adaptor.getComplex().getType().template cast<ComplexType>();
160 auto elementType = type.getElementType().template cast<FloatType>();
161
162 Value real =
163 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
164 Value imag =
165 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
166
167 // Trigonometric ops use a set of common building blocks to convert to real
168 // ops. Here we create these building blocks and call into an op-specific
169 // implementation in the subclass to combine them.
170 Value half = rewriter.create<arith::ConstantOp>(
171 loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
172 Value exp = rewriter.create<math::ExpOp>(loc, imag);
173 Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
174 Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
175 Value sin = rewriter.create<math::SinOp>(loc, real);
176 Value cos = rewriter.create<math::CosOp>(loc, real);
177
178 auto resultPair =
179 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
180
181 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
182 resultPair.second);
183 return success();
184 }
185
186 virtual std::pair<Value, Value>
187 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
188 Value cos, ConversionPatternRewriter &rewriter) const = 0;
189 };
190
191 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
192 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
193
194 std::pair<Value, Value>
combine__anon597693150111::CosOpConversion195 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
196 Value cos, ConversionPatternRewriter &rewriter) const override {
197 // Complex cosine is defined as;
198 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
199 // Plugging in:
200 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
201 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
202 // and defining t := exp(y)
203 // We get:
204 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
205 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
206 Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
207 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
208 Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
209 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
210 return {resultReal, resultImag};
211 }
212 };
213
// Lowers complex.div to real (arith/math) operations.
//
// The generic path is Smith's algorithm (described below). Both of its
// variants are emitted unconditionally and one is chosen with arith.select,
// so the lowering introduces no control flow. The generic result is replaced
// only when BOTH of its components are NaN. In that case a special-case value
// is used, chosen in this priority order:
//   1. the divisor is zero and the dividend has at least one non-NaN
//      component;
//   2. the dividend has an infinite component and the divisor is finite;
//   3. the dividend is finite and the divisor has an infinite component.
// If none of these holds, the (NaN) generic result is kept.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Decompose lhs = lhsReal + lhsImag * i and rhs = rhsReal + rhsImag * i.
    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.

    // Variant 1 (divide by rhsImag); selected when |rhsReal| < |rhsImag|.
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Variant 2 (divide by rhsReal); selected otherwise, including when the
    // magnitude comparison is unordered (NaN).
    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD(x, zero) is true exactly when x is not NaN, because zero is ordered.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    // Case 1 result: each dividend component times +/-inf, where the sign is
    // copied from rhsReal (a +0 or -0 in this case).
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // ONE (ordered and not equal) is false for NaN, so a NaN component does
    // not count as finite here.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    // lhs*IsInfWithSign is copysign(1, x) for an infinite component x and
    // copysign(0, x) otherwise.
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    // resultReal3 = inf * (lhsRealIsInfWithSign * rhsReal +
    //                      lhsImagIsInfWithSign * rhsImag)
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    // resultImag3 = inf * (lhsImagIsInfWithSign * rhsReal -
    //                      lhsRealIsInfWithSign * rhsImag)
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    // rhs*IsInfWithSign is copysign(1, x) for an infinite component x and
    // copysign(0, x) otherwise.
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    // The dividend is finite here, so both case-3 results are zero times a
    // finite value, i.e. a signed zero.
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Generic Smith result: pick the variant matching the larger divisor
    // component.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Fold in the special cases from lowest (3) to highest (1) priority, so
    // an earlier-numbered case wins when several apply.
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // Use the special-case value only when both components of the generic
    // result are NaN; otherwise keep the generic result.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
432
433 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
434 using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
435
436 LogicalResult
matchAndRewrite__anon597693150111::ExpOpConversion437 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
438 ConversionPatternRewriter &rewriter) const override {
439 auto loc = op.getLoc();
440 auto type = adaptor.getComplex().getType().cast<ComplexType>();
441 auto elementType = type.getElementType().cast<FloatType>();
442
443 Value real =
444 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
445 Value imag =
446 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
447 Value expReal = rewriter.create<math::ExpOp>(loc, real);
448 Value cosImag = rewriter.create<math::CosOp>(loc, imag);
449 Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
450 Value sinImag = rewriter.create<math::SinOp>(loc, imag);
451 Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
452
453 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
454 resultImag);
455 return success();
456 }
457 };
458
459 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
460 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
461
462 LogicalResult
matchAndRewrite__anon597693150111::Expm1OpConversion463 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
464 ConversionPatternRewriter &rewriter) const override {
465 auto type = adaptor.getComplex().getType().cast<ComplexType>();
466 auto elementType = type.getElementType().cast<FloatType>();
467
468 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
469 Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
470
471 Value real = b.create<complex::ReOp>(elementType, exp);
472 Value one = b.create<arith::ConstantOp>(elementType,
473 b.getFloatAttr(elementType, 1));
474 Value realMinusOne = b.create<arith::SubFOp>(real, one);
475 Value imag = b.create<complex::ImOp>(elementType, exp);
476
477 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
478 imag);
479 return success();
480 }
481 };
482
483 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
484 using OpConversionPattern<complex::LogOp>::OpConversionPattern;
485
486 LogicalResult
matchAndRewrite__anon597693150111::LogOpConversion487 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
488 ConversionPatternRewriter &rewriter) const override {
489 auto type = adaptor.getComplex().getType().cast<ComplexType>();
490 auto elementType = type.getElementType().cast<FloatType>();
491 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
492
493 Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
494 Value resultReal = b.create<math::LogOp>(elementType, abs);
495 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
496 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
497 Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
498 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
499 resultImag);
500 return success();
501 }
502 };
503
504 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
505 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
506
507 LogicalResult
matchAndRewrite__anon597693150111::Log1pOpConversion508 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
509 ConversionPatternRewriter &rewriter) const override {
510 auto type = adaptor.getComplex().getType().cast<ComplexType>();
511 auto elementType = type.getElementType().cast<FloatType>();
512 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
513
514 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
515 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
516 Value one = b.create<arith::ConstantOp>(elementType,
517 b.getFloatAttr(elementType, 1));
518 Value realPlusOne = b.create<arith::AddFOp>(real, one);
519 Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
520 rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
521 return success();
522 }
523 };
524
// Lowers complex.mul to real (arith/math) operations.
//
// With lhs = a + bi and rhs = c + di, the textbook product
//   (ac - bd) + (bc + ad) i
// is computed first. If BOTH of its components are NaN and one of the three
// cases below applies, the operands are rescaled and the product is
// recomputed as
//   inf * (a'c' - b'd') + inf * (b'c' + a'd') i.
// Cases 1 and 2 apply when an operand has an infinite component. Case 3
// applies when neither does but one of the pairwise products is infinite.
// Everything is expressed with arith.select, so no control flow is emitted.
// NOTE(review): this appears to mirror the C99 Annex G reference algorithm
// for complex multiplication; confirm against the standard before relying on
// that.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Components and their magnitudes. lhsReal/lhsImag/rhsReal/rhsImag are
    // reassigned further down; the *Abs values and the products below keep
    // referring to the original operands.
    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);

    // Textbook real part: ac - bd.
    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
    Value real =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    // Textbook imaginary part: bc + ad.
    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
    Value imag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // Recovery is considered only if both components are NaN.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Each lhs component becomes copysign(1, x) if infinite, else
    // copysign(0, x). A NaN rhs component becomes copysign(0, x).
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat =
        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat =
        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Symmetric to case 1, applied on top of case 1's rewrites. The lhs NaN
    // flags are therefore computed from the already-rewritten lhs values.
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat =
        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat =
        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    // Applies only when neither case 1 nor case 2 did. In that situation no
    // operand has been rewritten yet, so the NaN flags computed earlier are
    // still accurate for the current values.
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    // XOR with the i1 constant 1 is a logical NOT.
    Type i1Type = b.getI1Type();
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    // Replace any NaN operand component with a signed zero.
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);
    // Recompute only if the textbook result was NaN in both parts AND one of
    // the three cases applied.
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value newReal =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value newImag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
701
702 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
703 using OpConversionPattern<complex::NegOp>::OpConversionPattern;
704
705 LogicalResult
matchAndRewrite__anon597693150111::NegOpConversion706 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
707 ConversionPatternRewriter &rewriter) const override {
708 auto loc = op.getLoc();
709 auto type = adaptor.getComplex().getType().cast<ComplexType>();
710 auto elementType = type.getElementType().cast<FloatType>();
711
712 Value real =
713 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
714 Value imag =
715 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
716 Value negReal = rewriter.create<arith::NegFOp>(loc, real);
717 Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
718 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
719 return success();
720 }
721 };
722
723 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
724 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
725
726 std::pair<Value, Value>
combine__anon597693150111::SinOpConversion727 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
728 Value cos, ConversionPatternRewriter &rewriter) const override {
729 // Complex sine is defined as;
730 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
731 // Plugging in:
732 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
733 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
734 // and defining t := exp(y)
735 // We get:
736 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
737 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
738 Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
739 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
740 Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
741 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
742 return {resultReal, resultImag};
743 }
744 };
745
746 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
747 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
748 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
749
750 LogicalResult
matchAndRewrite__anon597693150111::SqrtOpConversion751 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
752 ConversionPatternRewriter &rewriter) const override {
753 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
754
755 auto type = op.getType().cast<ComplexType>();
756 Type elementType = type.getElementType();
757 Value arg = adaptor.getComplex();
758
759 Value zero =
760 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
761
762 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
763 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
764
765 Value absLhs = b.create<math::AbsOp>(real);
766 Value absArg = b.create<complex::AbsOp>(elementType, arg);
767 Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
768
769 Value half = b.create<arith::ConstantOp>(elementType,
770 b.getFloatAttr(elementType, 0.5));
771 Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
772 Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
773
774 Value realIsNegative =
775 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
776 Value imagIsNegative =
777 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
778
779 Value resultReal = sqrtAddAbs;
780
781 Value imagDivTwoResultReal = b.create<arith::DivFOp>(
782 imag, b.create<arith::AddFOp>(resultReal, resultReal));
783
784 Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
785
786 Value resultImag = b.create<arith::SelectOp>(
787 realIsNegative,
788 b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
789 resultReal),
790 imagDivTwoResultReal);
791
792 resultReal = b.create<arith::SelectOp>(
793 realIsNegative,
794 b.create<arith::DivFOp>(
795 imag, b.create<arith::AddFOp>(resultImag, resultImag)),
796 resultReal);
797
798 Value realIsZero =
799 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
800 Value imagIsZero =
801 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
802 Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
803
804 resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
805 resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
806
807 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
808 resultImag);
809 return success();
810 }
811 };
812
813 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
814 using OpConversionPattern<complex::SignOp>::OpConversionPattern;
815
816 LogicalResult
matchAndRewrite__anon597693150111::SignOpConversion817 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter) const override {
819 auto type = adaptor.getComplex().getType().cast<ComplexType>();
820 auto elementType = type.getElementType().cast<FloatType>();
821 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
822
823 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
824 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
825 Value zero =
826 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
827 Value realIsZero =
828 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
829 Value imagIsZero =
830 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
831 Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
832 auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
833 Value realSign = b.create<arith::DivFOp>(real, abs);
834 Value imagSign = b.create<arith::DivFOp>(imag, abs);
835 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
836 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
837 adaptor.getComplex(), sign);
838 return success();
839 }
840 };
841
842 struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
843 using OpConversionPattern<complex::TanOp>::OpConversionPattern;
844
845 LogicalResult
matchAndRewrite__anon597693150111::TanOpConversion846 matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
847 ConversionPatternRewriter &rewriter) const override {
848 auto loc = op.getLoc();
849 Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
850 Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
851 rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
852 return success();
853 }
854 };
855
856 struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
857 using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
858
859 LogicalResult
matchAndRewrite__anon597693150111::TanhOpConversion860 matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
861 ConversionPatternRewriter &rewriter) const override {
862 auto loc = op.getLoc();
863 auto type = adaptor.getComplex().getType().cast<ComplexType>();
864 auto elementType = type.getElementType().cast<FloatType>();
865
866 // The hyperbolic tangent for complex number can be calculated as follows.
867 // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
868 // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
869 Value real =
870 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
871 Value imag =
872 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
873 Value tanhA = rewriter.create<math::TanhOp>(loc, real);
874 Value cosB = rewriter.create<math::CosOp>(loc, imag);
875 Value sinB = rewriter.create<math::SinOp>(loc, imag);
876 Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
877 Value numerator =
878 rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
879 Value one = rewriter.create<arith::ConstantOp>(
880 loc, elementType, rewriter.getFloatAttr(elementType, 1));
881 Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
882 Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
883 rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
884 return success();
885 }
886 };
887
888 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
889 using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
890
891 LogicalResult
matchAndRewrite__anon597693150111::ConjOpConversion892 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
893 ConversionPatternRewriter &rewriter) const override {
894 auto loc = op.getLoc();
895 auto type = adaptor.getComplex().getType().cast<ComplexType>();
896 auto elementType = type.getElementType().cast<FloatType>();
897 Value real =
898 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
899 Value imag =
900 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
901 Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
902
903 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
904
905 return success();
906 }
907 };
908
909 /// Coverts x^y = (a+bi)^(c+di) to
910 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
911 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
powOpConversionImpl(mlir::ImplicitLocOpBuilder & builder,ComplexType type,Value a,Value b,Value c,Value d)912 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
913 ComplexType type, Value a, Value b, Value c,
914 Value d) {
915 auto elementType = type.getElementType().cast<FloatType>();
916
917 // Compute (a*a+b*b)^(0.5c).
918 Value aaPbb = builder.create<arith::AddFOp>(
919 builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
920 Value half = builder.create<arith::ConstantOp>(
921 elementType, builder.getFloatAttr(elementType, 0.5));
922 Value halfC = builder.create<arith::MulFOp>(half, c);
923 Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
924
925 // Compute exp(-d*atan2(b,a)).
926 Value negD = builder.create<arith::NegFOp>(d);
927 Value argX = builder.create<math::Atan2Op>(b, a);
928 Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
929 Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
930
931 // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
932 Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
933
934 // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
935 Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
936 Value halfD = builder.create<arith::MulFOp>(half, d);
937 Value q = builder.create<arith::AddFOp>(
938 builder.create<arith::MulFOp>(c, argX),
939 builder.create<arith::MulFOp>(halfD, lnAaPbb));
940
941 Value cosQ = builder.create<math::CosOp>(q);
942 Value sinQ = builder.create<math::SinOp>(q);
943 Value zero = builder.create<arith::ConstantOp>(
944 elementType, builder.getFloatAttr(elementType, 0));
945 Value one = builder.create<arith::ConstantOp>(
946 elementType, builder.getFloatAttr(elementType, 1));
947
948 Value xEqZero =
949 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
950 Value yGeZero = builder.create<arith::AndIOp>(
951 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero),
952 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero));
953 Value cEqZero =
954 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
955 Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
956 Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
957 Value complexOther = builder.create<complex::CreateOp>(
958 type, builder.create<arith::MulFOp>(coeff, cosQ),
959 builder.create<arith::MulFOp>(coeff, sinQ));
960
961 // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
962 // Branch Cuts for Complex Elementary Functions or Much Ado About
963 // Nothing's Sign Bit, W. Kahan, Section 10.
964 return builder.create<arith::SelectOp>(
965 builder.create<arith::AndIOp>(xEqZero, yGeZero),
966 builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero),
967 complexOther);
968 }
969
970 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
971 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
972
973 LogicalResult
matchAndRewrite__anon597693150111::PowOpConversion974 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
975 ConversionPatternRewriter &rewriter) const override {
976 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
977 auto type = adaptor.getLhs().getType().cast<ComplexType>();
978 auto elementType = type.getElementType().cast<FloatType>();
979
980 Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
981 Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
982 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
983 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
984
985 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
986 return success();
987 }
988 };
989
990 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
991 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
992
993 LogicalResult
matchAndRewrite__anon597693150111::RsqrtOpConversion994 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
995 ConversionPatternRewriter &rewriter) const override {
996 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
997 auto type = adaptor.getComplex().getType().cast<ComplexType>();
998 auto elementType = type.getElementType().cast<FloatType>();
999
1000 Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
1001 Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
1002 Value c = builder.create<arith::ConstantOp>(
1003 elementType, builder.getFloatAttr(elementType, -0.5));
1004 Value d = builder.create<arith::ConstantOp>(
1005 elementType, builder.getFloatAttr(elementType, 0));
1006
1007 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
1008 return success();
1009 }
1010 };
1011
1012 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1013 using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1014
1015 LogicalResult
matchAndRewrite__anon597693150111::AngleOpConversion1016 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1017 ConversionPatternRewriter &rewriter) const override {
1018 auto loc = op.getLoc();
1019 auto type = op.getType();
1020
1021 Value real =
1022 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1023 Value imag =
1024 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1025
1026 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
1027
1028 return success();
1029 }
1030 };
1031
1032 } // namespace
1033
populateComplexToStandardConversionPatterns(RewritePatternSet & patterns)1034 void mlir::populateComplexToStandardConversionPatterns(
1035 RewritePatternSet &patterns) {
1036 // clang-format off
1037 patterns.add<
1038 AbsOpConversion,
1039 AngleOpConversion,
1040 Atan2OpConversion,
1041 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1042 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1043 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1044 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1045 ConjOpConversion,
1046 CosOpConversion,
1047 DivOpConversion,
1048 ExpOpConversion,
1049 Expm1OpConversion,
1050 Log1pOpConversion,
1051 LogOpConversion,
1052 MulOpConversion,
1053 NegOpConversion,
1054 SignOpConversion,
1055 SinOpConversion,
1056 SqrtOpConversion,
1057 TanOpConversion,
1058 TanhOpConversion,
1059 PowOpConversion,
1060 RsqrtOpConversion
1061 >(patterns.getContext());
1062 // clang-format on
1063 }
1064
1065 namespace {
1066 struct ConvertComplexToStandardPass
1067 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1068 void runOnOperation() override;
1069 };
1070
runOnOperation()1071 void ConvertComplexToStandardPass::runOnOperation() {
1072 // Convert to the Standard dialect using the converter defined above.
1073 RewritePatternSet patterns(&getContext());
1074 populateComplexToStandardConversionPatterns(patterns);
1075
1076 ConversionTarget target(getContext());
1077 target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>();
1078 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1079 if (failed(
1080 applyPartialConversion(getOperation(), target, std::move(patterns))))
1081 signalPassFailure();
1082 }
1083 } // namespace
1084
createConvertComplexToStandardPass()1085 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
1086 return std::make_unique<ConvertComplexToStandardPass>();
1087 }
1088