1e4978713SBenjamin Kramer //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===// 2e4978713SBenjamin Kramer // 3e4978713SBenjamin Kramer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4e4978713SBenjamin Kramer // See https://llvm.org/LICENSE.txt for license information. 5e4978713SBenjamin Kramer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6e4978713SBenjamin Kramer // 7e4978713SBenjamin Kramer //===----------------------------------------------------------------------===// 8e4978713SBenjamin Kramer 9e4978713SBenjamin Kramer #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" 10e4978713SBenjamin Kramer 11e4978713SBenjamin Kramer #include "../PassDetail.h" 12e4978713SBenjamin Kramer #include "mlir/Dialect/Complex/IR/Complex.h" 13e4978713SBenjamin Kramer #include "mlir/Dialect/Func/IR/FuncOps.h" 14e4978713SBenjamin Kramer #include "mlir/IR/PatternMatch.h" 15e4978713SBenjamin Kramer 16e4978713SBenjamin Kramer using namespace mlir; 17e4978713SBenjamin Kramer 18e4978713SBenjamin Kramer namespace { 19eaba6e0bSlewuathe // Functor to resolve the function name corresponding to the given complex 20eaba6e0bSlewuathe // result type. 21eaba6e0bSlewuathe struct ComplexTypeResolver { 22eaba6e0bSlewuathe llvm::Optional<bool> operator()(Type type) const { 23eaba6e0bSlewuathe auto complexType = type.cast<ComplexType>(); 24eaba6e0bSlewuathe auto elementType = complexType.getElementType(); 25eaba6e0bSlewuathe if (!elementType.isa<Float32Type, Float64Type>()) 26eaba6e0bSlewuathe return {}; 27eaba6e0bSlewuathe 28eaba6e0bSlewuathe return elementType.getIntOrFloatBitWidth() == 64; 29eaba6e0bSlewuathe } 30eaba6e0bSlewuathe }; 31eaba6e0bSlewuathe 32eaba6e0bSlewuathe // Functor to resolve the function name corresponding to the given float result 33eaba6e0bSlewuathe // type. 34eaba6e0bSlewuathe struct FloatTypeResolver { 35eaba6e0bSlewuathe llvm::Optional<bool> operator()(Type type) const { 36eaba6e0bSlewuathe auto elementType = type.cast<FloatType>(); 37eaba6e0bSlewuathe if (!elementType.isa<Float32Type, Float64Type>()) 38eaba6e0bSlewuathe return {}; 39eaba6e0bSlewuathe 40eaba6e0bSlewuathe return elementType.getIntOrFloatBitWidth() == 64; 41eaba6e0bSlewuathe } 42eaba6e0bSlewuathe }; 43eaba6e0bSlewuathe 44e4978713SBenjamin Kramer // Pattern to convert scalar complex operations to calls to libm functions. 45e4978713SBenjamin Kramer // Additionally the libm function signatures are declared. 46eaba6e0bSlewuathe // TypeResolver is a functor returning the libm function name according to the 47eaba6e0bSlewuathe // expected type double or float. 48eaba6e0bSlewuathe template <typename Op, typename TypeResolver = ComplexTypeResolver> 49e4978713SBenjamin Kramer struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 50e4978713SBenjamin Kramer public: 51e4978713SBenjamin Kramer using OpRewritePattern<Op>::OpRewritePattern; 52eaba6e0bSlewuathe ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context, 53eaba6e0bSlewuathe StringRef floatFunc, 54eaba6e0bSlewuathe StringRef doubleFunc, 55eaba6e0bSlewuathe PatternBenefit benefit) 56e4978713SBenjamin Kramer : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 57e4978713SBenjamin Kramer doubleFunc(doubleFunc){}; 58e4978713SBenjamin Kramer 59e4978713SBenjamin Kramer LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 60e4978713SBenjamin Kramer 61e4978713SBenjamin Kramer private: 62e4978713SBenjamin Kramer std::string floatFunc, doubleFunc; 63e4978713SBenjamin Kramer }; 64e4978713SBenjamin Kramer } // namespace 65e4978713SBenjamin Kramer 66eaba6e0bSlewuathe template <typename Op, typename TypeResolver> 67eaba6e0bSlewuathe LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite( 68eaba6e0bSlewuathe Op op, PatternRewriter &rewriter) const { 69e4978713SBenjamin Kramer auto module = SymbolTable::getNearestSymbolTable(op); 70eaba6e0bSlewuathe auto isDouble = TypeResolver()(op.getType()); 71eaba6e0bSlewuathe if (!isDouble.hasValue()) 72e4978713SBenjamin Kramer return failure(); 73e4978713SBenjamin Kramer 74eaba6e0bSlewuathe auto name = isDouble.value() ? doubleFunc : floatFunc; 75eaba6e0bSlewuathe 76e4978713SBenjamin Kramer auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 77e4978713SBenjamin Kramer SymbolTable::lookupSymbolIn(module, name)); 78e4978713SBenjamin Kramer // Forward declare function if it hasn't already been 79e4978713SBenjamin Kramer if (!opFunc) { 80e4978713SBenjamin Kramer OpBuilder::InsertionGuard guard(rewriter); 81e4978713SBenjamin Kramer rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 82e4978713SBenjamin Kramer auto opFunctionTy = FunctionType::get( 83e4978713SBenjamin Kramer rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 84e4978713SBenjamin Kramer opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, 85e4978713SBenjamin Kramer opFunctionTy); 86e4978713SBenjamin Kramer opFunc.setPrivate(); 87e4978713SBenjamin Kramer } 88e4978713SBenjamin Kramer assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 89e4978713SBenjamin Kramer 90eaba6e0bSlewuathe rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(), 91eaba6e0bSlewuathe op->getOperands()); 92e4978713SBenjamin Kramer 93e4978713SBenjamin Kramer return success(); 94e4978713SBenjamin Kramer } 95e4978713SBenjamin Kramer 96e4978713SBenjamin Kramer void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, 97e4978713SBenjamin Kramer PatternBenefit benefit) { 98e4978713SBenjamin Kramer patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(), 99e4978713SBenjamin Kramer "cpowf", "cpow", benefit); 100e4978713SBenjamin Kramer patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(), 101e4978713SBenjamin Kramer "csqrtf", "csqrt", benefit); 102e4978713SBenjamin Kramer patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(), 103e4978713SBenjamin Kramer "ctanhf", "ctanh", benefit); 1049f0869a6Slewuathe patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(), 1059f0869a6Slewuathe "ccosf", "ccos", benefit); 1069f0869a6Slewuathe patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(), 1079f0869a6Slewuathe "csinf", "csin", benefit); 10872ee11a8Slewuathe patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(), 10972ee11a8Slewuathe "conjf", "conj", benefit); 110eaba6e0bSlewuathe patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>( 111eaba6e0bSlewuathe patterns.getContext(), "cabsf", "cabs", benefit); 112*f27deeeeSlewuathe patterns.add<ScalarOpToLibmCall<complex::AngleOp, FloatTypeResolver>>( 113*f27deeeeSlewuathe patterns.getContext(), "cargf", "carg", benefit); 114e4978713SBenjamin Kramer } 115e4978713SBenjamin Kramer 116e4978713SBenjamin Kramer namespace { 117e4978713SBenjamin Kramer struct ConvertComplexToLibmPass 118e4978713SBenjamin Kramer : public ConvertComplexToLibmBase<ConvertComplexToLibmPass> { 119e4978713SBenjamin Kramer void runOnOperation() override; 120e4978713SBenjamin Kramer }; 121e4978713SBenjamin Kramer } // namespace 122e4978713SBenjamin Kramer 123e4978713SBenjamin Kramer void ConvertComplexToLibmPass::runOnOperation() { 124e4978713SBenjamin Kramer auto module = getOperation(); 125e4978713SBenjamin Kramer 126e4978713SBenjamin Kramer RewritePatternSet patterns(&getContext()); 127e4978713SBenjamin Kramer populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1); 128e4978713SBenjamin Kramer 129e4978713SBenjamin Kramer ConversionTarget target(getContext()); 130e4978713SBenjamin Kramer target.addLegalDialect<func::FuncDialect>(); 131eaba6e0bSlewuathe target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp, 132*f27deeeeSlewuathe complex::AbsOp, complex::AngleOp>(); 133e4978713SBenjamin Kramer if (failed(applyPartialConversion(module, target, std::move(patterns)))) 134e4978713SBenjamin Kramer signalPassFailure(); 135e4978713SBenjamin Kramer } 136e4978713SBenjamin Kramer 137e4978713SBenjamin Kramer std::unique_ptr<OperationPass<ModuleOp>> 138e4978713SBenjamin Kramer mlir::createConvertComplexToLibmPass() { 139e4978713SBenjamin Kramer return std::make_unique<ConvertComplexToLibmPass>(); 140e4978713SBenjamin Kramer } 141