1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 
15 using namespace mlir;
16 using namespace mlir::LLVM;
17 
18 namespace {
19 
20 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
21   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
22 
23   LogicalResult
24   matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
25                   ConversionPatternRewriter &rewriter) const override {
26     complex::CreateOp::Adaptor transformed(operands);
27 
28     // Pack real and imaginary part in a complex number struct.
29     auto loc = complexOp.getLoc();
30     auto structType = typeConverter->convertType(complexOp.getType());
31     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
32     complexStruct.setReal(rewriter, loc, transformed.real());
33     complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
34 
35     rewriter.replaceOp(complexOp, {complexStruct});
36     return success();
37   }
38 };
39 
40 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
41   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
42 
43   LogicalResult
44   matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
45                   ConversionPatternRewriter &rewriter) const override {
46     complex::ReOp::Adaptor transformed(operands);
47 
48     // Extract real part from the complex number struct.
49     ComplexStructBuilder complexStruct(transformed.complex());
50     Value real = complexStruct.real(rewriter, op.getLoc());
51     rewriter.replaceOp(op, real);
52 
53     return success();
54   }
55 };
56 
57 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
58   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
59 
60   LogicalResult
61   matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
62                   ConversionPatternRewriter &rewriter) const override {
63     complex::ImOp::Adaptor transformed(operands);
64 
65     // Extract imaginary part from the complex number struct.
66     ComplexStructBuilder complexStruct(transformed.complex());
67     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
68     rewriter.replaceOp(op, imaginary);
69 
70     return success();
71   }
72 };
73 
74 struct BinaryComplexOperands {
75   std::complex<Value> lhs;
76   std::complex<Value> rhs;
77 };
78 
79 template <typename OpTy>
80 BinaryComplexOperands
81 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
82                             ConversionPatternRewriter &rewriter) {
83   auto loc = op.getLoc();
84   typename OpTy::Adaptor transformed(operands);
85 
86   // Extract real and imaginary values from operands.
87   BinaryComplexOperands unpacked;
88   ComplexStructBuilder lhs(transformed.lhs());
89   unpacked.lhs.real(lhs.real(rewriter, loc));
90   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
91   ComplexStructBuilder rhs(transformed.rhs());
92   unpacked.rhs.real(rhs.real(rewriter, loc));
93   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
94 
95   return unpacked;
96 }
97 
98 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
99   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
100 
101   LogicalResult
102   matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
103                   ConversionPatternRewriter &rewriter) const override {
104     auto loc = op.getLoc();
105     BinaryComplexOperands arg =
106         unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
107 
108     // Initialize complex number struct for result.
109     auto structType = typeConverter->convertType(op.getType());
110     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
111 
112     // Emit IR to add complex numbers.
113     auto fmf = LLVM::FMFAttr::get({}, op.getContext());
114     Value real =
115         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
116     Value imag =
117         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
118     result.setReal(rewriter, loc, real);
119     result.setImaginary(rewriter, loc, imag);
120 
121     rewriter.replaceOp(op, {result});
122     return success();
123   }
124 };
125 
126 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
127   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
128 
129   LogicalResult
130   matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
131                   ConversionPatternRewriter &rewriter) const override {
132     auto loc = op.getLoc();
133     BinaryComplexOperands arg =
134         unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
135 
136     // Initialize complex number struct for result.
137     auto structType = typeConverter->convertType(op.getType());
138     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
139 
140     // Emit IR to substract complex numbers.
141     auto fmf = LLVM::FMFAttr::get({}, op.getContext());
142     Value real =
143         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
144     Value imag =
145         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
146     result.setReal(rewriter, loc, real);
147     result.setImaginary(rewriter, loc, imag);
148 
149     rewriter.replaceOp(op, {result});
150     return success();
151   }
152 };
153 } // namespace
154 
155 void mlir::populateComplexToLLVMConversionPatterns(
156     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
157   // clang-format off
158   patterns.insert<
159       AddOpConversion,
160       CreateOpConversion,
161       ImOpConversion,
162       ReOpConversion,
163       SubOpConversion
164     >(converter);
165   // clang-format on
166 }
167 
168 namespace {
169 struct ConvertComplexToLLVMPass
170     : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
171   void runOnOperation() override;
172 };
173 } // namespace
174 
175 void ConvertComplexToLLVMPass::runOnOperation() {
176   auto module = getOperation();
177 
178   // Convert to the LLVM IR dialect using the converter defined above.
179   OwningRewritePatternList patterns;
180   LLVMTypeConverter converter(&getContext());
181   populateStdToLLVMConversionPatterns(converter, patterns);
182   populateComplexToLLVMConversionPatterns(converter, patterns);
183 
184   LLVMConversionTarget target(getContext());
185   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
186   if (failed(applyFullConversion(module, target, std::move(patterns))))
187     signalPassFailure();
188 }
189 
190 std::unique_ptr<OperationPass<ModuleOp>>
191 mlir::createConvertComplexToLLVMPass() {
192   return std::make_unique<ConvertComplexToLLVMPass>();
193 }
194