1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
10 #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
11 
12 #include "src/__support/CPP/TypeTraits.h"
13 #include "utils/UnitTest/Test.h"
14 #include "utils/testutils/RoundingModeUtils.h"
15 
16 #include <stdint.h>
17 
18 namespace __llvm_libc {
19 namespace testing {
20 namespace mpfr {
21 
// Operations that can be checked against MPFR.
//
// The operations are grouped by the shape of their inputs and outputs. Each
// group is bracketed by a `Begin*`/`End*` marker pair; the markers are not
// operations themselves. is_valid_operation() compares an operation against
// these markers to decide which input/output types are legal for it, so a new
// operation must be added between the markers of the group it belongs to.
enum class Operation : int {
  // Operations which take a single floating point number as input
  // and produce a single floating point number as output. The input
  // and output floating point numbers are of the same kind.
  BeginUnaryOperationsSingleOutput,
  Abs,
  Ceil,
  Cos,
  Exp,
  Exp2,
  Expm1,
  Floor,
  Log,
  Log2,
  Log10,
  Log1p,
  Mod2PI,
  ModPIOver2,
  ModPIOver4,
  Round,
  Sin,
  Sqrt,
  Tan,
  Trunc,
  EndUnaryOperationsSingleOutput,

  // Operations which take a single floating point number as input
  // but produce two outputs. The first output is a floating point
  // number of the same type as the input. The second output is of type
  // 'int'.
  BeginUnaryOperationsTwoOutputs,
  Frexp, // Floating point output, the first output, is the fractional part.
  EndUnaryOperationsTwoOutputs,

  // Operations which take two floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  BeginBinaryOperationsSingleOutput,
  Fmod,
  Hypot,
  EndBinaryOperationsSingleOutput,

  // Operations which take two floating point numbers of the same type as
  // input and produce two outputs. The first output is a floating point
  // number of the same type as the inputs. The second output is of type 'int'.
  BeginBinaryOperationsTwoOutputs,
  RemQuo, // The first output, the floating point output, is the remainder.
  EndBinaryOperationsTwoOutputs,

  // Operations which take three floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  BeginTernaryOperationsSingleOuput,
  // Correctly spelled alias of the (misspelled) marker above. The old name is
  // kept so existing references keep compiling. The alias has the same value,
  // so the numbering of the enumerators that follow is unchanged.
  BeginTernaryOperationsSingleOutput = BeginTernaryOperationsSingleOuput,
  Fma,
  EndTernaryOperationsSingleOutput,
};
78 
79 using __llvm_libc::testutils::ForceRoundingMode;
80 using __llvm_libc::testutils::RoundingMode;
81 
82 template <typename T> struct BinaryInput {
83   static_assert(
84       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
85       "Template parameter of BinaryInput must be a floating point type.");
86 
87   using Type = T;
88   T x, y;
89 };
90 
91 template <typename T> struct TernaryInput {
92   static_assert(
93       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
94       "Template parameter of TernaryInput must be a floating point type.");
95 
96   using Type = T;
97   T x, y, z;
98 };
99 
100 template <typename T> struct BinaryOutput {
101   T f;
102   int i;
103 };
104 
105 namespace internal {
106 
107 template <typename T1, typename T2>
108 struct AreMatchingBinaryInputAndBinaryOutput {
109   static constexpr bool VALUE = false;
110 };
111 
112 template <typename T>
113 struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
114   static constexpr bool VALUE = cpp::IsFloatingPointType<T>::Value;
115 };
116 
// The compare_* functions are only declared in this header; their definitions
// live elsewhere (presumably MPFRUtils.cpp). MPFRMatcher::match() returns
// their result directly, so `true` means the libc result is accepted.
// NOTE(review): the acceptance rule (presumably: within `ulp_tolerance` ULPs
// of the MPFR result computed in rounding mode `rounding`) is defined out of
// line -- confirm against the definitions.

// Unary operation with a single floating point output (e.g. Operation::Sin).
template <typename T>
bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
                                           double ulp_tolerance,
                                           RoundingMode rounding);
// Unary operation with a floating point and an int output
// (e.g. Operation::Frexp).
template <typename T>
bool compare_unary_operation_two_outputs(Operation op, T input,
                                         const BinaryOutput<T> &libc_output,
                                         double ulp_tolerance,
                                         RoundingMode rounding);
// Binary operation with a floating point and an int output
// (e.g. Operation::RemQuo).
template <typename T>
bool compare_binary_operation_two_outputs(Operation op,
                                          const BinaryInput<T> &input,
                                          const BinaryOutput<T> &libc_output,
                                          double ulp_tolerance,
                                          RoundingMode rounding);

// Binary operation with a single floating point output
// (e.g. Operation::Hypot).
template <typename T>
bool compare_binary_operation_one_output(Operation op,
                                         const BinaryInput<T> &input,
                                         T libc_output, double ulp_tolerance,
                                         RoundingMode rounding);

// Ternary operation with a single floating point output (e.g. Operation::Fma).
template <typename T>
bool compare_ternary_operation_one_output(Operation op,
                                          const TernaryInput<T> &input,
                                          T libc_output, double ulp_tolerance,
                                          RoundingMode rounding);
144 
// The explain_*_error functions are only declared in this header; their
// definitions live elsewhere (presumably MPFRUtils.cpp). They are called from
// MPFRMatcher::explainError() with the input and the recorded libc result
// (`match_value`), and write their report to `OS`.
// NOTE(review): the report contents (presumably input, libc value, MPFR value
// and ULP error) are defined out of line -- confirm against the definitions.

// Unary operation with a single floating point output.
template <typename T>
void explain_unary_operation_single_output_error(Operation op, T input,
                                                 T match_value,
                                                 double ulp_tolerance,
                                                 RoundingMode rounding,
                                                 testutils::StreamWrapper &OS);
// Unary operation with a floating point and an int output.
template <typename T>
void explain_unary_operation_two_outputs_error(
    Operation op, T input, const BinaryOutput<T> &match_value,
    double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
// Binary operation with a floating point and an int output.
template <typename T>
void explain_binary_operation_two_outputs_error(
    Operation op, const BinaryInput<T> &input,
    const BinaryOutput<T> &match_value, double ulp_tolerance,
    RoundingMode rounding, testutils::StreamWrapper &OS);

// Binary operation with a single floating point output.
template <typename T>
void explain_binary_operation_one_output_error(
    Operation op, const BinaryInput<T> &input, T match_value,
    double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);

// Ternary operation with a single floating point output.
template <typename T>
void explain_ternary_operation_one_output_error(
    Operation op, const TernaryInput<T> &input, T match_value,
    double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
170 
171 template <Operation op, typename InputType, typename OutputType>
172 class MPFRMatcher : public testing::Matcher<OutputType> {
173   InputType input;
174   OutputType match_value;
175   double ulp_tolerance;
176   RoundingMode rounding;
177 
178 public:
179   MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
180       : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
181 
182   bool match(OutputType libcResult) {
183     match_value = libcResult;
184     return match(input, match_value);
185   }
186 
187   // This method is marked with NOLINT because it the name `explainError`
188   // does not confirm to the coding style.
189   void explainError(testutils::StreamWrapper &OS) override { // NOLINT
190     explain_error(input, match_value, OS);
191   }
192 
193 private:
194   template <typename T> bool match(T in, T out) {
195     return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
196                                                  rounding);
197   }
198 
199   template <typename T> bool match(T in, const BinaryOutput<T> &out) {
200     return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
201                                                rounding);
202   }
203 
204   template <typename T> bool match(const BinaryInput<T> &in, T out) {
205     return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
206                                                rounding);
207   }
208 
209   template <typename T>
210   bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
211     return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
212                                                 rounding);
213   }
214 
215   template <typename T> bool match(const TernaryInput<T> &in, T out) {
216     return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
217                                                 rounding);
218   }
219 
220   template <typename T>
221   void explain_error(T in, T out, testutils::StreamWrapper &OS) {
222     explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
223                                                 rounding, OS);
224   }
225 
226   template <typename T>
227   void explain_error(T in, const BinaryOutput<T> &out,
228                      testutils::StreamWrapper &OS) {
229     explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
230                                               rounding, OS);
231   }
232 
233   template <typename T>
234   void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out,
235                      testutils::StreamWrapper &OS) {
236     explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
237                                                rounding, OS);
238   }
239 
240   template <typename T>
241   void explain_error(const BinaryInput<T> &in, T out,
242                      testutils::StreamWrapper &OS) {
243     explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
244                                               rounding, OS);
245   }
246 
247   template <typename T>
248   void explain_error(const TernaryInput<T> &in, T out,
249                      testutils::StreamWrapper &OS) {
250     explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
251                                                rounding, OS);
252   }
253 };
254 
255 } // namespace internal
256 
257 // Return true if the input and ouput types for the operation op are valid
258 // types.
259 template <Operation op, typename InputType, typename OutputType>
260 constexpr bool is_valid_operation() {
261   return (Operation::BeginUnaryOperationsSingleOutput < op &&
262           op < Operation::EndUnaryOperationsSingleOutput &&
263           cpp::IsSame<InputType, OutputType>::Value &&
264           cpp::IsFloatingPointType<InputType>::Value) ||
265          (Operation::BeginUnaryOperationsTwoOutputs < op &&
266           op < Operation::EndUnaryOperationsTwoOutputs &&
267           cpp::IsFloatingPointType<InputType>::Value &&
268           cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) ||
269          (Operation::BeginBinaryOperationsSingleOutput < op &&
270           op < Operation::EndBinaryOperationsSingleOutput &&
271           cpp::IsFloatingPointType<OutputType>::Value &&
272           cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) ||
273          (Operation::BeginBinaryOperationsTwoOutputs < op &&
274           op < Operation::EndBinaryOperationsTwoOutputs &&
275           internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
276                                                           OutputType>::VALUE) ||
277          (Operation::BeginTernaryOperationsSingleOuput < op &&
278           op < Operation::EndTernaryOperationsSingleOutput &&
279           cpp::IsFloatingPointType<OutputType>::Value &&
280           cpp::IsSame<InputType, TernaryInput<OutputType>>::Value);
281 }
282 
283 template <Operation op, typename InputType, typename OutputType>
284 __attribute__((no_sanitize("address")))
285 cpp::EnableIfType<is_valid_operation<op, InputType, OutputType>(),
286                   internal::MPFRMatcher<op, InputType, OutputType>>
287 get_mpfr_matcher(InputType input, OutputType output_unused,
288                  double ulp_tolerance, RoundingMode rounding) {
289   return internal::MPFRMatcher<op, InputType, OutputType>(input, ulp_tolerance,
290                                                           rounding);
291 }
292 
// Rounding helpers. They are only declared here; the definitions live
// elsewhere (presumably MPFRUtils.cpp).
// NOTE(review): presumably rounds `x` to an integral value of type T using
// rounding mode `mode` -- confirm against the definition.
template <typename T> T round(T x, RoundingMode mode);

// Presumably round `x` to a long and write it to `result`; the first overload
// takes no rounding mode (presumably it uses a default/current mode).
// NOTE(review): the meaning of the bool return (success vs. failure, e.g. on
// overflow) is not visible here -- check the definition before relying on it.
template <typename T> bool round_to_long(T x, long &result);
template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);
297 
298 } // namespace mpfr
299 } // namespace testing
300 } // namespace __llvm_libc
301 
// GET_MPFR_MACRO picks a macro by argument count: it always expands to its
// sixth argument. The EXPECT/ASSERT_MPFR_MATCH dispatchers pass the user's
// arguments followed by the candidate macros (*_ROUNDING, *_DEFAULT,
// GET_MPFR_DUMMY_ARG). Four user arguments therefore select *_DEFAULT, and
// five select *_ROUNDING.
#define GET_MPFR_MACRO(arg1, arg2, arg3, arg4, arg5, selected, ...) selected

// GET_MPFR_DUMMY_ARG is appended as the last candidate so that the variadic
// tail of GET_MPFR_MACRO always receives at least one argument. This avoids
// the compiler warning `gnu-zero-variadic-macro-arguments`.
#define GET_MPFR_DUMMY_ARG(...) 0
307 
// EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding])
//
// Checks, through EXPECT_THAT and get_mpfr_matcher<op>, that `match_value`
// (the libc result for `input`) is accepted by the MPFR matcher for operation
// `op` with tolerance `ulp_tolerance`. With four arguments the rounding mode
// is RoundingMode::Nearest (EXPECT_MPFR_MATCH_DEFAULT). A fifth argument
// selects the rounding mode explicitly (EXPECT_MPFR_MATCH_ROUNDING).
// GET_MPFR_MACRO chooses between the two by argument count.
//
// NOTE(review): `match_value` appears twice in each expansion: once as the
// value under test and once as the type-deduction argument of
// get_mpfr_matcher. It is therefore likely evaluated twice, so avoid passing
// expressions with side effects.
#define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  EXPECT_THAT(match_value,                                                     \
              __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(                \
                  input, match_value, ulp_tolerance,                           \
                  __llvm_libc::testing::mpfr::RoundingMode::Nearest))

#define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(   \
                               input, match_value, ulp_tolerance, rounding))

#define EXPECT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING,                      \
                 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
323 
// EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)
//
// Runs EXPECT_MPFR_MATCH once for each rounding mode, in this order: Nearest,
// Upward, Downward, TowardZero. Before each check a ForceRoundingMode guard
// is created for that mode, presumably switching the current rounding mode.
// `match_value` is then evaluated under the same mode that is passed to the
// MPFR matcher. The guards all stay alive until the closing brace and are
// destroyed in reverse order. Assuming each guard restores the mode it
// replaced, the rounding mode in effect before the macro is restored at the
// end of the block.
//
// NOTE(review): the expansion is a bare `{ ... }` block. That means
// `if (c) EXPECT_MPFR_MATCH_ALL_ROUNDING(...); else ...` does not compile.
// A `do { ... } while (0)` wrapper would fix this, but only if every call
// site ends the macro with `;`.
#define EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = __llvm_libc::testing::mpfr;                               \
    mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);                 \
    EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Nearest);                            \
    mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                  \
    EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Upward);                             \
    mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);                \
    EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Downward);                           \
    mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);              \
    EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::TowardZero);                         \
  }
340 
// ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding])
//
// Same as EXPECT_MPFR_MATCH, but checks through ASSERT_THAT instead of
// EXPECT_THAT. With four arguments the rounding mode is
// RoundingMode::Nearest (ASSERT_MPFR_MATCH_DEFAULT). A fifth argument selects
// the rounding mode explicitly (ASSERT_MPFR_MATCH_ROUNDING). GET_MPFR_MACRO
// chooses between the two by argument count.
//
// NOTE(review): `match_value` appears twice in each expansion: once as the
// value under test and once as the type-deduction argument of
// get_mpfr_matcher. It is therefore likely evaluated twice, so avoid passing
// expressions with side effects.
#define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  ASSERT_THAT(match_value,                                                     \
              __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(                \
                  input, match_value, ulp_tolerance,                           \
                  __llvm_libc::testing::mpfr::RoundingMode::Nearest))

#define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(   \
                               input, match_value, ulp_tolerance, rounding))

#define ASSERT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING,                      \
                 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
356 
// ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)
//
// Runs ASSERT_MPFR_MATCH once for each rounding mode, in this order: Nearest,
// Upward, Downward, TowardZero. Before each check a ForceRoundingMode guard
// is created for that mode, presumably switching the current rounding mode.
// `match_value` is then evaluated under the same mode that is passed to the
// MPFR matcher. The guards all stay alive until the closing brace and are
// destroyed in reverse order. Assuming each guard restores the mode it
// replaced, the rounding mode in effect before the macro is restored at the
// end of the block.
//
// NOTE(review): the expansion is a bare `{ ... }` block. That means
// `if (c) ASSERT_MPFR_MATCH_ALL_ROUNDING(...); else ...` does not compile.
// A `do { ... } while (0)` wrapper would fix this, but only if every call
// site ends the macro with `;`.
#define ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = __llvm_libc::testing::mpfr;                               \
    mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);                 \
    ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Nearest);                            \
    mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                  \
    ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Upward);                             \
    mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);                \
    ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::Downward);                           \
    mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);              \
    ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                   \
                      mpfr::RoundingMode::TowardZero);                         \
  }
373 
374 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
375