//===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" #include "../PassDetail.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; namespace { // Pattern to convert scalar complex operations to calls to libm functions. // Additionally the libm function signatures are declared. template struct ScalarOpToLibmCall : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, StringRef doubleFunc, PatternBenefit benefit) : OpRewritePattern(context, benefit), floatFunc(floatFunc), doubleFunc(doubleFunc){}; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; private: std::string floatFunc, doubleFunc; }; } // namespace template LogicalResult ScalarOpToLibmCall::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto module = SymbolTable::getNearestSymbolTable(op); auto type = op.getType().template cast(); Type elementType = type.getElementType(); if (!elementType.isa()) return failure(); auto name = elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; auto opFunc = dyn_cast_or_null( SymbolTable::lookupSymbolIn(module, name)); // Forward declare function if it hasn't already been if (!opFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); opFunc = rewriter.create(rewriter.getUnknownLoc(), name, opFunctionTy); opFunc.setPrivate(); } assert(isa(SymbolTable::lookupSymbolIn(module, name))); rewriter.replaceOpWithNewOp(op, name, type, op->getOperands()); return success(); } void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add>(patterns.getContext(), "cpowf", "cpow", benefit); patterns.add>(patterns.getContext(), "csqrtf", "csqrt", benefit); patterns.add>(patterns.getContext(), "ctanhf", "ctanh", benefit); patterns.add>(patterns.getContext(), "ccosf", "ccos", benefit); patterns.add>(patterns.getContext(), "csinf", "csin", benefit); patterns.add>(patterns.getContext(), "conjf", "conj", benefit); } namespace { struct ConvertComplexToLibmPass : public ConvertComplexToLibmBase { void runOnOperation() override; }; } // namespace void ConvertComplexToLibmPass::runOnOperation() { auto module = getOperation(); RewritePatternSet patterns(&getContext()); populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1); ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } std::unique_ptr> mlir::createConvertComplexToLibmPass() { return std::make_unique(); }