1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Dialect/Complex/IR/Complex.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 17 using namespace mlir; 18 using namespace mlir::LLVM; 19 20 //===----------------------------------------------------------------------===// 21 // ComplexStructBuilder implementation. 22 //===----------------------------------------------------------------------===// 23 24 static constexpr unsigned kRealPosInComplexNumberStruct = 0; 25 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; 26 27 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, 28 Location loc, Type type) { 29 Value val = builder.create<LLVM::UndefOp>(loc, type); 30 return ComplexStructBuilder(val); 31 } 32 33 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, 34 Value real) { 35 setPtr(builder, loc, kRealPosInComplexNumberStruct, real); 36 } 37 38 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { 39 return extractPtr(builder, loc, kRealPosInComplexNumberStruct); 40 } 41 42 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, 43 Value imaginary) { 44 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); 45 } 46 47 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { 48 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); 49 } 50 51 //===----------------------------------------------------------------------===// 52 // Conversion patterns. 53 //===----------------------------------------------------------------------===// 54 55 namespace { 56 57 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { 58 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; 59 60 LogicalResult 61 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 62 ConversionPatternRewriter &rewriter) const override { 63 auto loc = op.getLoc(); 64 65 ComplexStructBuilder complexStruct(adaptor.complex()); 66 Value real = complexStruct.real(rewriter, op.getLoc()); 67 Value imag = complexStruct.imaginary(rewriter, op.getLoc()); 68 69 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 70 Value sqNorm = rewriter.create<LLVM::FAddOp>( 71 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), 72 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); 73 74 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); 75 return success(); 76 } 77 }; 78 79 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 80 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 81 82 LogicalResult 83 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, 84 ConversionPatternRewriter &rewriter) const override { 85 // Pack real and imaginary part in a complex number struct. 86 auto loc = complexOp.getLoc(); 87 auto structType = typeConverter->convertType(complexOp.getType()); 88 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 89 complexStruct.setReal(rewriter, loc, adaptor.real()); 90 complexStruct.setImaginary(rewriter, loc, adaptor.imaginary()); 91 92 rewriter.replaceOp(complexOp, {complexStruct}); 93 return success(); 94 } 95 }; 96 97 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 98 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 99 100 LogicalResult 101 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, 102 ConversionPatternRewriter &rewriter) const override { 103 // Extract real part from the complex number struct. 104 ComplexStructBuilder complexStruct(adaptor.complex()); 105 Value real = complexStruct.real(rewriter, op.getLoc()); 106 rewriter.replaceOp(op, real); 107 108 return success(); 109 } 110 }; 111 112 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 113 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 114 115 LogicalResult 116 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, 117 ConversionPatternRewriter &rewriter) const override { 118 // Extract imaginary part from the complex number struct. 119 ComplexStructBuilder complexStruct(adaptor.complex()); 120 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 121 rewriter.replaceOp(op, imaginary); 122 123 return success(); 124 } 125 }; 126 127 struct BinaryComplexOperands { 128 std::complex<Value> lhs; 129 std::complex<Value> rhs; 130 }; 131 132 template <typename OpTy> 133 BinaryComplexOperands 134 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, 135 ConversionPatternRewriter &rewriter) { 136 auto loc = op.getLoc(); 137 138 // Extract real and imaginary values from operands. 139 BinaryComplexOperands unpacked; 140 ComplexStructBuilder lhs(adaptor.lhs()); 141 unpacked.lhs.real(lhs.real(rewriter, loc)); 142 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 143 ComplexStructBuilder rhs(adaptor.rhs()); 144 unpacked.rhs.real(rhs.real(rewriter, loc)); 145 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 146 147 return unpacked; 148 } 149 150 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 151 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 152 153 LogicalResult 154 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, 155 ConversionPatternRewriter &rewriter) const override { 156 auto loc = op.getLoc(); 157 BinaryComplexOperands arg = 158 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter); 159 160 // Initialize complex number struct for result. 161 auto structType = typeConverter->convertType(op.getType()); 162 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 163 164 // Emit IR to add complex numbers. 165 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 166 Value real = 167 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 168 Value imag = 169 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 170 result.setReal(rewriter, loc, real); 171 result.setImaginary(rewriter, loc, imag); 172 173 rewriter.replaceOp(op, {result}); 174 return success(); 175 } 176 }; 177 178 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { 179 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; 180 181 LogicalResult 182 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 183 ConversionPatternRewriter &rewriter) const override { 184 auto loc = op.getLoc(); 185 BinaryComplexOperands arg = 186 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter); 187 188 // Initialize complex number struct for result. 189 auto structType = typeConverter->convertType(op.getType()); 190 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 191 192 // Emit IR to add complex numbers. 193 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 194 Value rhsRe = arg.rhs.real(); 195 Value rhsIm = arg.rhs.imag(); 196 Value lhsRe = arg.lhs.real(); 197 Value lhsIm = arg.lhs.imag(); 198 199 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( 200 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), 201 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); 202 203 Value resultReal = rewriter.create<LLVM::FAddOp>( 204 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), 205 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); 206 207 Value resultImag = rewriter.create<LLVM::FSubOp>( 208 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 209 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 210 211 result.setReal( 212 rewriter, loc, 213 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); 214 result.setImaginary( 215 rewriter, loc, 216 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); 217 218 rewriter.replaceOp(op, {result}); 219 return success(); 220 } 221 }; 222 223 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { 224 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; 225 226 LogicalResult 227 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 228 ConversionPatternRewriter &rewriter) const override { 229 auto loc = op.getLoc(); 230 BinaryComplexOperands arg = 231 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter); 232 233 // Initialize complex number struct for result. 234 auto structType = typeConverter->convertType(op.getType()); 235 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 236 237 // Emit IR to add complex numbers. 238 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 239 Value rhsRe = arg.rhs.real(); 240 Value rhsIm = arg.rhs.imag(); 241 Value lhsRe = arg.lhs.real(); 242 Value lhsIm = arg.lhs.imag(); 243 244 Value real = rewriter.create<LLVM::FSubOp>( 245 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), 246 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); 247 248 Value imag = rewriter.create<LLVM::FAddOp>( 249 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 250 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 251 252 result.setReal(rewriter, loc, real); 253 result.setImaginary(rewriter, loc, imag); 254 255 rewriter.replaceOp(op, {result}); 256 return success(); 257 } 258 }; 259 260 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 261 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 262 263 LogicalResult 264 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, 265 ConversionPatternRewriter &rewriter) const override { 266 auto loc = op.getLoc(); 267 BinaryComplexOperands arg = 268 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter); 269 270 // Initialize complex number struct for result. 271 auto structType = typeConverter->convertType(op.getType()); 272 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 273 274 // Emit IR to substract complex numbers. 275 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 276 Value real = 277 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 278 Value imag = 279 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 280 result.setReal(rewriter, loc, real); 281 result.setImaginary(rewriter, loc, imag); 282 283 rewriter.replaceOp(op, {result}); 284 return success(); 285 } 286 }; 287 } // namespace 288 289 void mlir::populateComplexToLLVMConversionPatterns( 290 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 291 // clang-format off 292 patterns.add< 293 AbsOpConversion, 294 AddOpConversion, 295 CreateOpConversion, 296 DivOpConversion, 297 ImOpConversion, 298 MulOpConversion, 299 ReOpConversion, 300 SubOpConversion 301 >(converter); 302 // clang-format on 303 } 304 305 namespace { 306 struct ConvertComplexToLLVMPass 307 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> { 308 void runOnOperation() override; 309 }; 310 } // namespace 311 312 void ConvertComplexToLLVMPass::runOnOperation() { 313 auto module = getOperation(); 314 315 // Convert to the LLVM IR dialect using the converter defined above. 316 RewritePatternSet patterns(&getContext()); 317 LLVMTypeConverter converter(&getContext()); 318 populateComplexToLLVMConversionPatterns(converter, patterns); 319 320 LLVMConversionTarget target(getContext()); 321 target.addLegalOp<ModuleOp, FuncOp>(); 322 target.addIllegalDialect<complex::ComplexDialect>(); 323 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 324 signalPassFailure(); 325 } 326 327 std::unique_ptr<OperationPass<ModuleOp>> 328 mlir::createConvertComplexToLLVMPass() { 329 return std::make_unique<ConvertComplexToLLVMPass>(); 330 } 331