1*2ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2*2ea7fb7bSAdrian Kuegel //
3*2ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*2ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information.
5*2ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*2ea7fb7bSAdrian Kuegel //
7*2ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===//
8*2ea7fb7bSAdrian Kuegel 
9*2ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10*2ea7fb7bSAdrian Kuegel 
11*2ea7fb7bSAdrian Kuegel #include <memory>
12*2ea7fb7bSAdrian Kuegel 
13*2ea7fb7bSAdrian Kuegel #include "../PassDetail.h"
14*2ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h"
15*2ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h"
16*2ea7fb7bSAdrian Kuegel #include "mlir/Dialect/StandardOps/IR/Ops.h"
17*2ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h"
18*2ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h"
19*2ea7fb7bSAdrian Kuegel 
20*2ea7fb7bSAdrian Kuegel using namespace mlir;
21*2ea7fb7bSAdrian Kuegel 
22*2ea7fb7bSAdrian Kuegel namespace {
23*2ea7fb7bSAdrian Kuegel struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
24*2ea7fb7bSAdrian Kuegel   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
25*2ea7fb7bSAdrian Kuegel 
26*2ea7fb7bSAdrian Kuegel   LogicalResult
27*2ea7fb7bSAdrian Kuegel   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
28*2ea7fb7bSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
29*2ea7fb7bSAdrian Kuegel     complex::AbsOp::Adaptor transformed(operands);
30*2ea7fb7bSAdrian Kuegel     auto loc = op.getLoc();
31*2ea7fb7bSAdrian Kuegel     auto type = op.getType();
32*2ea7fb7bSAdrian Kuegel 
33*2ea7fb7bSAdrian Kuegel     Value real =
34*2ea7fb7bSAdrian Kuegel         rewriter.create<complex::ReOp>(loc, type, transformed.complex());
35*2ea7fb7bSAdrian Kuegel     Value imag =
36*2ea7fb7bSAdrian Kuegel         rewriter.create<complex::ImOp>(loc, type, transformed.complex());
37*2ea7fb7bSAdrian Kuegel     Value realSqr = rewriter.create<MulFOp>(loc, real, real);
38*2ea7fb7bSAdrian Kuegel     Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
39*2ea7fb7bSAdrian Kuegel     Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
40*2ea7fb7bSAdrian Kuegel 
41*2ea7fb7bSAdrian Kuegel     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
42*2ea7fb7bSAdrian Kuegel     return success();
43*2ea7fb7bSAdrian Kuegel   }
44*2ea7fb7bSAdrian Kuegel };
45*2ea7fb7bSAdrian Kuegel } // namespace
46*2ea7fb7bSAdrian Kuegel 
47*2ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns(
48*2ea7fb7bSAdrian Kuegel     RewritePatternSet &patterns) {
49*2ea7fb7bSAdrian Kuegel   patterns.add<AbsOpConversion>(patterns.getContext());
50*2ea7fb7bSAdrian Kuegel }
51*2ea7fb7bSAdrian Kuegel 
52*2ea7fb7bSAdrian Kuegel namespace {
53*2ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass
54*2ea7fb7bSAdrian Kuegel     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
55*2ea7fb7bSAdrian Kuegel   void runOnFunction() override;
56*2ea7fb7bSAdrian Kuegel };
57*2ea7fb7bSAdrian Kuegel 
58*2ea7fb7bSAdrian Kuegel void ConvertComplexToStandardPass::runOnFunction() {
59*2ea7fb7bSAdrian Kuegel   auto function = getFunction();
60*2ea7fb7bSAdrian Kuegel 
61*2ea7fb7bSAdrian Kuegel   // Convert to the Standard dialect using the converter defined above.
62*2ea7fb7bSAdrian Kuegel   RewritePatternSet patterns(&getContext());
63*2ea7fb7bSAdrian Kuegel   populateComplexToStandardConversionPatterns(patterns);
64*2ea7fb7bSAdrian Kuegel 
65*2ea7fb7bSAdrian Kuegel   ConversionTarget target(getContext());
66*2ea7fb7bSAdrian Kuegel   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
67*2ea7fb7bSAdrian Kuegel                          complex::ComplexDialect>();
68*2ea7fb7bSAdrian Kuegel   target.addIllegalOp<complex::AbsOp>();
69*2ea7fb7bSAdrian Kuegel   if (failed(applyPartialConversion(function, target, std::move(patterns))))
70*2ea7fb7bSAdrian Kuegel     signalPassFailure();
71*2ea7fb7bSAdrian Kuegel }
72*2ea7fb7bSAdrian Kuegel } // namespace
73*2ea7fb7bSAdrian Kuegel 
74*2ea7fb7bSAdrian Kuegel std::unique_ptr<OperationPass<FuncOp>>
75*2ea7fb7bSAdrian Kuegel mlir::createConvertComplexToStandardPass() {
76*2ea7fb7bSAdrian Kuegel   return std::make_unique<ConvertComplexToStandardPass>();
77*2ea7fb7bSAdrian Kuegel }
78