1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17
18 using namespace mlir;
19 using namespace mlir::LLVM;
20
21 //===----------------------------------------------------------------------===//
22 // ComplexStructBuilder implementation.
23 //===----------------------------------------------------------------------===//
24
/// Position of the real part inside the struct value that represents a complex
/// number (used by ComplexStructBuilder::setReal / ComplexStructBuilder::real).
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
/// Position of the imaginary part inside the struct value that represents a
/// complex number (used by ComplexStructBuilder::setImaginary /
/// ComplexStructBuilder::imaginary).
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
27
undef(OpBuilder & builder,Location loc,Type type)28 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
29 Location loc, Type type) {
30 Value val = builder.create<LLVM::UndefOp>(loc, type);
31 return ComplexStructBuilder(val);
32 }
33
setReal(OpBuilder & builder,Location loc,Value real)34 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
35 Value real) {
36 setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
37 }
38
real(OpBuilder & builder,Location loc)39 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
40 return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
41 }
42
setImaginary(OpBuilder & builder,Location loc,Value imaginary)43 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
44 Value imaginary) {
45 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
46 }
47
imaginary(OpBuilder & builder,Location loc)48 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
49 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
50 }
51
52 //===----------------------------------------------------------------------===//
53 // Conversion patterns.
54 //===----------------------------------------------------------------------===//
55
56 namespace {
57
58 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
59 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
60
61 LogicalResult
matchAndRewrite__anone1f7db290111::AbsOpConversion62 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter) const override {
64 auto loc = op.getLoc();
65
66 ComplexStructBuilder complexStruct(adaptor.getComplex());
67 Value real = complexStruct.real(rewriter, op.getLoc());
68 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
69
70 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
71 Value sqNorm = rewriter.create<LLVM::FAddOp>(
72 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
73 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
74
75 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
76 return success();
77 }
78 };
79
80 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
81 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
82
83 LogicalResult
matchAndRewrite__anone1f7db290111::ConstantOpLowering84 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 return LLVM::detail::oneToOneRewrite(
87 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
88 *getTypeConverter(), rewriter);
89 }
90 };
91
92 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
93 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
94
95 LogicalResult
matchAndRewrite__anone1f7db290111::CreateOpConversion96 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
97 ConversionPatternRewriter &rewriter) const override {
98 // Pack real and imaginary part in a complex number struct.
99 auto loc = complexOp.getLoc();
100 auto structType = typeConverter->convertType(complexOp.getType());
101 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
102 complexStruct.setReal(rewriter, loc, adaptor.getReal());
103 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
104
105 rewriter.replaceOp(complexOp, {complexStruct});
106 return success();
107 }
108 };
109
110 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
111 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
112
113 LogicalResult
matchAndRewrite__anone1f7db290111::ReOpConversion114 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
115 ConversionPatternRewriter &rewriter) const override {
116 // Extract real part from the complex number struct.
117 ComplexStructBuilder complexStruct(adaptor.getComplex());
118 Value real = complexStruct.real(rewriter, op.getLoc());
119 rewriter.replaceOp(op, real);
120
121 return success();
122 }
123 };
124
125 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
126 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
127
128 LogicalResult
matchAndRewrite__anone1f7db290111::ImOpConversion129 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
130 ConversionPatternRewriter &rewriter) const override {
131 // Extract imaginary part from the complex number struct.
132 ComplexStructBuilder complexStruct(adaptor.getComplex());
133 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
134 rewriter.replaceOp(op, imaginary);
135
136 return success();
137 }
138 };
139
140 struct BinaryComplexOperands {
141 std::complex<Value> lhs;
142 std::complex<Value> rhs;
143 };
144
145 template <typename OpTy>
146 BinaryComplexOperands
unpackBinaryComplexOperands(OpTy op,typename OpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter)147 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
148 ConversionPatternRewriter &rewriter) {
149 auto loc = op.getLoc();
150
151 // Extract real and imaginary values from operands.
152 BinaryComplexOperands unpacked;
153 ComplexStructBuilder lhs(adaptor.getLhs());
154 unpacked.lhs.real(lhs.real(rewriter, loc));
155 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
156 ComplexStructBuilder rhs(adaptor.getRhs());
157 unpacked.rhs.real(rhs.real(rewriter, loc));
158 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
159
160 return unpacked;
161 }
162
163 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
164 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
165
166 LogicalResult
matchAndRewrite__anone1f7db290111::AddOpConversion167 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
168 ConversionPatternRewriter &rewriter) const override {
169 auto loc = op.getLoc();
170 BinaryComplexOperands arg =
171 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
172
173 // Initialize complex number struct for result.
174 auto structType = typeConverter->convertType(op.getType());
175 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
176
177 // Emit IR to add complex numbers.
178 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
179 Value real =
180 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
181 Value imag =
182 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
183 result.setReal(rewriter, loc, real);
184 result.setImaginary(rewriter, loc, imag);
185
186 rewriter.replaceOp(op, {result});
187 return success();
188 }
189 };
190
191 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
192 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
193
194 LogicalResult
matchAndRewrite__anone1f7db290111::DivOpConversion195 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
196 ConversionPatternRewriter &rewriter) const override {
197 auto loc = op.getLoc();
198 BinaryComplexOperands arg =
199 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
200
201 // Initialize complex number struct for result.
202 auto structType = typeConverter->convertType(op.getType());
203 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
204
205 // Emit IR to add complex numbers.
206 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
207 Value rhsRe = arg.rhs.real();
208 Value rhsIm = arg.rhs.imag();
209 Value lhsRe = arg.lhs.real();
210 Value lhsIm = arg.lhs.imag();
211
212 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
213 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
214 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
215
216 Value resultReal = rewriter.create<LLVM::FAddOp>(
217 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
218 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
219
220 Value resultImag = rewriter.create<LLVM::FSubOp>(
221 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
222 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
223
224 result.setReal(
225 rewriter, loc,
226 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
227 result.setImaginary(
228 rewriter, loc,
229 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
230
231 rewriter.replaceOp(op, {result});
232 return success();
233 }
234 };
235
236 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
237 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
238
239 LogicalResult
matchAndRewrite__anone1f7db290111::MulOpConversion240 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter) const override {
242 auto loc = op.getLoc();
243 BinaryComplexOperands arg =
244 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
245
246 // Initialize complex number struct for result.
247 auto structType = typeConverter->convertType(op.getType());
248 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
249
250 // Emit IR to add complex numbers.
251 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
252 Value rhsRe = arg.rhs.real();
253 Value rhsIm = arg.rhs.imag();
254 Value lhsRe = arg.lhs.real();
255 Value lhsIm = arg.lhs.imag();
256
257 Value real = rewriter.create<LLVM::FSubOp>(
258 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
259 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
260
261 Value imag = rewriter.create<LLVM::FAddOp>(
262 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
263 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
264
265 result.setReal(rewriter, loc, real);
266 result.setImaginary(rewriter, loc, imag);
267
268 rewriter.replaceOp(op, {result});
269 return success();
270 }
271 };
272
273 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
274 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
275
276 LogicalResult
matchAndRewrite__anone1f7db290111::SubOpConversion277 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter) const override {
279 auto loc = op.getLoc();
280 BinaryComplexOperands arg =
281 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
282
283 // Initialize complex number struct for result.
284 auto structType = typeConverter->convertType(op.getType());
285 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
286
287 // Emit IR to substract complex numbers.
288 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
289 Value real =
290 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
291 Value imag =
292 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
293 result.setReal(rewriter, loc, real);
294 result.setImaginary(rewriter, loc, imag);
295
296 rewriter.replaceOp(op, {result});
297 return success();
298 }
299 };
300 } // namespace
301
populateComplexToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)302 void mlir::populateComplexToLLVMConversionPatterns(
303 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
304 // clang-format off
305 patterns.add<
306 AbsOpConversion,
307 AddOpConversion,
308 ConstantOpLowering,
309 CreateOpConversion,
310 DivOpConversion,
311 ImOpConversion,
312 MulOpConversion,
313 ReOpConversion,
314 SubOpConversion
315 >(converter);
316 // clang-format on
317 }
318
319 namespace {
320 struct ConvertComplexToLLVMPass
321 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
322 void runOnOperation() override;
323 };
324 } // namespace
325
runOnOperation()326 void ConvertComplexToLLVMPass::runOnOperation() {
327 // Convert to the LLVM IR dialect using the converter defined above.
328 RewritePatternSet patterns(&getContext());
329 LLVMTypeConverter converter(&getContext());
330 populateComplexToLLVMConversionPatterns(converter, patterns);
331
332 LLVMConversionTarget target(getContext());
333 target.addIllegalDialect<complex::ComplexDialect>();
334 if (failed(
335 applyPartialConversion(getOperation(), target, std::move(patterns))))
336 signalPassFailure();
337 }
338
createConvertComplexToLLVMPass()339 std::unique_ptr<Pass> mlir::createConvertComplexToLLVMPass() {
340 return std::make_unique<ConvertComplexToLLVMPass>();
341 }
342