1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 19 using namespace mlir; 20 using namespace mlir::LLVM; 21 22 //===----------------------------------------------------------------------===// 23 // ComplexStructBuilder implementation. 24 //===----------------------------------------------------------------------===// 25 26 static constexpr unsigned kRealPosInComplexNumberStruct = 0; 27 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; 28 29 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, 30 Location loc, Type type) { 31 Value val = builder.create<LLVM::UndefOp>(loc, type); 32 return ComplexStructBuilder(val); 33 } 34 35 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, 36 Value real) { 37 setPtr(builder, loc, kRealPosInComplexNumberStruct, real); 38 } 39 40 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { 41 return extractPtr(builder, loc, kRealPosInComplexNumberStruct); 42 } 43 44 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, 45 Value imaginary) { 46 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); 47 } 48 49 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { 50 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); 51 } 52 53 //===----------------------------------------------------------------------===// 54 // Conversion patterns. 55 //===----------------------------------------------------------------------===// 56 57 namespace { 58 59 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { 60 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; 61 62 LogicalResult 63 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 64 ConversionPatternRewriter &rewriter) const override { 65 auto loc = op.getLoc(); 66 67 ComplexStructBuilder complexStruct(adaptor.getComplex()); 68 Value real = complexStruct.real(rewriter, op.getLoc()); 69 Value imag = complexStruct.imaginary(rewriter, op.getLoc()); 70 71 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 72 Value sqNorm = rewriter.create<LLVM::FAddOp>( 73 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), 74 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); 75 76 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); 77 return success(); 78 } 79 }; 80 81 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 82 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 83 84 LogicalResult 85 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, 86 ConversionPatternRewriter &rewriter) const override { 87 // Pack real and imaginary part in a complex number struct. 88 auto loc = complexOp.getLoc(); 89 auto structType = typeConverter->convertType(complexOp.getType()); 90 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 91 complexStruct.setReal(rewriter, loc, adaptor.getReal()); 92 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary()); 93 94 rewriter.replaceOp(complexOp, {complexStruct}); 95 return success(); 96 } 97 }; 98 99 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 100 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 101 102 LogicalResult 103 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, 104 ConversionPatternRewriter &rewriter) const override { 105 // Extract real part from the complex number struct. 106 ComplexStructBuilder complexStruct(adaptor.getComplex()); 107 Value real = complexStruct.real(rewriter, op.getLoc()); 108 rewriter.replaceOp(op, real); 109 110 return success(); 111 } 112 }; 113 114 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 115 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 116 117 LogicalResult 118 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, 119 ConversionPatternRewriter &rewriter) const override { 120 // Extract imaginary part from the complex number struct. 121 ComplexStructBuilder complexStruct(adaptor.getComplex()); 122 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 123 rewriter.replaceOp(op, imaginary); 124 125 return success(); 126 } 127 }; 128 129 struct BinaryComplexOperands { 130 std::complex<Value> lhs; 131 std::complex<Value> rhs; 132 }; 133 134 template <typename OpTy> 135 BinaryComplexOperands 136 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, 137 ConversionPatternRewriter &rewriter) { 138 auto loc = op.getLoc(); 139 140 // Extract real and imaginary values from operands. 141 BinaryComplexOperands unpacked; 142 ComplexStructBuilder lhs(adaptor.getLhs()); 143 unpacked.lhs.real(lhs.real(rewriter, loc)); 144 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 145 ComplexStructBuilder rhs(adaptor.getRhs()); 146 unpacked.rhs.real(rhs.real(rewriter, loc)); 147 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 148 149 return unpacked; 150 } 151 152 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 153 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 154 155 LogicalResult 156 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, 157 ConversionPatternRewriter &rewriter) const override { 158 auto loc = op.getLoc(); 159 BinaryComplexOperands arg = 160 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter); 161 162 // Initialize complex number struct for result. 163 auto structType = typeConverter->convertType(op.getType()); 164 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 165 166 // Emit IR to add complex numbers. 167 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 168 Value real = 169 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 170 Value imag = 171 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 172 result.setReal(rewriter, loc, real); 173 result.setImaginary(rewriter, loc, imag); 174 175 rewriter.replaceOp(op, {result}); 176 return success(); 177 } 178 }; 179 180 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { 181 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; 182 183 LogicalResult 184 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 185 ConversionPatternRewriter &rewriter) const override { 186 auto loc = op.getLoc(); 187 BinaryComplexOperands arg = 188 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter); 189 190 // Initialize complex number struct for result. 191 auto structType = typeConverter->convertType(op.getType()); 192 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 193 194 // Emit IR to add complex numbers. 195 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 196 Value rhsRe = arg.rhs.real(); 197 Value rhsIm = arg.rhs.imag(); 198 Value lhsRe = arg.lhs.real(); 199 Value lhsIm = arg.lhs.imag(); 200 201 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( 202 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), 203 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); 204 205 Value resultReal = rewriter.create<LLVM::FAddOp>( 206 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), 207 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); 208 209 Value resultImag = rewriter.create<LLVM::FSubOp>( 210 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 211 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 212 213 result.setReal( 214 rewriter, loc, 215 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); 216 result.setImaginary( 217 rewriter, loc, 218 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); 219 220 rewriter.replaceOp(op, {result}); 221 return success(); 222 } 223 }; 224 225 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { 226 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; 227 228 LogicalResult 229 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 230 ConversionPatternRewriter &rewriter) const override { 231 auto loc = op.getLoc(); 232 BinaryComplexOperands arg = 233 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter); 234 235 // Initialize complex number struct for result. 236 auto structType = typeConverter->convertType(op.getType()); 237 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 238 239 // Emit IR to add complex numbers. 240 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 241 Value rhsRe = arg.rhs.real(); 242 Value rhsIm = arg.rhs.imag(); 243 Value lhsRe = arg.lhs.real(); 244 Value lhsIm = arg.lhs.imag(); 245 246 Value real = rewriter.create<LLVM::FSubOp>( 247 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), 248 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); 249 250 Value imag = rewriter.create<LLVM::FAddOp>( 251 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 252 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 253 254 result.setReal(rewriter, loc, real); 255 result.setImaginary(rewriter, loc, imag); 256 257 rewriter.replaceOp(op, {result}); 258 return success(); 259 } 260 }; 261 262 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 263 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 264 265 LogicalResult 266 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, 267 ConversionPatternRewriter &rewriter) const override { 268 auto loc = op.getLoc(); 269 BinaryComplexOperands arg = 270 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter); 271 272 // Initialize complex number struct for result. 273 auto structType = typeConverter->convertType(op.getType()); 274 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 275 276 // Emit IR to substract complex numbers. 277 auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); 278 Value real = 279 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 280 Value imag = 281 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 282 result.setReal(rewriter, loc, real); 283 result.setImaginary(rewriter, loc, imag); 284 285 rewriter.replaceOp(op, {result}); 286 return success(); 287 } 288 }; 289 } // namespace 290 291 void mlir::populateComplexToLLVMConversionPatterns( 292 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 293 // clang-format off 294 patterns.add< 295 AbsOpConversion, 296 AddOpConversion, 297 CreateOpConversion, 298 DivOpConversion, 299 ImOpConversion, 300 MulOpConversion, 301 ReOpConversion, 302 SubOpConversion 303 >(converter); 304 // clang-format on 305 } 306 307 namespace { 308 struct ConvertComplexToLLVMPass 309 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> { 310 void runOnOperation() override; 311 }; 312 } // namespace 313 314 void ConvertComplexToLLVMPass::runOnOperation() { 315 auto module = getOperation(); 316 317 // Convert to the LLVM IR dialect using the converter defined above. 318 RewritePatternSet patterns(&getContext()); 319 LLVMTypeConverter converter(&getContext()); 320 populateComplexToLLVMConversionPatterns(converter, patterns); 321 322 LLVMConversionTarget target(getContext()); 323 target.addLegalOp<ModuleOp, FuncOp>(); 324 target.addIllegalDialect<complex::ComplexDialect>(); 325 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 326 signalPassFailure(); 327 } 328 329 std::unique_ptr<OperationPass<ModuleOp>> 330 mlir::createConvertComplexToLLVMPass() { 331 return std::make_unique<ConvertComplexToLLVMPass>(); 332 } 333