1 //===-- runtime/dot-product.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "float.h" 10 #include "terminator.h" 11 #include "tools.h" 12 #include "flang/Runtime/cpp-type.h" 13 #include "flang/Runtime/descriptor.h" 14 #include "flang/Runtime/reduction.h" 15 #include <cfloat> 16 #include <cinttypes> 17 18 namespace Fortran::runtime { 19 20 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first 21 // argument; MATMUL does not. 22 23 // General accumulator for any type and stride; this is not used for 24 // contiguous numeric vectors. 25 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 26 class Accumulator { 27 public: 28 using Result = AccumulationType<RCAT, RKIND>; 29 Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} 30 void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { 31 if constexpr (RCAT == TypeCategory::Logical) { 32 sum_ = sum_ || 33 (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); 34 } else { 35 const XT &xElement{*x_.Element<XT>(&xAt)}; 36 const YT &yElement{*y_.Element<YT>(&yAt)}; 37 if constexpr (RCAT == TypeCategory::Complex) { 38 sum_ += std::conj(static_cast<Result>(xElement)) * 39 static_cast<Result>(yElement); 40 } else { 41 sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement); 42 } 43 } 44 } 45 Result GetResult() const { return sum_; } 46 47 private: 48 const Descriptor &x_, &y_; 49 Result sum_{}; 50 }; 51 52 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 53 static inline CppTypeFor<RCAT, RKIND> DoDotProduct( 54 const Descriptor &x, const Descriptor &y, Terminator &terminator) { 55 using Result = CppTypeFor<RCAT, RKIND>; 56 RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); 57 SubscriptValue n{x.GetDimension(0).Extent()}; 58 if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { 59 terminator.Crash( 60 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd", 61 static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN)); 62 } 63 if constexpr (RCAT != TypeCategory::Logical) { 64 if (x.GetDimension(0).ByteStride() == sizeof(XT) && 65 y.GetDimension(0).ByteStride() == sizeof(YT)) { 66 // Contiguous numeric vectors 67 if constexpr (std::is_same_v<XT, YT>) { 68 // Contiguous homogeneous numeric vectors 69 if constexpr (std::is_same_v<XT, float>) { 70 // TODO: call BLAS-1 SDOT or SDSDOT 71 } else if constexpr (std::is_same_v<XT, double>) { 72 // TODO: call BLAS-1 DDOT 73 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 74 // TODO: call BLAS-1 CDOTC 75 } else if constexpr (std::is_same_v<XT, std::complex<double>>) { 76 // TODO: call BLAS-1 ZDOTC 77 } 78 } 79 XT *xp{x.OffsetElement<XT>(0)}; 80 YT *yp{y.OffsetElement<YT>(0)}; 81 using AccumType = AccumulationType<RCAT, RKIND>; 82 AccumType accum{}; 83 if constexpr (RCAT == TypeCategory::Complex) { 84 for (SubscriptValue j{0}; j < n; ++j) { 85 accum += std::conj(static_cast<AccumType>(*xp++)) * 86 static_cast<AccumType>(*yp++); 87 } 88 } else { 89 for (SubscriptValue j{0}; j < n; ++j) { 90 accum += 91 static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++); 92 } 93 } 94 return static_cast<Result>(accum); 95 } 96 } 97 // Non-contiguous, heterogeneous, & LOGICAL cases 98 SubscriptValue xAt{x.GetDimension(0).LowerBound()}; 99 SubscriptValue yAt{y.GetDimension(0).LowerBound()}; 100 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 101 for (SubscriptValue j{0}; j < n; ++j) { 102 accumulator.AccumulateIndexed(xAt++, yAt++); 103 } 104 return static_cast<Result>(accumulator.GetResult()); 105 } 106 107 template <TypeCategory RCAT, int RKIND> struct DotProduct { 108 using Result = CppTypeFor<RCAT, RKIND>; 109 template <TypeCategory XCAT, int XKIND> struct DP1 { 110 template <TypeCategory YCAT, int YKIND> struct DP2 { 111 Result operator()(const Descriptor &x, const Descriptor &y, 112 Terminator &terminator) const { 113 if constexpr (constexpr auto resultType{ 114 GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 115 if constexpr (resultType->first == RCAT && 116 (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) { 117 return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>, 118 CppTypeFor<YCAT, YKIND>>(x, y, terminator); 119 } 120 } 121 terminator.Crash( 122 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))", 123 static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND, 124 static_cast<int>(YCAT), YKIND); 125 } 126 }; 127 Result operator()(const Descriptor &x, const Descriptor &y, 128 Terminator &terminator, TypeCategory yCat, int yKind) const { 129 return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator); 130 } 131 }; 132 Result operator()(const Descriptor &x, const Descriptor &y, 133 const char *source, int line) const { 134 Terminator terminator{source, line}; 135 if (RCAT != TypeCategory::Logical && x.type() == y.type()) { 136 // No conversions needed, operands and result have same known type 137 return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}( 138 x, y, terminator); 139 } else { 140 auto xCatKind{x.type().GetCategoryAndKind()}; 141 auto yCatKind{y.type().GetCategoryAndKind()}; 142 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 143 return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, 144 terminator, x, y, terminator, yCatKind->first, yCatKind->second); 145 } 146 } 147 }; 148 149 extern "C" { 150 std::int8_t RTNAME(DotProductInteger1)( 151 const Descriptor &x, const Descriptor &y, const char *source, int line) { 152 return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line); 153 } 154 std::int16_t RTNAME(DotProductInteger2)( 155 const Descriptor &x, const Descriptor &y, const char *source, int line) { 156 return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line); 157 } 158 std::int32_t RTNAME(DotProductInteger4)( 159 const Descriptor &x, const Descriptor &y, const char *source, int line) { 160 return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line); 161 } 162 std::int64_t RTNAME(DotProductInteger8)( 163 const Descriptor &x, const Descriptor &y, const char *source, int line) { 164 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); 165 } 166 #ifdef __SIZEOF_INT128__ 167 common::int128_t RTNAME(DotProductInteger16)( 168 const Descriptor &x, const Descriptor &y, const char *source, int line) { 169 return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line); 170 } 171 #endif 172 173 // TODO: REAL/COMPLEX(2 & 3) 174 // Intermediate results and operations are at least 64 bits 175 float RTNAME(DotProductReal4)( 176 const Descriptor &x, const Descriptor &y, const char *source, int line) { 177 return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line); 178 } 179 double RTNAME(DotProductReal8)( 180 const Descriptor &x, const Descriptor &y, const char *source, int line) { 181 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); 182 } 183 #if LDBL_MANT_DIG == 64 184 long double RTNAME(DotProductReal10)( 185 const Descriptor &x, const Descriptor &y, const char *source, int line) { 186 return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line); 187 } 188 #endif 189 #if LDBL_MANT_DIG == 113 || HAS_FLOAT128 190 CppTypeFor<TypeCategory::Real, 16> RTNAME(DotProductReal16)( 191 const Descriptor &x, const Descriptor &y, const char *source, int line) { 192 return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line); 193 } 194 #endif 195 196 void RTNAME(CppDotProductComplex4)(std::complex<float> &result, 197 const Descriptor &x, const Descriptor &y, const char *source, int line) { 198 auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)}; 199 result = std::complex<float>{ 200 static_cast<float>(z.real()), static_cast<float>(z.imag())}; 201 } 202 void RTNAME(CppDotProductComplex8)(std::complex<double> &result, 203 const Descriptor &x, const Descriptor &y, const char *source, int line) { 204 result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line); 205 } 206 #if LDBL_MANT_DIG == 64 207 void RTNAME(CppDotProductComplex10)(std::complex<long double> &result, 208 const Descriptor &x, const Descriptor &y, const char *source, int line) { 209 result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line); 210 } 211 #elif LDBL_MANT_DIG == 113 212 void RTNAME(CppDotProductComplex16)(std::complex<CppFloat128Type> &result, 213 const Descriptor &x, const Descriptor &y, const char *source, int line) { 214 result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line); 215 } 216 #endif 217 218 bool RTNAME(DotProductLogical)( 219 const Descriptor &x, const Descriptor &y, const char *source, int line) { 220 return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line); 221 } 222 } // extern "C" 223 } // namespace Fortran::runtime 224