1*2ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2*2ea7fb7bSAdrian Kuegel // 3*2ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*2ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information. 5*2ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*2ea7fb7bSAdrian Kuegel // 7*2ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===// 8*2ea7fb7bSAdrian Kuegel 9*2ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10*2ea7fb7bSAdrian Kuegel 11*2ea7fb7bSAdrian Kuegel #include <memory> 12*2ea7fb7bSAdrian Kuegel 13*2ea7fb7bSAdrian Kuegel #include "../PassDetail.h" 14*2ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h" 15*2ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h" 16*2ea7fb7bSAdrian Kuegel #include "mlir/Dialect/StandardOps/IR/Ops.h" 17*2ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h" 18*2ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h" 19*2ea7fb7bSAdrian Kuegel 20*2ea7fb7bSAdrian Kuegel using namespace mlir; 21*2ea7fb7bSAdrian Kuegel 22*2ea7fb7bSAdrian Kuegel namespace { 23*2ea7fb7bSAdrian Kuegel struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 24*2ea7fb7bSAdrian Kuegel using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 25*2ea7fb7bSAdrian Kuegel 26*2ea7fb7bSAdrian Kuegel LogicalResult 27*2ea7fb7bSAdrian Kuegel matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 28*2ea7fb7bSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 29*2ea7fb7bSAdrian Kuegel complex::AbsOp::Adaptor transformed(operands); 30*2ea7fb7bSAdrian Kuegel auto loc = op.getLoc(); 31*2ea7fb7bSAdrian Kuegel auto type = op.getType(); 32*2ea7fb7bSAdrian Kuegel 33*2ea7fb7bSAdrian Kuegel Value real = 34*2ea7fb7bSAdrian Kuegel rewriter.create<complex::ReOp>(loc, type, transformed.complex()); 35*2ea7fb7bSAdrian Kuegel Value imag = 36*2ea7fb7bSAdrian Kuegel rewriter.create<complex::ImOp>(loc, type, transformed.complex()); 37*2ea7fb7bSAdrian Kuegel Value realSqr = rewriter.create<MulFOp>(loc, real, real); 38*2ea7fb7bSAdrian Kuegel Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag); 39*2ea7fb7bSAdrian Kuegel Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr); 40*2ea7fb7bSAdrian Kuegel 41*2ea7fb7bSAdrian Kuegel rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 42*2ea7fb7bSAdrian Kuegel return success(); 43*2ea7fb7bSAdrian Kuegel } 44*2ea7fb7bSAdrian Kuegel }; 45*2ea7fb7bSAdrian Kuegel } // namespace 46*2ea7fb7bSAdrian Kuegel 47*2ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns( 48*2ea7fb7bSAdrian Kuegel RewritePatternSet &patterns) { 49*2ea7fb7bSAdrian Kuegel patterns.add<AbsOpConversion>(patterns.getContext()); 50*2ea7fb7bSAdrian Kuegel } 51*2ea7fb7bSAdrian Kuegel 52*2ea7fb7bSAdrian Kuegel namespace { 53*2ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass 54*2ea7fb7bSAdrian Kuegel : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 55*2ea7fb7bSAdrian Kuegel void runOnFunction() override; 56*2ea7fb7bSAdrian Kuegel }; 57*2ea7fb7bSAdrian Kuegel 58*2ea7fb7bSAdrian Kuegel void ConvertComplexToStandardPass::runOnFunction() { 59*2ea7fb7bSAdrian Kuegel auto function = getFunction(); 60*2ea7fb7bSAdrian Kuegel 61*2ea7fb7bSAdrian Kuegel // Convert to the Standard dialect using the converter defined above. 62*2ea7fb7bSAdrian Kuegel RewritePatternSet patterns(&getContext()); 63*2ea7fb7bSAdrian Kuegel populateComplexToStandardConversionPatterns(patterns); 64*2ea7fb7bSAdrian Kuegel 65*2ea7fb7bSAdrian Kuegel ConversionTarget target(getContext()); 66*2ea7fb7bSAdrian Kuegel target.addLegalDialect<StandardOpsDialect, math::MathDialect, 67*2ea7fb7bSAdrian Kuegel complex::ComplexDialect>(); 68*2ea7fb7bSAdrian Kuegel target.addIllegalOp<complex::AbsOp>(); 69*2ea7fb7bSAdrian Kuegel if (failed(applyPartialConversion(function, target, std::move(patterns)))) 70*2ea7fb7bSAdrian Kuegel signalPassFailure(); 71*2ea7fb7bSAdrian Kuegel } 72*2ea7fb7bSAdrian Kuegel } // namespace 73*2ea7fb7bSAdrian Kuegel 74*2ea7fb7bSAdrian Kuegel std::unique_ptr<OperationPass<FuncOp>> 75*2ea7fb7bSAdrian Kuegel mlir::createConvertComplexToStandardPass() { 76*2ea7fb7bSAdrian Kuegel return std::make_unique<ConvertComplexToStandardPass>(); 77*2ea7fb7bSAdrian Kuegel } 78