1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 
13 #include "../PassDetail.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
24   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
25 
26   LogicalResult
27   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
28                   ConversionPatternRewriter &rewriter) const override {
29     complex::AbsOp::Adaptor transformed(operands);
30     auto loc = op.getLoc();
31     auto type = op.getType();
32 
33     Value real =
34         rewriter.create<complex::ReOp>(loc, type, transformed.complex());
35     Value imag =
36         rewriter.create<complex::ImOp>(loc, type, transformed.complex());
37     Value realSqr = rewriter.create<MulFOp>(loc, real, real);
38     Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
39     Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
40 
41     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
42     return success();
43   }
44 };
45 } // namespace
46 
47 void mlir::populateComplexToStandardConversionPatterns(
48     RewritePatternSet &patterns) {
49   patterns.add<AbsOpConversion>(patterns.getContext());
50 }
51 
52 namespace {
53 struct ConvertComplexToStandardPass
54     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
55   void runOnFunction() override;
56 };
57 
58 void ConvertComplexToStandardPass::runOnFunction() {
59   auto function = getFunction();
60 
61   // Convert to the Standard dialect using the converter defined above.
62   RewritePatternSet patterns(&getContext());
63   populateComplexToStandardConversionPatterns(patterns);
64 
65   ConversionTarget target(getContext());
66   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
67                          complex::ComplexDialect>();
68   target.addIllegalOp<complex::AbsOp>();
69   if (failed(applyPartialConversion(function, target, std::move(patterns))))
70     signalPassFailure();
71 }
72 } // namespace
73 
74 std::unique_ptr<OperationPass<FuncOp>>
75 mlir::createConvertComplexToStandardPass() {
76   return std::make_unique<ConvertComplexToStandardPass>();
77 }
78