1 //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Complex/IR/Complex.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/IR/PatternMatch.h" 15 16 using namespace mlir; 17 18 namespace { 19 // Pattern to convert scalar complex operations to calls to libm functions. 20 // Additionally the libm function signatures are declared. 21 template <typename Op> 22 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 23 public: 24 using OpRewritePattern<Op>::OpRewritePattern; 25 ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, 26 StringRef doubleFunc, PatternBenefit benefit) 27 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 28 doubleFunc(doubleFunc){}; 29 30 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 31 32 private: 33 std::string floatFunc, doubleFunc; 34 }; 35 } // namespace 36 37 template <typename Op> 38 LogicalResult 39 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 40 PatternRewriter &rewriter) const { 41 auto module = SymbolTable::getNearestSymbolTable(op); 42 auto type = op.getType().template cast<ComplexType>(); 43 Type elementType = type.getElementType(); 44 if (!elementType.isa<Float32Type, Float64Type>()) 45 return failure(); 46 47 auto name = 48 elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 49 auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 50 SymbolTable::lookupSymbolIn(module, name)); 51 // Forward declare function if it hasn't already been 52 if (!opFunc) { 53 OpBuilder::InsertionGuard guard(rewriter); 54 rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 55 auto opFunctionTy = FunctionType::get( 56 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 57 opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, 58 opFunctionTy); 59 opFunc.setPrivate(); 60 } 61 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 62 63 rewriter.replaceOpWithNewOp<func::CallOp>(op, name, type, op->getOperands()); 64 65 return success(); 66 } 67 68 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, 69 PatternBenefit benefit) { 70 patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(), 71 "cpowf", "cpow", benefit); 72 patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(), 73 "csqrtf", "csqrt", benefit); 74 patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(), 75 "ctanhf", "ctanh", benefit); 76 patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(), 77 "ccosf", "ccos", benefit); 78 patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(), 79 "csinf", "csin", benefit); 80 patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(), 81 "conjf", "conj", benefit); 82 } 83 84 namespace { 85 struct ConvertComplexToLibmPass 86 : public ConvertComplexToLibmBase<ConvertComplexToLibmPass> { 87 void runOnOperation() override; 88 }; 89 } // namespace 90 91 void ConvertComplexToLibmPass::runOnOperation() { 92 auto module = getOperation(); 93 94 RewritePatternSet patterns(&getContext()); 95 populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1); 96 97 ConversionTarget target(getContext()); 98 target.addLegalDialect<func::FuncDialect>(); 99 target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>(); 100 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 101 signalPassFailure(); 102 } 103 104 std::unique_ptr<OperationPass<ModuleOp>> 105 mlir::createConvertComplexToLibmPass() { 106 return std::make_unique<ConvertComplexToLibmPass>(); 107 } 108