1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
25   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
26 
27   LogicalResult
28   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
29                   ConversionPatternRewriter &rewriter) const override {
30     complex::AbsOp::Adaptor transformed(operands);
31     auto loc = op.getLoc();
32     auto type = op.getType();
33 
34     Value real =
35         rewriter.create<complex::ReOp>(loc, type, transformed.complex());
36     Value imag =
37         rewriter.create<complex::ImOp>(loc, type, transformed.complex());
38     Value realSqr = rewriter.create<MulFOp>(loc, real, real);
39     Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
40     Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
41 
42     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
43     return success();
44   }
45 };
46 
47 template <typename ComparisonOp, CmpFPredicate p>
48 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
49   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
50   using ResultCombiner =
51       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
52                          AndOp, OrOp>;
53 
54   LogicalResult
55   matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
56                   ConversionPatternRewriter &rewriter) const override {
57     typename ComparisonOp::Adaptor transformed(operands);
58     auto loc = op.getLoc();
59     auto type = transformed.lhs()
60                     .getType()
61                     .template cast<ComplexType>()
62                     .getElementType();
63 
64     Value realLhs =
65         rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
66     Value imagLhs =
67         rewriter.create<complex::ImOp>(loc, type, transformed.lhs());
68     Value realRhs =
69         rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
70     Value imagRhs =
71         rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
72     Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
73     Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
74 
75     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
76                                                 imagComparison);
77     return success();
78   }
79 };
80 
/// Lowers `complex.div` to arithmetic on the real and imaginary parts.
///
/// The quotient is first computed with both variants of Smith's algorithm,
/// and the variant matching the relative magnitudes of the divisor's parts is
/// selected. If both components of that quotient are NaN, three special cases
/// (zero divisor, infinite dividend, infinite divisor) are checked to recover
/// an infinite or zero result instead.
/// NOTE(review): the special-case structure appears to mirror the `_Cdivd`
/// reference code in C99 Annex G -- confirm before modifying it.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::DivOp::Adaptor transformed(operands);
    auto loc = op.getLoc();
    auto type = transformed.lhs().getType().template cast<ComplexType>();
    // Cast to FloatType because the float semantics are needed below to build
    // the +inf constant (cast<> asserts if the element type is not a float).
    auto elementType = type.getElementType().cast<FloatType>();

    // Decompose lhs = lhsReal + lhsImag * i and rhs = rhsReal + rhsImag * i.
    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs());

    // Smith's algorithm to divide complex numbers. It is a numerically more
    // careful way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Both variants are emitted unconditionally; a select further below picks
    // one of them. Variant 1 (ratio rhsReal / rhsImag):
    Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<AddFOp>(
        loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag);
    Value resultReal1 =
        rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<SubFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal);
    Value resultImag1 =
        rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    // Variant 2 (ratio rhsImag / rhsReal):
    Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<AddFOp>(
        loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<AddFOp>(
        loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<SubFOp>(
        loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // The result is an infinity carrying the sign of rhsReal, multiplied into
    // each component of lhs.
    Value zero = rewriter.create<ConstantOp>(loc, elementType,
                                             rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal);
    Value rhsRealIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag);
    Value rhsImagIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against a non-NaN constant is true iff the value is not NaN.
    Value lhsRealIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<AndOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // Each lhs component is replaced by +-1 (if infinite) or +-0 (otherwise),
    // the textbook formula is applied, and the result is scaled by +inf.
    // ONE against +inf is false for both infinities and for NaN, so it tests
    // "finite".
    Value rhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal);
    Value lhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag);
    Value lhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    // Each rhs component is replaced by +-1 (if infinite) or +-0 (otherwise),
    // the textbook formula is applied, and the result is scaled by zero.
    Value lhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith variant: variant 1 when |rhsReal| < |rhsImag|, otherwise
    // variant 2.
    Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>(
        loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultReal1, resultReal2);
    Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultImag1, resultImag2);
    // Layer the special cases on top of the Smith result. The selects are
    // nested so that case 1 takes precedence over case 2, which takes
    // precedence over case 3.
    Value resultRealSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case values are used only when BOTH components of the
    // Smith result are NaN; otherwise the Smith result is kept as-is.
    Value resultRealIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
289 } // namespace
290 
291 void mlir::populateComplexToStandardConversionPatterns(
292     RewritePatternSet &patterns) {
293   patterns.add<AbsOpConversion,
294                ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
295                ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
296                DivOpConversion>(patterns.getContext());
297 }
298 
299 namespace {
300 struct ConvertComplexToStandardPass
301     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
302   void runOnFunction() override;
303 };
304 
305 void ConvertComplexToStandardPass::runOnFunction() {
306   auto function = getFunction();
307 
308   // Convert to the Standard dialect using the converter defined above.
309   RewritePatternSet patterns(&getContext());
310   populateComplexToStandardConversionPatterns(patterns);
311 
312   ConversionTarget target(getContext());
313   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
314                          complex::ComplexDialect>();
315   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
316                       complex::NotEqualOp>();
317   if (failed(applyPartialConversion(function, target, std::move(patterns))))
318     signalPassFailure();
319 }
320 } // namespace
321 
322 std::unique_ptr<OperationPass<FuncOp>>
323 mlir::createConvertComplexToStandardPass() {
324   return std::make_unique<ConvertComplexToStandardPass>();
325 }
326