1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 
15 using namespace mlir;
16 using namespace mlir::LLVM;
17 
18 namespace {
19 
20 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
21   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
22 
23   LogicalResult
24   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
25                   ConversionPatternRewriter &rewriter) const override {
26     complex::AbsOp::Adaptor transformed(operands);
27     auto loc = op.getLoc();
28 
29     ComplexStructBuilder complexStruct(transformed.complex());
30     Value real = complexStruct.real(rewriter, op.getLoc());
31     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
32 
33     auto fmf = LLVM::FMFAttr::get({}, op.getContext());
34     Value sqNorm = rewriter.create<LLVM::FAddOp>(
35         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
36         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
37 
38     rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
39     return success();
40   }
41 };
42 
43 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
44   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
45 
46   LogicalResult
47   matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
48                   ConversionPatternRewriter &rewriter) const override {
49     complex::CreateOp::Adaptor transformed(operands);
50 
51     // Pack real and imaginary part in a complex number struct.
52     auto loc = complexOp.getLoc();
53     auto structType = typeConverter->convertType(complexOp.getType());
54     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
55     complexStruct.setReal(rewriter, loc, transformed.real());
56     complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
57 
58     rewriter.replaceOp(complexOp, {complexStruct});
59     return success();
60   }
61 };
62 
63 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
64   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
65 
66   LogicalResult
67   matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
68                   ConversionPatternRewriter &rewriter) const override {
69     complex::ReOp::Adaptor transformed(operands);
70 
71     // Extract real part from the complex number struct.
72     ComplexStructBuilder complexStruct(transformed.complex());
73     Value real = complexStruct.real(rewriter, op.getLoc());
74     rewriter.replaceOp(op, real);
75 
76     return success();
77   }
78 };
79 
80 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
81   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
82 
83   LogicalResult
84   matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
85                   ConversionPatternRewriter &rewriter) const override {
86     complex::ImOp::Adaptor transformed(operands);
87 
88     // Extract imaginary part from the complex number struct.
89     ComplexStructBuilder complexStruct(transformed.complex());
90     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
91     rewriter.replaceOp(op, imaginary);
92 
93     return success();
94   }
95 };
96 
97 struct BinaryComplexOperands {
98   std::complex<Value> lhs;
99   std::complex<Value> rhs;
100 };
101 
102 template <typename OpTy>
103 BinaryComplexOperands
104 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
105                             ConversionPatternRewriter &rewriter) {
106   auto loc = op.getLoc();
107   typename OpTy::Adaptor transformed(operands);
108 
109   // Extract real and imaginary values from operands.
110   BinaryComplexOperands unpacked;
111   ComplexStructBuilder lhs(transformed.lhs());
112   unpacked.lhs.real(lhs.real(rewriter, loc));
113   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
114   ComplexStructBuilder rhs(transformed.rhs());
115   unpacked.rhs.real(rhs.real(rewriter, loc));
116   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
117 
118   return unpacked;
119 }
120 
121 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
122   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
123 
124   LogicalResult
125   matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
126                   ConversionPatternRewriter &rewriter) const override {
127     auto loc = op.getLoc();
128     BinaryComplexOperands arg =
129         unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
130 
131     // Initialize complex number struct for result.
132     auto structType = typeConverter->convertType(op.getType());
133     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
134 
135     // Emit IR to add complex numbers.
136     auto fmf = LLVM::FMFAttr::get({}, op.getContext());
137     Value real =
138         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
139     Value imag =
140         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
141     result.setReal(rewriter, loc, real);
142     result.setImaginary(rewriter, loc, imag);
143 
144     rewriter.replaceOp(op, {result});
145     return success();
146   }
147 };
148 
149 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
150   using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
151 
152   LogicalResult
153   matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
154                   ConversionPatternRewriter &rewriter) const override {
155     auto loc = op.getLoc();
156     BinaryComplexOperands arg =
157         unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter);
158 
159     // Initialize complex number struct for result.
160     auto structType = typeConverter->convertType(op.getType());
161     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
162 
163     // Emit IR to add complex numbers.
164     auto fmf = LLVM::FMFAttr::get({}, op.getContext());
165     Value rhsRe = arg.rhs.real();
166     Value rhsIm = arg.rhs.imag();
167     Value lhsRe = arg.lhs.real();
168     Value lhsIm = arg.lhs.imag();
169 
170     Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
171         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
172         rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
173 
174     Value resultReal = rewriter.create<LLVM::FAddOp>(
175         loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
176         rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
177 
178     Value resultImag = rewriter.create<LLVM::FSubOp>(
179         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
180         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
181 
182     result.setReal(
183         rewriter, loc,
184         rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
185     result.setImaginary(
186         rewriter, loc,
187         rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
188 
189     rewriter.replaceOp(op, {result});
190     return success();
191   }
192 };
193 
194 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
195   using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
196 
197   LogicalResult
198   matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
199                   ConversionPatternRewriter &rewriter) const override {
200     auto loc = op.getLoc();
201     BinaryComplexOperands arg =
202         unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter);
203 
204     // Initialize complex number struct for result.
205     auto structType = typeConverter->convertType(op.getType());
206     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
207 
208     // Emit IR to add complex numbers.
209     auto fmf = LLVM::FMFAttr::get({}, op.getContext());
210     Value rhsRe = arg.rhs.real();
211     Value rhsIm = arg.rhs.imag();
212     Value lhsRe = arg.lhs.real();
213     Value lhsIm = arg.lhs.imag();
214 
215     Value real = rewriter.create<LLVM::FSubOp>(
216         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
217         rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
218 
219     Value imag = rewriter.create<LLVM::FAddOp>(
220         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
221         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
222 
223     result.setReal(rewriter, loc, real);
224     result.setImaginary(rewriter, loc, imag);
225 
226     rewriter.replaceOp(op, {result});
227     return success();
228   }
229 };
230 
231 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
232   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
233 
234   LogicalResult
235   matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
236                   ConversionPatternRewriter &rewriter) const override {
237     auto loc = op.getLoc();
238     BinaryComplexOperands arg =
239         unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
240 
241     // Initialize complex number struct for result.
242     auto structType = typeConverter->convertType(op.getType());
243     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
244 
245     // Emit IR to substract complex numbers.
246     auto fmf = LLVM::FMFAttr::get({}, op.getContext());
247     Value real =
248         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
249     Value imag =
250         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
251     result.setReal(rewriter, loc, real);
252     result.setImaginary(rewriter, loc, imag);
253 
254     rewriter.replaceOp(op, {result});
255     return success();
256   }
257 };
258 } // namespace
259 
260 void mlir::populateComplexToLLVMConversionPatterns(
261     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
262   // clang-format off
263   patterns.insert<
264       AbsOpConversion,
265       AddOpConversion,
266       CreateOpConversion,
267       DivOpConversion,
268       ImOpConversion,
269       MulOpConversion,
270       ReOpConversion,
271       SubOpConversion
272     >(converter);
273   // clang-format on
274 }
275 
276 namespace {
277 struct ConvertComplexToLLVMPass
278     : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
279   void runOnOperation() override;
280 };
281 } // namespace
282 
283 void ConvertComplexToLLVMPass::runOnOperation() {
284   auto module = getOperation();
285 
286   // Convert to the LLVM IR dialect using the converter defined above.
287   OwningRewritePatternList patterns;
288   LLVMTypeConverter converter(&getContext());
289   populateStdToLLVMConversionPatterns(converter, patterns);
290   populateComplexToLLVMConversionPatterns(converter, patterns);
291 
292   LLVMConversionTarget target(getContext());
293   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
294   if (failed(applyFullConversion(module, target, std::move(patterns))))
295     signalPassFailure();
296 }
297 
298 std::unique_ptr<OperationPass<ModuleOp>>
299 mlir::createConvertComplexToLLVMPass() {
300   return std::make_unique<ConvertComplexToLLVMPass>();
301 }
302