1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
25   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
26 
27   LogicalResult
28   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
29                   ConversionPatternRewriter &rewriter) const override {
30     complex::AbsOp::Adaptor transformed(operands);
31     auto loc = op.getLoc();
32     auto type = op.getType();
33 
34     Value real =
35         rewriter.create<complex::ReOp>(loc, type, transformed.complex());
36     Value imag =
37         rewriter.create<complex::ImOp>(loc, type, transformed.complex());
38     Value realSqr = rewriter.create<MulFOp>(loc, real, real);
39     Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
40     Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
41 
42     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
43     return success();
44   }
45 };
46 
47 template <typename ComparisonOp, CmpFPredicate p>
48 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
49   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
50   using ResultCombiner =
51       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
52                          AndOp, OrOp>;
53 
54   LogicalResult
55   matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
56                   ConversionPatternRewriter &rewriter) const override {
57     typename ComparisonOp::Adaptor transformed(operands);
58     auto loc = op.getLoc();
59     auto type = transformed.lhs()
60                     .getType()
61                     .template cast<ComplexType>()
62                     .getElementType();
63 
64     Value realLhs =
65         rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
66     Value imagLhs =
67         rewriter.create<complex::ImOp>(loc, type, transformed.lhs());
68     Value realRhs =
69         rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
70     Value imagRhs =
71         rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
72     Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
73     Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
74 
75     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
76                                                 imagComparison);
77     return success();
78   }
79 };
80 } // namespace
81 
82 void mlir::populateComplexToStandardConversionPatterns(
83     RewritePatternSet &patterns) {
84   patterns.add<AbsOpConversion,
85                ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
86                ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>>(
87       patterns.getContext());
88 }
89 
90 namespace {
91 struct ConvertComplexToStandardPass
92     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
93   void runOnFunction() override;
94 };
95 
96 void ConvertComplexToStandardPass::runOnFunction() {
97   auto function = getFunction();
98 
99   // Convert to the Standard dialect using the converter defined above.
100   RewritePatternSet patterns(&getContext());
101   populateComplexToStandardConversionPatterns(patterns);
102 
103   ConversionTarget target(getContext());
104   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
105                          complex::ComplexDialect>();
106   target.addIllegalOp<complex::AbsOp, complex::EqualOp, complex::NotEqualOp>();
107   if (failed(applyPartialConversion(function, target, std::move(patterns))))
108     signalPassFailure();
109 }
110 } // namespace
111 
112 std::unique_ptr<OperationPass<FuncOp>>
113 mlir::createConvertComplexToStandardPass() {
114   return std::make_unique<ConvertComplexToStandardPass>();
115 }
116