1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Complex/IR/Complex.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 
16 using namespace mlir;
17 using namespace mlir::LLVM;
18 
19 //===----------------------------------------------------------------------===//
20 // ComplexStructBuilder implementation.
21 //===----------------------------------------------------------------------===//
22 
23 static constexpr unsigned kRealPosInComplexNumberStruct = 0;
24 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
25 
26 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
27                                                  Location loc, Type type) {
28   Value val = builder.create<LLVM::UndefOp>(loc, type);
29   return ComplexStructBuilder(val);
30 }
31 
32 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
33                                    Value real) {
34   setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
35 }
36 
37 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
38   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
39 }
40 
41 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
42                                         Value imaginary) {
43   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
44 }
45 
46 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
47   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
48 }
49 
50 //===----------------------------------------------------------------------===//
51 // Conversion patterns.
52 //===----------------------------------------------------------------------===//
53 
54 namespace {
55 
56 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
57   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
58 
59   LogicalResult
60   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
61                   ConversionPatternRewriter &rewriter) const override {
62     complex::AbsOp::Adaptor transformed(operands);
63     auto loc = op.getLoc();
64 
65     ComplexStructBuilder complexStruct(transformed.complex());
66     Value real = complexStruct.real(rewriter, op.getLoc());
67     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
68 
69     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
70     Value sqNorm = rewriter.create<LLVM::FAddOp>(
71         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
72         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
73 
74     rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
75     return success();
76   }
77 };
78 
79 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
80   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
81 
82   LogicalResult
83   matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
84                   ConversionPatternRewriter &rewriter) const override {
85     complex::CreateOp::Adaptor transformed(operands);
86 
87     // Pack real and imaginary part in a complex number struct.
88     auto loc = complexOp.getLoc();
89     auto structType = typeConverter->convertType(complexOp.getType());
90     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
91     complexStruct.setReal(rewriter, loc, transformed.real());
92     complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
93 
94     rewriter.replaceOp(complexOp, {complexStruct});
95     return success();
96   }
97 };
98 
99 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
100   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
101 
102   LogicalResult
103   matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
104                   ConversionPatternRewriter &rewriter) const override {
105     complex::ReOp::Adaptor transformed(operands);
106 
107     // Extract real part from the complex number struct.
108     ComplexStructBuilder complexStruct(transformed.complex());
109     Value real = complexStruct.real(rewriter, op.getLoc());
110     rewriter.replaceOp(op, real);
111 
112     return success();
113   }
114 };
115 
116 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
117   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
118 
119   LogicalResult
120   matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
121                   ConversionPatternRewriter &rewriter) const override {
122     complex::ImOp::Adaptor transformed(operands);
123 
124     // Extract imaginary part from the complex number struct.
125     ComplexStructBuilder complexStruct(transformed.complex());
126     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
127     rewriter.replaceOp(op, imaginary);
128 
129     return success();
130   }
131 };
132 
133 struct BinaryComplexOperands {
134   std::complex<Value> lhs;
135   std::complex<Value> rhs;
136 };
137 
138 template <typename OpTy>
139 BinaryComplexOperands
140 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
141                             ConversionPatternRewriter &rewriter) {
142   auto loc = op.getLoc();
143   typename OpTy::Adaptor transformed(operands);
144 
145   // Extract real and imaginary values from operands.
146   BinaryComplexOperands unpacked;
147   ComplexStructBuilder lhs(transformed.lhs());
148   unpacked.lhs.real(lhs.real(rewriter, loc));
149   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
150   ComplexStructBuilder rhs(transformed.rhs());
151   unpacked.rhs.real(rhs.real(rewriter, loc));
152   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
153 
154   return unpacked;
155 }
156 
157 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
158   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
159 
160   LogicalResult
161   matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
162                   ConversionPatternRewriter &rewriter) const override {
163     auto loc = op.getLoc();
164     BinaryComplexOperands arg =
165         unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
166 
167     // Initialize complex number struct for result.
168     auto structType = typeConverter->convertType(op.getType());
169     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
170 
171     // Emit IR to add complex numbers.
172     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
173     Value real =
174         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
175     Value imag =
176         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
177     result.setReal(rewriter, loc, real);
178     result.setImaginary(rewriter, loc, imag);
179 
180     rewriter.replaceOp(op, {result});
181     return success();
182   }
183 };
184 
185 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
186   using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
187 
188   LogicalResult
189   matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
190                   ConversionPatternRewriter &rewriter) const override {
191     auto loc = op.getLoc();
192     BinaryComplexOperands arg =
193         unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter);
194 
195     // Initialize complex number struct for result.
196     auto structType = typeConverter->convertType(op.getType());
197     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
198 
199     // Emit IR to add complex numbers.
200     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
201     Value rhsRe = arg.rhs.real();
202     Value rhsIm = arg.rhs.imag();
203     Value lhsRe = arg.lhs.real();
204     Value lhsIm = arg.lhs.imag();
205 
206     Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
207         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
208         rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
209 
210     Value resultReal = rewriter.create<LLVM::FAddOp>(
211         loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
212         rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
213 
214     Value resultImag = rewriter.create<LLVM::FSubOp>(
215         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
216         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
217 
218     result.setReal(
219         rewriter, loc,
220         rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
221     result.setImaginary(
222         rewriter, loc,
223         rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
224 
225     rewriter.replaceOp(op, {result});
226     return success();
227   }
228 };
229 
230 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
231   using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
232 
233   LogicalResult
234   matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
235                   ConversionPatternRewriter &rewriter) const override {
236     auto loc = op.getLoc();
237     BinaryComplexOperands arg =
238         unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter);
239 
240     // Initialize complex number struct for result.
241     auto structType = typeConverter->convertType(op.getType());
242     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
243 
244     // Emit IR to add complex numbers.
245     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
246     Value rhsRe = arg.rhs.real();
247     Value rhsIm = arg.rhs.imag();
248     Value lhsRe = arg.lhs.real();
249     Value lhsIm = arg.lhs.imag();
250 
251     Value real = rewriter.create<LLVM::FSubOp>(
252         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
253         rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
254 
255     Value imag = rewriter.create<LLVM::FAddOp>(
256         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
257         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
258 
259     result.setReal(rewriter, loc, real);
260     result.setImaginary(rewriter, loc, imag);
261 
262     rewriter.replaceOp(op, {result});
263     return success();
264   }
265 };
266 
267 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
268   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
269 
270   LogicalResult
271   matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
272                   ConversionPatternRewriter &rewriter) const override {
273     auto loc = op.getLoc();
274     BinaryComplexOperands arg =
275         unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
276 
277     // Initialize complex number struct for result.
278     auto structType = typeConverter->convertType(op.getType());
279     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
280 
281     // Emit IR to substract complex numbers.
282     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
283     Value real =
284         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
285     Value imag =
286         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
287     result.setReal(rewriter, loc, real);
288     result.setImaginary(rewriter, loc, imag);
289 
290     rewriter.replaceOp(op, {result});
291     return success();
292   }
293 };
294 } // namespace
295 
296 void mlir::populateComplexToLLVMConversionPatterns(
297     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
298   // clang-format off
299   patterns.add<
300       AbsOpConversion,
301       AddOpConversion,
302       CreateOpConversion,
303       DivOpConversion,
304       ImOpConversion,
305       MulOpConversion,
306       ReOpConversion,
307       SubOpConversion
308     >(converter);
309   // clang-format on
310 }
311 
312 namespace {
313 struct ConvertComplexToLLVMPass
314     : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
315   void runOnOperation() override;
316 };
317 } // namespace
318 
319 void ConvertComplexToLLVMPass::runOnOperation() {
320   auto module = getOperation();
321 
322   // Convert to the LLVM IR dialect using the converter defined above.
323   RewritePatternSet patterns(&getContext());
324   LLVMTypeConverter converter(&getContext());
325   populateComplexToLLVMConversionPatterns(converter, patterns);
326 
327   LLVMConversionTarget target(getContext());
328   target.addLegalOp<ModuleOp, FuncOp>();
329   target.addLegalOp<LLVM::DialectCastOp>();
330   target.addIllegalDialect<complex::ComplexDialect>();
331   if (failed(applyPartialConversion(module, target, std::move(patterns))))
332     signalPassFailure();
333 }
334 
335 std::unique_ptr<OperationPass<ModuleOp>>
336 mlir::createConvertComplexToLLVMPass() {
337   return std::make_unique<ConvertComplexToLLVMPass>();
338 }
339