1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Complex/IR/Complex.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 15 using namespace mlir; 16 using namespace mlir::LLVM; 17 18 namespace { 19 20 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { 21 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; 22 23 LogicalResult 24 matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 25 ConversionPatternRewriter &rewriter) const override { 26 complex::AbsOp::Adaptor transformed(operands); 27 auto loc = op.getLoc(); 28 29 ComplexStructBuilder complexStruct(transformed.complex()); 30 Value real = complexStruct.real(rewriter, op.getLoc()); 31 Value imag = complexStruct.imaginary(rewriter, op.getLoc()); 32 33 auto fmf = LLVM::FMFAttr::get({}, op.getContext()); 34 Value sqNorm = rewriter.create<LLVM::FAddOp>( 35 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), 36 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); 37 38 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); 39 return success(); 40 } 41 }; 42 43 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 44 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 45 46 LogicalResult 47 matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands, 48 ConversionPatternRewriter &rewriter) const override { 49 complex::CreateOp::Adaptor transformed(operands); 50 51 // Pack real and imaginary part in a complex number struct. 52 auto loc = complexOp.getLoc(); 53 auto structType = typeConverter->convertType(complexOp.getType()); 54 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 55 complexStruct.setReal(rewriter, loc, transformed.real()); 56 complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); 57 58 rewriter.replaceOp(complexOp, {complexStruct}); 59 return success(); 60 } 61 }; 62 63 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 64 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 65 66 LogicalResult 67 matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands, 68 ConversionPatternRewriter &rewriter) const override { 69 complex::ReOp::Adaptor transformed(operands); 70 71 // Extract real part from the complex number struct. 72 ComplexStructBuilder complexStruct(transformed.complex()); 73 Value real = complexStruct.real(rewriter, op.getLoc()); 74 rewriter.replaceOp(op, real); 75 76 return success(); 77 } 78 }; 79 80 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 81 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 82 83 LogicalResult 84 matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands, 85 ConversionPatternRewriter &rewriter) const override { 86 complex::ImOp::Adaptor transformed(operands); 87 88 // Extract imaginary part from the complex number struct. 89 ComplexStructBuilder complexStruct(transformed.complex()); 90 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 91 rewriter.replaceOp(op, imaginary); 92 93 return success(); 94 } 95 }; 96 97 struct BinaryComplexOperands { 98 std::complex<Value> lhs; 99 std::complex<Value> rhs; 100 }; 101 102 template <typename OpTy> 103 BinaryComplexOperands 104 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands, 105 ConversionPatternRewriter &rewriter) { 106 auto loc = op.getLoc(); 107 typename OpTy::Adaptor transformed(operands); 108 109 // Extract real and imaginary values from operands. 110 BinaryComplexOperands unpacked; 111 ComplexStructBuilder lhs(transformed.lhs()); 112 unpacked.lhs.real(lhs.real(rewriter, loc)); 113 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 114 ComplexStructBuilder rhs(transformed.rhs()); 115 unpacked.rhs.real(rhs.real(rewriter, loc)); 116 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 117 118 return unpacked; 119 } 120 121 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 122 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 123 124 LogicalResult 125 matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands, 126 ConversionPatternRewriter &rewriter) const override { 127 auto loc = op.getLoc(); 128 BinaryComplexOperands arg = 129 unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter); 130 131 // Initialize complex number struct for result. 132 auto structType = typeConverter->convertType(op.getType()); 133 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 134 135 // Emit IR to add complex numbers. 136 auto fmf = LLVM::FMFAttr::get({}, op.getContext()); 137 Value real = 138 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 139 Value imag = 140 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 141 result.setReal(rewriter, loc, real); 142 result.setImaginary(rewriter, loc, imag); 143 144 rewriter.replaceOp(op, {result}); 145 return success(); 146 } 147 }; 148 149 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { 150 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; 151 152 LogicalResult 153 matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands, 154 ConversionPatternRewriter &rewriter) const override { 155 auto loc = op.getLoc(); 156 BinaryComplexOperands arg = 157 unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter); 158 159 // Initialize complex number struct for result. 160 auto structType = typeConverter->convertType(op.getType()); 161 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 162 163 // Emit IR to add complex numbers. 164 auto fmf = LLVM::FMFAttr::get({}, op.getContext()); 165 Value rhsRe = arg.rhs.real(); 166 Value rhsIm = arg.rhs.imag(); 167 Value lhsRe = arg.lhs.real(); 168 Value lhsIm = arg.lhs.imag(); 169 170 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( 171 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), 172 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); 173 174 Value resultReal = rewriter.create<LLVM::FAddOp>( 175 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), 176 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); 177 178 Value resultImag = rewriter.create<LLVM::FSubOp>( 179 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 180 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 181 182 result.setReal( 183 rewriter, loc, 184 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); 185 result.setImaginary( 186 rewriter, loc, 187 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); 188 189 rewriter.replaceOp(op, {result}); 190 return success(); 191 } 192 }; 193 194 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { 195 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; 196 197 LogicalResult 198 matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands, 199 ConversionPatternRewriter &rewriter) const override { 200 auto loc = op.getLoc(); 201 BinaryComplexOperands arg = 202 unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter); 203 204 // Initialize complex number struct for result. 205 auto structType = typeConverter->convertType(op.getType()); 206 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 207 208 // Emit IR to add complex numbers. 209 auto fmf = LLVM::FMFAttr::get({}, op.getContext()); 210 Value rhsRe = arg.rhs.real(); 211 Value rhsIm = arg.rhs.imag(); 212 Value lhsRe = arg.lhs.real(); 213 Value lhsIm = arg.lhs.imag(); 214 215 Value real = rewriter.create<LLVM::FSubOp>( 216 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), 217 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); 218 219 Value imag = rewriter.create<LLVM::FAddOp>( 220 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 221 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 222 223 result.setReal(rewriter, loc, real); 224 result.setImaginary(rewriter, loc, imag); 225 226 rewriter.replaceOp(op, {result}); 227 return success(); 228 } 229 }; 230 231 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 232 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 233 234 LogicalResult 235 matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands, 236 ConversionPatternRewriter &rewriter) const override { 237 auto loc = op.getLoc(); 238 BinaryComplexOperands arg = 239 unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter); 240 241 // Initialize complex number struct for result. 242 auto structType = typeConverter->convertType(op.getType()); 243 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 244 245 // Emit IR to substract complex numbers. 246 auto fmf = LLVM::FMFAttr::get({}, op.getContext()); 247 Value real = 248 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 249 Value imag = 250 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 251 result.setReal(rewriter, loc, real); 252 result.setImaginary(rewriter, loc, imag); 253 254 rewriter.replaceOp(op, {result}); 255 return success(); 256 } 257 }; 258 } // namespace 259 260 void mlir::populateComplexToLLVMConversionPatterns( 261 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 262 // clang-format off 263 patterns.insert< 264 AbsOpConversion, 265 AddOpConversion, 266 CreateOpConversion, 267 DivOpConversion, 268 ImOpConversion, 269 MulOpConversion, 270 ReOpConversion, 271 SubOpConversion 272 >(converter); 273 // clang-format on 274 } 275 276 namespace { 277 struct ConvertComplexToLLVMPass 278 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> { 279 void runOnOperation() override; 280 }; 281 } // namespace 282 283 void ConvertComplexToLLVMPass::runOnOperation() { 284 auto module = getOperation(); 285 286 // Convert to the LLVM IR dialect using the converter defined above. 287 OwningRewritePatternList patterns; 288 LLVMTypeConverter converter(&getContext()); 289 populateStdToLLVMConversionPatterns(converter, patterns); 290 populateComplexToLLVMConversionPatterns(converter, patterns); 291 292 LLVMConversionTarget target(getContext()); 293 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 294 if (failed(applyFullConversion(module, target, std::move(patterns)))) 295 signalPassFailure(); 296 } 297 298 std::unique_ptr<OperationPass<ModuleOp>> 299 mlir::createConvertComplexToLLVMPass() { 300 return std::make_unique<ConvertComplexToLLVMPass>(); 301 } 302