1 //===- MathOps.cpp - MLIR operations for math implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Math/IR/Math.h"
11 #include "mlir/IR/Builders.h"
12 
13 using namespace mlir;
14 using namespace mlir::math;
15 
16 //===----------------------------------------------------------------------===//
17 // TableGen'd op method definitions
18 //===----------------------------------------------------------------------===//
19 
20 #define GET_OP_CLASSES
21 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
22 
23 //===----------------------------------------------------------------------===//
24 // AbsOp folder
25 //===----------------------------------------------------------------------===//
26 
27 OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
28   auto constOperand = operands.front();
29   if (!constOperand)
30     return {};
31 
32   auto attr = constOperand.dyn_cast<FloatAttr>();
33   if (!attr)
34     return {};
35 
36   auto ft = getType().cast<FloatType>();
37 
38   APFloat apf = attr.getValue();
39 
40   if (ft.getWidth() == 64)
41     return FloatAttr::get(getType(), fabs(apf.convertToDouble()));
42 
43   if (ft.getWidth() == 32)
44     return FloatAttr::get(getType(), fabsf(apf.convertToFloat()));
45 
46   return {};
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // CeilOp folder
51 //===----------------------------------------------------------------------===//
52 
53 OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
54   auto constOperand = operands.front();
55   if (!constOperand)
56     return {};
57 
58   auto attr = constOperand.dyn_cast<FloatAttr>();
59   if (!attr)
60     return {};
61 
62   APFloat sourceVal = attr.getValue();
63   sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive);
64 
65   return FloatAttr::get(getType(), sourceVal);
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // CopySignOp folder
70 //===----------------------------------------------------------------------===//
71 
72 OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
73   auto ft = getType().dyn_cast<FloatType>();
74   if (!ft)
75     return {};
76 
77   APFloat vals[2]{APFloat(ft.getFloatSemantics()),
78                   APFloat(ft.getFloatSemantics())};
79   for (int i = 0; i < 2; ++i) {
80     if (!operands[i])
81       return {};
82 
83     auto attr = operands[i].dyn_cast<FloatAttr>();
84     if (!attr)
85       return {};
86 
87     vals[i] = attr.getValue();
88   }
89 
90   vals[0].copySign(vals[1]);
91 
92   return FloatAttr::get(getType(), vals[0]);
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // CountLeadingZerosOp folder
97 //===----------------------------------------------------------------------===//
98 
99 OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
100   auto constOperand = operands.front();
101   if (!constOperand)
102     return {};
103 
104   auto attr = constOperand.dyn_cast<IntegerAttr>();
105   if (!attr)
106     return {};
107 
108   return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros());
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // CountTrailingZerosOp folder
113 //===----------------------------------------------------------------------===//
114 
115 OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
116   auto constOperand = operands.front();
117   if (!constOperand)
118     return {};
119 
120   auto attr = constOperand.dyn_cast<IntegerAttr>();
121   if (!attr)
122     return {};
123 
124   return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros());
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // CtPopOp folder
129 //===----------------------------------------------------------------------===//
130 
131 OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
132   auto constOperand = operands.front();
133   if (!constOperand)
134     return {};
135 
136   auto attr = constOperand.dyn_cast<IntegerAttr>();
137   if (!attr)
138     return {};
139 
140   return IntegerAttr::get(getType(), attr.getValue().countPopulation());
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // Log2Op folder
145 //===----------------------------------------------------------------------===//
146 
147 OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
148   auto constOperand = operands.front();
149   if (!constOperand)
150     return {};
151 
152   auto attr = constOperand.dyn_cast<FloatAttr>();
153   if (!attr)
154     return {};
155 
156   auto ft = getType().cast<FloatType>();
157 
158   APFloat apf = attr.getValue();
159 
160   if (apf.isNegative())
161     return {};
162 
163   if (ft.getWidth() == 64)
164     return FloatAttr::get(getType(), log2(apf.convertToDouble()));
165 
166   if (ft.getWidth() == 32)
167     return FloatAttr::get(getType(), log2f(apf.convertToFloat()));
168 
169   return {};
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // PowFOp folder
174 //===----------------------------------------------------------------------===//
175 
176 OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
177   auto ft = getType().dyn_cast<FloatType>();
178   if (!ft)
179     return {};
180 
181   APFloat vals[2]{APFloat(ft.getFloatSemantics()),
182                   APFloat(ft.getFloatSemantics())};
183   for (int i = 0; i < 2; ++i) {
184     if (!operands[i])
185       return {};
186 
187     auto attr = operands[i].dyn_cast<FloatAttr>();
188     if (!attr)
189       return {};
190 
191     vals[i] = attr.getValue();
192   }
193 
194   if (ft.getWidth() == 64)
195     return FloatAttr::get(
196         getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble()));
197 
198   if (ft.getWidth() == 32)
199     return FloatAttr::get(
200         getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat()));
201 
202   return {};
203 }
204 
205 OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
206   auto constOperand = operands.front();
207   if (!constOperand)
208     return {};
209 
210   auto attr = constOperand.dyn_cast<FloatAttr>();
211   if (!attr)
212     return {};
213 
214   auto ft = getType().cast<FloatType>();
215 
216   APFloat apf = attr.getValue();
217 
218   if (apf.isNegative())
219     return {};
220 
221   if (ft.getWidth() == 64)
222     return FloatAttr::get(getType(), sqrt(apf.convertToDouble()));
223 
224   if (ft.getWidth() == 32)
225     return FloatAttr::get(getType(), sqrtf(apf.convertToFloat()));
226 
227   return {};
228 }
229 
230 /// Materialize an integer or floating point constant.
231 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
232                                                   Attribute value, Type type,
233                                                   Location loc) {
234   return builder.create<arith::ConstantOp>(loc, value, type);
235 }
236