1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
10 #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
11 
12 #include "src/__support/CPP/TypeTraits.h"
13 #include "utils/UnitTest/Test.h"
14 
15 #include <stdint.h>
16 
17 namespace __llvm_libc {
18 namespace testing {
19 namespace mpfr {
20 
// The operations that the MPFR matcher can check. Enumerators named
// Begin*/End* are sentinels: is_valid_operation treats an operation as
// belonging to a category only when it lies strictly between that category's
// Begin and End enumerators, so the relative order of enumerators matters.
enum class Operation : int {
  // Operations which take a single floating point number as input
  // and produce a single floating point number as output. The input
  // and output floating point numbers are of the same kind.
  BeginUnaryOperationsSingleOutput,
  Abs,
  Ceil,
  Cos,
  Exp,
  Exp2,
  Expm1,
  Floor,
  Log,
  Log2,
  Mod2PI,
  ModPIOver2,
  ModPIOver4,
  Round,
  Sin,
  Sqrt,
  Tan,
  Trunc,
  EndUnaryOperationsSingleOutput,

  // Operations which take a single floating point number as input
  // but produce two outputs. The first output is a floating point
  // number of the same type as the input. The second output is of type
  // 'int'.
  BeginUnaryOperationsTwoOutputs,
  Frexp, // The floating point output, i.e. the first output, is the fractional part.
  EndUnaryOperationsTwoOutputs,

  // Operations which take two floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  BeginBinaryOperationsSingleOutput,
  Hypot,
  EndBinaryOperationsSingleOutput,

  // Operations which take two floating point numbers of the same type as
  // input and produce two outputs. The first output is a floating point
  // number of the same type as the inputs. The second output is of type 'int'.
  BeginBinaryOperationsTwoOutputs,
  RemQuo, // The first output, the floating point output, is the remainder.
  EndBinaryOperationsTwoOutputs,

  // Operations which take three floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  // NOTE: "Ouput" in the Begin sentinel below is a misspelling, but other
  // code (e.g. is_valid_operation) refers to it by this exact name, so it is
  // kept as-is.
  BeginTernaryOperationsSingleOuput,
  Fma,
  EndTernaryOperationsSingleOutput,
};
74 
// Rounding modes a test can request for the MPFR comparison and for the
// ForceRoundingMode guard below.
enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };

// Maps `mode` to an integer rounding-mode value. Defined out of line;
// presumably this is the matching FE_* value from <fenv.h> — confirm against
// the definition.
int get_fe_rounding(RoundingMode mode);

// Scope guard that forces a rounding mode for its lifetime. The constructor
// and destructor are defined out of line; presumably the constructor saves
// the current mode into `old_rounding_mode` and installs `rounding_mode`,
// and the destructor restores `old_rounding_mode` — confirm in the
// definitions.
struct ForceRoundingMode {
  // Explicit so a RoundingMode never silently converts into a guard object.
  explicit ForceRoundingMode(RoundingMode mode);
  ~ForceRoundingMode();

  // A copied guard would run the (restoring) destructor a second time, so
  // the guard is neither copyable nor copy-assignable.
  ForceRoundingMode(const ForceRoundingMode &) = delete;
  ForceRoundingMode &operator=(const ForceRoundingMode &) = delete;

  int old_rounding_mode;
  int rounding_mode;
};
86 
87 template <typename T> struct BinaryInput {
88   static_assert(
89       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
90       "Template parameter of BinaryInput must be a floating point type.");
91 
92   using Type = T;
93   T x, y;
94 };
95 
96 template <typename T> struct TernaryInput {
97   static_assert(
98       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
99       "Template parameter of TernaryInput must be a floating point type.");
100 
101   using Type = T;
102   T x, y, z;
103 };
104 
// Result of an operation that produces two outputs (e.g. Frexp, RemQuo):
// a floating point value plus an 'int' value.
template <typename FloatT> struct BinaryOutput {
  FloatT f; // The floating point (first) output.
  int i;    // The integer (second) output.
};
109 
110 namespace internal {
111 
112 template <typename T1, typename T2>
113 struct AreMatchingBinaryInputAndBinaryOutput {
114   static constexpr bool VALUE = false;
115 };
116 
117 template <typename T>
118 struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
119   static constexpr bool VALUE = cpp::IsFloatingPointType<T>::Value;
120 };
121 
122 template <typename T>
123 bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
124                                            double ulp_tolerance,
125                                            RoundingMode rounding);
126 template <typename T>
127 bool compare_unary_operation_two_outputs(Operation op, T input,
128                                          const BinaryOutput<T> &libc_output,
129                                          double ulp_tolerance,
130                                          RoundingMode rounding);
131 template <typename T>
132 bool compare_binary_operation_two_outputs(Operation op,
133                                           const BinaryInput<T> &input,
134                                           const BinaryOutput<T> &libc_output,
135                                           double ulp_tolerance,
136                                           RoundingMode rounding);
137 
138 template <typename T>
139 bool compare_binary_operation_one_output(Operation op,
140                                          const BinaryInput<T> &input,
141                                          T libc_output, double ulp_tolerance,
142                                          RoundingMode rounding);
143 
144 template <typename T>
145 bool compare_ternary_operation_one_output(Operation op,
146                                           const TernaryInput<T> &input,
147                                           T libc_output, double ulp_tolerance,
148                                           RoundingMode rounding);
149 
150 template <typename T>
151 void explain_unary_operation_single_output_error(Operation op, T input,
152                                                  T match_value,
153                                                  double ulp_tolerance,
154                                                  RoundingMode rounding,
155                                                  testutils::StreamWrapper &OS);
156 template <typename T>
157 void explain_unary_operation_two_outputs_error(
158     Operation op, T input, const BinaryOutput<T> &match_value,
159     double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
160 template <typename T>
161 void explain_binary_operation_two_outputs_error(
162     Operation op, const BinaryInput<T> &input,
163     const BinaryOutput<T> &match_value, double ulp_tolerance,
164     RoundingMode rounding, testutils::StreamWrapper &OS);
165 
166 template <typename T>
167 void explain_binary_operation_one_output_error(
168     Operation op, const BinaryInput<T> &input, T match_value,
169     double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
170 
171 template <typename T>
172 void explain_ternary_operation_one_output_error(
173     Operation op, const TernaryInput<T> &input, T match_value,
174     double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
175 
176 template <Operation op, typename InputType, typename OutputType>
177 class MPFRMatcher : public testing::Matcher<OutputType> {
178   InputType input;
179   OutputType match_value;
180   double ulp_tolerance;
181   RoundingMode rounding;
182 
183 public:
184   MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
185       : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
186 
187   bool match(OutputType libcResult) {
188     match_value = libcResult;
189     return match(input, match_value);
190   }
191 
192   // This method is marked with NOLINT because it the name `explainError`
193   // does not confirm to the coding style.
194   void explainError(testutils::StreamWrapper &OS) override { // NOLINT
195     explain_error(input, match_value, OS);
196   }
197 
198 private:
199   template <typename T> bool match(T in, T out) {
200     return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
201                                                  rounding);
202   }
203 
204   template <typename T> bool match(T in, const BinaryOutput<T> &out) {
205     return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
206                                                rounding);
207   }
208 
209   template <typename T> bool match(const BinaryInput<T> &in, T out) {
210     return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
211                                                rounding);
212   }
213 
214   template <typename T>
215   bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
216     return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
217                                                 rounding);
218   }
219 
220   template <typename T> bool match(const TernaryInput<T> &in, T out) {
221     return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
222                                                 rounding);
223   }
224 
225   template <typename T>
226   void explain_error(T in, T out, testutils::StreamWrapper &OS) {
227     explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
228                                                 rounding, OS);
229   }
230 
231   template <typename T>
232   void explain_error(T in, const BinaryOutput<T> &out,
233                      testutils::StreamWrapper &OS) {
234     explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
235                                               rounding, OS);
236   }
237 
238   template <typename T>
239   void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out,
240                      testutils::StreamWrapper &OS) {
241     explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
242                                                rounding, OS);
243   }
244 
245   template <typename T>
246   void explain_error(const BinaryInput<T> &in, T out,
247                      testutils::StreamWrapper &OS) {
248     explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
249                                               rounding, OS);
250   }
251 
252   template <typename T>
253   void explain_error(const TernaryInput<T> &in, T out,
254                      testutils::StreamWrapper &OS) {
255     explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
256                                                rounding, OS);
257   }
258 };
259 
260 } // namespace internal
261 
262 // Return true if the input and ouput types for the operation op are valid
263 // types.
264 template <Operation op, typename InputType, typename OutputType>
265 constexpr bool is_valid_operation() {
266   return (Operation::BeginUnaryOperationsSingleOutput < op &&
267           op < Operation::EndUnaryOperationsSingleOutput &&
268           cpp::IsSame<InputType, OutputType>::Value &&
269           cpp::IsFloatingPointType<InputType>::Value) ||
270          (Operation::BeginUnaryOperationsTwoOutputs < op &&
271           op < Operation::EndUnaryOperationsTwoOutputs &&
272           cpp::IsFloatingPointType<InputType>::Value &&
273           cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) ||
274          (Operation::BeginBinaryOperationsSingleOutput < op &&
275           op < Operation::EndBinaryOperationsSingleOutput &&
276           cpp::IsFloatingPointType<OutputType>::Value &&
277           cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) ||
278          (Operation::BeginBinaryOperationsTwoOutputs < op &&
279           op < Operation::EndBinaryOperationsTwoOutputs &&
280           internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
281                                                           OutputType>::VALUE) ||
282          (Operation::BeginTernaryOperationsSingleOuput < op &&
283           op < Operation::EndTernaryOperationsSingleOutput &&
284           cpp::IsFloatingPointType<OutputType>::Value &&
285           cpp::IsSame<InputType, TernaryInput<OutputType>>::Value);
286 }
287 
288 template <Operation op, typename InputType, typename OutputType>
289 __attribute__((no_sanitize("address")))
290 cpp::EnableIfType<is_valid_operation<op, InputType, OutputType>(),
291                   internal::MPFRMatcher<op, InputType, OutputType>>
292 get_mpfr_matcher(InputType input, OutputType output_unused,
293                  double ulp_tolerance, RoundingMode rounding) {
294   return internal::MPFRMatcher<op, InputType, OutputType>(input, ulp_tolerance,
295                                                           rounding);
296 }
297 
298 template <typename T> T round(T x, RoundingMode mode);
299 
300 template <typename T> bool round_to_long(T x, long &result);
301 template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);
302 
303 } // namespace mpfr
304 } // namespace testing
305 } // namespace __llvm_libc
306 
// EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding]) picks
// an implementation from its argument count. It forwards its own arguments,
// followed by EXPECT_MPFR_MATCH_ROUNDING, EXPECT_MPFR_MATCH_DEFAULT and
// GET_MPFR_DUMMY_ARG, to GET_MPFR_MACRO, which yields its sixth argument:
// EXPECT_MPFR_MATCH_DEFAULT for four user arguments and
// EXPECT_MPFR_MATCH_ROUNDING for five. GET_MPFR_DUMMY_ARG is appended only so
// that GET_MPFR_MACRO's `...` is never empty, which avoids the compiler
// warning `gnu-zero-variadic-macro-arguments`.
#define GET_MPFR_DUMMY_ARG(...) 0

#define GET_MPFR_MACRO(_1, _2, _3, _4, _5, NAME, ...) NAME

// Checks `match_value` against MPFR's result for `op` on `input` with the
// given `ulp_tolerance` and `rounding` mode.
#define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(   \
                               input, match_value, ulp_tolerance, rounding))

// EXPECT_MPFR_MATCH_ROUNDING with RoundingMode::Nearest.
#define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,            \
                             __llvm_libc::testing::mpfr::RoundingMode::Nearest)

#define EXPECT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING,                      \
                 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
328 
// Runs EXPECT_MPFR_MATCH once per rounding mode, in the order Nearest,
// Upward, Downward, TowardZero. For each mode, a ForceRoundingMode guard for
// that mode is alive while the check runs, and the same mode is passed to the
// comparison. Because `match_value` is substituted into every check, an
// expression passed as `match_value` (e.g. a libc call) is evaluated once per
// forced mode.
// Each guard lives in its own scope, so it is destroyed before the next mode
// is forced. The do { ... } while (0) wrapper makes the expansion a single
// statement that requires a trailing semicolon and is safe inside an
// unbraced if/else.
#define EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  do {                                                                         \
    namespace mpfr = __llvm_libc::testing::mpfr;                               \
    {                                                                          \
      mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);               \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Nearest);                          \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Upward);                           \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);              \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Downward);                         \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);            \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::TowardZero);                       \
    }                                                                          \
  } while (0)
345 
// ASSERT_* counterparts of the EXPECT_MPFR_MATCH* macros. They take the same
// arguments but are built on ASSERT_THAT instead of EXPECT_THAT.
// ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding])
// dispatches on argument count through GET_MPFR_MACRO: four arguments select
// ASSERT_MPFR_MATCH_DEFAULT (RoundingMode::Nearest), five select
// ASSERT_MPFR_MATCH_ROUNDING.
#define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(   \
                               input, match_value, ulp_tolerance, rounding))

// ASSERT_MPFR_MATCH_ROUNDING with RoundingMode::Nearest.
#define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,            \
                             __llvm_libc::testing::mpfr::RoundingMode::Nearest)

#define ASSERT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING,                      \
                 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
361 
// Runs ASSERT_MPFR_MATCH once per rounding mode, in the order Nearest,
// Upward, Downward, TowardZero. For each mode, a ForceRoundingMode guard for
// that mode is alive while the check runs, and the same mode is passed to the
// comparison. Because `match_value` is substituted into every check, an
// expression passed as `match_value` (e.g. a libc call) is evaluated once per
// forced mode.
// Each guard lives in its own scope, so it is destroyed before the next mode
// is forced (and also on an early exit from a failed assertion). The
// do { ... } while (0) wrapper makes the expansion a single statement that
// requires a trailing semicolon and is safe inside an unbraced if/else.
#define ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  do {                                                                         \
    namespace mpfr = __llvm_libc::testing::mpfr;                               \
    {                                                                          \
      mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);               \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Nearest);                          \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Upward);                           \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);              \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Downward);                         \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);            \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::TowardZero);                       \
    }                                                                          \
  } while (0)
378 
379 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
380