1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Dialect/Complex/IR/Complex.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 17 using namespace mlir; 18 using namespace mlir::LLVM; 19 20 //===----------------------------------------------------------------------===// 21 // ComplexStructBuilder implementation. 22 //===----------------------------------------------------------------------===// 23 24 static constexpr unsigned kRealPosInComplexNumberStruct = 0; 25 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; 26 27 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, 28 Location loc, Type type) { 29 Value val = builder.create<LLVM::UndefOp>(loc, type); 30 return ComplexStructBuilder(val); 31 } 32 33 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, 34 Value real) { 35 setPtr(builder, loc, kRealPosInComplexNumberStruct, real); 36 } 37 38 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { 39 return extractPtr(builder, loc, kRealPosInComplexNumberStruct); 40 } 41 42 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, 43 Value imaginary) { 44 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); 45 } 46 47 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { 48 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); 49 } 50 51 //===----------------------------------------------------------------------===// 52 // Conversion patterns. 53 //===----------------------------------------------------------------------===// 54 55 namespace { 56 57 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { 58 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; 59 60 LogicalResult 61 matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 62 ConversionPatternRewriter &rewriter) const override { 63 complex::AbsOp::Adaptor transformed(operands); 64 auto loc = op.getLoc(); 65 66 ComplexStructBuilder complexStruct(transformed.complex()); 67 Value real = complexStruct.real(rewriter, op.getLoc()); 68 Value imag = complexStruct.imaginary(rewriter, op.getLoc()); 69 70 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 71 Value sqNorm = rewriter.create<LLVM::FAddOp>( 72 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), 73 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); 74 75 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); 76 return success(); 77 } 78 }; 79 80 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 81 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 82 83 LogicalResult 84 matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands, 85 ConversionPatternRewriter &rewriter) const override { 86 complex::CreateOp::Adaptor transformed(operands); 87 88 // Pack real and imaginary part in a complex number struct. 89 auto loc = complexOp.getLoc(); 90 auto structType = typeConverter->convertType(complexOp.getType()); 91 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 92 complexStruct.setReal(rewriter, loc, transformed.real()); 93 complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); 94 95 rewriter.replaceOp(complexOp, {complexStruct}); 96 return success(); 97 } 98 }; 99 100 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 101 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 102 103 LogicalResult 104 matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands, 105 ConversionPatternRewriter &rewriter) const override { 106 complex::ReOp::Adaptor transformed(operands); 107 108 // Extract real part from the complex number struct. 109 ComplexStructBuilder complexStruct(transformed.complex()); 110 Value real = complexStruct.real(rewriter, op.getLoc()); 111 rewriter.replaceOp(op, real); 112 113 return success(); 114 } 115 }; 116 117 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 118 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 119 120 LogicalResult 121 matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands, 122 ConversionPatternRewriter &rewriter) const override { 123 complex::ImOp::Adaptor transformed(operands); 124 125 // Extract imaginary part from the complex number struct. 126 ComplexStructBuilder complexStruct(transformed.complex()); 127 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 128 rewriter.replaceOp(op, imaginary); 129 130 return success(); 131 } 132 }; 133 134 struct BinaryComplexOperands { 135 std::complex<Value> lhs; 136 std::complex<Value> rhs; 137 }; 138 139 template <typename OpTy> 140 BinaryComplexOperands 141 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands, 142 ConversionPatternRewriter &rewriter) { 143 auto loc = op.getLoc(); 144 typename OpTy::Adaptor transformed(operands); 145 146 // Extract real and imaginary values from operands. 147 BinaryComplexOperands unpacked; 148 ComplexStructBuilder lhs(transformed.lhs()); 149 unpacked.lhs.real(lhs.real(rewriter, loc)); 150 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 151 ComplexStructBuilder rhs(transformed.rhs()); 152 unpacked.rhs.real(rhs.real(rewriter, loc)); 153 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 154 155 return unpacked; 156 } 157 158 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 159 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 160 161 LogicalResult 162 matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands, 163 ConversionPatternRewriter &rewriter) const override { 164 auto loc = op.getLoc(); 165 BinaryComplexOperands arg = 166 unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter); 167 168 // Initialize complex number struct for result. 169 auto structType = typeConverter->convertType(op.getType()); 170 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 171 172 // Emit IR to add complex numbers. 173 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 174 Value real = 175 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 176 Value imag = 177 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 178 result.setReal(rewriter, loc, real); 179 result.setImaginary(rewriter, loc, imag); 180 181 rewriter.replaceOp(op, {result}); 182 return success(); 183 } 184 }; 185 186 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { 187 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; 188 189 LogicalResult 190 matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands, 191 ConversionPatternRewriter &rewriter) const override { 192 auto loc = op.getLoc(); 193 BinaryComplexOperands arg = 194 unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter); 195 196 // Initialize complex number struct for result. 197 auto structType = typeConverter->convertType(op.getType()); 198 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 199 200 // Emit IR to add complex numbers. 201 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 202 Value rhsRe = arg.rhs.real(); 203 Value rhsIm = arg.rhs.imag(); 204 Value lhsRe = arg.lhs.real(); 205 Value lhsIm = arg.lhs.imag(); 206 207 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( 208 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), 209 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); 210 211 Value resultReal = rewriter.create<LLVM::FAddOp>( 212 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), 213 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); 214 215 Value resultImag = rewriter.create<LLVM::FSubOp>( 216 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 217 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 218 219 result.setReal( 220 rewriter, loc, 221 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); 222 result.setImaginary( 223 rewriter, loc, 224 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); 225 226 rewriter.replaceOp(op, {result}); 227 return success(); 228 } 229 }; 230 231 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { 232 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; 233 234 LogicalResult 235 matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands, 236 ConversionPatternRewriter &rewriter) const override { 237 auto loc = op.getLoc(); 238 BinaryComplexOperands arg = 239 unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter); 240 241 // Initialize complex number struct for result. 242 auto structType = typeConverter->convertType(op.getType()); 243 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 244 245 // Emit IR to add complex numbers. 246 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 247 Value rhsRe = arg.rhs.real(); 248 Value rhsIm = arg.rhs.imag(); 249 Value lhsRe = arg.lhs.real(); 250 Value lhsIm = arg.lhs.imag(); 251 252 Value real = rewriter.create<LLVM::FSubOp>( 253 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), 254 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); 255 256 Value imag = rewriter.create<LLVM::FAddOp>( 257 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 258 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 259 260 result.setReal(rewriter, loc, real); 261 result.setImaginary(rewriter, loc, imag); 262 263 rewriter.replaceOp(op, {result}); 264 return success(); 265 } 266 }; 267 268 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 269 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 270 271 LogicalResult 272 matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands, 273 ConversionPatternRewriter &rewriter) const override { 274 auto loc = op.getLoc(); 275 BinaryComplexOperands arg = 276 unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter); 277 278 // Initialize complex number struct for result. 279 auto structType = typeConverter->convertType(op.getType()); 280 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 281 282 // Emit IR to substract complex numbers. 283 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 284 Value real = 285 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 286 Value imag = 287 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 288 result.setReal(rewriter, loc, real); 289 result.setImaginary(rewriter, loc, imag); 290 291 rewriter.replaceOp(op, {result}); 292 return success(); 293 } 294 }; 295 } // namespace 296 297 void mlir::populateComplexToLLVMConversionPatterns( 298 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 299 // clang-format off 300 patterns.add< 301 AbsOpConversion, 302 AddOpConversion, 303 CreateOpConversion, 304 DivOpConversion, 305 ImOpConversion, 306 MulOpConversion, 307 ReOpConversion, 308 SubOpConversion 309 >(converter); 310 // clang-format on 311 } 312 313 namespace { 314 struct ConvertComplexToLLVMPass 315 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> { 316 void runOnOperation() override; 317 }; 318 } // namespace 319 320 void ConvertComplexToLLVMPass::runOnOperation() { 321 auto module = getOperation(); 322 323 // Convert to the LLVM IR dialect using the converter defined above. 324 RewritePatternSet patterns(&getContext()); 325 LLVMTypeConverter converter(&getContext()); 326 populateComplexToLLVMConversionPatterns(converter, patterns); 327 328 LLVMConversionTarget target(getContext()); 329 target.addLegalOp<ModuleOp, FuncOp>(); 330 target.addIllegalDialect<complex::ComplexDialect>(); 331 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 332 signalPassFailure(); 333 } 334 335 std::unique_ptr<OperationPass<ModuleOp>> 336 mlir::createConvertComplexToLLVMPass() { 337 return std::make_unique<ConvertComplexToLLVMPass>(); 338 } 339