1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 
19 using namespace mlir;
20 using namespace mlir::LLVM;
21 
22 //===----------------------------------------------------------------------===//
23 // ComplexStructBuilder implementation.
24 //===----------------------------------------------------------------------===//
25 
// Field positions used by ComplexStructBuilder when reading or writing the
// real (first) and imaginary (second) component of the complex struct.
static constexpr unsigned kRealPosInComplexNumberStruct = 0u;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1u;
28 
29 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
30                                                  Location loc, Type type) {
31   Value val = builder.create<LLVM::UndefOp>(loc, type);
32   return ComplexStructBuilder(val);
33 }
34 
35 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
36                                    Value real) {
37   setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
38 }
39 
40 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
41   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
42 }
43 
44 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
45                                         Value imaginary) {
46   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
47 }
48 
49 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
50   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // Conversion patterns.
55 //===----------------------------------------------------------------------===//
56 
57 namespace {
58 
59 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
60   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
61 
62   LogicalResult
63   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
64                   ConversionPatternRewriter &rewriter) const override {
65     auto loc = op.getLoc();
66 
67     ComplexStructBuilder complexStruct(adaptor.complex());
68     Value real = complexStruct.real(rewriter, op.getLoc());
69     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
70 
71     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
72     Value sqNorm = rewriter.create<LLVM::FAddOp>(
73         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
74         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
75 
76     rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
77     return success();
78   }
79 };
80 
81 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
82   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
83 
84   LogicalResult
85   matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
86                   ConversionPatternRewriter &rewriter) const override {
87     // Pack real and imaginary part in a complex number struct.
88     auto loc = complexOp.getLoc();
89     auto structType = typeConverter->convertType(complexOp.getType());
90     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
91     complexStruct.setReal(rewriter, loc, adaptor.real());
92     complexStruct.setImaginary(rewriter, loc, adaptor.imaginary());
93 
94     rewriter.replaceOp(complexOp, {complexStruct});
95     return success();
96   }
97 };
98 
99 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
100   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
101 
102   LogicalResult
103   matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
104                   ConversionPatternRewriter &rewriter) const override {
105     // Extract real part from the complex number struct.
106     ComplexStructBuilder complexStruct(adaptor.complex());
107     Value real = complexStruct.real(rewriter, op.getLoc());
108     rewriter.replaceOp(op, real);
109 
110     return success();
111   }
112 };
113 
114 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
115   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
116 
117   LogicalResult
118   matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
119                   ConversionPatternRewriter &rewriter) const override {
120     // Extract imaginary part from the complex number struct.
121     ComplexStructBuilder complexStruct(adaptor.complex());
122     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
123     rewriter.replaceOp(op, imaginary);
124 
125     return success();
126   }
127 };
128 
/// Holds the unpacked real and imaginary values of the two operands of a
/// binary complex operation.
///
/// NOTE(review): the C++ standard leaves the behavior of `std::complex<T>`
/// unspecified for `T` other than float, double and long double. This code
/// only uses the `real()`/`imag()` accessors and setters; presumably that is
/// fine with the supported standard libraries, but confirm.
struct BinaryComplexOperands {
  std::complex<Value> lhs;
  std::complex<Value> rhs;
};
133 
134 template <typename OpTy>
135 BinaryComplexOperands
136 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
137                             ConversionPatternRewriter &rewriter) {
138   auto loc = op.getLoc();
139 
140   // Extract real and imaginary values from operands.
141   BinaryComplexOperands unpacked;
142   ComplexStructBuilder lhs(adaptor.lhs());
143   unpacked.lhs.real(lhs.real(rewriter, loc));
144   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
145   ComplexStructBuilder rhs(adaptor.rhs());
146   unpacked.rhs.real(rhs.real(rewriter, loc));
147   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
148 
149   return unpacked;
150 }
151 
152 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
153   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
154 
155   LogicalResult
156   matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
157                   ConversionPatternRewriter &rewriter) const override {
158     auto loc = op.getLoc();
159     BinaryComplexOperands arg =
160         unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
161 
162     // Initialize complex number struct for result.
163     auto structType = typeConverter->convertType(op.getType());
164     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
165 
166     // Emit IR to add complex numbers.
167     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
168     Value real =
169         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
170     Value imag =
171         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
172     result.setReal(rewriter, loc, real);
173     result.setImaginary(rewriter, loc, imag);
174 
175     rewriter.replaceOp(op, {result});
176     return success();
177   }
178 };
179 
180 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
181   using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
182 
183   LogicalResult
184   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
185                   ConversionPatternRewriter &rewriter) const override {
186     auto loc = op.getLoc();
187     BinaryComplexOperands arg =
188         unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
189 
190     // Initialize complex number struct for result.
191     auto structType = typeConverter->convertType(op.getType());
192     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
193 
194     // Emit IR to add complex numbers.
195     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
196     Value rhsRe = arg.rhs.real();
197     Value rhsIm = arg.rhs.imag();
198     Value lhsRe = arg.lhs.real();
199     Value lhsIm = arg.lhs.imag();
200 
201     Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
202         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
203         rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
204 
205     Value resultReal = rewriter.create<LLVM::FAddOp>(
206         loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
207         rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
208 
209     Value resultImag = rewriter.create<LLVM::FSubOp>(
210         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
211         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
212 
213     result.setReal(
214         rewriter, loc,
215         rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
216     result.setImaginary(
217         rewriter, loc,
218         rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
219 
220     rewriter.replaceOp(op, {result});
221     return success();
222   }
223 };
224 
225 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
226   using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
227 
228   LogicalResult
229   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
230                   ConversionPatternRewriter &rewriter) const override {
231     auto loc = op.getLoc();
232     BinaryComplexOperands arg =
233         unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
234 
235     // Initialize complex number struct for result.
236     auto structType = typeConverter->convertType(op.getType());
237     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
238 
239     // Emit IR to add complex numbers.
240     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
241     Value rhsRe = arg.rhs.real();
242     Value rhsIm = arg.rhs.imag();
243     Value lhsRe = arg.lhs.real();
244     Value lhsIm = arg.lhs.imag();
245 
246     Value real = rewriter.create<LLVM::FSubOp>(
247         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
248         rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
249 
250     Value imag = rewriter.create<LLVM::FAddOp>(
251         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
252         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
253 
254     result.setReal(rewriter, loc, real);
255     result.setImaginary(rewriter, loc, imag);
256 
257     rewriter.replaceOp(op, {result});
258     return success();
259   }
260 };
261 
262 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
263   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
264 
265   LogicalResult
266   matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
267                   ConversionPatternRewriter &rewriter) const override {
268     auto loc = op.getLoc();
269     BinaryComplexOperands arg =
270         unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
271 
272     // Initialize complex number struct for result.
273     auto structType = typeConverter->convertType(op.getType());
274     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
275 
276     // Emit IR to substract complex numbers.
277     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
278     Value real =
279         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
280     Value imag =
281         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
282     result.setReal(rewriter, loc, real);
283     result.setImaginary(rewriter, loc, imag);
284 
285     rewriter.replaceOp(op, {result});
286     return success();
287   }
288 };
289 } // namespace
290 
291 void mlir::populateComplexToLLVMConversionPatterns(
292     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
293   // clang-format off
294   patterns.add<
295       AbsOpConversion,
296       AddOpConversion,
297       CreateOpConversion,
298       DivOpConversion,
299       ImOpConversion,
300       MulOpConversion,
301       ReOpConversion,
302       SubOpConversion
303     >(converter);
304   // clang-format on
305 }
306 
307 namespace {
/// Pass that converts Complex dialect operations to the LLVM dialect; the
/// conversion itself is implemented in `runOnOperation`.
struct ConvertComplexToLLVMPass
    : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
  /// Runs the conversion on the operation this pass is scheduled on.
  void runOnOperation() override;
};
312 } // namespace
313 
314 void ConvertComplexToLLVMPass::runOnOperation() {
315   auto module = getOperation();
316 
317   // Convert to the LLVM IR dialect using the converter defined above.
318   RewritePatternSet patterns(&getContext());
319   LLVMTypeConverter converter(&getContext());
320   populateComplexToLLVMConversionPatterns(converter, patterns);
321 
322   LLVMConversionTarget target(getContext());
323   target.addLegalOp<ModuleOp, FuncOp>();
324   target.addIllegalDialect<complex::ComplexDialect>();
325   if (failed(applyPartialConversion(module, target, std::move(patterns))))
326     signalPassFailure();
327 }
328 
/// Creates a new instance of the Complex-to-LLVM conversion pass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertComplexToLLVMPass() {
  return std::make_unique<ConvertComplexToLLVMPass>();
}
333