1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 13 #include "../PassDetail.h" 14 #include "mlir/Dialect/Complex/IR/Complex.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 20 using namespace mlir; 21 22 namespace { 23 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 24 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 25 26 LogicalResult 27 matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 28 ConversionPatternRewriter &rewriter) const override { 29 complex::AbsOp::Adaptor transformed(operands); 30 auto loc = op.getLoc(); 31 auto type = op.getType(); 32 33 Value real = 34 rewriter.create<complex::ReOp>(loc, type, transformed.complex()); 35 Value imag = 36 rewriter.create<complex::ImOp>(loc, type, transformed.complex()); 37 Value realSqr = rewriter.create<MulFOp>(loc, real, real); 38 Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag); 39 Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr); 40 41 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 42 return success(); 43 } 44 }; 45 } // namespace 46 47 void mlir::populateComplexToStandardConversionPatterns( 48 RewritePatternSet &patterns) { 49 patterns.add<AbsOpConversion>(patterns.getContext()); 50 } 51 52 namespace { 53 struct ConvertComplexToStandardPass 54 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 55 void runOnFunction() override; 56 }; 57 58 void ConvertComplexToStandardPass::runOnFunction() { 59 auto function = getFunction(); 60 61 // Convert to the Standard dialect using the converter defined above. 62 RewritePatternSet patterns(&getContext()); 63 populateComplexToStandardConversionPatterns(patterns); 64 65 ConversionTarget target(getContext()); 66 target.addLegalDialect<StandardOpsDialect, math::MathDialect, 67 complex::ComplexDialect>(); 68 target.addIllegalOp<complex::AbsOp>(); 69 if (failed(applyPartialConversion(function, target, std::move(patterns)))) 70 signalPassFailure(); 71 } 72 } // namespace 73 74 std::unique_ptr<OperationPass<FuncOp>> 75 mlir::createConvertComplexToStandardPass() { 76 return std::make_unique<ConvertComplexToStandardPass>(); 77 } 78