1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
10 #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
11 
12 #include "src/__support/CPP/TypeTraits.h"
13 #include "utils/UnitTest/Test.h"
14 
15 #include <stdint.h>
16 
17 namespace __llvm_libc {
18 namespace testing {
19 namespace mpfr {
20 
// Identifies the math operation whose libc result is checked against MPFR.
//
// The enumerators are grouped by signature. Every group is delimited by a pair
// of Begin*/End* sentinel enumerators which are not operations themselves:
// is_valid_operation() classifies an operation by testing whether it lies
// strictly between such a pair, so a new operation must be added inside the
// group that matches its signature.
enum class Operation : int {
  // Operations which take a single floating point number as input
  // and produce a single floating point number as output. The input
  // and output floating point numbers are of the same kind.
  BeginUnaryOperationsSingleOutput,
  Abs,
  Ceil,
  Cos,
  Exp,
  Exp2,
  Expm1,
  Floor,
  Log,
  Log2,
  Log10,
  Log1p,
  Mod2PI,
  ModPIOver2,
  ModPIOver4,
  Round,
  Sin,
  Sqrt,
  Tan,
  Trunc,
  EndUnaryOperationsSingleOutput,

  // Operations which take a single floating point number as input
  // but produce two outputs. The first output is a floating point
  // number of the same type as the input. The second output is of type
  // 'int'.
  BeginUnaryOperationsTwoOutputs,
  Frexp, // Floating point output, the first output, is the fractional part.
  EndUnaryOperationsTwoOutputs,

  // Operations which take two floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  BeginBinaryOperationsSingleOutput,
  Fmod,
  Hypot,
  EndBinaryOperationsSingleOutput,

  // Operations which take two floating point numbers of the same type as
  // input and produce two outputs. The first output is a floating point number
  // of the same type as the inputs. The second output is of type 'int'.
  BeginBinaryOperationsTwoOutputs,
  RemQuo, // The first output, the floating point output, is the remainder.
  EndBinaryOperationsTwoOutputs,

  // Operations which take three floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  BeginTernaryOperationsSingleOutput,
  // Misspelled legacy name of the sentinel above. It is kept with the same
  // value so that existing references keep compiling; because it is an alias,
  // the values of the enumerators that follow are unchanged.
  BeginTernaryOperationsSingleOuput = BeginTernaryOperationsSingleOutput,
  Fma,
  EndTernaryOperationsSingleOutput,
};
77 
// Rounding modes that can be requested when checking a result. The underlying
// type is a single byte.
enum class RoundingMode : uint8_t {
  Upward,
  Downward,
  TowardZero,
  Nearest,
};

// Converts `mode` to an `int` rounding mode value. Only the declaration lives
// in this header; presumably the result is the matching FE_* constant from
// <fenv.h> -- confirm against the definition.
int get_fe_rounding(RoundingMode mode);
81 
82 struct ForceRoundingMode {
83   ForceRoundingMode(RoundingMode);
84   ~ForceRoundingMode();
85 
86   int old_rounding_mode;
87   int rounding_mode;
88 };
89 
90 template <typename T> struct BinaryInput {
91   static_assert(
92       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
93       "Template parameter of BinaryInput must be a floating point type.");
94 
95   using Type = T;
96   T x, y;
97 };
98 
99 template <typename T> struct TernaryInput {
100   static_assert(
101       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
102       "Template parameter of TernaryInput must be a floating point type.");
103 
104   using Type = T;
105   T x, y, z;
106 };
107 
// Result of an operation with two outputs: a value of type T together with an
// 'int'. It is a plain aggregate, so it can be brace-initialized as {f, i}.
template <typename T>
struct BinaryOutput {
  // The output of type T (the first output).
  T f;
  // The 'int' output (the second output).
  int i;
};
112 
113 namespace internal {
114 
115 template <typename T1, typename T2>
116 struct AreMatchingBinaryInputAndBinaryOutput {
117   static constexpr bool VALUE = false;
118 };
119 
120 template <typename T>
121 struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
122   static constexpr bool VALUE = cpp::IsFloatingPointType<T>::Value;
123 };
124 
125 template <typename T>
126 bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
127                                            double ulp_tolerance,
128                                            RoundingMode rounding);
129 template <typename T>
130 bool compare_unary_operation_two_outputs(Operation op, T input,
131                                          const BinaryOutput<T> &libc_output,
132                                          double ulp_tolerance,
133                                          RoundingMode rounding);
134 template <typename T>
135 bool compare_binary_operation_two_outputs(Operation op,
136                                           const BinaryInput<T> &input,
137                                           const BinaryOutput<T> &libc_output,
138                                           double ulp_tolerance,
139                                           RoundingMode rounding);
140 
141 template <typename T>
142 bool compare_binary_operation_one_output(Operation op,
143                                          const BinaryInput<T> &input,
144                                          T libc_output, double ulp_tolerance,
145                                          RoundingMode rounding);
146 
147 template <typename T>
148 bool compare_ternary_operation_one_output(Operation op,
149                                           const TernaryInput<T> &input,
150                                           T libc_output, double ulp_tolerance,
151                                           RoundingMode rounding);
152 
153 template <typename T>
154 void explain_unary_operation_single_output_error(Operation op, T input,
155                                                  T match_value,
156                                                  double ulp_tolerance,
157                                                  RoundingMode rounding,
158                                                  testutils::StreamWrapper &OS);
159 template <typename T>
160 void explain_unary_operation_two_outputs_error(
161     Operation op, T input, const BinaryOutput<T> &match_value,
162     double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
163 template <typename T>
164 void explain_binary_operation_two_outputs_error(
165     Operation op, const BinaryInput<T> &input,
166     const BinaryOutput<T> &match_value, double ulp_tolerance,
167     RoundingMode rounding, testutils::StreamWrapper &OS);
168 
169 template <typename T>
170 void explain_binary_operation_one_output_error(
171     Operation op, const BinaryInput<T> &input, T match_value,
172     double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
173 
174 template <typename T>
175 void explain_ternary_operation_one_output_error(
176     Operation op, const TernaryInput<T> &input, T match_value,
177     double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
178 
179 template <Operation op, typename InputType, typename OutputType>
180 class MPFRMatcher : public testing::Matcher<OutputType> {
181   InputType input;
182   OutputType match_value;
183   double ulp_tolerance;
184   RoundingMode rounding;
185 
186 public:
187   MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
188       : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
189 
190   bool match(OutputType libcResult) {
191     match_value = libcResult;
192     return match(input, match_value);
193   }
194 
195   // This method is marked with NOLINT because it the name `explainError`
196   // does not confirm to the coding style.
197   void explainError(testutils::StreamWrapper &OS) override { // NOLINT
198     explain_error(input, match_value, OS);
199   }
200 
201 private:
202   template <typename T> bool match(T in, T out) {
203     return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
204                                                  rounding);
205   }
206 
207   template <typename T> bool match(T in, const BinaryOutput<T> &out) {
208     return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
209                                                rounding);
210   }
211 
212   template <typename T> bool match(const BinaryInput<T> &in, T out) {
213     return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
214                                                rounding);
215   }
216 
217   template <typename T>
218   bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
219     return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
220                                                 rounding);
221   }
222 
223   template <typename T> bool match(const TernaryInput<T> &in, T out) {
224     return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
225                                                 rounding);
226   }
227 
228   template <typename T>
229   void explain_error(T in, T out, testutils::StreamWrapper &OS) {
230     explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
231                                                 rounding, OS);
232   }
233 
234   template <typename T>
235   void explain_error(T in, const BinaryOutput<T> &out,
236                      testutils::StreamWrapper &OS) {
237     explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
238                                               rounding, OS);
239   }
240 
241   template <typename T>
242   void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out,
243                      testutils::StreamWrapper &OS) {
244     explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
245                                                rounding, OS);
246   }
247 
248   template <typename T>
249   void explain_error(const BinaryInput<T> &in, T out,
250                      testutils::StreamWrapper &OS) {
251     explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
252                                               rounding, OS);
253   }
254 
255   template <typename T>
256   void explain_error(const TernaryInput<T> &in, T out,
257                      testutils::StreamWrapper &OS) {
258     explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
259                                                rounding, OS);
260   }
261 };
262 
263 } // namespace internal
264 
265 // Return true if the input and ouput types for the operation op are valid
266 // types.
267 template <Operation op, typename InputType, typename OutputType>
268 constexpr bool is_valid_operation() {
269   return (Operation::BeginUnaryOperationsSingleOutput < op &&
270           op < Operation::EndUnaryOperationsSingleOutput &&
271           cpp::IsSame<InputType, OutputType>::Value &&
272           cpp::IsFloatingPointType<InputType>::Value) ||
273          (Operation::BeginUnaryOperationsTwoOutputs < op &&
274           op < Operation::EndUnaryOperationsTwoOutputs &&
275           cpp::IsFloatingPointType<InputType>::Value &&
276           cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) ||
277          (Operation::BeginBinaryOperationsSingleOutput < op &&
278           op < Operation::EndBinaryOperationsSingleOutput &&
279           cpp::IsFloatingPointType<OutputType>::Value &&
280           cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) ||
281          (Operation::BeginBinaryOperationsTwoOutputs < op &&
282           op < Operation::EndBinaryOperationsTwoOutputs &&
283           internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
284                                                           OutputType>::VALUE) ||
285          (Operation::BeginTernaryOperationsSingleOuput < op &&
286           op < Operation::EndTernaryOperationsSingleOutput &&
287           cpp::IsFloatingPointType<OutputType>::Value &&
288           cpp::IsSame<InputType, TernaryInput<OutputType>>::Value);
289 }
290 
291 template <Operation op, typename InputType, typename OutputType>
292 __attribute__((no_sanitize("address")))
293 cpp::EnableIfType<is_valid_operation<op, InputType, OutputType>(),
294                   internal::MPFRMatcher<op, InputType, OutputType>>
295 get_mpfr_matcher(InputType input, OutputType output_unused,
296                  double ulp_tolerance, RoundingMode rounding) {
297   return internal::MPFRMatcher<op, InputType, OutputType>(input, ulp_tolerance,
298                                                           rounding);
299 }
300 
301 template <typename T> T round(T x, RoundingMode mode);
302 
303 template <typename T> bool round_to_long(T x, long &result);
304 template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);
305 
306 } // namespace mpfr
307 } // namespace testing
308 } // namespace __llvm_libc
309 
// GET_MPFR_DUMMY_ARG is appended as the last argument of every GET_MPFR_MACRO
// invocation, so that the variadic part of GET_MPFR_MACRO is never empty. This
// is a simple way to avoid the compiler warning
// `gnu-zero-variadic-macro-arguments`. It accepts any arguments and expands
// to 0.
#define GET_MPFR_DUMMY_ARG(...) 0
313 
// Expands to its sixth argument. Used to pick a macro by argument count: the
// caller passes its own arguments followed by the candidate macro names, so
// that the number of caller arguments decides which name lands in the sixth
// position.
#define GET_MPFR_MACRO(ARG1, ARG2, ARG3, ARG4, ARG5, NAME, ...) NAME
315 
// EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding])
//
// Passes `match_value` to EXPECT_THAT together with the matcher returned by
// get_mpfr_matcher<op>. Note that `match_value` appears twice in the
// expansion, so the expression is evaluated twice.

// Five argument form: the rounding mode is given by the caller.
#define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(   \
                               input, match_value, ulp_tolerance, rounding))

// Four argument form: same as above with RoundingMode::Nearest.
#define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  EXPECT_MPFR_MATCH_ROUNDING(                                                  \
      op, input, match_value, ulp_tolerance,                                   \
      __llvm_libc::testing::mpfr::RoundingMode::Nearest)

// Dispatches on the number of arguments: four arguments select the DEFAULT
// form, five arguments select the ROUNDING form.
#define EXPECT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING,                      \
                 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
331 
// Runs EXPECT_MPFR_MATCH once per rounding mode (Nearest, Upward, Downward,
// TowardZero), constructing a ForceRoundingMode object for the mode before
// each check. `input` and `match_value` are expanded anew for every mode. All
// four ForceRoundingMode objects live until the closing brace and are
// therefore destroyed in reverse order of construction.
#define EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = __llvm_libc::testing::mpfr;                               \
    /* Round to nearest. */                                                    \
    mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);                 \
    EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Nearest);                            \
    /* Round upward. */                                                        \
    mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                  \
    EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Upward);                             \
    /* Round downward. */                                                      \
    mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);                \
    EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Downward);                           \
    /* Round toward zero. */                                                   \
    mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);              \
    EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::TowardZero);                         \
  }
348 
// ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding])
//
// Passes `match_value` to ASSERT_THAT together with the matcher returned by
// get_mpfr_matcher<op>. Note that `match_value` appears twice in the
// expansion, so the expression is evaluated twice.

// Five argument form: the rounding mode is given by the caller.
#define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(   \
                               input, match_value, ulp_tolerance, rounding))

// Four argument form: same as above with RoundingMode::Nearest.
#define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  ASSERT_MPFR_MATCH_ROUNDING(                                                  \
      op, input, match_value, ulp_tolerance,                                   \
      __llvm_libc::testing::mpfr::RoundingMode::Nearest)

// Dispatches on the number of arguments: four arguments select the DEFAULT
// form, five arguments select the ROUNDING form.
#define ASSERT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING,                      \
                 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
364 
// Runs ASSERT_MPFR_MATCH once per rounding mode (Nearest, Upward, Downward,
// TowardZero), constructing a ForceRoundingMode object for the mode before
// each check. `input` and `match_value` are expanded anew for every mode. All
// four ForceRoundingMode objects live until the closing brace and are
// therefore destroyed in reverse order of construction.
#define ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = __llvm_libc::testing::mpfr;                               \
    /* Round to nearest. */                                                    \
    mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);                 \
    ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Nearest);                            \
    /* Round upward. */                                                        \
    mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                  \
    ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Upward);                             \
    /* Round downward. */                                                      \
    mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);                \
    ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Downward);                           \
    /* Round toward zero. */                                                   \
    mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);              \
    ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::TowardZero);                         \
  }
381 
382 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
383