//===- MathOps.cpp - MLIR operations for math implementation --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
84348d8abSStephan Herhut 
91773dddaSWilliam S. Moses #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1083bd4fe2Sjacquesguan #include "mlir/Dialect/CommonFolders.h"
114348d8abSStephan Herhut #include "mlir/Dialect/Math/IR/Math.h"
121773dddaSWilliam S. Moses #include "mlir/IR/Builders.h"
134348d8abSStephan Herhut 
144348d8abSStephan Herhut using namespace mlir;
154348d8abSStephan Herhut using namespace mlir::math;
164348d8abSStephan Herhut 
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
204348d8abSStephan Herhut 
214348d8abSStephan Herhut #define GET_OP_CLASSES
224348d8abSStephan Herhut #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
231773dddaSWilliam S. Moses 
//===----------------------------------------------------------------------===//
// AbsOp folder
//===----------------------------------------------------------------------===//
27e609417cSjacquesguan 
fold(ArrayRef<Attribute> operands)28e609417cSjacquesguan OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
2983bd4fe2Sjacquesguan   return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
301881d6fcSMehdi Amini     const APFloat &result(a);
3183bd4fe2Sjacquesguan     return abs(result);
3283bd4fe2Sjacquesguan   });
33e609417cSjacquesguan }
34e609417cSjacquesguan 
//===----------------------------------------------------------------------===//
// CeilOp folder
//===----------------------------------------------------------------------===//
381773dddaSWilliam S. Moses 
fold(ArrayRef<Attribute> operands)391773dddaSWilliam S. Moses OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
4083bd4fe2Sjacquesguan   return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
4183bd4fe2Sjacquesguan     APFloat result(a);
4283bd4fe2Sjacquesguan     result.roundToIntegral(llvm::RoundingMode::TowardPositive);
4383bd4fe2Sjacquesguan     return result;
4483bd4fe2Sjacquesguan   });
451773dddaSWilliam S. Moses }
461773dddaSWilliam S. Moses 
//===----------------------------------------------------------------------===//
// CopySignOp folder
//===----------------------------------------------------------------------===//
50e609417cSjacquesguan 
fold(ArrayRef<Attribute> operands)51e609417cSjacquesguan OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
5283bd4fe2Sjacquesguan   return constFoldBinaryOp<FloatAttr>(operands,
5383bd4fe2Sjacquesguan                                       [](const APFloat &a, const APFloat &b) {
5483bd4fe2Sjacquesguan                                         APFloat result(a);
5583bd4fe2Sjacquesguan                                         result.copySign(b);
5683bd4fe2Sjacquesguan                                         return result;
5783bd4fe2Sjacquesguan                                       });
58e609417cSjacquesguan }
59e609417cSjacquesguan 
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp folder
//===----------------------------------------------------------------------===//
63e609417cSjacquesguan 
fold(ArrayRef<Attribute> operands)64e609417cSjacquesguan OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
6583bd4fe2Sjacquesguan   return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
6683bd4fe2Sjacquesguan     return APInt(a.getBitWidth(), a.countLeadingZeros());
6783bd4fe2Sjacquesguan   });
68e609417cSjacquesguan }
69e609417cSjacquesguan 
//===----------------------------------------------------------------------===//
// CountTrailingZerosOp folder
//===----------------------------------------------------------------------===//
73e609417cSjacquesguan 
fold(ArrayRef<Attribute> operands)74e609417cSjacquesguan OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
7583bd4fe2Sjacquesguan   return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
7683bd4fe2Sjacquesguan     return APInt(a.getBitWidth(), a.countTrailingZeros());
7783bd4fe2Sjacquesguan   });
78e609417cSjacquesguan }
79e609417cSjacquesguan 
//===----------------------------------------------------------------------===//
// CtPopOp folder
//===----------------------------------------------------------------------===//
83e609417cSjacquesguan 
fold(ArrayRef<Attribute> operands)84e609417cSjacquesguan OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
8583bd4fe2Sjacquesguan   return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
8683bd4fe2Sjacquesguan     return APInt(a.getBitWidth(), a.countPopulation());
8783bd4fe2Sjacquesguan   });
88e609417cSjacquesguan }
89e609417cSjacquesguan 
//===----------------------------------------------------------------------===//
// LogOp folder
//===----------------------------------------------------------------------===//
939c22853eSjacquesguan 
fold(ArrayRef<Attribute> operands)949c22853eSjacquesguan OpFoldResult math::LogOp::fold(ArrayRef<Attribute> operands) {
959c22853eSjacquesguan   return constFoldUnaryOpConditional<FloatAttr>(
969c22853eSjacquesguan       operands, [](const APFloat &a) -> Optional<APFloat> {
979c22853eSjacquesguan         if (a.isNegative())
989c22853eSjacquesguan           return {};
999c22853eSjacquesguan 
1009c22853eSjacquesguan         if (a.getSizeInBits(a.getSemantics()) == 64)
1019c22853eSjacquesguan           return APFloat(log(a.convertToDouble()));
1029c22853eSjacquesguan 
1039c22853eSjacquesguan         if (a.getSizeInBits(a.getSemantics()) == 32)
1049c22853eSjacquesguan           return APFloat(logf(a.convertToFloat()));
1059c22853eSjacquesguan 
1069c22853eSjacquesguan         return {};
1079c22853eSjacquesguan       });
1089c22853eSjacquesguan }
1099c22853eSjacquesguan 
//===----------------------------------------------------------------------===//
// Log2Op folder
//===----------------------------------------------------------------------===//
1131773dddaSWilliam S. Moses 
fold(ArrayRef<Attribute> operands)1141773dddaSWilliam S. Moses OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
115ad4b7fb3Sjacquesguan   return constFoldUnaryOpConditional<FloatAttr>(
116ad4b7fb3Sjacquesguan       operands, [](const APFloat &a) -> Optional<APFloat> {
117ad4b7fb3Sjacquesguan         if (a.isNegative())
1181773dddaSWilliam S. Moses           return {};
1191773dddaSWilliam S. Moses 
120ad4b7fb3Sjacquesguan         if (a.getSizeInBits(a.getSemantics()) == 64)
121ad4b7fb3Sjacquesguan           return APFloat(log2(a.convertToDouble()));
1221773dddaSWilliam S. Moses 
123ad4b7fb3Sjacquesguan         if (a.getSizeInBits(a.getSemantics()) == 32)
124ad4b7fb3Sjacquesguan           return APFloat(log2f(a.convertToFloat()));
125164c7afaSWilliam S. Moses 
126164c7afaSWilliam S. Moses         return {};
127ad4b7fb3Sjacquesguan       });
128164c7afaSWilliam S. Moses }
129164c7afaSWilliam S. Moses 
//===----------------------------------------------------------------------===//
// Log10Op folder
//===----------------------------------------------------------------------===//
133a5cae20bSjacquesguan 
fold(ArrayRef<Attribute> operands)134a5cae20bSjacquesguan OpFoldResult math::Log10Op::fold(ArrayRef<Attribute> operands) {
135a5cae20bSjacquesguan   return constFoldUnaryOpConditional<FloatAttr>(
136a5cae20bSjacquesguan       operands, [](const APFloat &a) -> Optional<APFloat> {
137a5cae20bSjacquesguan         if (a.isNegative())
138a5cae20bSjacquesguan           return {};
139a5cae20bSjacquesguan 
140a5cae20bSjacquesguan         switch (a.getSizeInBits(a.getSemantics())) {
141a5cae20bSjacquesguan         case 64:
142a5cae20bSjacquesguan           return APFloat(log10(a.convertToDouble()));
143a5cae20bSjacquesguan         case 32:
144a5cae20bSjacquesguan           return APFloat(log10f(a.convertToFloat()));
145a5cae20bSjacquesguan         default:
146a5cae20bSjacquesguan           return {};
147a5cae20bSjacquesguan         }
148a5cae20bSjacquesguan       });
149a5cae20bSjacquesguan }
150a5cae20bSjacquesguan 
//===----------------------------------------------------------------------===//
// Log1pOp folder
//===----------------------------------------------------------------------===//
154c3d856bfSjacquesguan 
fold(ArrayRef<Attribute> operands)155c3d856bfSjacquesguan OpFoldResult math::Log1pOp::fold(ArrayRef<Attribute> operands) {
156c3d856bfSjacquesguan   return constFoldUnaryOpConditional<FloatAttr>(
157c3d856bfSjacquesguan       operands, [](const APFloat &a) -> Optional<APFloat> {
158c3d856bfSjacquesguan         switch (a.getSizeInBits(a.getSemantics())) {
159c3d856bfSjacquesguan         case 64:
160c3d856bfSjacquesguan           if ((a + APFloat(1.0)).isNegative())
161c3d856bfSjacquesguan             return {};
162c3d856bfSjacquesguan           return APFloat(log1p(a.convertToDouble()));
163c3d856bfSjacquesguan         case 32:
164c3d856bfSjacquesguan           if ((a + APFloat(1.0f)).isNegative())
165c3d856bfSjacquesguan             return {};
166c3d856bfSjacquesguan           return APFloat(log1pf(a.convertToFloat()));
167c3d856bfSjacquesguan         default:
168c3d856bfSjacquesguan           return {};
169c3d856bfSjacquesguan         }
170c3d856bfSjacquesguan       });
171c3d856bfSjacquesguan }
172c3d856bfSjacquesguan 
//===----------------------------------------------------------------------===//
// PowFOp folder
//===----------------------------------------------------------------------===//
176164c7afaSWilliam S. Moses 
fold(ArrayRef<Attribute> operands)177164c7afaSWilliam S. Moses OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
178362240e0Sjacquesguan   return constFoldBinaryOpConditional<FloatAttr>(
179362240e0Sjacquesguan       operands, [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
180362240e0Sjacquesguan         if (a.getSizeInBits(a.getSemantics()) == 64 &&
181362240e0Sjacquesguan             b.getSizeInBits(b.getSemantics()) == 64)
182362240e0Sjacquesguan           return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
183164c7afaSWilliam S. Moses 
184362240e0Sjacquesguan         if (a.getSizeInBits(a.getSemantics()) == 32 &&
185362240e0Sjacquesguan             b.getSizeInBits(b.getSemantics()) == 32)
186362240e0Sjacquesguan           return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
1871773dddaSWilliam S. Moses 
1881773dddaSWilliam S. Moses         return {};
189362240e0Sjacquesguan       });
1901773dddaSWilliam S. Moses }
1911773dddaSWilliam S. Moses 
//===----------------------------------------------------------------------===//
// SqrtOp folder
//===----------------------------------------------------------------------===//
1954d7d5c5fSjacquesguan 
fold(ArrayRef<Attribute> operands)19626c95ae3Sjacquesguan OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
1974d7d5c5fSjacquesguan   return constFoldUnaryOpConditional<FloatAttr>(
1984d7d5c5fSjacquesguan       operands, [](const APFloat &a) -> Optional<APFloat> {
1994d7d5c5fSjacquesguan         if (a.isNegative())
20026c95ae3Sjacquesguan           return {};
20126c95ae3Sjacquesguan 
2024d7d5c5fSjacquesguan         switch (a.getSizeInBits(a.getSemantics())) {
2034d7d5c5fSjacquesguan         case 64:
2044d7d5c5fSjacquesguan           return APFloat(sqrt(a.convertToDouble()));
2054d7d5c5fSjacquesguan         case 32:
2064d7d5c5fSjacquesguan           return APFloat(sqrtf(a.convertToFloat()));
2074d7d5c5fSjacquesguan         default:
20826c95ae3Sjacquesguan           return {};
2094d7d5c5fSjacquesguan         }
2104d7d5c5fSjacquesguan       });
21126c95ae3Sjacquesguan }
21226c95ae3Sjacquesguan 
//===----------------------------------------------------------------------===//
// ExpOp folder
//===----------------------------------------------------------------------===//
2169e241c70Sjacquesguan 
fold(ArrayRef<Attribute> operands)2179e241c70Sjacquesguan OpFoldResult math::ExpOp::fold(ArrayRef<Attribute> operands) {
2189e241c70Sjacquesguan   return constFoldUnaryOpConditional<FloatAttr>(
2199e241c70Sjacquesguan       operands, [](const APFloat &a) -> Optional<APFloat> {
2209e241c70Sjacquesguan         switch (a.getSizeInBits(a.getSemantics())) {
2219e241c70Sjacquesguan         case 64:
2229e241c70Sjacquesguan           return APFloat(exp(a.convertToDouble()));
2239e241c70Sjacquesguan         case 32:
2249e241c70Sjacquesguan           return APFloat(expf(a.convertToFloat()));
2259e241c70Sjacquesguan         default:
2269e241c70Sjacquesguan           return {};
2279e241c70Sjacquesguan         }
2289e241c70Sjacquesguan       });
2299e241c70Sjacquesguan }
2309e241c70Sjacquesguan 
//===----------------------------------------------------------------------===//
// Exp2Op folder
//===----------------------------------------------------------------------===//
234*78015047Sjacquesguan 
fold(ArrayRef<Attribute> operands)235*78015047Sjacquesguan OpFoldResult math::Exp2Op::fold(ArrayRef<Attribute> operands) {
236*78015047Sjacquesguan   return constFoldUnaryOpConditional<FloatAttr>(
237*78015047Sjacquesguan       operands, [](const APFloat &a) -> Optional<APFloat> {
238*78015047Sjacquesguan         switch (a.getSizeInBits(a.getSemantics())) {
239*78015047Sjacquesguan         case 64:
240*78015047Sjacquesguan           return APFloat(exp2(a.convertToDouble()));
241*78015047Sjacquesguan         case 32:
242*78015047Sjacquesguan           return APFloat(exp2f(a.convertToFloat()));
243*78015047Sjacquesguan         default:
244*78015047Sjacquesguan           return {};
245*78015047Sjacquesguan         }
246*78015047Sjacquesguan       });
247*78015047Sjacquesguan }
248*78015047Sjacquesguan 
2491773dddaSWilliam S. Moses /// Materialize an integer or floating point constant.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)2501773dddaSWilliam S. Moses Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
2511773dddaSWilliam S. Moses                                                   Attribute value, Type type,
2521773dddaSWilliam S. Moses                                                   Location loc) {
2531773dddaSWilliam S. Moses   return builder.create<arith::ConstantOp>(loc, value, type);
2541773dddaSWilliam S. Moses }
255