1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef _LIBCPP___ALGORITHM_PSTL_BACKENDS_CPU_BACKENDS_TRANSFORM_REDUCE_H
10 #define _LIBCPP___ALGORITHM_PSTL_BACKENDS_CPU_BACKENDS_TRANSFORM_REDUCE_H
11
12 #include <__algorithm/pstl_backends/cpu_backends/backend.h>
13 #include <__config>
14 #include <__iterator/concepts.h>
15 #include <__iterator/iterator_traits.h>
16 #include <__numeric/transform_reduce.h>
17 #include <__type_traits/is_arithmetic.h>
18 #include <__type_traits/is_execution_policy.h>
19 #include <__type_traits/operation_traits.h>
20 #include <__utility/move.h>
21 #include <new>
22 #include <optional>
23
24 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
25 # pragma GCC system_header
26 #endif
27
28 _LIBCPP_PUSH_MACROS
29 #include <__undef_macros>
30
31 #if !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
32
33 _LIBCPP_BEGIN_NAMESPACE_STD
34
35 template <typename _DifferenceType,
36 typename _Tp,
37 typename _BinaryOperation,
38 typename _UnaryOperation,
39 typename _UnaryResult = invoke_result_t<_UnaryOperation, _DifferenceType>,
40 __enable_if_t<__desugars_to<__plus_tag, _BinaryOperation, _Tp, _UnaryResult>::value && is_arithmetic_v<_Tp> &&
41 is_arithmetic_v<_UnaryResult>,
42 int> = 0>
43 _LIBCPP_HIDE_FROM_ABI _Tp
__simd_transform_reduce(_DifferenceType __n,_Tp __init,_BinaryOperation,_UnaryOperation __f)44 __simd_transform_reduce(_DifferenceType __n, _Tp __init, _BinaryOperation, _UnaryOperation __f) noexcept {
45 _PSTL_PRAGMA_SIMD_REDUCTION(+ : __init)
46 for (_DifferenceType __i = 0; __i < __n; ++__i)
47 __init += __f(__i);
48 return __init;
49 }
50
51 template <typename _Size,
52 typename _Tp,
53 typename _BinaryOperation,
54 typename _UnaryOperation,
55 typename _UnaryResult = invoke_result_t<_UnaryOperation, _Size>,
56 __enable_if_t<!(__desugars_to<__plus_tag, _BinaryOperation, _Tp, _UnaryResult>::value &&
57 is_arithmetic_v<_Tp> && is_arithmetic_v<_UnaryResult>),
58 int> = 0>
59 _LIBCPP_HIDE_FROM_ABI _Tp
__simd_transform_reduce(_Size __n,_Tp __init,_BinaryOperation __binary_op,_UnaryOperation __f)60 __simd_transform_reduce(_Size __n, _Tp __init, _BinaryOperation __binary_op, _UnaryOperation __f) noexcept {
61 const _Size __block_size = __lane_size / sizeof(_Tp);
62 if (__n > 2 * __block_size && __block_size > 1) {
63 alignas(__lane_size) char __lane_buffer[__lane_size];
64 _Tp* __lane = reinterpret_cast<_Tp*>(__lane_buffer);
65
66 // initializer
67 _PSTL_PRAGMA_SIMD
68 for (_Size __i = 0; __i < __block_size; ++__i) {
69 ::new (__lane + __i) _Tp(__binary_op(__f(__i), __f(__block_size + __i)));
70 }
71 // main loop
72 _Size __i = 2 * __block_size;
73 const _Size __last_iteration = __block_size * (__n / __block_size);
74 for (; __i < __last_iteration; __i += __block_size) {
75 _PSTL_PRAGMA_SIMD
76 for (_Size __j = 0; __j < __block_size; ++__j) {
77 __lane[__j] = __binary_op(std::move(__lane[__j]), __f(__i + __j));
78 }
79 }
80 // remainder
81 _PSTL_PRAGMA_SIMD
82 for (_Size __j = 0; __j < __n - __last_iteration; ++__j) {
83 __lane[__j] = __binary_op(std::move(__lane[__j]), __f(__last_iteration + __j));
84 }
85 // combiner
86 for (_Size __j = 0; __j < __block_size; ++__j) {
87 __init = __binary_op(std::move(__init), std::move(__lane[__j]));
88 }
89 // destroyer
90 _PSTL_PRAGMA_SIMD
91 for (_Size __j = 0; __j < __block_size; ++__j) {
92 __lane[__j].~_Tp();
93 }
94 } else {
95 for (_Size __i = 0; __i < __n; ++__i) {
96 __init = __binary_op(std::move(__init), __f(__i));
97 }
98 }
99 return __init;
100 }
101
102 template <class _ExecutionPolicy,
103 class _ForwardIterator1,
104 class _ForwardIterator2,
105 class _Tp,
106 class _BinaryOperation1,
107 class _BinaryOperation2>
__pstl_transform_reduce(__cpu_backend_tag,_ForwardIterator1 __first1,_ForwardIterator1 __last1,_ForwardIterator2 __first2,_Tp __init,_BinaryOperation1 __reduce,_BinaryOperation2 __transform)108 _LIBCPP_HIDE_FROM_ABI optional<_Tp> __pstl_transform_reduce(
109 __cpu_backend_tag,
110 _ForwardIterator1 __first1,
111 _ForwardIterator1 __last1,
112 _ForwardIterator2 __first2,
113 _Tp __init,
114 _BinaryOperation1 __reduce,
115 _BinaryOperation2 __transform) {
116 if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> &&
117 __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
118 __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value) {
119 return __par_backend::__parallel_transform_reduce(
120 __first1,
121 std::move(__last1),
122 [__first1, __first2, __transform](_ForwardIterator1 __iter) {
123 return __transform(*__iter, *(__first2 + (__iter - __first1)));
124 },
125 std::move(__init),
126 std::move(__reduce),
127 [__first1, __first2, __reduce, __transform](
128 _ForwardIterator1 __brick_first, _ForwardIterator1 __brick_last, _Tp __brick_init) {
129 return *std::__pstl_transform_reduce<__remove_parallel_policy_t<_ExecutionPolicy>>(
130 __cpu_backend_tag{},
131 __brick_first,
132 std::move(__brick_last),
133 __first2 + (__brick_first - __first1),
134 std::move(__brick_init),
135 std::move(__reduce),
136 std::move(__transform));
137 });
138 } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> &&
139 __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
140 __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value) {
141 return std::__simd_transform_reduce(
142 __last1 - __first1, std::move(__init), std::move(__reduce), [&](__iter_diff_t<_ForwardIterator1> __i) {
143 return __transform(__first1[__i], __first2[__i]);
144 });
145 } else {
146 return std::transform_reduce(
147 std::move(__first1),
148 std::move(__last1),
149 std::move(__first2),
150 std::move(__init),
151 std::move(__reduce),
152 std::move(__transform));
153 }
154 }
155
156 template <class _ExecutionPolicy, class _ForwardIterator, class _Tp, class _BinaryOperation, class _UnaryOperation>
__pstl_transform_reduce(__cpu_backend_tag,_ForwardIterator __first,_ForwardIterator __last,_Tp __init,_BinaryOperation __reduce,_UnaryOperation __transform)157 _LIBCPP_HIDE_FROM_ABI optional<_Tp> __pstl_transform_reduce(
158 __cpu_backend_tag,
159 _ForwardIterator __first,
160 _ForwardIterator __last,
161 _Tp __init,
162 _BinaryOperation __reduce,
163 _UnaryOperation __transform) {
164 if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> &&
165 __has_random_access_iterator_category_or_concept<_ForwardIterator>::value) {
166 return __par_backend::__parallel_transform_reduce(
167 std::move(__first),
168 std::move(__last),
169 [__transform](_ForwardIterator __iter) { return __transform(*__iter); },
170 std::move(__init),
171 __reduce,
172 [__transform, __reduce](auto __brick_first, auto __brick_last, _Tp __brick_init) {
173 auto __res = std::__pstl_transform_reduce<__remove_parallel_policy_t<_ExecutionPolicy>>(
174 __cpu_backend_tag{},
175 std::move(__brick_first),
176 std::move(__brick_last),
177 std::move(__brick_init),
178 std::move(__reduce),
179 std::move(__transform));
180 _LIBCPP_ASSERT_INTERNAL(__res, "unseq/seq should never try to allocate!");
181 return *std::move(__res);
182 });
183 } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> &&
184 __has_random_access_iterator_category_or_concept<_ForwardIterator>::value) {
185 return std::__simd_transform_reduce(
186 __last - __first,
187 std::move(__init),
188 std::move(__reduce),
189 [=, &__transform](__iter_diff_t<_ForwardIterator> __i) { return __transform(__first[__i]); });
190 } else {
191 return std::transform_reduce(
192 std::move(__first), std::move(__last), std::move(__init), std::move(__reduce), std::move(__transform));
193 }
194 }
195
196 _LIBCPP_END_NAMESPACE_STD
197
198 #endif // !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
199
200 _LIBCPP_POP_MACROS
201
202 #endif // _LIBCPP___ALGORITHM_PSTL_BACKENDS_CPU_BACKENDS_TRANSFORM_REDUCE_H
203