1 //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/IR/PatternMatch.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 // Pattern to convert scalar complex operations to calls to libm functions.
20 // Additionally the libm function signatures are declared.
21 template <typename Op>
22 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
23 public:
24   using OpRewritePattern<Op>::OpRewritePattern;
25   ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
26                          StringRef doubleFunc, PatternBenefit benefit)
27       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
28         doubleFunc(doubleFunc){};
29 
30   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
31 
32 private:
33   std::string floatFunc, doubleFunc;
34 };
35 } // namespace
36 
37 template <typename Op>
38 LogicalResult
39 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
40                                         PatternRewriter &rewriter) const {
41   auto module = SymbolTable::getNearestSymbolTable(op);
42   auto type = op.getType().template cast<ComplexType>();
43   Type elementType = type.getElementType();
44   if (!elementType.isa<Float32Type, Float64Type>())
45     return failure();
46 
47   auto name =
48       elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
49   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
50       SymbolTable::lookupSymbolIn(module, name));
51   // Forward declare function if it hasn't already been
52   if (!opFunc) {
53     OpBuilder::InsertionGuard guard(rewriter);
54     rewriter.setInsertionPointToStart(&module->getRegion(0).front());
55     auto opFunctionTy = FunctionType::get(
56         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
57     opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
58                                            opFunctionTy);
59     opFunc.setPrivate();
60   }
61   assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
62 
63   rewriter.replaceOpWithNewOp<func::CallOp>(op, name, type, op->getOperands());
64 
65   return success();
66 }
67 
68 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
69                                                    PatternBenefit benefit) {
70   patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(),
71                                                    "cpowf", "cpow", benefit);
72   patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(),
73                                                     "csqrtf", "csqrt", benefit);
74   patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(),
75                                                     "ctanhf", "ctanh", benefit);
76   patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(),
77                                                    "ccosf", "ccos", benefit);
78   patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(),
79                                                    "csinf", "csin", benefit);
80 }
81 
82 namespace {
83 struct ConvertComplexToLibmPass
84     : public ConvertComplexToLibmBase<ConvertComplexToLibmPass> {
85   void runOnOperation() override;
86 };
87 } // namespace
88 
89 void ConvertComplexToLibmPass::runOnOperation() {
90   auto module = getOperation();
91 
92   RewritePatternSet patterns(&getContext());
93   populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1);
94 
95   ConversionTarget target(getContext());
96   target.addLegalDialect<func::FuncDialect>();
97   target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>();
98   if (failed(applyPartialConversion(module, target, std::move(patterns))))
99     signalPassFailure();
100 }
101 
102 std::unique_ptr<OperationPass<ModuleOp>>
103 mlir::createConvertComplexToLibmPass() {
104   return std::make_unique<ConvertComplexToLibmPass>();
105 }
106