1 //===- MathOps.cpp - MLIR operations for math implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Builders.h"

#include <cmath>
12 
13 using namespace mlir;
14 using namespace mlir::math;
15 
16 //===----------------------------------------------------------------------===//
17 // TableGen'd op method definitions
18 //===----------------------------------------------------------------------===//
19 
20 #define GET_OP_CLASSES
21 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
22 
23 //===----------------------------------------------------------------------===//
24 // CeilOp folder
25 //===----------------------------------------------------------------------===//
26 
27 OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
28   auto constOperand = operands.front();
29   if (!constOperand)
30     return {};
31 
32   auto attr = constOperand.dyn_cast<FloatAttr>();
33   if (!attr)
34     return {};
35 
36   APFloat sourceVal = attr.getValue();
37   sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive);
38 
39   return FloatAttr::get(getType(), sourceVal);
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Log2Op folder
44 //===----------------------------------------------------------------------===//
45 
46 OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
47   auto constOperand = operands.front();
48   if (!constOperand)
49     return {};
50 
51   auto attr = constOperand.dyn_cast<FloatAttr>();
52   if (!attr)
53     return {};
54 
55   auto ft = getType().cast<FloatType>();
56 
57   APFloat apf = attr.getValue();
58 
59   if (apf.isNegative())
60     return {};
61 
62   if (ft.getWidth() == 64)
63     return FloatAttr::get(getType(), log2(apf.convertToDouble()));
64 
65   if (ft.getWidth() == 32)
66     return FloatAttr::get(getType(), log2f(apf.convertToDouble()));
67 
68   return {};
69 }
70 
71 /// Materialize an integer or floating point constant.
72 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
73                                                   Attribute value, Type type,
74                                                   Location loc) {
75   return builder.create<arith::ConstantOp>(loc, value, type);
76 }
77