1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLibm/MathToLibm.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Math/IR/Math.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Utils/IndexingUtils.h" 16 #include "mlir/Dialect/Vector/IR/VectorOps.h" 17 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 18 #include "mlir/IR/BuiltinDialect.h" 19 #include "mlir/IR/PatternMatch.h" 20 21 using namespace mlir; 22 23 namespace { 24 // Pattern to convert vector operations to scalar operations. This is needed as 25 // libm calls require scalars. 26 template <typename Op> 27 struct VecOpToScalarOp : public OpRewritePattern<Op> { 28 public: 29 using OpRewritePattern<Op>::OpRewritePattern; 30 31 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 32 }; 33 // Pattern to convert scalar math operations to calls to libm functions. 34 // Additionally the libm function signatures are declared. 35 template <typename Op> 36 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 37 public: 38 using OpRewritePattern<Op>::OpRewritePattern; 39 ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, 40 StringRef doubleFunc, PatternBenefit benefit) 41 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 42 doubleFunc(doubleFunc){}; 43 44 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 45 46 private: 47 std::string floatFunc, doubleFunc; 48 }; 49 } // namespace 50 51 template <typename Op> 52 LogicalResult 53 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 54 auto opType = op.getType(); 55 auto loc = op.getLoc(); 56 auto vecType = opType.template dyn_cast<VectorType>(); 57 58 if (!vecType) 59 return failure(); 60 if (!vecType.hasRank()) 61 return failure(); 62 auto shape = vecType.getShape(); 63 int64_t numElements = vecType.getNumElements(); 64 65 Value result = rewriter.create<arith::ConstantOp>( 66 loc, DenseElementsAttr::get( 67 vecType, FloatAttr::get(vecType.getElementType(), 0.0))); 68 SmallVector<int64_t> ones(shape.size(), 1); 69 SmallVector<int64_t> strides = computeStrides(shape, ones); 70 for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { 71 SmallVector<int64_t> positions = delinearize(strides, linearIndex); 72 SmallVector<Value> operands; 73 for (auto input : op->getOperands()) 74 operands.push_back( 75 rewriter.create<vector::ExtractOp>(loc, input, positions)); 76 Value scalarOp = 77 rewriter.create<Op>(loc, vecType.getElementType(), operands); 78 result = 79 rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); 80 } 81 rewriter.replaceOp(op, {result}); 82 return success(); 83 } 84 85 template <typename Op> 86 LogicalResult 87 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 88 PatternRewriter &rewriter) const { 89 auto module = SymbolTable::getNearestSymbolTable(op); 90 auto type = op.getType(); 91 // TODO: Support Float16 by upcasting to Float32 92 if (!type.template isa<Float32Type, Float64Type>()) 93 return failure(); 94 95 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 96 auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 97 SymbolTable::lookupSymbolIn(module, name)); 98 // Forward declare function if it hasn't already been 99 if (!opFunc) { 100 OpBuilder::InsertionGuard guard(rewriter); 101 rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 102 auto opFunctionTy = FunctionType::get( 103 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 104 opFunc = 105 rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy); 106 opFunc.setPrivate(); 107 } 108 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 109 110 rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(), 111 op->getOperands()); 112 113 return success(); 114 } 115 116 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, 117 PatternBenefit benefit) { 118 patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, 119 VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit); 120 patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), 121 "atan2f", "atan2", benefit); 122 patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff", 123 "erf", benefit); 124 patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), 125 "expm1f", "expm1", benefit); 126 patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", 127 "tanh", benefit); 128 } 129 130 namespace { 131 struct ConvertMathToLibmPass 132 : public ConvertMathToLibmBase<ConvertMathToLibmPass> { 133 void runOnOperation() override; 134 }; 135 } // namespace 136 137 void ConvertMathToLibmPass::runOnOperation() { 138 auto module = getOperation(); 139 140 RewritePatternSet patterns(&getContext()); 141 populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); 142 143 ConversionTarget target(getContext()); 144 target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect, 145 StandardOpsDialect, vector::VectorDialect>(); 146 target.addIllegalDialect<math::MathDialect>(); 147 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 148 signalPassFailure(); 149 } 150 151 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { 152 return std::make_unique<ConvertMathToLibmPass>(); 153 } 154