1 //===-- runtime/dot-product.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "float.h"
10 #include "terminator.h"
11 #include "tools.h"
12 #include "flang/Runtime/cpp-type.h"
13 #include "flang/Runtime/descriptor.h"
14 #include "flang/Runtime/reduction.h"
15 #include <cfloat>
16 #include <cinttypes>
17 
18 namespace Fortran::runtime {
19 
20 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
21 // argument; MATMUL does not.
22 
23 // General accumulator for any type and stride; this is not used for
24 // contiguous numeric vectors.
25 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
26 class Accumulator {
27 public:
28   using Result = AccumulationType<RCAT, RKIND>;
Accumulator(const Descriptor & x,const Descriptor & y)29   Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
AccumulateIndexed(SubscriptValue xAt,SubscriptValue yAt)30   void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
31     if constexpr (RCAT == TypeCategory::Logical) {
32       sum_ = sum_ ||
33           (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
34     } else {
35       const XT &xElement{*x_.Element<XT>(&xAt)};
36       const YT &yElement{*y_.Element<YT>(&yAt)};
37       if constexpr (RCAT == TypeCategory::Complex) {
38         sum_ += std::conj(static_cast<Result>(xElement)) *
39             static_cast<Result>(yElement);
40       } else {
41         sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
42       }
43     }
44   }
GetResult() const45   Result GetResult() const { return sum_; }
46 
47 private:
48   const Descriptor &x_, &y_;
49   Result sum_{};
50 };
51 
52 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
DoDotProduct(const Descriptor & x,const Descriptor & y,Terminator & terminator)53 static inline CppTypeFor<RCAT, RKIND> DoDotProduct(
54     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
55   using Result = CppTypeFor<RCAT, RKIND>;
56   RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
57   SubscriptValue n{x.GetDimension(0).Extent()};
58   if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
59     terminator.Crash(
60         "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
61         static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
62   }
63   if constexpr (RCAT != TypeCategory::Logical) {
64     if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
65         y.GetDimension(0).ByteStride() == sizeof(YT)) {
66       // Contiguous numeric vectors
67       if constexpr (std::is_same_v<XT, YT>) {
68         // Contiguous homogeneous numeric vectors
69         if constexpr (std::is_same_v<XT, float>) {
70           // TODO: call BLAS-1 SDOT or SDSDOT
71         } else if constexpr (std::is_same_v<XT, double>) {
72           // TODO: call BLAS-1 DDOT
73         } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
74           // TODO: call BLAS-1 CDOTC
75         } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
76           // TODO: call BLAS-1 ZDOTC
77         }
78       }
79       XT *xp{x.OffsetElement<XT>(0)};
80       YT *yp{y.OffsetElement<YT>(0)};
81       using AccumType = AccumulationType<RCAT, RKIND>;
82       AccumType accum{};
83       if constexpr (RCAT == TypeCategory::Complex) {
84         for (SubscriptValue j{0}; j < n; ++j) {
85           accum += std::conj(static_cast<AccumType>(*xp++)) *
86               static_cast<AccumType>(*yp++);
87         }
88       } else {
89         for (SubscriptValue j{0}; j < n; ++j) {
90           accum +=
91               static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
92         }
93       }
94       return static_cast<Result>(accum);
95     }
96   }
97   // Non-contiguous, heterogeneous, & LOGICAL cases
98   SubscriptValue xAt{x.GetDimension(0).LowerBound()};
99   SubscriptValue yAt{y.GetDimension(0).LowerBound()};
100   Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
101   for (SubscriptValue j{0}; j < n; ++j) {
102     accumulator.AccumulateIndexed(xAt++, yAt++);
103   }
104   return static_cast<Result>(accumulator.GetResult());
105 }
106 
107 template <TypeCategory RCAT, int RKIND> struct DotProduct {
108   using Result = CppTypeFor<RCAT, RKIND>;
109   template <TypeCategory XCAT, int XKIND> struct DP1 {
110     template <TypeCategory YCAT, int YKIND> struct DP2 {
operator ()Fortran::runtime::DotProduct::DP1::DP2111       Result operator()(const Descriptor &x, const Descriptor &y,
112           Terminator &terminator) const {
113         if constexpr (constexpr auto resultType{
114                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
115           if constexpr (resultType->first == RCAT &&
116               (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
117             return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
118                 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
119           }
120         }
121         terminator.Crash(
122             "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
123             static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
124             static_cast<int>(YCAT), YKIND);
125       }
126     };
operator ()Fortran::runtime::DotProduct::DP1127     Result operator()(const Descriptor &x, const Descriptor &y,
128         Terminator &terminator, TypeCategory yCat, int yKind) const {
129       return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
130     }
131   };
operator ()Fortran::runtime::DotProduct132   Result operator()(const Descriptor &x, const Descriptor &y,
133       const char *source, int line) const {
134     Terminator terminator{source, line};
135     if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
136       // No conversions needed, operands and result have same known type
137       return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
138           x, y, terminator);
139     } else {
140       auto xCatKind{x.type().GetCategoryAndKind()};
141       auto yCatKind{y.type().GetCategoryAndKind()};
142       RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
143       return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
144           terminator, x, y, terminator, yCatKind->first, yCatKind->second);
145     }
146   }
147 };
148 
149 extern "C" {
RTNAME(DotProductInteger1)150 std::int8_t RTNAME(DotProductInteger1)(
151     const Descriptor &x, const Descriptor &y, const char *source, int line) {
152   return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
153 }
RTNAME(DotProductInteger2)154 std::int16_t RTNAME(DotProductInteger2)(
155     const Descriptor &x, const Descriptor &y, const char *source, int line) {
156   return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
157 }
RTNAME(DotProductInteger4)158 std::int32_t RTNAME(DotProductInteger4)(
159     const Descriptor &x, const Descriptor &y, const char *source, int line) {
160   return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
161 }
RTNAME(DotProductInteger8)162 std::int64_t RTNAME(DotProductInteger8)(
163     const Descriptor &x, const Descriptor &y, const char *source, int line) {
164   return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
165 }
166 #ifdef __SIZEOF_INT128__
RTNAME(DotProductInteger16)167 common::int128_t RTNAME(DotProductInteger16)(
168     const Descriptor &x, const Descriptor &y, const char *source, int line) {
169   return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
170 }
171 #endif
172 
173 // TODO: REAL/COMPLEX(2 & 3)
174 // Intermediate results and operations are at least 64 bits
RTNAME(DotProductReal4)175 float RTNAME(DotProductReal4)(
176     const Descriptor &x, const Descriptor &y, const char *source, int line) {
177   return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
178 }
RTNAME(DotProductReal8)179 double RTNAME(DotProductReal8)(
180     const Descriptor &x, const Descriptor &y, const char *source, int line) {
181   return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
182 }
183 #if LDBL_MANT_DIG == 64
RTNAME(DotProductReal10)184 long double RTNAME(DotProductReal10)(
185     const Descriptor &x, const Descriptor &y, const char *source, int line) {
186   return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
187 }
188 #endif
189 #if LDBL_MANT_DIG == 113 || HAS_FLOAT128
RTNAME(DotProductReal16)190 CppTypeFor<TypeCategory::Real, 16> RTNAME(DotProductReal16)(
191     const Descriptor &x, const Descriptor &y, const char *source, int line) {
192   return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
193 }
194 #endif
195 
RTNAME(CppDotProductComplex4)196 void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
197     const Descriptor &x, const Descriptor &y, const char *source, int line) {
198   auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)};
199   result = std::complex<float>{
200       static_cast<float>(z.real()), static_cast<float>(z.imag())};
201 }
RTNAME(CppDotProductComplex8)202 void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
203     const Descriptor &x, const Descriptor &y, const char *source, int line) {
204   result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
205 }
206 #if LDBL_MANT_DIG == 64
RTNAME(CppDotProductComplex10)207 void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
208     const Descriptor &x, const Descriptor &y, const char *source, int line) {
209   result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
210 }
211 #elif LDBL_MANT_DIG == 113
RTNAME(CppDotProductComplex16)212 void RTNAME(CppDotProductComplex16)(std::complex<CppFloat128Type> &result,
213     const Descriptor &x, const Descriptor &y, const char *source, int line) {
214   result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
215 }
216 #endif
217 
RTNAME(DotProductLogical)218 bool RTNAME(DotProductLogical)(
219     const Descriptor &x, const Descriptor &y, const char *source, int line) {
220   return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
221 }
222 } // extern "C"
223 } // namespace Fortran::runtime
224