xref: /llvm-project-15.0.7/flang/runtime/sum.cpp (revision e552fa28)
1 //===-- runtime/sum.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // Implements SUM for all required operand types and shapes.
10 //
// Real and complex SUM reductions try to limit floating-point rounding
// error and cancellation in their intermediate results; the summation
// scheme is implemented by RealSumAccumulator and is reused, one part at
// a time, by ComplexSumAccumulator.
14 
#include "reduction-templates.h"
#include "reduction.h"
#include "flang/Common/long-double.h"
#include <cinttypes>
#include <cmath>
#include <complex>
20 
21 namespace Fortran::runtime {
22 
23 template <typename INTERMEDIATE> class IntegerSumAccumulator {
24 public:
25   explicit IntegerSumAccumulator(const Descriptor &array) : array_{array} {}
26   void Reinitialize() { sum_ = 0; }
27   template <typename A> void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
28     *p = static_cast<A>(sum_);
29   }
30   template <typename A> bool AccumulateAt(const SubscriptValue at[]) {
31     sum_ += *array_.Element<A>(at);
32     return true;
33   }
34 
35 private:
36   const Descriptor &array_;
37   INTERMEDIATE sum_{0};
38 };
39 
40 template <typename INTERMEDIATE> class RealSumAccumulator {
41 public:
42   explicit RealSumAccumulator(const Descriptor &array) : array_{array} {}
43   void Reinitialize() { positives_ = negatives_ = inOrder_ = 0; }
44   template <typename A> A Result() const {
45     auto sum{static_cast<A>(positives_ + negatives_)};
46     return std::isfinite(sum) ? sum : static_cast<A>(inOrder_);
47   }
48   template <typename A> void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
49     *p = Result<A>();
50   }
51   template <typename A> bool Accumulate(A x) {
52     // Accumulate the nonnegative and negative elements independently
53     // to reduce cancellation; also record an in-order sum for use
54     // in case of overflow.
55     if (x >= 0) {
56       positives_ += x;
57     } else {
58       negatives_ += x;
59     }
60     inOrder_ += x;
61     return true;
62   }
63   template <typename A> bool AccumulateAt(const SubscriptValue at[]) {
64     return Accumulate(*array_.Element<A>(at));
65   }
66 
67 private:
68   const Descriptor &array_;
69   INTERMEDIATE positives_{0.0}, negatives_{0.0}, inOrder_{0.0};
70 };
71 
72 template <typename PART> class ComplexSumAccumulator {
73 public:
74   explicit ComplexSumAccumulator(const Descriptor &array) : array_{array} {}
75   void Reinitialize() {
76     reals_.Reinitialize();
77     imaginaries_.Reinitialize();
78   }
79   template <typename A> void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
80     using ResultPart = typename A::value_type;
81     *p = {reals_.template Result<ResultPart>(),
82         imaginaries_.template Result<ResultPart>()};
83   }
84   template <typename A> bool Accumulate(const A &z) {
85     reals_.Accumulate(z.real());
86     imaginaries_.Accumulate(z.imag());
87     return true;
88   }
89   template <typename A> bool AccumulateAt(const SubscriptValue at[]) {
90     return Accumulate(*array_.Element<A>(at));
91   }
92 
93 private:
94   const Descriptor &array_;
95   RealSumAccumulator<PART> reals_{array_}, imaginaries_{array_};
96 };
97 
98 extern "C" {
99 CppTypeFor<TypeCategory::Integer, 1> RTNAME(SumInteger1)(const Descriptor &x,
100     const char *source, int line, int dim, const Descriptor *mask) {
101   return GetTotalReduction<TypeCategory::Integer, 1>(x, source, line, dim, mask,
102       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
103 }
104 CppTypeFor<TypeCategory::Integer, 2> RTNAME(SumInteger2)(const Descriptor &x,
105     const char *source, int line, int dim, const Descriptor *mask) {
106   return GetTotalReduction<TypeCategory::Integer, 2>(x, source, line, dim, mask,
107       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
108 }
109 CppTypeFor<TypeCategory::Integer, 4> RTNAME(SumInteger4)(const Descriptor &x,
110     const char *source, int line, int dim, const Descriptor *mask) {
111   return GetTotalReduction<TypeCategory::Integer, 4>(x, source, line, dim, mask,
112       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
113 }
114 CppTypeFor<TypeCategory::Integer, 8> RTNAME(SumInteger8)(const Descriptor &x,
115     const char *source, int line, int dim, const Descriptor *mask) {
116   return GetTotalReduction<TypeCategory::Integer, 8>(x, source, line, dim, mask,
117       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 8>>{x}, "SUM");
118 }
119 #ifdef __SIZEOF_INT128__
120 CppTypeFor<TypeCategory::Integer, 16> RTNAME(SumInteger16)(const Descriptor &x,
121     const char *source, int line, int dim, const Descriptor *mask) {
122   return GetTotalReduction<TypeCategory::Integer, 16>(x, source, line, dim,
123       mask, IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 16>>{x},
124       "SUM");
125 }
126 #endif
127 
128 // TODO: real/complex(2 & 3)
129 CppTypeFor<TypeCategory::Real, 4> RTNAME(SumReal4)(const Descriptor &x,
130     const char *source, int line, int dim, const Descriptor *mask) {
131   return GetTotalReduction<TypeCategory::Real, 4>(
132       x, source, line, dim, mask, RealSumAccumulator<double>{x}, "SUM");
133 }
134 CppTypeFor<TypeCategory::Real, 8> RTNAME(SumReal8)(const Descriptor &x,
135     const char *source, int line, int dim, const Descriptor *mask) {
136   return GetTotalReduction<TypeCategory::Real, 8>(
137       x, source, line, dim, mask, RealSumAccumulator<double>{x}, "SUM");
138 }
// At most one of these is compiled, depending on the LONG_DOUBLE setting
// (from flang/Common/long-double.h): long double serves as REAL(10) when it
// is 80 bits wide and as REAL(16) when it is 128 bits wide. Partial sums
// are accumulated in long double in either case.
#if LONG_DOUBLE == 80
CppTypeFor<TypeCategory::Real, 10> RTNAME(SumReal10)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator = RealSumAccumulator<long double>;
  return GetTotalReduction<TypeCategory::Real, 10>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
#elif LONG_DOUBLE == 128
CppTypeFor<TypeCategory::Real, 16> RTNAME(SumReal16)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator = RealSumAccumulator<long double>;
  return GetTotalReduction<TypeCategory::Real, 16>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
#endif
152 
153 void RTNAME(CppSumComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
154     const Descriptor &x, const char *source, int line, int dim,
155     const Descriptor *mask) {
156   result = GetTotalReduction<TypeCategory::Complex, 4>(
157       x, source, line, dim, mask, ComplexSumAccumulator<double>{x}, "SUM");
158 }
159 void RTNAME(CppSumComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
160     const Descriptor &x, const char *source, int line, int dim,
161     const Descriptor *mask) {
162   result = GetTotalReduction<TypeCategory::Complex, 8>(
163       x, source, line, dim, mask, ComplexSumAccumulator<double>{x}, "SUM");
164 }
// COMPLEX counterparts of SumReal10/SumReal16: at most one is compiled,
// selected by LONG_DOUBLE, and each part is accumulated in long double.
// The sum is returned through `result`.
#if LONG_DOUBLE == 80
void RTNAME(CppSumComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator = ComplexSumAccumulator<long double>;
  result = GetTotalReduction<TypeCategory::Complex, 10>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
#elif LONG_DOUBLE == 128
void RTNAME(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator = ComplexSumAccumulator<long double>;
  result = GetTotalReduction<TypeCategory::Complex, 16>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
#endif
180 
181 void RTNAME(SumDim)(Descriptor &result, const Descriptor &x, int dim,
182     const char *source, int line, const Descriptor *mask) {
183   TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
184       ComplexSumAccumulator>(result, x, dim, source, line, mask, "SUM");
185 }
186 } // extern "C"
187 } // namespace Fortran::runtime
188