1 //===- MathOps.cpp - MLIR operations for math implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Math/IR/Math.h"
11 #include "mlir/IR/Builders.h"
12 
13 using namespace mlir;
14 using namespace mlir::math;
15 
16 //===----------------------------------------------------------------------===//
17 // TableGen'd op method definitions
18 //===----------------------------------------------------------------------===//
19 
20 #define GET_OP_CLASSES
21 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
22 
23 //===----------------------------------------------------------------------===//
24 // CeilOp folder
25 //===----------------------------------------------------------------------===//
26 
27 OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
28   auto constOperand = operands.front();
29   if (!constOperand)
30     return {};
31 
32   auto attr = constOperand.dyn_cast<FloatAttr>();
33   if (!attr)
34     return {};
35 
36   APFloat sourceVal = attr.getValue();
37   sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive);
38 
39   return FloatAttr::get(getType(), sourceVal);
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Log2Op folder
44 //===----------------------------------------------------------------------===//
45 
46 OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
47   auto constOperand = operands.front();
48   if (!constOperand)
49     return {};
50 
51   auto attr = constOperand.dyn_cast<FloatAttr>();
52   if (!attr)
53     return {};
54 
55   auto ft = getType().cast<FloatType>();
56 
57   APFloat apf = attr.getValue();
58 
59   if (apf.isNegative())
60     return {};
61 
62   if (ft.getWidth() == 64)
63     return FloatAttr::get(getType(), log2(apf.convertToDouble()));
64 
65   if (ft.getWidth() == 32)
66     return FloatAttr::get(getType(), log2f(apf.convertToFloat()));
67 
68   return {};
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // PowFOp folder
73 //===----------------------------------------------------------------------===//
74 
75 OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
76   auto ft = getType().dyn_cast<FloatType>();
77   if (!ft)
78     return {};
79 
80   APFloat vals[2]{APFloat(ft.getFloatSemantics()),
81                   APFloat(ft.getFloatSemantics())};
82   for (int i = 0; i < 2; ++i) {
83     if (!operands[i])
84       return {};
85 
86     auto attr = operands[i].dyn_cast<FloatAttr>();
87     if (!attr)
88       return {};
89 
90     vals[i] = attr.getValue();
91   }
92 
93   if (ft.getWidth() == 64)
94     return FloatAttr::get(
95         getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble()));
96 
97   if (ft.getWidth() == 32)
98     return FloatAttr::get(
99         getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat()));
100 
101   return {};
102 }
103 
104 OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
105   auto constOperand = operands.front();
106   if (!constOperand)
107     return {};
108 
109   auto attr = constOperand.dyn_cast<FloatAttr>();
110   if (!attr)
111     return {};
112 
113   auto ft = getType().cast<FloatType>();
114 
115   APFloat apf = attr.getValue();
116 
117   if (apf.isNegative())
118     return {};
119 
120   if (ft.getWidth() == 64)
121     return FloatAttr::get(getType(), sqrt(apf.convertToDouble()));
122 
123   if (ft.getWidth() == 32)
124     return FloatAttr::get(getType(), sqrtf(apf.convertToFloat()));
125 
126   return {};
127 }
128 
129 /// Materialize an integer or floating point constant.
130 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
131                                                   Attribute value, Type type,
132                                                   Location loc) {
133   return builder.create<arith::ConstantOp>(loc, value, type);
134 }
135