1 //===-- runtime/dot-product.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "terminator.h"
10 #include "tools.h"
11 #include "flang/Runtime/cpp-type.h"
12 #include "flang/Runtime/descriptor.h"
13 #include "flang/Runtime/reduction.h"
14 #include <cinttypes>
15 
16 namespace Fortran::runtime {
17 
18 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
19 // argument; MATMUL does not.
20 
21 // General accumulator for any type and stride; this is not used for
22 // contiguous numeric vectors.
23 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
24 class Accumulator {
25 public:
26   using Result = AccumulationType<RCAT, RKIND>;
27   Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
28   void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
29     if constexpr (RCAT == TypeCategory::Logical) {
30       sum_ = sum_ ||
31           (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
32     } else {
33       const XT &xElement{*x_.Element<XT>(&xAt)};
34       const YT &yElement{*y_.Element<YT>(&yAt)};
35       if constexpr (RCAT == TypeCategory::Complex) {
36         sum_ += std::conj(static_cast<Result>(xElement)) *
37             static_cast<Result>(yElement);
38       } else {
39         sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
40       }
41     }
42   }
43   Result GetResult() const { return sum_; }
44 
45 private:
46   const Descriptor &x_, &y_;
47   Result sum_{};
48 };
49 
50 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
51 static inline CppTypeFor<RCAT, RKIND> DoDotProduct(
52     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
53   using Result = CppTypeFor<RCAT, RKIND>;
54   RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
55   SubscriptValue n{x.GetDimension(0).Extent()};
56   if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
57     terminator.Crash(
58         "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
59         static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
60   }
61   if constexpr (RCAT != TypeCategory::Logical) {
62     if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
63         y.GetDimension(0).ByteStride() == sizeof(YT)) {
64       // Contiguous numeric vectors
65       if constexpr (std::is_same_v<XT, YT>) {
66         // Contiguous homogeneous numeric vectors
67         if constexpr (std::is_same_v<XT, float>) {
68           // TODO: call BLAS-1 SDOT or SDSDOT
69         } else if constexpr (std::is_same_v<XT, double>) {
70           // TODO: call BLAS-1 DDOT
71         } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
72           // TODO: call BLAS-1 CDOTC
73         } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
74           // TODO: call BLAS-1 ZDOTC
75         }
76       }
77       XT *xp{x.OffsetElement<XT>(0)};
78       YT *yp{y.OffsetElement<YT>(0)};
79       using AccumType = AccumulationType<RCAT, RKIND>;
80       AccumType accum{};
81       if constexpr (RCAT == TypeCategory::Complex) {
82         for (SubscriptValue j{0}; j < n; ++j) {
83           accum += std::conj(static_cast<AccumType>(*xp++)) *
84               static_cast<AccumType>(*yp++);
85         }
86       } else {
87         for (SubscriptValue j{0}; j < n; ++j) {
88           accum +=
89               static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
90         }
91       }
92       return static_cast<Result>(accum);
93     }
94   }
95   // Non-contiguous, heterogeneous, & LOGICAL cases
96   SubscriptValue xAt{x.GetDimension(0).LowerBound()};
97   SubscriptValue yAt{y.GetDimension(0).LowerBound()};
98   Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
99   for (SubscriptValue j{0}; j < n; ++j) {
100     accumulator.AccumulateIndexed(xAt++, yAt++);
101   }
102   return static_cast<Result>(accumulator.GetResult());
103 }
104 
105 template <TypeCategory RCAT, int RKIND> struct DotProduct {
106   using Result = CppTypeFor<RCAT, RKIND>;
107   template <TypeCategory XCAT, int XKIND> struct DP1 {
108     template <TypeCategory YCAT, int YKIND> struct DP2 {
109       Result operator()(const Descriptor &x, const Descriptor &y,
110           Terminator &terminator) const {
111         if constexpr (constexpr auto resultType{
112                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
113           if constexpr (resultType->first == RCAT &&
114               (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
115             return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
116                 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
117           }
118         }
119         terminator.Crash(
120             "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
121             static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
122             static_cast<int>(YCAT), YKIND);
123       }
124     };
125     Result operator()(const Descriptor &x, const Descriptor &y,
126         Terminator &terminator, TypeCategory yCat, int yKind) const {
127       return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
128     }
129   };
130   Result operator()(const Descriptor &x, const Descriptor &y,
131       const char *source, int line) const {
132     Terminator terminator{source, line};
133     if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
134       // No conversions needed, operands and result have same known type
135       return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
136           x, y, terminator);
137     } else {
138       auto xCatKind{x.type().GetCategoryAndKind()};
139       auto yCatKind{y.type().GetCategoryAndKind()};
140       RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
141       return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
142           terminator, x, y, terminator, yCatKind->first, yCatKind->second);
143     }
144   }
145 };
146 
147 extern "C" {
148 std::int8_t RTNAME(DotProductInteger1)(
149     const Descriptor &x, const Descriptor &y, const char *source, int line) {
150   return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
151 }
152 std::int16_t RTNAME(DotProductInteger2)(
153     const Descriptor &x, const Descriptor &y, const char *source, int line) {
154   return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
155 }
156 std::int32_t RTNAME(DotProductInteger4)(
157     const Descriptor &x, const Descriptor &y, const char *source, int line) {
158   return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
159 }
160 std::int64_t RTNAME(DotProductInteger8)(
161     const Descriptor &x, const Descriptor &y, const char *source, int line) {
162   return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
163 }
164 #ifdef __SIZEOF_INT128__
165 common::int128_t RTNAME(DotProductInteger16)(
166     const Descriptor &x, const Descriptor &y, const char *source, int line) {
167   return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
168 }
169 #endif
170 
171 // TODO: REAL/COMPLEX(2 & 3)
172 // Intermediate results and operations are at least 64 bits
173 float RTNAME(DotProductReal4)(
174     const Descriptor &x, const Descriptor &y, const char *source, int line) {
175   return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
176 }
177 double RTNAME(DotProductReal8)(
178     const Descriptor &x, const Descriptor &y, const char *source, int line) {
179   return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
180 }
#if LONG_DOUBLE == 80
// DOT_PRODUCT with a REAL(KIND=10) result, returned as long double; only
// compiled when LONG_DOUBLE is 80.  source and line are passed through to
// DotProduct, which builds its Terminator from them.
long double RTNAME(DotProductReal10)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Reduction = DotProduct<TypeCategory::Real, 10>;
  return Reduction{}(x, y, source, line);
}
#elif LONG_DOUBLE == 128
// DOT_PRODUCT with a REAL(KIND=16) result, returned as long double; only
// compiled when LONG_DOUBLE is 128.  source and line are passed through to
// DotProduct, which builds its Terminator from them.
long double RTNAME(DotProductReal16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Reduction = DotProduct<TypeCategory::Real, 16>;
  return Reduction{}(x, y, source, line);
}
#endif
192 
193 void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
194     const Descriptor &x, const Descriptor &y, const char *source, int line) {
195   auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)};
196   result = std::complex<float>{
197       static_cast<float>(z.real()), static_cast<float>(z.imag())};
198 }
199 void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
200     const Descriptor &x, const Descriptor &y, const char *source, int line) {
201   result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
202 }
#if LONG_DOUBLE == 80
// DOT_PRODUCT with a COMPLEX(KIND=10) result, delivered through the result
// reference; only compiled when LONG_DOUBLE is 80.  source and line are
// passed through to DotProduct, which builds its Terminator from them.
void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Reduction = DotProduct<TypeCategory::Complex, 10>;
  result = Reduction{}(x, y, source, line);
}
#elif LONG_DOUBLE == 128
// DOT_PRODUCT with a COMPLEX(KIND=16) result, delivered through the result
// reference; only compiled when LONG_DOUBLE is 128.  source and line are
// passed through to DotProduct, which builds its Terminator from them.
void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Reduction = DotProduct<TypeCategory::Complex, 16>;
  result = Reduction{}(x, y, source, line);
}
#endif
214 
215 bool RTNAME(DotProductLogical)(
216     const Descriptor &x, const Descriptor &y, const char *source, int line) {
217   return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
218 }
219 } // extern "C"
220 } // namespace Fortran::runtime
221