14348d8abSStephan Herhut //===- MathOps.cpp - MLIR operations for math implementation --------------===//
24348d8abSStephan Herhut //
34348d8abSStephan Herhut // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44348d8abSStephan Herhut // See https://llvm.org/LICENSE.txt for license information.
54348d8abSStephan Herhut // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64348d8abSStephan Herhut //
74348d8abSStephan Herhut //===----------------------------------------------------------------------===//
84348d8abSStephan Herhut
91773dddaSWilliam S. Moses #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1083bd4fe2Sjacquesguan #include "mlir/Dialect/CommonFolders.h"
114348d8abSStephan Herhut #include "mlir/Dialect/Math/IR/Math.h"
121773dddaSWilliam S. Moses #include "mlir/IR/Builders.h"
134348d8abSStephan Herhut
144348d8abSStephan Herhut using namespace mlir;
154348d8abSStephan Herhut using namespace mlir::math;
164348d8abSStephan Herhut
174348d8abSStephan Herhut //===----------------------------------------------------------------------===//
184348d8abSStephan Herhut // TableGen'd op method definitions
194348d8abSStephan Herhut //===----------------------------------------------------------------------===//
204348d8abSStephan Herhut
214348d8abSStephan Herhut #define GET_OP_CLASSES
224348d8abSStephan Herhut #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
231773dddaSWilliam S. Moses
241773dddaSWilliam S. Moses //===----------------------------------------------------------------------===//
25e609417cSjacquesguan // AbsOp folder
26e609417cSjacquesguan //===----------------------------------------------------------------------===//
27e609417cSjacquesguan
fold(ArrayRef<Attribute> operands)28e609417cSjacquesguan OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
2983bd4fe2Sjacquesguan return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
301881d6fcSMehdi Amini const APFloat &result(a);
3183bd4fe2Sjacquesguan return abs(result);
3283bd4fe2Sjacquesguan });
33e609417cSjacquesguan }
34e609417cSjacquesguan
35e609417cSjacquesguan //===----------------------------------------------------------------------===//
361773dddaSWilliam S. Moses // CeilOp folder
371773dddaSWilliam S. Moses //===----------------------------------------------------------------------===//
381773dddaSWilliam S. Moses
fold(ArrayRef<Attribute> operands)391773dddaSWilliam S. Moses OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
4083bd4fe2Sjacquesguan return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
4183bd4fe2Sjacquesguan APFloat result(a);
4283bd4fe2Sjacquesguan result.roundToIntegral(llvm::RoundingMode::TowardPositive);
4383bd4fe2Sjacquesguan return result;
4483bd4fe2Sjacquesguan });
451773dddaSWilliam S. Moses }
461773dddaSWilliam S. Moses
471773dddaSWilliam S. Moses //===----------------------------------------------------------------------===//
48e609417cSjacquesguan // CopySignOp folder
49e609417cSjacquesguan //===----------------------------------------------------------------------===//
50e609417cSjacquesguan
fold(ArrayRef<Attribute> operands)51e609417cSjacquesguan OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
5283bd4fe2Sjacquesguan return constFoldBinaryOp<FloatAttr>(operands,
5383bd4fe2Sjacquesguan [](const APFloat &a, const APFloat &b) {
5483bd4fe2Sjacquesguan APFloat result(a);
5583bd4fe2Sjacquesguan result.copySign(b);
5683bd4fe2Sjacquesguan return result;
5783bd4fe2Sjacquesguan });
58e609417cSjacquesguan }
59e609417cSjacquesguan
60e609417cSjacquesguan //===----------------------------------------------------------------------===//
61e609417cSjacquesguan // CountLeadingZerosOp folder
62e609417cSjacquesguan //===----------------------------------------------------------------------===//
63e609417cSjacquesguan
fold(ArrayRef<Attribute> operands)64e609417cSjacquesguan OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
6583bd4fe2Sjacquesguan return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
6683bd4fe2Sjacquesguan return APInt(a.getBitWidth(), a.countLeadingZeros());
6783bd4fe2Sjacquesguan });
68e609417cSjacquesguan }
69e609417cSjacquesguan
70e609417cSjacquesguan //===----------------------------------------------------------------------===//
71e609417cSjacquesguan // CountTrailingZerosOp folder
72e609417cSjacquesguan //===----------------------------------------------------------------------===//
73e609417cSjacquesguan
fold(ArrayRef<Attribute> operands)74e609417cSjacquesguan OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
7583bd4fe2Sjacquesguan return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
7683bd4fe2Sjacquesguan return APInt(a.getBitWidth(), a.countTrailingZeros());
7783bd4fe2Sjacquesguan });
78e609417cSjacquesguan }
79e609417cSjacquesguan
80e609417cSjacquesguan //===----------------------------------------------------------------------===//
81e609417cSjacquesguan // CtPopOp folder
82e609417cSjacquesguan //===----------------------------------------------------------------------===//
83e609417cSjacquesguan
fold(ArrayRef<Attribute> operands)84e609417cSjacquesguan OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
8583bd4fe2Sjacquesguan return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
8683bd4fe2Sjacquesguan return APInt(a.getBitWidth(), a.countPopulation());
8783bd4fe2Sjacquesguan });
88e609417cSjacquesguan }
89e609417cSjacquesguan
90e609417cSjacquesguan //===----------------------------------------------------------------------===//
919c22853eSjacquesguan // LogOp folder
929c22853eSjacquesguan //===----------------------------------------------------------------------===//
939c22853eSjacquesguan
fold(ArrayRef<Attribute> operands)949c22853eSjacquesguan OpFoldResult math::LogOp::fold(ArrayRef<Attribute> operands) {
959c22853eSjacquesguan return constFoldUnaryOpConditional<FloatAttr>(
969c22853eSjacquesguan operands, [](const APFloat &a) -> Optional<APFloat> {
979c22853eSjacquesguan if (a.isNegative())
989c22853eSjacquesguan return {};
999c22853eSjacquesguan
1009c22853eSjacquesguan if (a.getSizeInBits(a.getSemantics()) == 64)
1019c22853eSjacquesguan return APFloat(log(a.convertToDouble()));
1029c22853eSjacquesguan
1039c22853eSjacquesguan if (a.getSizeInBits(a.getSemantics()) == 32)
1049c22853eSjacquesguan return APFloat(logf(a.convertToFloat()));
1059c22853eSjacquesguan
1069c22853eSjacquesguan return {};
1079c22853eSjacquesguan });
1089c22853eSjacquesguan }
1099c22853eSjacquesguan
1109c22853eSjacquesguan //===----------------------------------------------------------------------===//
1111773dddaSWilliam S. Moses // Log2Op folder
1121773dddaSWilliam S. Moses //===----------------------------------------------------------------------===//
1131773dddaSWilliam S. Moses
fold(ArrayRef<Attribute> operands)1141773dddaSWilliam S. Moses OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
115ad4b7fb3Sjacquesguan return constFoldUnaryOpConditional<FloatAttr>(
116ad4b7fb3Sjacquesguan operands, [](const APFloat &a) -> Optional<APFloat> {
117ad4b7fb3Sjacquesguan if (a.isNegative())
1181773dddaSWilliam S. Moses return {};
1191773dddaSWilliam S. Moses
120ad4b7fb3Sjacquesguan if (a.getSizeInBits(a.getSemantics()) == 64)
121ad4b7fb3Sjacquesguan return APFloat(log2(a.convertToDouble()));
1221773dddaSWilliam S. Moses
123ad4b7fb3Sjacquesguan if (a.getSizeInBits(a.getSemantics()) == 32)
124ad4b7fb3Sjacquesguan return APFloat(log2f(a.convertToFloat()));
125164c7afaSWilliam S. Moses
126164c7afaSWilliam S. Moses return {};
127ad4b7fb3Sjacquesguan });
128164c7afaSWilliam S. Moses }
129164c7afaSWilliam S. Moses
130164c7afaSWilliam S. Moses //===----------------------------------------------------------------------===//
131a5cae20bSjacquesguan // Log10Op folder
132a5cae20bSjacquesguan //===----------------------------------------------------------------------===//
133a5cae20bSjacquesguan
fold(ArrayRef<Attribute> operands)134a5cae20bSjacquesguan OpFoldResult math::Log10Op::fold(ArrayRef<Attribute> operands) {
135a5cae20bSjacquesguan return constFoldUnaryOpConditional<FloatAttr>(
136a5cae20bSjacquesguan operands, [](const APFloat &a) -> Optional<APFloat> {
137a5cae20bSjacquesguan if (a.isNegative())
138a5cae20bSjacquesguan return {};
139a5cae20bSjacquesguan
140a5cae20bSjacquesguan switch (a.getSizeInBits(a.getSemantics())) {
141a5cae20bSjacquesguan case 64:
142a5cae20bSjacquesguan return APFloat(log10(a.convertToDouble()));
143a5cae20bSjacquesguan case 32:
144a5cae20bSjacquesguan return APFloat(log10f(a.convertToFloat()));
145a5cae20bSjacquesguan default:
146a5cae20bSjacquesguan return {};
147a5cae20bSjacquesguan }
148a5cae20bSjacquesguan });
149a5cae20bSjacquesguan }
150a5cae20bSjacquesguan
151a5cae20bSjacquesguan //===----------------------------------------------------------------------===//
152c3d856bfSjacquesguan // Log1pOp folder
153c3d856bfSjacquesguan //===----------------------------------------------------------------------===//
154c3d856bfSjacquesguan
fold(ArrayRef<Attribute> operands)155c3d856bfSjacquesguan OpFoldResult math::Log1pOp::fold(ArrayRef<Attribute> operands) {
156c3d856bfSjacquesguan return constFoldUnaryOpConditional<FloatAttr>(
157c3d856bfSjacquesguan operands, [](const APFloat &a) -> Optional<APFloat> {
158c3d856bfSjacquesguan switch (a.getSizeInBits(a.getSemantics())) {
159c3d856bfSjacquesguan case 64:
160c3d856bfSjacquesguan if ((a + APFloat(1.0)).isNegative())
161c3d856bfSjacquesguan return {};
162c3d856bfSjacquesguan return APFloat(log1p(a.convertToDouble()));
163c3d856bfSjacquesguan case 32:
164c3d856bfSjacquesguan if ((a + APFloat(1.0f)).isNegative())
165c3d856bfSjacquesguan return {};
166c3d856bfSjacquesguan return APFloat(log1pf(a.convertToFloat()));
167c3d856bfSjacquesguan default:
168c3d856bfSjacquesguan return {};
169c3d856bfSjacquesguan }
170c3d856bfSjacquesguan });
171c3d856bfSjacquesguan }
172c3d856bfSjacquesguan
173c3d856bfSjacquesguan //===----------------------------------------------------------------------===//
174164c7afaSWilliam S. Moses // PowFOp folder
175164c7afaSWilliam S. Moses //===----------------------------------------------------------------------===//
176164c7afaSWilliam S. Moses
fold(ArrayRef<Attribute> operands)177164c7afaSWilliam S. Moses OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
178362240e0Sjacquesguan return constFoldBinaryOpConditional<FloatAttr>(
179362240e0Sjacquesguan operands, [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
180362240e0Sjacquesguan if (a.getSizeInBits(a.getSemantics()) == 64 &&
181362240e0Sjacquesguan b.getSizeInBits(b.getSemantics()) == 64)
182362240e0Sjacquesguan return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
183164c7afaSWilliam S. Moses
184362240e0Sjacquesguan if (a.getSizeInBits(a.getSemantics()) == 32 &&
185362240e0Sjacquesguan b.getSizeInBits(b.getSemantics()) == 32)
186362240e0Sjacquesguan return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
1871773dddaSWilliam S. Moses
1881773dddaSWilliam S. Moses return {};
189362240e0Sjacquesguan });
1901773dddaSWilliam S. Moses }
1911773dddaSWilliam S. Moses
1924d7d5c5fSjacquesguan //===----------------------------------------------------------------------===//
1934d7d5c5fSjacquesguan // SqrtOp folder
1944d7d5c5fSjacquesguan //===----------------------------------------------------------------------===//
1954d7d5c5fSjacquesguan
fold(ArrayRef<Attribute> operands)19626c95ae3Sjacquesguan OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
1974d7d5c5fSjacquesguan return constFoldUnaryOpConditional<FloatAttr>(
1984d7d5c5fSjacquesguan operands, [](const APFloat &a) -> Optional<APFloat> {
1994d7d5c5fSjacquesguan if (a.isNegative())
20026c95ae3Sjacquesguan return {};
20126c95ae3Sjacquesguan
2024d7d5c5fSjacquesguan switch (a.getSizeInBits(a.getSemantics())) {
2034d7d5c5fSjacquesguan case 64:
2044d7d5c5fSjacquesguan return APFloat(sqrt(a.convertToDouble()));
2054d7d5c5fSjacquesguan case 32:
2064d7d5c5fSjacquesguan return APFloat(sqrtf(a.convertToFloat()));
2074d7d5c5fSjacquesguan default:
20826c95ae3Sjacquesguan return {};
2094d7d5c5fSjacquesguan }
2104d7d5c5fSjacquesguan });
21126c95ae3Sjacquesguan }
21226c95ae3Sjacquesguan
2139e241c70Sjacquesguan //===----------------------------------------------------------------------===//
2149e241c70Sjacquesguan // ExpOp folder
2159e241c70Sjacquesguan //===----------------------------------------------------------------------===//
2169e241c70Sjacquesguan
fold(ArrayRef<Attribute> operands)2179e241c70Sjacquesguan OpFoldResult math::ExpOp::fold(ArrayRef<Attribute> operands) {
2189e241c70Sjacquesguan return constFoldUnaryOpConditional<FloatAttr>(
2199e241c70Sjacquesguan operands, [](const APFloat &a) -> Optional<APFloat> {
2209e241c70Sjacquesguan switch (a.getSizeInBits(a.getSemantics())) {
2219e241c70Sjacquesguan case 64:
2229e241c70Sjacquesguan return APFloat(exp(a.convertToDouble()));
2239e241c70Sjacquesguan case 32:
2249e241c70Sjacquesguan return APFloat(expf(a.convertToFloat()));
2259e241c70Sjacquesguan default:
2269e241c70Sjacquesguan return {};
2279e241c70Sjacquesguan }
2289e241c70Sjacquesguan });
2299e241c70Sjacquesguan }
2309e241c70Sjacquesguan
231*78015047Sjacquesguan //===----------------------------------------------------------------------===//
232*78015047Sjacquesguan // Exp2Op folder
233*78015047Sjacquesguan //===----------------------------------------------------------------------===//
234*78015047Sjacquesguan
fold(ArrayRef<Attribute> operands)235*78015047Sjacquesguan OpFoldResult math::Exp2Op::fold(ArrayRef<Attribute> operands) {
236*78015047Sjacquesguan return constFoldUnaryOpConditional<FloatAttr>(
237*78015047Sjacquesguan operands, [](const APFloat &a) -> Optional<APFloat> {
238*78015047Sjacquesguan switch (a.getSizeInBits(a.getSemantics())) {
239*78015047Sjacquesguan case 64:
240*78015047Sjacquesguan return APFloat(exp2(a.convertToDouble()));
241*78015047Sjacquesguan case 32:
242*78015047Sjacquesguan return APFloat(exp2f(a.convertToFloat()));
243*78015047Sjacquesguan default:
244*78015047Sjacquesguan return {};
245*78015047Sjacquesguan }
246*78015047Sjacquesguan });
247*78015047Sjacquesguan }
248*78015047Sjacquesguan
2491773dddaSWilliam S. Moses /// Materialize an integer or floating point constant.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)2501773dddaSWilliam S. Moses Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
2511773dddaSWilliam S. Moses Attribute value, Type type,
2521773dddaSWilliam S. Moses Location loc) {
2531773dddaSWilliam S. Moses return builder.create<arith::ConstantOp>(loc, value, type);
2541773dddaSWilliam S. Moses }
255