1 //===- MathOps.cpp - MLIR operations for math implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/CommonFolders.h"
11 #include "mlir/Dialect/Math/IR/Math.h"
12 #include "mlir/IR/Builders.h"
13
14 using namespace mlir;
15 using namespace mlir::math;
16
17 //===----------------------------------------------------------------------===//
18 // TableGen'd op method definitions
19 //===----------------------------------------------------------------------===//
20
21 #define GET_OP_CLASSES
22 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
23
24 //===----------------------------------------------------------------------===//
25 // AbsOp folder
26 //===----------------------------------------------------------------------===//
27
fold(ArrayRef<Attribute> operands)28 OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
29 return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
30 const APFloat &result(a);
31 return abs(result);
32 });
33 }
34
35 //===----------------------------------------------------------------------===//
36 // CeilOp folder
37 //===----------------------------------------------------------------------===//
38
fold(ArrayRef<Attribute> operands)39 OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
40 return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
41 APFloat result(a);
42 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
43 return result;
44 });
45 }
46
47 //===----------------------------------------------------------------------===//
48 // CopySignOp folder
49 //===----------------------------------------------------------------------===//
50
fold(ArrayRef<Attribute> operands)51 OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
52 return constFoldBinaryOp<FloatAttr>(operands,
53 [](const APFloat &a, const APFloat &b) {
54 APFloat result(a);
55 result.copySign(b);
56 return result;
57 });
58 }
59
60 //===----------------------------------------------------------------------===//
61 // CountLeadingZerosOp folder
62 //===----------------------------------------------------------------------===//
63
fold(ArrayRef<Attribute> operands)64 OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
65 return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
66 return APInt(a.getBitWidth(), a.countLeadingZeros());
67 });
68 }
69
70 //===----------------------------------------------------------------------===//
71 // CountTrailingZerosOp folder
72 //===----------------------------------------------------------------------===//
73
fold(ArrayRef<Attribute> operands)74 OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
75 return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
76 return APInt(a.getBitWidth(), a.countTrailingZeros());
77 });
78 }
79
80 //===----------------------------------------------------------------------===//
81 // CtPopOp folder
82 //===----------------------------------------------------------------------===//
83
fold(ArrayRef<Attribute> operands)84 OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
85 return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
86 return APInt(a.getBitWidth(), a.countPopulation());
87 });
88 }
89
90 //===----------------------------------------------------------------------===//
91 // LogOp folder
92 //===----------------------------------------------------------------------===//
93
fold(ArrayRef<Attribute> operands)94 OpFoldResult math::LogOp::fold(ArrayRef<Attribute> operands) {
95 return constFoldUnaryOpConditional<FloatAttr>(
96 operands, [](const APFloat &a) -> Optional<APFloat> {
97 if (a.isNegative())
98 return {};
99
100 if (a.getSizeInBits(a.getSemantics()) == 64)
101 return APFloat(log(a.convertToDouble()));
102
103 if (a.getSizeInBits(a.getSemantics()) == 32)
104 return APFloat(logf(a.convertToFloat()));
105
106 return {};
107 });
108 }
109
110 //===----------------------------------------------------------------------===//
111 // Log2Op folder
112 //===----------------------------------------------------------------------===//
113
fold(ArrayRef<Attribute> operands)114 OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
115 return constFoldUnaryOpConditional<FloatAttr>(
116 operands, [](const APFloat &a) -> Optional<APFloat> {
117 if (a.isNegative())
118 return {};
119
120 if (a.getSizeInBits(a.getSemantics()) == 64)
121 return APFloat(log2(a.convertToDouble()));
122
123 if (a.getSizeInBits(a.getSemantics()) == 32)
124 return APFloat(log2f(a.convertToFloat()));
125
126 return {};
127 });
128 }
129
130 //===----------------------------------------------------------------------===//
131 // Log10Op folder
132 //===----------------------------------------------------------------------===//
133
fold(ArrayRef<Attribute> operands)134 OpFoldResult math::Log10Op::fold(ArrayRef<Attribute> operands) {
135 return constFoldUnaryOpConditional<FloatAttr>(
136 operands, [](const APFloat &a) -> Optional<APFloat> {
137 if (a.isNegative())
138 return {};
139
140 switch (a.getSizeInBits(a.getSemantics())) {
141 case 64:
142 return APFloat(log10(a.convertToDouble()));
143 case 32:
144 return APFloat(log10f(a.convertToFloat()));
145 default:
146 return {};
147 }
148 });
149 }
150
151 //===----------------------------------------------------------------------===//
152 // Log1pOp folder
153 //===----------------------------------------------------------------------===//
154
fold(ArrayRef<Attribute> operands)155 OpFoldResult math::Log1pOp::fold(ArrayRef<Attribute> operands) {
156 return constFoldUnaryOpConditional<FloatAttr>(
157 operands, [](const APFloat &a) -> Optional<APFloat> {
158 switch (a.getSizeInBits(a.getSemantics())) {
159 case 64:
160 if ((a + APFloat(1.0)).isNegative())
161 return {};
162 return APFloat(log1p(a.convertToDouble()));
163 case 32:
164 if ((a + APFloat(1.0f)).isNegative())
165 return {};
166 return APFloat(log1pf(a.convertToFloat()));
167 default:
168 return {};
169 }
170 });
171 }
172
173 //===----------------------------------------------------------------------===//
174 // PowFOp folder
175 //===----------------------------------------------------------------------===//
176
fold(ArrayRef<Attribute> operands)177 OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
178 return constFoldBinaryOpConditional<FloatAttr>(
179 operands, [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
180 if (a.getSizeInBits(a.getSemantics()) == 64 &&
181 b.getSizeInBits(b.getSemantics()) == 64)
182 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
183
184 if (a.getSizeInBits(a.getSemantics()) == 32 &&
185 b.getSizeInBits(b.getSemantics()) == 32)
186 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
187
188 return {};
189 });
190 }
191
192 //===----------------------------------------------------------------------===//
193 // SqrtOp folder
194 //===----------------------------------------------------------------------===//
195
fold(ArrayRef<Attribute> operands)196 OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
197 return constFoldUnaryOpConditional<FloatAttr>(
198 operands, [](const APFloat &a) -> Optional<APFloat> {
199 if (a.isNegative())
200 return {};
201
202 switch (a.getSizeInBits(a.getSemantics())) {
203 case 64:
204 return APFloat(sqrt(a.convertToDouble()));
205 case 32:
206 return APFloat(sqrtf(a.convertToFloat()));
207 default:
208 return {};
209 }
210 });
211 }
212
213 //===----------------------------------------------------------------------===//
214 // ExpOp folder
215 //===----------------------------------------------------------------------===//
216
fold(ArrayRef<Attribute> operands)217 OpFoldResult math::ExpOp::fold(ArrayRef<Attribute> operands) {
218 return constFoldUnaryOpConditional<FloatAttr>(
219 operands, [](const APFloat &a) -> Optional<APFloat> {
220 switch (a.getSizeInBits(a.getSemantics())) {
221 case 64:
222 return APFloat(exp(a.convertToDouble()));
223 case 32:
224 return APFloat(expf(a.convertToFloat()));
225 default:
226 return {};
227 }
228 });
229 }
230
231 //===----------------------------------------------------------------------===//
232 // Exp2Op folder
233 //===----------------------------------------------------------------------===//
234
fold(ArrayRef<Attribute> operands)235 OpFoldResult math::Exp2Op::fold(ArrayRef<Attribute> operands) {
236 return constFoldUnaryOpConditional<FloatAttr>(
237 operands, [](const APFloat &a) -> Optional<APFloat> {
238 switch (a.getSizeInBits(a.getSemantics())) {
239 case 64:
240 return APFloat(exp2(a.convertToDouble()));
241 case 32:
242 return APFloat(exp2f(a.convertToFloat()));
243 default:
244 return {};
245 }
246 });
247 }
248
249 /// Materialize an integer or floating point constant.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)250 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
251 Attribute value, Type type,
252 Location loc) {
253 return builder.create<arith::ConstantOp>(loc, value, type);
254 }
255