1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLibm/MathToLibm.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Math/IR/Math.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Vector/VectorOps.h" 16 #include "mlir/Dialect/Vector/VectorUtils.h" 17 #include "mlir/IR/BuiltinDialect.h" 18 #include "mlir/IR/PatternMatch.h" 19 20 using namespace mlir; 21 22 namespace { 23 // Pattern to convert vector operations to scalar operations. This is needed as 24 // libm calls require scalars. 25 template <typename Op> 26 struct VecOpToScalarOp : public OpRewritePattern<Op> { 27 public: 28 using OpRewritePattern<Op>::OpRewritePattern; 29 30 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 31 }; 32 // Pattern to convert scalar math operations to calls to libm functions. 33 // Additionally the libm function signatures are declared. 34 template <typename Op> 35 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 36 public: 37 using OpRewritePattern<Op>::OpRewritePattern; 38 ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, 39 StringRef doubleFunc, PatternBenefit benefit) 40 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 41 doubleFunc(doubleFunc){}; 42 43 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 44 45 private: 46 std::string floatFunc, doubleFunc; 47 }; 48 } // namespace 49 50 template <typename Op> 51 LogicalResult 52 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 53 auto opType = op.getType(); 54 auto loc = op.getLoc(); 55 auto vecType = opType.template dyn_cast<VectorType>(); 56 57 if (!vecType) 58 return failure(); 59 if (!vecType.hasRank()) 60 return failure(); 61 auto shape = vecType.getShape(); 62 int64_t numElements = vecType.getNumElements(); 63 64 Value result = rewriter.create<arith::ConstantOp>( 65 loc, DenseElementsAttr::get( 66 vecType, FloatAttr::get(vecType.getElementType(), 0.0))); 67 SmallVector<int64_t> ones(shape.size(), 1); 68 SmallVector<int64_t> strides = computeStrides(shape, ones); 69 for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { 70 SmallVector<int64_t> positions = delinearize(strides, linearIndex); 71 SmallVector<Value> operands; 72 for (auto input : op->getOperands()) 73 operands.push_back( 74 rewriter.create<vector::ExtractOp>(loc, input, positions)); 75 Value scalarOp = 76 rewriter.create<Op>(loc, vecType.getElementType(), operands); 77 result = 78 rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); 79 } 80 rewriter.replaceOp(op, {result}); 81 return success(); 82 } 83 84 template <typename Op> 85 LogicalResult 86 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 87 PatternRewriter &rewriter) const { 88 auto module = SymbolTable::getNearestSymbolTable(op); 89 auto type = op.getType(); 90 // TODO: Support Float16 by upcasting to Float32 91 if (!type.template isa<Float32Type, Float64Type>()) 92 return failure(); 93 94 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 95 auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 96 SymbolTable::lookupSymbolIn(module, name)); 97 // Forward declare function if it hasn't already been 98 if (!opFunc) { 99 OpBuilder::InsertionGuard guard(rewriter); 100 rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 101 auto opFunctionTy = FunctionType::get( 102 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 103 opFunc = 104 rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy); 105 opFunc.setPrivate(); 106 } 107 assert(SymbolTable::lookupSymbolIn(module, name) 108 ->template hasTrait<mlir::OpTrait::FunctionLike>()); 109 110 rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(), 111 op->getOperands()); 112 113 return success(); 114 } 115 116 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, 117 PatternBenefit benefit) { 118 patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, 119 VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit); 120 patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), 121 "atan2f", "atan2", benefit); 122 patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff", 123 "erf", benefit); 124 patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), 125 "expm1f", "expm1", benefit); 126 patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", 127 "tanh", benefit); 128 } 129 130 namespace { 131 struct ConvertMathToLibmPass 132 : public ConvertMathToLibmBase<ConvertMathToLibmPass> { 133 void runOnOperation() override; 134 }; 135 } // namespace 136 137 void ConvertMathToLibmPass::runOnOperation() { 138 auto module = getOperation(); 139 140 RewritePatternSet patterns(&getContext()); 141 populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); 142 143 ConversionTarget target(getContext()); 144 target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect, 145 StandardOpsDialect, vector::VectorDialect>(); 146 target.addIllegalDialect<math::MathDialect>(); 147 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 148 signalPassFailure(); 149 } 150 151 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { 152 return std::make_unique<ConvertMathToLibmPass>(); 153 } 154