1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLibm/MathToLibm.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/Utils/IndexingUtils.h" 16 #include "mlir/Dialect/Vector/IR/VectorOps.h" 17 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 18 #include "mlir/IR/BuiltinDialect.h" 19 #include "mlir/IR/PatternMatch.h" 20 21 using namespace mlir; 22 23 namespace { 24 // Pattern to convert vector operations to scalar operations. This is needed as 25 // libm calls require scalars. 26 template <typename Op> 27 struct VecOpToScalarOp : public OpRewritePattern<Op> { 28 public: 29 using OpRewritePattern<Op>::OpRewritePattern; 30 31 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 32 }; 33 // Pattern to promote an op of a smaller floating point type to F32. 34 template <typename Op> 35 struct PromoteOpToF32 : public OpRewritePattern<Op> { 36 public: 37 using OpRewritePattern<Op>::OpRewritePattern; 38 39 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 40 }; 41 // Pattern to convert scalar math operations to calls to libm functions. 42 // Additionally the libm function signatures are declared. 
43 template <typename Op> 44 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 45 public: 46 using OpRewritePattern<Op>::OpRewritePattern; 47 ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, 48 StringRef doubleFunc, PatternBenefit benefit) 49 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 50 doubleFunc(doubleFunc){}; 51 52 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 53 54 private: 55 std::string floatFunc, doubleFunc; 56 }; 57 } // namespace 58 59 template <typename Op> 60 LogicalResult 61 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 62 auto opType = op.getType(); 63 auto loc = op.getLoc(); 64 auto vecType = opType.template dyn_cast<VectorType>(); 65 66 if (!vecType) 67 return failure(); 68 if (!vecType.hasRank()) 69 return failure(); 70 auto shape = vecType.getShape(); 71 int64_t numElements = vecType.getNumElements(); 72 73 Value result = rewriter.create<arith::ConstantOp>( 74 loc, DenseElementsAttr::get( 75 vecType, FloatAttr::get(vecType.getElementType(), 0.0))); 76 SmallVector<int64_t> ones(shape.size(), 1); 77 SmallVector<int64_t> strides = computeStrides(shape, ones); 78 for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { 79 SmallVector<int64_t> positions = delinearize(strides, linearIndex); 80 SmallVector<Value> operands; 81 for (auto input : op->getOperands()) 82 operands.push_back( 83 rewriter.create<vector::ExtractOp>(loc, input, positions)); 84 Value scalarOp = 85 rewriter.create<Op>(loc, vecType.getElementType(), operands); 86 result = 87 rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); 88 } 89 rewriter.replaceOp(op, {result}); 90 return success(); 91 } 92 93 template <typename Op> 94 LogicalResult 95 PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 96 auto opType = op.getType(); 97 if (!opType.template isa<Float16Type, BFloat16Type>()) 98 return failure(); 99 100 auto loc = 
op.getLoc(); 101 auto f32 = rewriter.getF32Type(); 102 auto extendedOperands = llvm::to_vector( 103 llvm::map_range(op->getOperands(), [&](Value operand) -> Value { 104 return rewriter.create<arith::ExtFOp>(loc, f32, operand); 105 })); 106 auto newOp = rewriter.create<Op>(loc, f32, extendedOperands); 107 rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp); 108 return success(); 109 } 110 111 template <typename Op> 112 LogicalResult 113 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 114 PatternRewriter &rewriter) const { 115 auto module = SymbolTable::getNearestSymbolTable(op); 116 auto type = op.getType(); 117 if (!type.template isa<Float32Type, Float64Type>()) 118 return failure(); 119 120 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 121 auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 122 SymbolTable::lookupSymbolIn(module, name)); 123 // Forward declare function if it hasn't already been 124 if (!opFunc) { 125 OpBuilder::InsertionGuard guard(rewriter); 126 rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 127 auto opFunctionTy = FunctionType::get( 128 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 129 opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, 130 opFunctionTy); 131 opFunc.setPrivate(); 132 } 133 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 134 135 rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(), 136 op->getOperands()); 137 138 return success(); 139 } 140 141 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, 142 PatternBenefit benefit) { 143 patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, 144 VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>, 145 VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>, 146 VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>, 147 VecOpToScalarOp<math::TanOp>>(patterns.getContext(), benefit); 148 
patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>, 149 PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>, 150 PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>, 151 PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>, 152 PromoteOpToF32<math::TanOp>>(patterns.getContext(), benefit); 153 patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf", 154 "atan", benefit); 155 patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), 156 "atan2f", "atan2", benefit); 157 patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff", 158 "erf", benefit); 159 patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), 160 "expm1f", "expm1", benefit); 161 patterns.add<ScalarOpToLibmCall<math::TanOp>>(patterns.getContext(), "tanf", 162 "tan", benefit); 163 patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", 164 "tanh", benefit); 165 patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(), 166 "roundf", "round", benefit); 167 patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf", 168 "cos", benefit); 169 patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf", 170 "sin", benefit); 171 } 172 173 namespace { 174 struct ConvertMathToLibmPass 175 : public ConvertMathToLibmBase<ConvertMathToLibmPass> { 176 void runOnOperation() override; 177 }; 178 } // namespace 179 180 void ConvertMathToLibmPass::runOnOperation() { 181 auto module = getOperation(); 182 183 RewritePatternSet patterns(&getContext()); 184 populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); 185 186 ConversionTarget target(getContext()); 187 target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect, 188 func::FuncDialect, vector::VectorDialect>(); 189 target.addIllegalDialect<math::MathDialect>(); 190 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 191 signalPassFailure(); 192 } 193 194 
/// Creates a fresh instance of the math-to-libm conversion pass, returned as
/// an owning pointer to a module-level pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertMathToLibmPass>();
  return pass;
}