1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 
18 using namespace mlir;
19 using namespace mlir::LLVM;
20 
21 //===----------------------------------------------------------------------===//
22 // ComplexStructBuilder implementation.
23 //===----------------------------------------------------------------------===//
24 
// Field indices inside the LLVM struct that holds a lowered complex number:
// field 0 carries the real part and field 1 the imaginary part. They are read
// by the ComplexStructBuilder accessors/mutators defined in this file.
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
27 
28 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
29                                                  Location loc, Type type) {
30   Value val = builder.create<LLVM::UndefOp>(loc, type);
31   return ComplexStructBuilder(val);
32 }
33 
34 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
35                                    Value real) {
36   setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
37 }
38 
39 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
40   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
41 }
42 
43 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
44                                         Value imaginary) {
45   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
46 }
47 
48 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
49   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // Conversion patterns.
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
57 
58 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
59   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
60 
61   LogicalResult
62   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
63                   ConversionPatternRewriter &rewriter) const override {
64     auto loc = op.getLoc();
65 
66     ComplexStructBuilder complexStruct(adaptor.getComplex());
67     Value real = complexStruct.real(rewriter, op.getLoc());
68     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
69 
70     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
71     Value sqNorm = rewriter.create<LLVM::FAddOp>(
72         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
73         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
74 
75     rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
76     return success();
77   }
78 };
79 
80 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
81   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
82 
83   LogicalResult
84   matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
85                   ConversionPatternRewriter &rewriter) const override {
86     return LLVM::detail::oneToOneRewrite(
87         op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
88         *getTypeConverter(), rewriter);
89   }
90 };
91 
92 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
93   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
94 
95   LogicalResult
96   matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
97                   ConversionPatternRewriter &rewriter) const override {
98     // Pack real and imaginary part in a complex number struct.
99     auto loc = complexOp.getLoc();
100     auto structType = typeConverter->convertType(complexOp.getType());
101     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
102     complexStruct.setReal(rewriter, loc, adaptor.getReal());
103     complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
104 
105     rewriter.replaceOp(complexOp, {complexStruct});
106     return success();
107   }
108 };
109 
110 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
111   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
112 
113   LogicalResult
114   matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
115                   ConversionPatternRewriter &rewriter) const override {
116     // Extract real part from the complex number struct.
117     ComplexStructBuilder complexStruct(adaptor.getComplex());
118     Value real = complexStruct.real(rewriter, op.getLoc());
119     rewriter.replaceOp(op, real);
120 
121     return success();
122   }
123 };
124 
125 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
126   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
127 
128   LogicalResult
129   matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
130                   ConversionPatternRewriter &rewriter) const override {
131     // Extract imaginary part from the complex number struct.
132     ComplexStructBuilder complexStruct(adaptor.getComplex());
133     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
134     rewriter.replaceOp(op, imaginary);
135 
136     return success();
137   }
138 };
139 
140 struct BinaryComplexOperands {
141   std::complex<Value> lhs;
142   std::complex<Value> rhs;
143 };
144 
145 template <typename OpTy>
146 BinaryComplexOperands
147 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
148                             ConversionPatternRewriter &rewriter) {
149   auto loc = op.getLoc();
150 
151   // Extract real and imaginary values from operands.
152   BinaryComplexOperands unpacked;
153   ComplexStructBuilder lhs(adaptor.getLhs());
154   unpacked.lhs.real(lhs.real(rewriter, loc));
155   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
156   ComplexStructBuilder rhs(adaptor.getRhs());
157   unpacked.rhs.real(rhs.real(rewriter, loc));
158   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
159 
160   return unpacked;
161 }
162 
163 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
164   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
165 
166   LogicalResult
167   matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
168                   ConversionPatternRewriter &rewriter) const override {
169     auto loc = op.getLoc();
170     BinaryComplexOperands arg =
171         unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
172 
173     // Initialize complex number struct for result.
174     auto structType = typeConverter->convertType(op.getType());
175     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
176 
177     // Emit IR to add complex numbers.
178     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
179     Value real =
180         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
181     Value imag =
182         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
183     result.setReal(rewriter, loc, real);
184     result.setImaginary(rewriter, loc, imag);
185 
186     rewriter.replaceOp(op, {result});
187     return success();
188   }
189 };
190 
191 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
192   using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
193 
194   LogicalResult
195   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
196                   ConversionPatternRewriter &rewriter) const override {
197     auto loc = op.getLoc();
198     BinaryComplexOperands arg =
199         unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
200 
201     // Initialize complex number struct for result.
202     auto structType = typeConverter->convertType(op.getType());
203     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
204 
205     // Emit IR to add complex numbers.
206     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
207     Value rhsRe = arg.rhs.real();
208     Value rhsIm = arg.rhs.imag();
209     Value lhsRe = arg.lhs.real();
210     Value lhsIm = arg.lhs.imag();
211 
212     Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
213         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
214         rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
215 
216     Value resultReal = rewriter.create<LLVM::FAddOp>(
217         loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
218         rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
219 
220     Value resultImag = rewriter.create<LLVM::FSubOp>(
221         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
222         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
223 
224     result.setReal(
225         rewriter, loc,
226         rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
227     result.setImaginary(
228         rewriter, loc,
229         rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
230 
231     rewriter.replaceOp(op, {result});
232     return success();
233   }
234 };
235 
236 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
237   using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
238 
239   LogicalResult
240   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
241                   ConversionPatternRewriter &rewriter) const override {
242     auto loc = op.getLoc();
243     BinaryComplexOperands arg =
244         unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
245 
246     // Initialize complex number struct for result.
247     auto structType = typeConverter->convertType(op.getType());
248     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
249 
250     // Emit IR to add complex numbers.
251     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
252     Value rhsRe = arg.rhs.real();
253     Value rhsIm = arg.rhs.imag();
254     Value lhsRe = arg.lhs.real();
255     Value lhsIm = arg.lhs.imag();
256 
257     Value real = rewriter.create<LLVM::FSubOp>(
258         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
259         rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
260 
261     Value imag = rewriter.create<LLVM::FAddOp>(
262         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
263         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
264 
265     result.setReal(rewriter, loc, real);
266     result.setImaginary(rewriter, loc, imag);
267 
268     rewriter.replaceOp(op, {result});
269     return success();
270   }
271 };
272 
273 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
274   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
275 
276   LogicalResult
277   matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
278                   ConversionPatternRewriter &rewriter) const override {
279     auto loc = op.getLoc();
280     BinaryComplexOperands arg =
281         unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
282 
283     // Initialize complex number struct for result.
284     auto structType = typeConverter->convertType(op.getType());
285     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
286 
287     // Emit IR to substract complex numbers.
288     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
289     Value real =
290         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
291     Value imag =
292         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
293     result.setReal(rewriter, loc, real);
294     result.setImaginary(rewriter, loc, imag);
295 
296     rewriter.replaceOp(op, {result});
297     return success();
298   }
299 };
300 } // namespace
301 
302 void mlir::populateComplexToLLVMConversionPatterns(
303     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
304   // clang-format off
305   patterns.add<
306       AbsOpConversion,
307       AddOpConversion,
308       ConstantOpLowering,
309       CreateOpConversion,
310       DivOpConversion,
311       ImOpConversion,
312       MulOpConversion,
313       ReOpConversion,
314       SubOpConversion
315     >(converter);
316   // clang-format on
317 }
318 
319 namespace {
320 struct ConvertComplexToLLVMPass
321     : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
322   void runOnOperation() override;
323 };
324 } // namespace
325 
326 void ConvertComplexToLLVMPass::runOnOperation() {
327   // Convert to the LLVM IR dialect using the converter defined above.
328   RewritePatternSet patterns(&getContext());
329   LLVMTypeConverter converter(&getContext());
330   populateComplexToLLVMConversionPatterns(converter, patterns);
331 
332   LLVMConversionTarget target(getContext());
333   target.addIllegalDialect<complex::ComplexDialect>();
334   if (failed(
335           applyPartialConversion(getOperation(), target, std::move(patterns))))
336     signalPassFailure();
337 }
338 
339 std::unique_ptr<Pass> mlir::createConvertComplexToLLVMPass() {
340   return std::make_unique<ConvertComplexToLLVMPass>();
341 }
342