1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 19 using namespace mlir; 20 using namespace mlir::LLVM; 21 22 //===----------------------------------------------------------------------===// 23 // ComplexStructBuilder implementation. 24 //===----------------------------------------------------------------------===// 25 26 static constexpr unsigned kRealPosInComplexNumberStruct = 0; 27 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; 28 29 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, 30 Location loc, Type type) { 31 Value val = builder.create<LLVM::UndefOp>(loc, type); 32 return ComplexStructBuilder(val); 33 } 34 35 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, 36 Value real) { 37 setPtr(builder, loc, kRealPosInComplexNumberStruct, real); 38 } 39 40 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { 41 return extractPtr(builder, loc, kRealPosInComplexNumberStruct); 42 } 43 44 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, 45 Value imaginary) { 46 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); 47 } 48 49 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { 50 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); 51 } 52 53 //===----------------------------------------------------------------------===// 54 // Conversion patterns. 55 //===----------------------------------------------------------------------===// 56 57 namespace { 58 59 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { 60 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; 61 62 LogicalResult 63 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 64 ConversionPatternRewriter &rewriter) const override { 65 auto loc = op.getLoc(); 66 67 ComplexStructBuilder complexStruct(adaptor.getComplex()); 68 Value real = complexStruct.real(rewriter, op.getLoc()); 69 Value imag = complexStruct.imaginary(rewriter, op.getLoc()); 70 71 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 72 Value sqNorm = rewriter.create<LLVM::FAddOp>( 73 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), 74 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); 75 76 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); 77 return success(); 78 } 79 }; 80 81 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { 82 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 83 84 LogicalResult 85 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor, 86 ConversionPatternRewriter &rewriter) const override { 87 return LLVM::detail::oneToOneRewrite( 88 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), 89 *getTypeConverter(), rewriter); 90 } 91 }; 92 93 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 94 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 95 96 LogicalResult 97 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, 98 ConversionPatternRewriter &rewriter) const override { 99 // Pack real and imaginary part in a complex number struct. 100 auto loc = complexOp.getLoc(); 101 auto structType = typeConverter->convertType(complexOp.getType()); 102 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 103 complexStruct.setReal(rewriter, loc, adaptor.getReal()); 104 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary()); 105 106 rewriter.replaceOp(complexOp, {complexStruct}); 107 return success(); 108 } 109 }; 110 111 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 112 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 113 114 LogicalResult 115 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, 116 ConversionPatternRewriter &rewriter) const override { 117 // Extract real part from the complex number struct. 118 ComplexStructBuilder complexStruct(adaptor.getComplex()); 119 Value real = complexStruct.real(rewriter, op.getLoc()); 120 rewriter.replaceOp(op, real); 121 122 return success(); 123 } 124 }; 125 126 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 127 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 128 129 LogicalResult 130 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, 131 ConversionPatternRewriter &rewriter) const override { 132 // Extract imaginary part from the complex number struct. 133 ComplexStructBuilder complexStruct(adaptor.getComplex()); 134 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 135 rewriter.replaceOp(op, imaginary); 136 137 return success(); 138 } 139 }; 140 141 struct BinaryComplexOperands { 142 std::complex<Value> lhs; 143 std::complex<Value> rhs; 144 }; 145 146 template <typename OpTy> 147 BinaryComplexOperands 148 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, 149 ConversionPatternRewriter &rewriter) { 150 auto loc = op.getLoc(); 151 152 // Extract real and imaginary values from operands. 153 BinaryComplexOperands unpacked; 154 ComplexStructBuilder lhs(adaptor.getLhs()); 155 unpacked.lhs.real(lhs.real(rewriter, loc)); 156 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 157 ComplexStructBuilder rhs(adaptor.getRhs()); 158 unpacked.rhs.real(rhs.real(rewriter, loc)); 159 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 160 161 return unpacked; 162 } 163 164 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 165 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 166 167 LogicalResult 168 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, 169 ConversionPatternRewriter &rewriter) const override { 170 auto loc = op.getLoc(); 171 BinaryComplexOperands arg = 172 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter); 173 174 // Initialize complex number struct for result. 175 auto structType = typeConverter->convertType(op.getType()); 176 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 177 178 // Emit IR to add complex numbers. 179 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 180 Value real = 181 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 182 Value imag = 183 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 184 result.setReal(rewriter, loc, real); 185 result.setImaginary(rewriter, loc, imag); 186 187 rewriter.replaceOp(op, {result}); 188 return success(); 189 } 190 }; 191 192 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { 193 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; 194 195 LogicalResult 196 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 197 ConversionPatternRewriter &rewriter) const override { 198 auto loc = op.getLoc(); 199 BinaryComplexOperands arg = 200 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter); 201 202 // Initialize complex number struct for result. 203 auto structType = typeConverter->convertType(op.getType()); 204 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 205 206 // Emit IR to add complex numbers. 207 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 208 Value rhsRe = arg.rhs.real(); 209 Value rhsIm = arg.rhs.imag(); 210 Value lhsRe = arg.lhs.real(); 211 Value lhsIm = arg.lhs.imag(); 212 213 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( 214 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), 215 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); 216 217 Value resultReal = rewriter.create<LLVM::FAddOp>( 218 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), 219 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); 220 221 Value resultImag = rewriter.create<LLVM::FSubOp>( 222 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 223 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 224 225 result.setReal( 226 rewriter, loc, 227 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); 228 result.setImaginary( 229 rewriter, loc, 230 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); 231 232 rewriter.replaceOp(op, {result}); 233 return success(); 234 } 235 }; 236 237 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { 238 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; 239 240 LogicalResult 241 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 242 ConversionPatternRewriter &rewriter) const override { 243 auto loc = op.getLoc(); 244 BinaryComplexOperands arg = 245 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter); 246 247 // Initialize complex number struct for result. 248 auto structType = typeConverter->convertType(op.getType()); 249 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 250 251 // Emit IR to add complex numbers. 252 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 253 Value rhsRe = arg.rhs.real(); 254 Value rhsIm = arg.rhs.imag(); 255 Value lhsRe = arg.lhs.real(); 256 Value lhsIm = arg.lhs.imag(); 257 258 Value real = rewriter.create<LLVM::FSubOp>( 259 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), 260 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); 261 262 Value imag = rewriter.create<LLVM::FAddOp>( 263 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 264 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 265 266 result.setReal(rewriter, loc, real); 267 result.setImaginary(rewriter, loc, imag); 268 269 rewriter.replaceOp(op, {result}); 270 return success(); 271 } 272 }; 273 274 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 275 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 276 277 LogicalResult 278 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, 279 ConversionPatternRewriter &rewriter) const override { 280 auto loc = op.getLoc(); 281 BinaryComplexOperands arg = 282 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter); 283 284 // Initialize complex number struct for result. 285 auto structType = typeConverter->convertType(op.getType()); 286 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 287 288 // Emit IR to substract complex numbers. 289 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 290 Value real = 291 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 292 Value imag = 293 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 294 result.setReal(rewriter, loc, real); 295 result.setImaginary(rewriter, loc, imag); 296 297 rewriter.replaceOp(op, {result}); 298 return success(); 299 } 300 }; 301 } // namespace 302 303 void mlir::populateComplexToLLVMConversionPatterns( 304 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 305 // clang-format off 306 patterns.add< 307 AbsOpConversion, 308 AddOpConversion, 309 ConstantOpLowering, 310 CreateOpConversion, 311 DivOpConversion, 312 ImOpConversion, 313 MulOpConversion, 314 ReOpConversion, 315 SubOpConversion 316 >(converter); 317 // clang-format on 318 } 319 320 namespace { 321 struct ConvertComplexToLLVMPass 322 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> { 323 void runOnOperation() override; 324 }; 325 } // namespace 326 327 void ConvertComplexToLLVMPass::runOnOperation() { 328 auto module = getOperation(); 329 330 // Convert to the LLVM IR dialect using the converter defined above. 331 RewritePatternSet patterns(&getContext()); 332 LLVMTypeConverter converter(&getContext()); 333 populateComplexToLLVMConversionPatterns(converter, patterns); 334 335 LLVMConversionTarget target(getContext()); 336 target.addLegalOp<ModuleOp, FuncOp>(); 337 target.addIllegalDialect<complex::ComplexDialect>(); 338 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 339 signalPassFailure(); 340 } 341 342 std::unique_ptr<OperationPass<ModuleOp>> 343 mlir::createConvertComplexToLLVMPass() { 344 return std::make_unique<ConvertComplexToLLVMPass>(); 345 } 346