1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/Math/IR/Math.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 using namespace mlir; 22 23 namespace { 24 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 25 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 26 27 LogicalResult 28 matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 29 ConversionPatternRewriter &rewriter) const override { 30 complex::AbsOp::Adaptor transformed(operands); 31 auto loc = op.getLoc(); 32 auto type = op.getType(); 33 34 Value real = 35 rewriter.create<complex::ReOp>(loc, type, transformed.complex()); 36 Value imag = 37 rewriter.create<complex::ImOp>(loc, type, transformed.complex()); 38 Value realSqr = rewriter.create<MulFOp>(loc, real, real); 39 Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag); 40 Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr); 41 42 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 43 return success(); 44 } 45 }; 46 47 template <typename ComparisonOp, CmpFPredicate p> 48 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 49 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 50 using ResultCombiner = 51 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 52 AndOp, OrOp>; 53 54 LogicalResult 55 matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands, 56 ConversionPatternRewriter &rewriter) const override { 57 typename ComparisonOp::Adaptor transformed(operands); 58 auto loc = op.getLoc(); 59 auto type = transformed.lhs() 60 .getType() 61 .template cast<ComplexType>() 62 .getElementType(); 63 64 Value realLhs = 65 rewriter.create<complex::ReOp>(loc, type, transformed.lhs()); 66 Value imagLhs = 67 rewriter.create<complex::ImOp>(loc, type, transformed.lhs()); 68 Value realRhs = 69 rewriter.create<complex::ReOp>(loc, type, transformed.rhs()); 70 Value imagRhs = 71 rewriter.create<complex::ImOp>(loc, type, transformed.rhs()); 72 Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs); 73 Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs); 74 75 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 76 imagComparison); 77 return success(); 78 } 79 }; 80 } // namespace 81 82 void mlir::populateComplexToStandardConversionPatterns( 83 RewritePatternSet &patterns) { 84 patterns.add<AbsOpConversion, 85 ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>, 86 ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>>( 87 patterns.getContext()); 88 } 89 90 namespace { 91 struct ConvertComplexToStandardPass 92 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 93 void runOnFunction() override; 94 }; 95 96 void ConvertComplexToStandardPass::runOnFunction() { 97 auto function = getFunction(); 98 99 // Convert to the Standard dialect using the converter defined above. 100 RewritePatternSet patterns(&getContext()); 101 populateComplexToStandardConversionPatterns(patterns); 102 103 ConversionTarget target(getContext()); 104 target.addLegalDialect<StandardOpsDialect, math::MathDialect, 105 complex::ComplexDialect>(); 106 target.addIllegalOp<complex::AbsOp, complex::EqualOp, complex::NotEqualOp>(); 107 if (failed(applyPartialConversion(function, target, std::move(patterns)))) 108 signalPassFailure(); 109 } 110 } // namespace 111 112 std::unique_ptr<OperationPass<FuncOp>> 113 mlir::createConvertComplexToStandardPass() { 114 return std::make_unique<ConvertComplexToStandardPass>(); 115 } 116