1 //===- MathOps.cpp - MLIR operations for math implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/CommonFolders.h"
11 #include "mlir/Dialect/Math/IR/Math.h"
12 #include "mlir/IR/Builders.h"
13 
14 using namespace mlir;
15 using namespace mlir::math;
16 
17 //===----------------------------------------------------------------------===//
18 // TableGen'd op method definitions
19 //===----------------------------------------------------------------------===//
20 
21 #define GET_OP_CLASSES
22 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
23 
24 //===----------------------------------------------------------------------===//
25 // AbsOp folder
26 //===----------------------------------------------------------------------===//
27 
fold(ArrayRef<Attribute> operands)28 OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
29   return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
30     const APFloat &result(a);
31     return abs(result);
32   });
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // CeilOp folder
37 //===----------------------------------------------------------------------===//
38 
fold(ArrayRef<Attribute> operands)39 OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
40   return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
41     APFloat result(a);
42     result.roundToIntegral(llvm::RoundingMode::TowardPositive);
43     return result;
44   });
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // CopySignOp folder
49 //===----------------------------------------------------------------------===//
50 
fold(ArrayRef<Attribute> operands)51 OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
52   return constFoldBinaryOp<FloatAttr>(operands,
53                                       [](const APFloat &a, const APFloat &b) {
54                                         APFloat result(a);
55                                         result.copySign(b);
56                                         return result;
57                                       });
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // CountLeadingZerosOp folder
62 //===----------------------------------------------------------------------===//
63 
fold(ArrayRef<Attribute> operands)64 OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
65   return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
66     return APInt(a.getBitWidth(), a.countLeadingZeros());
67   });
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // CountTrailingZerosOp folder
72 //===----------------------------------------------------------------------===//
73 
fold(ArrayRef<Attribute> operands)74 OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
75   return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
76     return APInt(a.getBitWidth(), a.countTrailingZeros());
77   });
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // CtPopOp folder
82 //===----------------------------------------------------------------------===//
83 
fold(ArrayRef<Attribute> operands)84 OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
85   return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
86     return APInt(a.getBitWidth(), a.countPopulation());
87   });
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // LogOp folder
92 //===----------------------------------------------------------------------===//
93 
fold(ArrayRef<Attribute> operands)94 OpFoldResult math::LogOp::fold(ArrayRef<Attribute> operands) {
95   return constFoldUnaryOpConditional<FloatAttr>(
96       operands, [](const APFloat &a) -> Optional<APFloat> {
97         if (a.isNegative())
98           return {};
99 
100         if (a.getSizeInBits(a.getSemantics()) == 64)
101           return APFloat(log(a.convertToDouble()));
102 
103         if (a.getSizeInBits(a.getSemantics()) == 32)
104           return APFloat(logf(a.convertToFloat()));
105 
106         return {};
107       });
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // Log2Op folder
112 //===----------------------------------------------------------------------===//
113 
fold(ArrayRef<Attribute> operands)114 OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
115   return constFoldUnaryOpConditional<FloatAttr>(
116       operands, [](const APFloat &a) -> Optional<APFloat> {
117         if (a.isNegative())
118           return {};
119 
120         if (a.getSizeInBits(a.getSemantics()) == 64)
121           return APFloat(log2(a.convertToDouble()));
122 
123         if (a.getSizeInBits(a.getSemantics()) == 32)
124           return APFloat(log2f(a.convertToFloat()));
125 
126         return {};
127       });
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // Log10Op folder
132 //===----------------------------------------------------------------------===//
133 
fold(ArrayRef<Attribute> operands)134 OpFoldResult math::Log10Op::fold(ArrayRef<Attribute> operands) {
135   return constFoldUnaryOpConditional<FloatAttr>(
136       operands, [](const APFloat &a) -> Optional<APFloat> {
137         if (a.isNegative())
138           return {};
139 
140         switch (a.getSizeInBits(a.getSemantics())) {
141         case 64:
142           return APFloat(log10(a.convertToDouble()));
143         case 32:
144           return APFloat(log10f(a.convertToFloat()));
145         default:
146           return {};
147         }
148       });
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // Log1pOp folder
153 //===----------------------------------------------------------------------===//
154 
fold(ArrayRef<Attribute> operands)155 OpFoldResult math::Log1pOp::fold(ArrayRef<Attribute> operands) {
156   return constFoldUnaryOpConditional<FloatAttr>(
157       operands, [](const APFloat &a) -> Optional<APFloat> {
158         switch (a.getSizeInBits(a.getSemantics())) {
159         case 64:
160           if ((a + APFloat(1.0)).isNegative())
161             return {};
162           return APFloat(log1p(a.convertToDouble()));
163         case 32:
164           if ((a + APFloat(1.0f)).isNegative())
165             return {};
166           return APFloat(log1pf(a.convertToFloat()));
167         default:
168           return {};
169         }
170       });
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // PowFOp folder
175 //===----------------------------------------------------------------------===//
176 
fold(ArrayRef<Attribute> operands)177 OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
178   return constFoldBinaryOpConditional<FloatAttr>(
179       operands, [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
180         if (a.getSizeInBits(a.getSemantics()) == 64 &&
181             b.getSizeInBits(b.getSemantics()) == 64)
182           return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
183 
184         if (a.getSizeInBits(a.getSemantics()) == 32 &&
185             b.getSizeInBits(b.getSemantics()) == 32)
186           return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
187 
188         return {};
189       });
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // SqrtOp folder
194 //===----------------------------------------------------------------------===//
195 
fold(ArrayRef<Attribute> operands)196 OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
197   return constFoldUnaryOpConditional<FloatAttr>(
198       operands, [](const APFloat &a) -> Optional<APFloat> {
199         if (a.isNegative())
200           return {};
201 
202         switch (a.getSizeInBits(a.getSemantics())) {
203         case 64:
204           return APFloat(sqrt(a.convertToDouble()));
205         case 32:
206           return APFloat(sqrtf(a.convertToFloat()));
207         default:
208           return {};
209         }
210       });
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // ExpOp folder
215 //===----------------------------------------------------------------------===//
216 
fold(ArrayRef<Attribute> operands)217 OpFoldResult math::ExpOp::fold(ArrayRef<Attribute> operands) {
218   return constFoldUnaryOpConditional<FloatAttr>(
219       operands, [](const APFloat &a) -> Optional<APFloat> {
220         switch (a.getSizeInBits(a.getSemantics())) {
221         case 64:
222           return APFloat(exp(a.convertToDouble()));
223         case 32:
224           return APFloat(expf(a.convertToFloat()));
225         default:
226           return {};
227         }
228       });
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // Exp2Op folder
233 //===----------------------------------------------------------------------===//
234 
fold(ArrayRef<Attribute> operands)235 OpFoldResult math::Exp2Op::fold(ArrayRef<Attribute> operands) {
236   return constFoldUnaryOpConditional<FloatAttr>(
237       operands, [](const APFloat &a) -> Optional<APFloat> {
238         switch (a.getSizeInBits(a.getSemantics())) {
239         case 64:
240           return APFloat(exp2(a.convertToDouble()));
241         case 32:
242           return APFloat(exp2f(a.convertToFloat()));
243         default:
244           return {};
245         }
246       });
247 }
248 
249 /// Materialize an integer or floating point constant.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)250 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
251                                                   Attribute value, Type type,
252                                                   Location loc) {
253   return builder.create<arith::ConstantOp>(loc, value, type);
254 }
255