1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Complex/IR/Complex.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 16 using namespace mlir; 17 using namespace mlir::LLVM; 18 19 //===----------------------------------------------------------------------===// 20 // ComplexStructBuilder implementation. 21 //===----------------------------------------------------------------------===// 22 23 static constexpr unsigned kRealPosInComplexNumberStruct = 0; 24 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; 25 26 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, 27 Location loc, Type type) { 28 Value val = builder.create<LLVM::UndefOp>(loc, type); 29 return ComplexStructBuilder(val); 30 } 31 32 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, 33 Value real) { 34 setPtr(builder, loc, kRealPosInComplexNumberStruct, real); 35 } 36 37 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { 38 return extractPtr(builder, loc, kRealPosInComplexNumberStruct); 39 } 40 41 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, 42 Value imaginary) { 43 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); 44 } 45 46 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { 47 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); 48 } 49 50 //===----------------------------------------------------------------------===// 51 // Conversion patterns. 52 //===----------------------------------------------------------------------===// 53 54 namespace { 55 56 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { 57 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; 58 59 LogicalResult 60 matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 61 ConversionPatternRewriter &rewriter) const override { 62 complex::AbsOp::Adaptor transformed(operands); 63 auto loc = op.getLoc(); 64 65 ComplexStructBuilder complexStruct(transformed.complex()); 66 Value real = complexStruct.real(rewriter, op.getLoc()); 67 Value imag = complexStruct.imaginary(rewriter, op.getLoc()); 68 69 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 70 Value sqNorm = rewriter.create<LLVM::FAddOp>( 71 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), 72 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); 73 74 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); 75 return success(); 76 } 77 }; 78 79 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 80 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 81 82 LogicalResult 83 matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands, 84 ConversionPatternRewriter &rewriter) const override { 85 complex::CreateOp::Adaptor transformed(operands); 86 87 // Pack real and imaginary part in a complex number struct. 88 auto loc = complexOp.getLoc(); 89 auto structType = typeConverter->convertType(complexOp.getType()); 90 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 91 complexStruct.setReal(rewriter, loc, transformed.real()); 92 complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); 93 94 rewriter.replaceOp(complexOp, {complexStruct}); 95 return success(); 96 } 97 }; 98 99 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 100 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 101 102 LogicalResult 103 matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands, 104 ConversionPatternRewriter &rewriter) const override { 105 complex::ReOp::Adaptor transformed(operands); 106 107 // Extract real part from the complex number struct. 108 ComplexStructBuilder complexStruct(transformed.complex()); 109 Value real = complexStruct.real(rewriter, op.getLoc()); 110 rewriter.replaceOp(op, real); 111 112 return success(); 113 } 114 }; 115 116 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 117 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 118 119 LogicalResult 120 matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands, 121 ConversionPatternRewriter &rewriter) const override { 122 complex::ImOp::Adaptor transformed(operands); 123 124 // Extract imaginary part from the complex number struct. 125 ComplexStructBuilder complexStruct(transformed.complex()); 126 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 127 rewriter.replaceOp(op, imaginary); 128 129 return success(); 130 } 131 }; 132 133 struct BinaryComplexOperands { 134 std::complex<Value> lhs; 135 std::complex<Value> rhs; 136 }; 137 138 template <typename OpTy> 139 BinaryComplexOperands 140 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands, 141 ConversionPatternRewriter &rewriter) { 142 auto loc = op.getLoc(); 143 typename OpTy::Adaptor transformed(operands); 144 145 // Extract real and imaginary values from operands. 146 BinaryComplexOperands unpacked; 147 ComplexStructBuilder lhs(transformed.lhs()); 148 unpacked.lhs.real(lhs.real(rewriter, loc)); 149 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 150 ComplexStructBuilder rhs(transformed.rhs()); 151 unpacked.rhs.real(rhs.real(rewriter, loc)); 152 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 153 154 return unpacked; 155 } 156 157 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 158 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 159 160 LogicalResult 161 matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands, 162 ConversionPatternRewriter &rewriter) const override { 163 auto loc = op.getLoc(); 164 BinaryComplexOperands arg = 165 unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter); 166 167 // Initialize complex number struct for result. 168 auto structType = typeConverter->convertType(op.getType()); 169 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 170 171 // Emit IR to add complex numbers. 172 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 173 Value real = 174 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 175 Value imag = 176 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 177 result.setReal(rewriter, loc, real); 178 result.setImaginary(rewriter, loc, imag); 179 180 rewriter.replaceOp(op, {result}); 181 return success(); 182 } 183 }; 184 185 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { 186 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; 187 188 LogicalResult 189 matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands, 190 ConversionPatternRewriter &rewriter) const override { 191 auto loc = op.getLoc(); 192 BinaryComplexOperands arg = 193 unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter); 194 195 // Initialize complex number struct for result. 196 auto structType = typeConverter->convertType(op.getType()); 197 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 198 199 // Emit IR to add complex numbers. 200 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 201 Value rhsRe = arg.rhs.real(); 202 Value rhsIm = arg.rhs.imag(); 203 Value lhsRe = arg.lhs.real(); 204 Value lhsIm = arg.lhs.imag(); 205 206 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( 207 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), 208 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); 209 210 Value resultReal = rewriter.create<LLVM::FAddOp>( 211 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), 212 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); 213 214 Value resultImag = rewriter.create<LLVM::FSubOp>( 215 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 216 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 217 218 result.setReal( 219 rewriter, loc, 220 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); 221 result.setImaginary( 222 rewriter, loc, 223 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); 224 225 rewriter.replaceOp(op, {result}); 226 return success(); 227 } 228 }; 229 230 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { 231 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; 232 233 LogicalResult 234 matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands, 235 ConversionPatternRewriter &rewriter) const override { 236 auto loc = op.getLoc(); 237 BinaryComplexOperands arg = 238 unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter); 239 240 // Initialize complex number struct for result. 241 auto structType = typeConverter->convertType(op.getType()); 242 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 243 244 // Emit IR to add complex numbers. 245 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 246 Value rhsRe = arg.rhs.real(); 247 Value rhsIm = arg.rhs.imag(); 248 Value lhsRe = arg.lhs.real(); 249 Value lhsIm = arg.lhs.imag(); 250 251 Value real = rewriter.create<LLVM::FSubOp>( 252 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), 253 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); 254 255 Value imag = rewriter.create<LLVM::FAddOp>( 256 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 257 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 258 259 result.setReal(rewriter, loc, real); 260 result.setImaginary(rewriter, loc, imag); 261 262 rewriter.replaceOp(op, {result}); 263 return success(); 264 } 265 }; 266 267 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 268 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 269 270 LogicalResult 271 matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands, 272 ConversionPatternRewriter &rewriter) const override { 273 auto loc = op.getLoc(); 274 BinaryComplexOperands arg = 275 unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter); 276 277 // Initialize complex number struct for result. 278 auto structType = typeConverter->convertType(op.getType()); 279 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 280 281 // Emit IR to substract complex numbers. 282 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 283 Value real = 284 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 285 Value imag = 286 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 287 result.setReal(rewriter, loc, real); 288 result.setImaginary(rewriter, loc, imag); 289 290 rewriter.replaceOp(op, {result}); 291 return success(); 292 } 293 }; 294 } // namespace 295 296 void mlir::populateComplexToLLVMConversionPatterns( 297 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 298 // clang-format off 299 patterns.add< 300 AbsOpConversion, 301 AddOpConversion, 302 CreateOpConversion, 303 DivOpConversion, 304 ImOpConversion, 305 MulOpConversion, 306 ReOpConversion, 307 SubOpConversion 308 >(converter); 309 // clang-format on 310 } 311 312 namespace { 313 struct ConvertComplexToLLVMPass 314 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> { 315 void runOnOperation() override; 316 }; 317 } // namespace 318 319 void ConvertComplexToLLVMPass::runOnOperation() { 320 auto module = getOperation(); 321 322 // Convert to the LLVM IR dialect using the converter defined above. 323 RewritePatternSet patterns(&getContext()); 324 LLVMTypeConverter converter(&getContext()); 325 populateComplexToLLVMConversionPatterns(converter, patterns); 326 327 LLVMConversionTarget target(getContext()); 328 target.addLegalOp<ModuleOp, FuncOp>(); 329 target.addLegalOp<LLVM::DialectCastOp>(); 330 target.addIllegalDialect<complex::ComplexDialect>(); 331 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 332 signalPassFailure(); 333 } 334 335 std::unique_ptr<OperationPass<ModuleOp>> 336 mlir::createConvertComplexToLLVMPass() { 337 return std::make_unique<ConvertComplexToLLVMPass>(); 338 } 339