1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/Math/IR/Math.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 using namespace mlir; 22 23 namespace { 24 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 25 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 26 27 LogicalResult 28 matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 29 ConversionPatternRewriter &rewriter) const override { 30 complex::AbsOp::Adaptor transformed(operands); 31 auto loc = op.getLoc(); 32 auto type = op.getType(); 33 34 Value real = 35 rewriter.create<complex::ReOp>(loc, type, transformed.complex()); 36 Value imag = 37 rewriter.create<complex::ImOp>(loc, type, transformed.complex()); 38 Value realSqr = rewriter.create<MulFOp>(loc, real, real); 39 Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag); 40 Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr); 41 42 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 43 return success(); 44 } 45 }; 46 47 template <typename ComparisonOp, CmpFPredicate p> 48 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 49 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 50 using ResultCombiner = 51 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 52 AndOp, OrOp>; 53 54 LogicalResult 55 matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands, 56 ConversionPatternRewriter &rewriter) const override { 57 typename ComparisonOp::Adaptor transformed(operands); 58 auto loc = op.getLoc(); 59 auto type = transformed.lhs() 60 .getType() 61 .template cast<ComplexType>() 62 .getElementType(); 63 64 Value realLhs = 65 rewriter.create<complex::ReOp>(loc, type, transformed.lhs()); 66 Value imagLhs = 67 rewriter.create<complex::ImOp>(loc, type, transformed.lhs()); 68 Value realRhs = 69 rewriter.create<complex::ReOp>(loc, type, transformed.rhs()); 70 Value imagRhs = 71 rewriter.create<complex::ImOp>(loc, type, transformed.rhs()); 72 Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs); 73 Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs); 74 75 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 76 imagComparison); 77 return success(); 78 } 79 }; 80 81 struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 82 using OpConversionPattern<complex::DivOp>::OpConversionPattern; 83 84 LogicalResult 85 matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands, 86 ConversionPatternRewriter &rewriter) const override { 87 complex::DivOp::Adaptor transformed(operands); 88 auto loc = op.getLoc(); 89 auto type = transformed.lhs().getType().template cast<ComplexType>(); 90 auto elementType = type.getElementType().cast<FloatType>(); 91 92 Value lhsReal = 93 rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs()); 94 Value lhsImag = 95 rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs()); 96 Value rhsReal = 97 rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs()); 98 Value rhsImag = 99 rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs()); 100 101 // Smith's algorithm to divide complex numbers. It is just a bit smarter 102 // way to compute the following formula: 103 // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 104 // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 105 // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 106 // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 107 // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 108 // 109 // Depending on whether |rhsReal| < |rhsImag| we compute either 110 // rhsRealImagRatio = rhsReal / rhsImag 111 // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 112 // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 113 // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 114 // 115 // or 116 // 117 // rhsImagRealRatio = rhsImag / rhsReal 118 // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 119 // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 120 // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 121 // 122 // See https://dl.acm.org/citation.cfm?id=368661 for more details. 123 Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag); 124 Value rhsRealImagDenom = rewriter.create<AddFOp>( 125 loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal)); 126 Value realNumerator1 = rewriter.create<AddFOp>( 127 loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag); 128 Value resultReal1 = 129 rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom); 130 Value imagNumerator1 = rewriter.create<SubFOp>( 131 loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal); 132 Value resultImag1 = 133 rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom); 134 135 Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal); 136 Value rhsImagRealDenom = rewriter.create<AddFOp>( 137 loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag)); 138 Value realNumerator2 = rewriter.create<AddFOp>( 139 loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio)); 140 Value resultReal2 = 141 rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom); 142 Value imagNumerator2 = rewriter.create<SubFOp>( 143 loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio)); 144 Value resultImag2 = 145 rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom); 146 147 // Consider corner cases. 148 // Case 1. Zero denominator, numerator contains at most one NaN value. 149 Value zero = rewriter.create<ConstantOp>(loc, elementType, 150 rewriter.getZeroAttr(elementType)); 151 Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal); 152 Value rhsRealIsZero = 153 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero); 154 Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag); 155 Value rhsImagIsZero = 156 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero); 157 Value lhsRealIsNotNaN = 158 rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero); 159 Value lhsImagIsNotNaN = 160 rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero); 161 Value lhsContainsNotNaNValue = 162 rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 163 Value resultIsInfinity = rewriter.create<AndOp>( 164 loc, lhsContainsNotNaNValue, 165 rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero)); 166 Value inf = rewriter.create<ConstantOp>( 167 loc, elementType, 168 rewriter.getFloatAttr( 169 elementType, APFloat::getInf(elementType.getFloatSemantics()))); 170 Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal); 171 Value infinityResultReal = 172 rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal); 173 Value infinityResultImag = 174 rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag); 175 176 // Case 2. Infinite numerator, finite denominator. 177 Value rhsRealFinite = 178 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf); 179 Value rhsImagFinite = 180 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf); 181 Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite); 182 Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal); 183 Value lhsRealInfinite = 184 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf); 185 Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag); 186 Value lhsImagInfinite = 187 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf); 188 Value lhsInfinite = 189 rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite); 190 Value infNumFiniteDenom = 191 rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite); 192 Value one = rewriter.create<ConstantOp>( 193 loc, elementType, rewriter.getFloatAttr(elementType, 1)); 194 Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>( 195 loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero), 196 lhsReal); 197 Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>( 198 loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero), 199 lhsImag); 200 Value lhsRealIsInfWithSignTimesRhsReal = 201 rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal); 202 Value lhsImagIsInfWithSignTimesRhsImag = 203 rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag); 204 Value resultReal3 = rewriter.create<MulFOp>( 205 loc, inf, 206 rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 207 lhsImagIsInfWithSignTimesRhsImag)); 208 Value lhsRealIsInfWithSignTimesRhsImag = 209 rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag); 210 Value lhsImagIsInfWithSignTimesRhsReal = 211 rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal); 212 Value resultImag3 = rewriter.create<MulFOp>( 213 loc, inf, 214 rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 215 lhsRealIsInfWithSignTimesRhsImag)); 216 217 // Case 3: Finite numerator, infinite denominator. 218 Value lhsRealFinite = 219 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf); 220 Value lhsImagFinite = 221 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf); 222 Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite); 223 Value rhsRealInfinite = 224 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf); 225 Value rhsImagInfinite = 226 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf); 227 Value rhsInfinite = 228 rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite); 229 Value finiteNumInfiniteDenom = 230 rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite); 231 Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>( 232 loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero), 233 rhsReal); 234 Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>( 235 loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero), 236 rhsImag); 237 Value rhsRealIsInfWithSignTimesLhsReal = 238 rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign); 239 Value rhsImagIsInfWithSignTimesLhsImag = 240 rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign); 241 Value resultReal4 = rewriter.create<MulFOp>( 242 loc, zero, 243 rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 244 rhsImagIsInfWithSignTimesLhsImag)); 245 Value rhsRealIsInfWithSignTimesLhsImag = 246 rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign); 247 Value rhsImagIsInfWithSignTimesLhsReal = 248 rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign); 249 Value resultImag4 = rewriter.create<MulFOp>( 250 loc, zero, 251 rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 252 rhsImagIsInfWithSignTimesLhsReal)); 253 254 Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>( 255 loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 256 Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 257 resultReal1, resultReal2); 258 Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 259 resultImag1, resultImag2); 260 Value resultRealSpecialCase3 = rewriter.create<SelectOp>( 261 loc, finiteNumInfiniteDenom, resultReal4, resultReal); 262 Value resultImagSpecialCase3 = rewriter.create<SelectOp>( 263 loc, finiteNumInfiniteDenom, resultImag4, resultImag); 264 Value resultRealSpecialCase2 = rewriter.create<SelectOp>( 265 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 266 Value resultImagSpecialCase2 = rewriter.create<SelectOp>( 267 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 268 Value resultRealSpecialCase1 = rewriter.create<SelectOp>( 269 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 270 Value resultImagSpecialCase1 = rewriter.create<SelectOp>( 271 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 272 273 Value resultRealIsNaN = 274 rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero); 275 Value resultImagIsNaN = 276 rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero); 277 Value resultIsNaN = 278 rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN); 279 Value resultRealWithSpecialCases = rewriter.create<SelectOp>( 280 loc, resultIsNaN, resultRealSpecialCase1, resultReal); 281 Value resultImagWithSpecialCases = rewriter.create<SelectOp>( 282 loc, resultIsNaN, resultImagSpecialCase1, resultImag); 283 284 rewriter.replaceOpWithNewOp<complex::CreateOp>( 285 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 286 return success(); 287 } 288 }; 289 } // namespace 290 291 void mlir::populateComplexToStandardConversionPatterns( 292 RewritePatternSet &patterns) { 293 patterns.add<AbsOpConversion, 294 ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>, 295 ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>, 296 DivOpConversion>(patterns.getContext()); 297 } 298 299 namespace { 300 struct ConvertComplexToStandardPass 301 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 302 void runOnFunction() override; 303 }; 304 305 void ConvertComplexToStandardPass::runOnFunction() { 306 auto function = getFunction(); 307 308 // Convert to the Standard dialect using the converter defined above. 309 RewritePatternSet patterns(&getContext()); 310 populateComplexToStandardConversionPatterns(patterns); 311 312 ConversionTarget target(getContext()); 313 target.addLegalDialect<StandardOpsDialect, math::MathDialect, 314 complex::ComplexDialect>(); 315 target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp, 316 complex::NotEqualOp>(); 317 if (failed(applyPartialConversion(function, target, std::move(patterns)))) 318 signalPassFailure(); 319 } 320 } // namespace 321 322 std::unique_ptr<OperationPass<FuncOp>> 323 mlir::createConvertComplexToStandardPass() { 324 return std::make_unique<ConvertComplexToStandardPass>(); 325 } 326