1 //===-- runtime/dot-product.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "cpp-type.h"
10 #include "descriptor.h"
11 #include "reduction.h"
12 #include "terminator.h"
13 #include "tools.h"
14 #include <cinttypes>
15 
16 namespace Fortran::runtime {
17 
18 template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
19 class Accumulator {
20 public:
21   using Result = RESULT;
22   Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
23   void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
24     if constexpr (XCAT == TypeCategory::Complex) {
25       sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
26           static_cast<Result>(*y_.Element<YT>(&yAt));
27     } else if constexpr (XCAT == TypeCategory::Logical) {
28       sum_ = sum_ ||
29           (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
30     } else {
31       sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
32           static_cast<Result>(*y_.Element<YT>(&yAt));
33     }
34   }
35   Result GetResult() const { return sum_; }
36 
37 private:
38   const Descriptor &x_, &y_;
39   Result sum_{};
40 };
41 
42 template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
43 static inline RESULT DoDotProduct(
44     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
45   RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
46   SubscriptValue n{x.GetDimension(0).Extent()};
47   if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
48     terminator.Crash(
49         "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
50         static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
51   }
52   if constexpr (std::is_same_v<XT, YT>) {
53     if constexpr (std::is_same_v<XT, float>) {
54       // TODO: call BLAS-1 SDOT or SDSDOT
55     } else if constexpr (std::is_same_v<XT, double>) {
56       // TODO: call BLAS-1 DDOT
57     } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
58       // TODO: call BLAS-1 CDOTC
59     } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
60       // TODO: call BLAS-1 ZDOTC
61     }
62   }
63   SubscriptValue xAt{x.GetDimension(0).LowerBound()};
64   SubscriptValue yAt{y.GetDimension(0).LowerBound()};
65   Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y};
66   for (SubscriptValue j{0}; j < n; ++j) {
67     accumulator.Accumulate(xAt++, yAt++);
68   }
69   return accumulator.GetResult();
70 }
71 
72 template <TypeCategory RCAT, int RKIND> struct DotProduct {
73   using Result = CppTypeFor<RCAT, RKIND>;
74   template <TypeCategory XCAT, int XKIND> struct DP1 {
75     template <TypeCategory YCAT, int YKIND> struct DP2 {
76       Result operator()(const Descriptor &x, const Descriptor &y,
77           Terminator &terminator) const {
78         if constexpr (constexpr auto resultType{
79                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
80           if constexpr (resultType->first == RCAT &&
81               resultType->second <= RKIND) {
82             return DoDotProduct<Result, XCAT, CppTypeFor<XCAT, XKIND>,
83                 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
84           }
85         }
86         terminator.Crash(
87             "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
88             static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
89             static_cast<int>(YCAT), YKIND);
90       }
91     };
92     Result operator()(const Descriptor &x, const Descriptor &y,
93         Terminator &terminator, TypeCategory yCat, int yKind) const {
94       return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
95     }
96   };
97   Result operator()(const Descriptor &x, const Descriptor &y,
98       const char *source, int line) const {
99     Terminator terminator{source, line};
100     auto xCatKind{x.type().GetCategoryAndKind()};
101     auto yCatKind{y.type().GetCategoryAndKind()};
102     RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
103     return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, terminator,
104         x, y, terminator, yCatKind->first, yCatKind->second);
105   }
106 };
107 
108 extern "C" {
109 std::int8_t RTNAME(DotProductInteger1)(
110     const Descriptor &x, const Descriptor &y, const char *source, int line) {
111   return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
112 }
113 std::int16_t RTNAME(DotProductInteger2)(
114     const Descriptor &x, const Descriptor &y, const char *source, int line) {
115   return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
116 }
117 std::int32_t RTNAME(DotProductInteger4)(
118     const Descriptor &x, const Descriptor &y, const char *source, int line) {
119   return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
120 }
121 std::int64_t RTNAME(DotProductInteger8)(
122     const Descriptor &x, const Descriptor &y, const char *source, int line) {
123   return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
124 }
125 #ifdef __SIZEOF_INT128__
126 common::int128_t RTNAME(DotProductInteger16)(
127     const Descriptor &x, const Descriptor &y, const char *source, int line) {
128   return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
129 }
130 #endif
131 
132 // TODO: REAL/COMPLEX(2 & 3)
133 float RTNAME(DotProductReal4)(
134     const Descriptor &x, const Descriptor &y, const char *source, int line) {
135   return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
136 }
137 double RTNAME(DotProductReal8)(
138     const Descriptor &x, const Descriptor &y, const char *source, int line) {
139   return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
140 }
141 #if LONG_DOUBLE == 80
142 long double RTNAME(DotProductReal10)(
143     const Descriptor &x, const Descriptor &y, const char *source, int line) {
144   return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
145 }
146 #elif LONG_DOUBLE == 128
147 long double RTNAME(DotProductReal16)(
148     const Descriptor &x, const Descriptor &y, const char *source, int line) {
149   return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
150 }
151 #endif
152 
153 void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
154     const Descriptor &x, const Descriptor &y, const char *source, int line) {
155   auto z{DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line)};
156   result = std::complex<float>{
157       static_cast<float>(z.real()), static_cast<float>(z.imag())};
158 }
159 void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
160     const Descriptor &x, const Descriptor &y, const char *source, int line) {
161   result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
162 }
163 #if LONG_DOUBLE == 80
164 void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
165     const Descriptor &x, const Descriptor &y, const char *source, int line) {
166   result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
167 }
168 #elif LONG_DOUBLE == 128
169 void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
170     const Descriptor &x, const Descriptor &y, const char *source, int line) {
171   result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
172 }
173 #endif
174 
175 bool RTNAME(DotProductLogical)(
176     const Descriptor &x, const Descriptor &y, const char *source, int line) {
177   return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
178 }
179 } // extern "C"
180 } // namespace Fortran::runtime
181