1 //===-- runtime/dot-product.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "terminator.h" 10 #include "tools.h" 11 #include "flang/Runtime/cpp-type.h" 12 #include "flang/Runtime/descriptor.h" 13 #include "flang/Runtime/reduction.h" 14 #include <cinttypes> 15 16 namespace Fortran::runtime { 17 18 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first 19 // argument; MATMUL does not. 20 21 // General accumulator for any type and stride; this is not used for 22 // contiguous numeric vectors. 23 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 24 class Accumulator { 25 public: 26 using Result = AccumulationType<RCAT, RKIND>; 27 Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} 28 void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { 29 if constexpr (RCAT == TypeCategory::Logical) { 30 sum_ = sum_ || 31 (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); 32 } else { 33 const XT &xElement{*x_.Element<XT>(&xAt)}; 34 const YT &yElement{*y_.Element<YT>(&yAt)}; 35 if constexpr (RCAT == TypeCategory::Complex) { 36 sum_ += std::conj(static_cast<Result>(xElement)) * 37 static_cast<Result>(yElement); 38 } else { 39 sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement); 40 } 41 } 42 } 43 Result GetResult() const { return sum_; } 44 45 private: 46 const Descriptor &x_, &y_; 47 Result sum_{}; 48 }; 49 50 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 51 static inline CppTypeFor<RCAT, RKIND> DoDotProduct( 52 const Descriptor &x, const Descriptor &y, Terminator &terminator) { 53 using Result = CppTypeFor<RCAT, RKIND>; 54 RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); 55 SubscriptValue n{x.GetDimension(0).Extent()}; 56 if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { 57 terminator.Crash( 58 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd", 59 static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN)); 60 } 61 if constexpr (RCAT != TypeCategory::Logical) { 62 if (x.GetDimension(0).ByteStride() == sizeof(XT) && 63 y.GetDimension(0).ByteStride() == sizeof(YT)) { 64 // Contiguous numeric vectors 65 if constexpr (std::is_same_v<XT, YT>) { 66 // Contiguous homogeneous numeric vectors 67 if constexpr (std::is_same_v<XT, float>) { 68 // TODO: call BLAS-1 SDOT or SDSDOT 69 } else if constexpr (std::is_same_v<XT, double>) { 70 // TODO: call BLAS-1 DDOT 71 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 72 // TODO: call BLAS-1 CDOTC 73 } else if constexpr (std::is_same_v<XT, std::complex<double>>) { 74 // TODO: call BLAS-1 ZDOTC 75 } 76 } 77 XT *xp{x.OffsetElement<XT>(0)}; 78 YT *yp{y.OffsetElement<YT>(0)}; 79 using AccumType = AccumulationType<RCAT, RKIND>; 80 AccumType accum{}; 81 if constexpr (RCAT == TypeCategory::Complex) { 82 for (SubscriptValue j{0}; j < n; ++j) { 83 accum += std::conj(static_cast<AccumType>(*xp++)) * 84 static_cast<AccumType>(*yp++); 85 } 86 } else { 87 for (SubscriptValue j{0}; j < n; ++j) { 88 accum += 89 static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++); 90 } 91 } 92 return static_cast<Result>(accum); 93 } 94 } 95 // Non-contiguous, heterogeneous, & LOGICAL cases 96 SubscriptValue xAt{x.GetDimension(0).LowerBound()}; 97 SubscriptValue yAt{y.GetDimension(0).LowerBound()}; 98 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 99 for (SubscriptValue j{0}; j < n; ++j) { 100 accumulator.AccumulateIndexed(xAt++, yAt++); 101 } 102 return static_cast<Result>(accumulator.GetResult()); 103 } 104 105 template <TypeCategory RCAT, int RKIND> struct DotProduct { 106 using Result = CppTypeFor<RCAT, RKIND>; 107 template <TypeCategory XCAT, int XKIND> struct DP1 { 108 template <TypeCategory YCAT, int YKIND> struct DP2 { 109 Result operator()(const Descriptor &x, const Descriptor &y, 110 Terminator &terminator) const { 111 if constexpr (constexpr auto resultType{ 112 GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 113 if constexpr (resultType->first == RCAT && 114 (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) { 115 return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>, 116 CppTypeFor<YCAT, YKIND>>(x, y, terminator); 117 } 118 } 119 terminator.Crash( 120 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))", 121 static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND, 122 static_cast<int>(YCAT), YKIND); 123 } 124 }; 125 Result operator()(const Descriptor &x, const Descriptor &y, 126 Terminator &terminator, TypeCategory yCat, int yKind) const { 127 return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator); 128 } 129 }; 130 Result operator()(const Descriptor &x, const Descriptor &y, 131 const char *source, int line) const { 132 Terminator terminator{source, line}; 133 if (RCAT != TypeCategory::Logical && x.type() == y.type()) { 134 // No conversions needed, operands and result have same known type 135 return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}( 136 x, y, terminator); 137 } else { 138 auto xCatKind{x.type().GetCategoryAndKind()}; 139 auto yCatKind{y.type().GetCategoryAndKind()}; 140 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 141 return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, 142 terminator, x, y, terminator, yCatKind->first, yCatKind->second); 143 } 144 } 145 }; 146 147 extern "C" { 148 std::int8_t RTNAME(DotProductInteger1)( 149 const Descriptor &x, const Descriptor &y, const char *source, int line) { 150 return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line); 151 } 152 std::int16_t RTNAME(DotProductInteger2)( 153 const Descriptor &x, const Descriptor &y, const char *source, int line) { 154 return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line); 155 } 156 std::int32_t RTNAME(DotProductInteger4)( 157 const Descriptor &x, const Descriptor &y, const char *source, int line) { 158 return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line); 159 } 160 std::int64_t RTNAME(DotProductInteger8)( 161 const Descriptor &x, const Descriptor &y, const char *source, int line) { 162 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); 163 } 164 #ifdef __SIZEOF_INT128__ 165 common::int128_t RTNAME(DotProductInteger16)( 166 const Descriptor &x, const Descriptor &y, const char *source, int line) { 167 return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line); 168 } 169 #endif 170 171 // TODO: REAL/COMPLEX(2 & 3) 172 // Intermediate results and operations are at least 64 bits 173 float RTNAME(DotProductReal4)( 174 const Descriptor &x, const Descriptor &y, const char *source, int line) { 175 return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line); 176 } 177 double RTNAME(DotProductReal8)( 178 const Descriptor &x, const Descriptor &y, const char *source, int line) { 179 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); 180 } 181 #if LONG_DOUBLE == 80 182 long double RTNAME(DotProductReal10)( 183 const Descriptor &x, const Descriptor &y, const char *source, int line) { 184 return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line); 185 } 186 #elif LONG_DOUBLE == 128 187 long double RTNAME(DotProductReal16)( 188 const Descriptor &x, const Descriptor &y, const char *source, int line) { 189 return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line); 190 } 191 #endif 192 193 void RTNAME(CppDotProductComplex4)(std::complex<float> &result, 194 const Descriptor &x, const Descriptor &y, const char *source, int line) { 195 auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)}; 196 result = std::complex<float>{ 197 static_cast<float>(z.real()), static_cast<float>(z.imag())}; 198 } 199 void RTNAME(CppDotProductComplex8)(std::complex<double> &result, 200 const Descriptor &x, const Descriptor &y, const char *source, int line) { 201 result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line); 202 } 203 #if LONG_DOUBLE == 80 204 void RTNAME(CppDotProductComplex10)(std::complex<long double> &result, 205 const Descriptor &x, const Descriptor &y, const char *source, int line) { 206 result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line); 207 } 208 #elif LONG_DOUBLE == 128 209 void RTNAME(CppDotProductComplex16)(std::complex<long double> &result, 210 const Descriptor &x, const Descriptor &y, const char *source, int line) { 211 result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line); 212 } 213 #endif 214 215 bool RTNAME(DotProductLogical)( 216 const Descriptor &x, const Descriptor &y, const char *source, int line) { 217 return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line); 218 } 219 } // extern "C" 220 } // namespace Fortran::runtime 221