xref: /llvm-project-15.0.7/flang/runtime/sum.cpp (revision 4daa33f6)
1beb5ac8bSpeter klausler //===-- runtime/sum.cpp ---------------------------------------------------===//
2beb5ac8bSpeter klausler //
3beb5ac8bSpeter klausler // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4beb5ac8bSpeter klausler // See https://llvm.org/LICENSE.txt for license information.
5beb5ac8bSpeter klausler // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6beb5ac8bSpeter klausler //
7beb5ac8bSpeter klausler //===----------------------------------------------------------------------===//
8beb5ac8bSpeter klausler 
// Implements SUM for all required operand types and shapes.
//
// Real and complex SUM reductions limit the accumulation of floating-point
// rounding error in intermediate results by using "Kahan summation"
// (compensated summation, closely related to manual "double-double"
// arithmetic).
14beb5ac8bSpeter klausler 
15beb5ac8bSpeter klausler #include "reduction-templates.h"
16*4daa33f6SPeter Klausler #include "flang/Runtime/float128.h"
17830c0b90SPeter Klausler #include "flang/Runtime/reduction.h"
18*4daa33f6SPeter Klausler #include <cfloat>
19beb5ac8bSpeter klausler #include <cinttypes>
20beb5ac8bSpeter klausler #include <complex>
21beb5ac8bSpeter klausler 
22beb5ac8bSpeter klausler namespace Fortran::runtime {
23beb5ac8bSpeter klausler 
24beb5ac8bSpeter klausler template <typename INTERMEDIATE> class IntegerSumAccumulator {
25beb5ac8bSpeter klausler public:
IntegerSumAccumulator(const Descriptor & array)26beb5ac8bSpeter klausler   explicit IntegerSumAccumulator(const Descriptor &array) : array_{array} {}
Reinitialize()27beb5ac8bSpeter klausler   void Reinitialize() { sum_ = 0; }
GetResult(A * p,int=-1) const28beb5ac8bSpeter klausler   template <typename A> void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
29beb5ac8bSpeter klausler     *p = static_cast<A>(sum_);
30beb5ac8bSpeter klausler   }
AccumulateAt(const SubscriptValue at[])31beb5ac8bSpeter klausler   template <typename A> bool AccumulateAt(const SubscriptValue at[]) {
32beb5ac8bSpeter klausler     sum_ += *array_.Element<A>(at);
33beb5ac8bSpeter klausler     return true;
34beb5ac8bSpeter klausler   }
35beb5ac8bSpeter klausler 
36beb5ac8bSpeter klausler private:
37beb5ac8bSpeter klausler   const Descriptor &array_;
38beb5ac8bSpeter klausler   INTERMEDIATE sum_{0};
39beb5ac8bSpeter klausler };
40beb5ac8bSpeter klausler 
41beb5ac8bSpeter klausler template <typename INTERMEDIATE> class RealSumAccumulator {
42beb5ac8bSpeter klausler public:
RealSumAccumulator(const Descriptor & array)43beb5ac8bSpeter klausler   explicit RealSumAccumulator(const Descriptor &array) : array_{array} {}
Reinitialize()44c375ec86Speter klausler   void Reinitialize() { sum_ = correction_ = 0; }
Result() const45c375ec86Speter klausler   template <typename A> A Result() const { return sum_; }
GetResult(A * p,int=-1) const46beb5ac8bSpeter klausler   template <typename A> void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
47beb5ac8bSpeter klausler     *p = Result<A>();
48beb5ac8bSpeter klausler   }
Accumulate(A x)49beb5ac8bSpeter klausler   template <typename A> bool Accumulate(A x) {
50c375ec86Speter klausler     // Kahan summation
51c375ec86Speter klausler     auto next{x + correction_};
52c375ec86Speter klausler     auto oldSum{sum_};
53c375ec86Speter klausler     sum_ += next;
54c375ec86Speter klausler     correction_ = (sum_ - oldSum) - next; // algebraically zero
55beb5ac8bSpeter klausler     return true;
56beb5ac8bSpeter klausler   }
AccumulateAt(const SubscriptValue at[])57beb5ac8bSpeter klausler   template <typename A> bool AccumulateAt(const SubscriptValue at[]) {
58beb5ac8bSpeter klausler     return Accumulate(*array_.Element<A>(at));
59beb5ac8bSpeter klausler   }
60beb5ac8bSpeter klausler 
61beb5ac8bSpeter klausler private:
62beb5ac8bSpeter klausler   const Descriptor &array_;
63c375ec86Speter klausler   INTERMEDIATE sum_{0.0}, correction_{0.0};
64beb5ac8bSpeter klausler };
65beb5ac8bSpeter klausler 
66beb5ac8bSpeter klausler template <typename PART> class ComplexSumAccumulator {
67beb5ac8bSpeter klausler public:
ComplexSumAccumulator(const Descriptor & array)68beb5ac8bSpeter klausler   explicit ComplexSumAccumulator(const Descriptor &array) : array_{array} {}
Reinitialize()69beb5ac8bSpeter klausler   void Reinitialize() {
70beb5ac8bSpeter klausler     reals_.Reinitialize();
71beb5ac8bSpeter klausler     imaginaries_.Reinitialize();
72beb5ac8bSpeter klausler   }
GetResult(A * p,int=-1) const73beb5ac8bSpeter klausler   template <typename A> void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
74beb5ac8bSpeter klausler     using ResultPart = typename A::value_type;
75beb5ac8bSpeter klausler     *p = {reals_.template Result<ResultPart>(),
76beb5ac8bSpeter klausler         imaginaries_.template Result<ResultPart>()};
77beb5ac8bSpeter klausler   }
Accumulate(const A & z)78beb5ac8bSpeter klausler   template <typename A> bool Accumulate(const A &z) {
79beb5ac8bSpeter klausler     reals_.Accumulate(z.real());
80beb5ac8bSpeter klausler     imaginaries_.Accumulate(z.imag());
81beb5ac8bSpeter klausler     return true;
82beb5ac8bSpeter klausler   }
AccumulateAt(const SubscriptValue at[])83beb5ac8bSpeter klausler   template <typename A> bool AccumulateAt(const SubscriptValue at[]) {
84beb5ac8bSpeter klausler     return Accumulate(*array_.Element<A>(at));
85beb5ac8bSpeter klausler   }
86beb5ac8bSpeter klausler 
87beb5ac8bSpeter klausler private:
88beb5ac8bSpeter klausler   const Descriptor &array_;
89beb5ac8bSpeter klausler   RealSumAccumulator<PART> reals_{array_}, imaginaries_{array_};
90beb5ac8bSpeter klausler };
91beb5ac8bSpeter klausler 
92beb5ac8bSpeter klausler extern "C" {
RTNAME(SumInteger1)93beb5ac8bSpeter klausler CppTypeFor<TypeCategory::Integer, 1> RTNAME(SumInteger1)(const Descriptor &x,
94beb5ac8bSpeter klausler     const char *source, int line, int dim, const Descriptor *mask) {
95beb5ac8bSpeter klausler   return GetTotalReduction<TypeCategory::Integer, 1>(x, source, line, dim, mask,
96beb5ac8bSpeter klausler       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
97beb5ac8bSpeter klausler }
RTNAME(SumInteger2)98beb5ac8bSpeter klausler CppTypeFor<TypeCategory::Integer, 2> RTNAME(SumInteger2)(const Descriptor &x,
99beb5ac8bSpeter klausler     const char *source, int line, int dim, const Descriptor *mask) {
100beb5ac8bSpeter klausler   return GetTotalReduction<TypeCategory::Integer, 2>(x, source, line, dim, mask,
101beb5ac8bSpeter klausler       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
102beb5ac8bSpeter klausler }
RTNAME(SumInteger4)103beb5ac8bSpeter klausler CppTypeFor<TypeCategory::Integer, 4> RTNAME(SumInteger4)(const Descriptor &x,
104beb5ac8bSpeter klausler     const char *source, int line, int dim, const Descriptor *mask) {
105beb5ac8bSpeter klausler   return GetTotalReduction<TypeCategory::Integer, 4>(x, source, line, dim, mask,
106beb5ac8bSpeter klausler       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
107beb5ac8bSpeter klausler }
RTNAME(SumInteger8)108beb5ac8bSpeter klausler CppTypeFor<TypeCategory::Integer, 8> RTNAME(SumInteger8)(const Descriptor &x,
109beb5ac8bSpeter klausler     const char *source, int line, int dim, const Descriptor *mask) {
110beb5ac8bSpeter klausler   return GetTotalReduction<TypeCategory::Integer, 8>(x, source, line, dim, mask,
111beb5ac8bSpeter klausler       IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 8>>{x}, "SUM");
112beb5ac8bSpeter klausler }
113beb5ac8bSpeter klausler #ifdef __SIZEOF_INT128__
RTNAME(SumInteger16)114beb5ac8bSpeter klausler CppTypeFor<TypeCategory::Integer, 16> RTNAME(SumInteger16)(const Descriptor &x,
115beb5ac8bSpeter klausler     const char *source, int line, int dim, const Descriptor *mask) {
116beb5ac8bSpeter klausler   return GetTotalReduction<TypeCategory::Integer, 16>(x, source, line, dim,
117beb5ac8bSpeter klausler       mask, IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 16>>{x},
118beb5ac8bSpeter klausler       "SUM");
119beb5ac8bSpeter klausler }
120beb5ac8bSpeter klausler #endif
121beb5ac8bSpeter klausler 
// TODO: SUM for REAL and COMPLEX kinds 2 and 3
RTNAME(SumReal4)123beb5ac8bSpeter klausler CppTypeFor<TypeCategory::Real, 4> RTNAME(SumReal4)(const Descriptor &x,
124beb5ac8bSpeter klausler     const char *source, int line, int dim, const Descriptor *mask) {
125beb5ac8bSpeter klausler   return GetTotalReduction<TypeCategory::Real, 4>(
126beb5ac8bSpeter klausler       x, source, line, dim, mask, RealSumAccumulator<double>{x}, "SUM");
127beb5ac8bSpeter klausler }
RTNAME(SumReal8)128beb5ac8bSpeter klausler CppTypeFor<TypeCategory::Real, 8> RTNAME(SumReal8)(const Descriptor &x,
129beb5ac8bSpeter klausler     const char *source, int line, int dim, const Descriptor *mask) {
130beb5ac8bSpeter klausler   return GetTotalReduction<TypeCategory::Real, 8>(
131beb5ac8bSpeter klausler       x, source, line, dim, mask, RealSumAccumulator<double>{x}, "SUM");
132beb5ac8bSpeter klausler }
#if LDBL_MANT_DIG == 64
// SUM of a REAL(10) array.  Only built where long double has a 64-bit
// significand (e.g. the x87 extended format); the compensated sum is carried
// in long double.
CppTypeFor<TypeCategory::Real, 10> RTNAME(SumReal10)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator = RealSumAccumulator<long double>;
  return GetTotalReduction<TypeCategory::Real, 10>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// SUM of a REAL(16) array.  This entry point exists either when long double is
// itself a 113-bit-significand format or when HAS_FLOAT128 supplies REAL(16)
// (presumably via __float128 — see flang/Runtime/float128.h).  The compensated
// sum is therefore carried in the REAL(16) type itself.  Accumulating in
// `long double`, as was done before, silently drops to long-double precision
// on HAS_FLOAT128 targets whose long double is narrower (LDBL_MANT_DIG == 64).
// When LDBL_MANT_DIG == 113 the two choices are the same type.
CppTypeFor<TypeCategory::Real, 16> RTNAME(SumReal16)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim, mask,
      RealSumAccumulator<CppTypeFor<TypeCategory::Real, 16>>{x}, "SUM");
}
#endif
147beb5ac8bSpeter klausler 
RTNAME(CppSumComplex4)148beb5ac8bSpeter klausler void RTNAME(CppSumComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
149beb5ac8bSpeter klausler     const Descriptor &x, const char *source, int line, int dim,
150beb5ac8bSpeter klausler     const Descriptor *mask) {
151beb5ac8bSpeter klausler   result = GetTotalReduction<TypeCategory::Complex, 4>(
152beb5ac8bSpeter klausler       x, source, line, dim, mask, ComplexSumAccumulator<double>{x}, "SUM");
153beb5ac8bSpeter klausler }
RTNAME(CppSumComplex8)154beb5ac8bSpeter klausler void RTNAME(CppSumComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
155beb5ac8bSpeter klausler     const Descriptor &x, const char *source, int line, int dim,
156beb5ac8bSpeter klausler     const Descriptor *mask) {
157beb5ac8bSpeter klausler   result = GetTotalReduction<TypeCategory::Complex, 8>(
158beb5ac8bSpeter klausler       x, source, line, dim, mask, ComplexSumAccumulator<double>{x}, "SUM");
159beb5ac8bSpeter klausler }
#if LDBL_MANT_DIG == 64
// SUM of a COMPLEX(10) array, stored into `result`; each component is summed
// in long double.  Only built where long double has a 64-bit significand.
void RTNAME(CppSumComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator = ComplexSumAccumulator<long double>;
  result = GetTotalReduction<TypeCategory::Complex, 10>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
#endif
#if LDBL_MANT_DIG == 113
// SUM of a COMPLEX(16) array, stored into `result`; each component is summed
// in long double.  Only built where long double has a 113-bit significand.
// NOTE(review): unlike SumReal16, this entry point is not enabled for
// HAS_FLOAT128 alone; doing so would need a matching declaration in
// flang/Runtime/reduction.h — confirm before extending.
void RTNAME(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator = ComplexSumAccumulator<long double>;
  result = GetTotalReduction<TypeCategory::Complex, 16>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
#endif
175beb5ac8bSpeter klausler 
RTNAME(SumDim)176beb5ac8bSpeter klausler void RTNAME(SumDim)(Descriptor &result, const Descriptor &x, int dim,
177beb5ac8bSpeter klausler     const char *source, int line, const Descriptor *mask) {
178beb5ac8bSpeter klausler   TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
179beb5ac8bSpeter klausler       ComplexSumAccumulator>(result, x, dim, source, line, mask, "SUM");
180beb5ac8bSpeter klausler }
181beb5ac8bSpeter klausler } // extern "C"
182beb5ac8bSpeter klausler } // namespace Fortran::runtime
183