1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Complex/IR/Complex.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 15 using namespace mlir; 16 using namespace mlir::LLVM; 17 18 namespace { 19 20 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 21 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 22 23 LogicalResult 24 matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands, 25 ConversionPatternRewriter &rewriter) const override { 26 complex::CreateOp::Adaptor transformed(operands); 27 28 // Pack real and imaginary part in a complex number struct. 29 auto loc = complexOp.getLoc(); 30 auto structType = typeConverter->convertType(complexOp.getType()); 31 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 32 complexStruct.setReal(rewriter, loc, transformed.real()); 33 complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); 34 35 rewriter.replaceOp(complexOp, {complexStruct}); 36 return success(); 37 } 38 }; 39 40 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 41 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 42 43 LogicalResult 44 matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands, 45 ConversionPatternRewriter &rewriter) const override { 46 complex::ReOp::Adaptor transformed(operands); 47 48 // Extract real part from the complex number struct. 49 ComplexStructBuilder complexStruct(transformed.complex()); 50 Value real = complexStruct.real(rewriter, op.getLoc()); 51 rewriter.replaceOp(op, real); 52 53 return success(); 54 } 55 }; 56 57 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 58 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 59 60 LogicalResult 61 matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands, 62 ConversionPatternRewriter &rewriter) const override { 63 complex::ImOp::Adaptor transformed(operands); 64 65 // Extract imaginary part from the complex number struct. 66 ComplexStructBuilder complexStruct(transformed.complex()); 67 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 68 rewriter.replaceOp(op, imaginary); 69 70 return success(); 71 } 72 }; 73 74 struct BinaryComplexOperands { 75 std::complex<Value> lhs; 76 std::complex<Value> rhs; 77 }; 78 79 template <typename OpTy> 80 BinaryComplexOperands 81 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands, 82 ConversionPatternRewriter &rewriter) { 83 auto loc = op.getLoc(); 84 typename OpTy::Adaptor transformed(operands); 85 86 // Extract real and imaginary values from operands. 87 BinaryComplexOperands unpacked; 88 ComplexStructBuilder lhs(transformed.lhs()); 89 unpacked.lhs.real(lhs.real(rewriter, loc)); 90 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 91 ComplexStructBuilder rhs(transformed.rhs()); 92 unpacked.rhs.real(rhs.real(rewriter, loc)); 93 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 94 95 return unpacked; 96 } 97 98 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 99 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 100 101 LogicalResult 102 matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands, 103 ConversionPatternRewriter &rewriter) const override { 104 auto loc = op.getLoc(); 105 BinaryComplexOperands arg = 106 unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter); 107 108 // Initialize complex number struct for result. 109 auto structType = typeConverter->convertType(op.getType()); 110 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 111 112 // Emit IR to add complex numbers. 113 auto fmf = LLVM::FMFAttr::get({}, op.getContext()); 114 Value real = 115 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 116 Value imag = 117 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 118 result.setReal(rewriter, loc, real); 119 result.setImaginary(rewriter, loc, imag); 120 121 rewriter.replaceOp(op, {result}); 122 return success(); 123 } 124 }; 125 126 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 127 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 128 129 LogicalResult 130 matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands, 131 ConversionPatternRewriter &rewriter) const override { 132 auto loc = op.getLoc(); 133 BinaryComplexOperands arg = 134 unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter); 135 136 // Initialize complex number struct for result. 137 auto structType = typeConverter->convertType(op.getType()); 138 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 139 140 // Emit IR to substract complex numbers. 141 auto fmf = LLVM::FMFAttr::get({}, op.getContext()); 142 Value real = 143 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 144 Value imag = 145 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 146 result.setReal(rewriter, loc, real); 147 result.setImaginary(rewriter, loc, imag); 148 149 rewriter.replaceOp(op, {result}); 150 return success(); 151 } 152 }; 153 } // namespace 154 155 void mlir::populateComplexToLLVMConversionPatterns( 156 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 157 // clang-format off 158 patterns.insert< 159 AddOpConversion, 160 CreateOpConversion, 161 ImOpConversion, 162 ReOpConversion, 163 SubOpConversion 164 >(converter); 165 // clang-format on 166 } 167 168 namespace { 169 struct ConvertComplexToLLVMPass 170 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> { 171 void runOnOperation() override; 172 }; 173 } // namespace 174 175 void ConvertComplexToLLVMPass::runOnOperation() { 176 auto module = getOperation(); 177 178 // Convert to the LLVM IR dialect using the converter defined above. 179 OwningRewritePatternList patterns; 180 LLVMTypeConverter converter(&getContext()); 181 populateStdToLLVMConversionPatterns(converter, patterns); 182 populateComplexToLLVMConversionPatterns(converter, patterns); 183 184 LLVMConversionTarget target(getContext()); 185 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 186 if (failed(applyFullConversion(module, target, std::move(patterns)))) 187 signalPassFailure(); 188 } 189 190 std::unique_ptr<OperationPass<ModuleOp>> 191 mlir::createConvertComplexToLLVMPass() { 192 return std::make_unique<ConvertComplexToLLVMPass>(); 193 } 194