1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 
19 using namespace mlir;
20 using namespace mlir::LLVM;
21 
22 //===----------------------------------------------------------------------===//
23 // ComplexStructBuilder implementation.
24 //===----------------------------------------------------------------------===//
25 
// Positions passed to setPtr/extractPtr by the ComplexStructBuilder accessors
// in this file: position 0 is used for the real part and position 1 for the
// imaginary part.
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
28 
29 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
30                                                  Location loc, Type type) {
31   Value val = builder.create<LLVM::UndefOp>(loc, type);
32   return ComplexStructBuilder(val);
33 }
34 
35 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
36                                    Value real) {
37   setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
38 }
39 
40 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
41   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
42 }
43 
44 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
45                                         Value imaginary) {
46   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
47 }
48 
49 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
50   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // Conversion patterns.
55 //===----------------------------------------------------------------------===//
56 
57 namespace {
58 
59 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
60   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
61 
62   LogicalResult
63   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
64                   ConversionPatternRewriter &rewriter) const override {
65     auto loc = op.getLoc();
66 
67     ComplexStructBuilder complexStruct(adaptor.getComplex());
68     Value real = complexStruct.real(rewriter, op.getLoc());
69     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
70 
71     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
72     Value sqNorm = rewriter.create<LLVM::FAddOp>(
73         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
74         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
75 
76     rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
77     return success();
78   }
79 };
80 
81 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
82   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
83 
84   LogicalResult
85   matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
86                   ConversionPatternRewriter &rewriter) const override {
87     return LLVM::detail::oneToOneRewrite(
88         op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
89         *getTypeConverter(), rewriter);
90   }
91 };
92 
93 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
94   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
95 
96   LogicalResult
97   matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
98                   ConversionPatternRewriter &rewriter) const override {
99     // Pack real and imaginary part in a complex number struct.
100     auto loc = complexOp.getLoc();
101     auto structType = typeConverter->convertType(complexOp.getType());
102     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
103     complexStruct.setReal(rewriter, loc, adaptor.getReal());
104     complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
105 
106     rewriter.replaceOp(complexOp, {complexStruct});
107     return success();
108   }
109 };
110 
111 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
112   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
113 
114   LogicalResult
115   matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
116                   ConversionPatternRewriter &rewriter) const override {
117     // Extract real part from the complex number struct.
118     ComplexStructBuilder complexStruct(adaptor.getComplex());
119     Value real = complexStruct.real(rewriter, op.getLoc());
120     rewriter.replaceOp(op, real);
121 
122     return success();
123   }
124 };
125 
126 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
127   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
128 
129   LogicalResult
130   matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
131                   ConversionPatternRewriter &rewriter) const override {
132     // Extract imaginary part from the complex number struct.
133     ComplexStructBuilder complexStruct(adaptor.getComplex());
134     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
135     rewriter.replaceOp(op, imaginary);
136 
137     return success();
138   }
139 };
140 
/// Real and imaginary parts of the two operands of a binary complex op, as
/// filled in by unpackBinaryComplexOperands.
// NOTE(review): std::complex is instantiated here with `Value`; the C++
// standard only specifies std::complex for float, double and long double.
// Confirm this is acceptable on all supported toolchains, and that <complex>
// is reachable through the existing includes.
struct BinaryComplexOperands {
  // Parts of the left-hand operand.
  std::complex<Value> lhs;
  // Parts of the right-hand operand.
  std::complex<Value> rhs;
};
145 
146 template <typename OpTy>
147 BinaryComplexOperands
148 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
149                             ConversionPatternRewriter &rewriter) {
150   auto loc = op.getLoc();
151 
152   // Extract real and imaginary values from operands.
153   BinaryComplexOperands unpacked;
154   ComplexStructBuilder lhs(adaptor.getLhs());
155   unpacked.lhs.real(lhs.real(rewriter, loc));
156   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
157   ComplexStructBuilder rhs(adaptor.getRhs());
158   unpacked.rhs.real(rhs.real(rewriter, loc));
159   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
160 
161   return unpacked;
162 }
163 
164 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
165   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
166 
167   LogicalResult
168   matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
169                   ConversionPatternRewriter &rewriter) const override {
170     auto loc = op.getLoc();
171     BinaryComplexOperands arg =
172         unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
173 
174     // Initialize complex number struct for result.
175     auto structType = typeConverter->convertType(op.getType());
176     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
177 
178     // Emit IR to add complex numbers.
179     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
180     Value real =
181         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
182     Value imag =
183         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
184     result.setReal(rewriter, loc, real);
185     result.setImaginary(rewriter, loc, imag);
186 
187     rewriter.replaceOp(op, {result});
188     return success();
189   }
190 };
191 
192 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
193   using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
194 
195   LogicalResult
196   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
197                   ConversionPatternRewriter &rewriter) const override {
198     auto loc = op.getLoc();
199     BinaryComplexOperands arg =
200         unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
201 
202     // Initialize complex number struct for result.
203     auto structType = typeConverter->convertType(op.getType());
204     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
205 
206     // Emit IR to add complex numbers.
207     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
208     Value rhsRe = arg.rhs.real();
209     Value rhsIm = arg.rhs.imag();
210     Value lhsRe = arg.lhs.real();
211     Value lhsIm = arg.lhs.imag();
212 
213     Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
214         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
215         rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
216 
217     Value resultReal = rewriter.create<LLVM::FAddOp>(
218         loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
219         rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
220 
221     Value resultImag = rewriter.create<LLVM::FSubOp>(
222         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
223         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
224 
225     result.setReal(
226         rewriter, loc,
227         rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
228     result.setImaginary(
229         rewriter, loc,
230         rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
231 
232     rewriter.replaceOp(op, {result});
233     return success();
234   }
235 };
236 
237 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
238   using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
239 
240   LogicalResult
241   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
242                   ConversionPatternRewriter &rewriter) const override {
243     auto loc = op.getLoc();
244     BinaryComplexOperands arg =
245         unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
246 
247     // Initialize complex number struct for result.
248     auto structType = typeConverter->convertType(op.getType());
249     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
250 
251     // Emit IR to add complex numbers.
252     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
253     Value rhsRe = arg.rhs.real();
254     Value rhsIm = arg.rhs.imag();
255     Value lhsRe = arg.lhs.real();
256     Value lhsIm = arg.lhs.imag();
257 
258     Value real = rewriter.create<LLVM::FSubOp>(
259         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
260         rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
261 
262     Value imag = rewriter.create<LLVM::FAddOp>(
263         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
264         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
265 
266     result.setReal(rewriter, loc, real);
267     result.setImaginary(rewriter, loc, imag);
268 
269     rewriter.replaceOp(op, {result});
270     return success();
271   }
272 };
273 
274 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
275   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
276 
277   LogicalResult
278   matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
279                   ConversionPatternRewriter &rewriter) const override {
280     auto loc = op.getLoc();
281     BinaryComplexOperands arg =
282         unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
283 
284     // Initialize complex number struct for result.
285     auto structType = typeConverter->convertType(op.getType());
286     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
287 
288     // Emit IR to substract complex numbers.
289     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
290     Value real =
291         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
292     Value imag =
293         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
294     result.setReal(rewriter, loc, real);
295     result.setImaginary(rewriter, loc, imag);
296 
297     rewriter.replaceOp(op, {result});
298     return success();
299   }
300 };
301 } // namespace
302 
303 void mlir::populateComplexToLLVMConversionPatterns(
304     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
305   // clang-format off
306   patterns.add<
307       AbsOpConversion,
308       AddOpConversion,
309       ConstantOpLowering,
310       CreateOpConversion,
311       DivOpConversion,
312       ImOpConversion,
313       MulOpConversion,
314       ReOpConversion,
315       SubOpConversion
316     >(converter);
317   // clang-format on
318 }
319 
namespace {
/// Pass type derived from ConvertComplexToLLVMBase. Its only override,
/// runOnOperation(), is defined out of line further down in this file.
struct ConvertComplexToLLVMPass
    : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
  void runOnOperation() override;
};
} // namespace
326 
327 void ConvertComplexToLLVMPass::runOnOperation() {
328   auto module = getOperation();
329 
330   // Convert to the LLVM IR dialect using the converter defined above.
331   RewritePatternSet patterns(&getContext());
332   LLVMTypeConverter converter(&getContext());
333   populateComplexToLLVMConversionPatterns(converter, patterns);
334 
335   LLVMConversionTarget target(getContext());
336   target.addLegalOp<ModuleOp, FuncOp>();
337   target.addIllegalDialect<complex::ComplexDialect>();
338   if (failed(applyPartialConversion(module, target, std::move(patterns))))
339     signalPassFailure();
340 }
341 
342 std::unique_ptr<OperationPass<ModuleOp>>
343 mlir::createConvertComplexToLLVMPass() {
344   return std::make_unique<ConvertComplexToLLVMPass>();
345 }
346