1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 18 using namespace mlir; 19 using namespace mlir::LLVM; 20 21 //===----------------------------------------------------------------------===// 22 // ComplexStructBuilder implementation. 23 //===----------------------------------------------------------------------===// 24 25 static constexpr unsigned kRealPosInComplexNumberStruct = 0; 26 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; 27 28 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, 29 Location loc, Type type) { 30 Value val = builder.create<LLVM::UndefOp>(loc, type); 31 return ComplexStructBuilder(val); 32 } 33 34 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, 35 Value real) { 36 setPtr(builder, loc, kRealPosInComplexNumberStruct, real); 37 } 38 39 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { 40 return extractPtr(builder, loc, kRealPosInComplexNumberStruct); 41 } 42 43 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, 44 Value imaginary) { 45 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); 46 } 47 48 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { 49 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); 50 } 51 52 //===----------------------------------------------------------------------===// 53 // Conversion patterns. 54 //===----------------------------------------------------------------------===// 55 56 namespace { 57 58 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { 59 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; 60 61 LogicalResult 62 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 63 ConversionPatternRewriter &rewriter) const override { 64 auto loc = op.getLoc(); 65 66 ComplexStructBuilder complexStruct(adaptor.getComplex()); 67 Value real = complexStruct.real(rewriter, op.getLoc()); 68 Value imag = complexStruct.imaginary(rewriter, op.getLoc()); 69 70 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 71 Value sqNorm = rewriter.create<LLVM::FAddOp>( 72 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), 73 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); 74 75 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); 76 return success(); 77 } 78 }; 79 80 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { 81 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 82 83 LogicalResult 84 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor, 85 ConversionPatternRewriter &rewriter) const override { 86 return LLVM::detail::oneToOneRewrite( 87 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), 88 *getTypeConverter(), rewriter); 89 } 90 }; 91 92 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 93 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 94 95 LogicalResult 96 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, 97 ConversionPatternRewriter &rewriter) const override { 98 // Pack real and imaginary part in a complex number struct. 99 auto loc = complexOp.getLoc(); 100 auto structType = typeConverter->convertType(complexOp.getType()); 101 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 102 complexStruct.setReal(rewriter, loc, adaptor.getReal()); 103 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary()); 104 105 rewriter.replaceOp(complexOp, {complexStruct}); 106 return success(); 107 } 108 }; 109 110 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 111 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 112 113 LogicalResult 114 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, 115 ConversionPatternRewriter &rewriter) const override { 116 // Extract real part from the complex number struct. 117 ComplexStructBuilder complexStruct(adaptor.getComplex()); 118 Value real = complexStruct.real(rewriter, op.getLoc()); 119 rewriter.replaceOp(op, real); 120 121 return success(); 122 } 123 }; 124 125 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 126 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 127 128 LogicalResult 129 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, 130 ConversionPatternRewriter &rewriter) const override { 131 // Extract imaginary part from the complex number struct. 132 ComplexStructBuilder complexStruct(adaptor.getComplex()); 133 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 134 rewriter.replaceOp(op, imaginary); 135 136 return success(); 137 } 138 }; 139 140 struct BinaryComplexOperands { 141 std::complex<Value> lhs; 142 std::complex<Value> rhs; 143 }; 144 145 template <typename OpTy> 146 BinaryComplexOperands 147 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, 148 ConversionPatternRewriter &rewriter) { 149 auto loc = op.getLoc(); 150 151 // Extract real and imaginary values from operands. 152 BinaryComplexOperands unpacked; 153 ComplexStructBuilder lhs(adaptor.getLhs()); 154 unpacked.lhs.real(lhs.real(rewriter, loc)); 155 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 156 ComplexStructBuilder rhs(adaptor.getRhs()); 157 unpacked.rhs.real(rhs.real(rewriter, loc)); 158 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 159 160 return unpacked; 161 } 162 163 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 164 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 165 166 LogicalResult 167 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, 168 ConversionPatternRewriter &rewriter) const override { 169 auto loc = op.getLoc(); 170 BinaryComplexOperands arg = 171 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter); 172 173 // Initialize complex number struct for result. 174 auto structType = typeConverter->convertType(op.getType()); 175 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 176 177 // Emit IR to add complex numbers. 178 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 179 Value real = 180 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 181 Value imag = 182 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 183 result.setReal(rewriter, loc, real); 184 result.setImaginary(rewriter, loc, imag); 185 186 rewriter.replaceOp(op, {result}); 187 return success(); 188 } 189 }; 190 191 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { 192 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; 193 194 LogicalResult 195 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 196 ConversionPatternRewriter &rewriter) const override { 197 auto loc = op.getLoc(); 198 BinaryComplexOperands arg = 199 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter); 200 201 // Initialize complex number struct for result. 202 auto structType = typeConverter->convertType(op.getType()); 203 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 204 205 // Emit IR to add complex numbers. 206 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 207 Value rhsRe = arg.rhs.real(); 208 Value rhsIm = arg.rhs.imag(); 209 Value lhsRe = arg.lhs.real(); 210 Value lhsIm = arg.lhs.imag(); 211 212 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( 213 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), 214 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); 215 216 Value resultReal = rewriter.create<LLVM::FAddOp>( 217 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), 218 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); 219 220 Value resultImag = rewriter.create<LLVM::FSubOp>( 221 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 222 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 223 224 result.setReal( 225 rewriter, loc, 226 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); 227 result.setImaginary( 228 rewriter, loc, 229 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); 230 231 rewriter.replaceOp(op, {result}); 232 return success(); 233 } 234 }; 235 236 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { 237 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; 238 239 LogicalResult 240 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 241 ConversionPatternRewriter &rewriter) const override { 242 auto loc = op.getLoc(); 243 BinaryComplexOperands arg = 244 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter); 245 246 // Initialize complex number struct for result. 247 auto structType = typeConverter->convertType(op.getType()); 248 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 249 250 // Emit IR to add complex numbers. 251 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 252 Value rhsRe = arg.rhs.real(); 253 Value rhsIm = arg.rhs.imag(); 254 Value lhsRe = arg.lhs.real(); 255 Value lhsIm = arg.lhs.imag(); 256 257 Value real = rewriter.create<LLVM::FSubOp>( 258 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), 259 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); 260 261 Value imag = rewriter.create<LLVM::FAddOp>( 262 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 263 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 264 265 result.setReal(rewriter, loc, real); 266 result.setImaginary(rewriter, loc, imag); 267 268 rewriter.replaceOp(op, {result}); 269 return success(); 270 } 271 }; 272 273 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 274 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 275 276 LogicalResult 277 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, 278 ConversionPatternRewriter &rewriter) const override { 279 auto loc = op.getLoc(); 280 BinaryComplexOperands arg = 281 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter); 282 283 // Initialize complex number struct for result. 284 auto structType = typeConverter->convertType(op.getType()); 285 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 286 287 // Emit IR to substract complex numbers. 288 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 289 Value real = 290 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 291 Value imag = 292 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 293 result.setReal(rewriter, loc, real); 294 result.setImaginary(rewriter, loc, imag); 295 296 rewriter.replaceOp(op, {result}); 297 return success(); 298 } 299 }; 300 } // namespace 301 302 void mlir::populateComplexToLLVMConversionPatterns( 303 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 304 // clang-format off 305 patterns.add< 306 AbsOpConversion, 307 AddOpConversion, 308 ConstantOpLowering, 309 CreateOpConversion, 310 DivOpConversion, 311 ImOpConversion, 312 MulOpConversion, 313 ReOpConversion, 314 SubOpConversion 315 >(converter); 316 // clang-format on 317 } 318 319 namespace { 320 struct ConvertComplexToLLVMPass 321 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> { 322 void runOnOperation() override; 323 }; 324 } // namespace 325 326 void ConvertComplexToLLVMPass::runOnOperation() { 327 // Convert to the LLVM IR dialect using the converter defined above. 328 RewritePatternSet patterns(&getContext()); 329 LLVMTypeConverter converter(&getContext()); 330 populateComplexToLLVMConversionPatterns(converter, patterns); 331 332 LLVMConversionTarget target(getContext()); 333 target.addIllegalDialect<complex::ComplexDialect>(); 334 if (failed( 335 applyPartialConversion(getOperation(), target, std::move(patterns)))) 336 signalPassFailure(); 337 } 338 339 std::unique_ptr<Pass> mlir::createConvertComplexToLLVMPass() { 340 return std::make_unique<ConvertComplexToLLVMPass>(); 341 } 342