1 //===-- runtime/dot-product.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "cpp-type.h" 10 #include "descriptor.h" 11 #include "reduction.h" 12 #include "terminator.h" 13 #include "tools.h" 14 #include <cinttypes> 15 16 namespace Fortran::runtime { 17 18 template <typename RESULT, TypeCategory XCAT, typename XT, typename YT> 19 class Accumulator { 20 public: 21 using Result = RESULT; 22 Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} 23 void Accumulate(SubscriptValue xAt, SubscriptValue yAt) { 24 if constexpr (XCAT == TypeCategory::Complex) { 25 sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) * 26 static_cast<Result>(*y_.Element<YT>(&yAt)); 27 } else if constexpr (XCAT == TypeCategory::Logical) { 28 sum_ = sum_ || 29 (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); 30 } else { 31 sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) * 32 static_cast<Result>(*y_.Element<YT>(&yAt)); 33 } 34 } 35 Result GetResult() const { return sum_; } 36 37 private: 38 const Descriptor &x_, &y_; 39 Result sum_{}; 40 }; 41 42 template <typename RESULT, TypeCategory XCAT, typename XT, typename YT> 43 static inline RESULT DoDotProduct( 44 const Descriptor &x, const Descriptor &y, Terminator &terminator) { 45 RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); 46 SubscriptValue n{x.GetDimension(0).Extent()}; 47 if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { 48 terminator.Crash( 49 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd", 50 static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN)); 51 } 52 if constexpr (std::is_same_v<XT, YT>) { 53 if constexpr (std::is_same_v<XT, float>) { 54 // TODO: call BLAS-1 SDOT or SDSDOT 55 } else if constexpr (std::is_same_v<XT, double>) { 56 // TODO: call BLAS-1 DDOT 57 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 58 // TODO: call BLAS-1 CDOTC 59 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 60 // TODO: call BLAS-1 ZDOTC 61 } 62 } 63 SubscriptValue xAt{x.GetDimension(0).LowerBound()}; 64 SubscriptValue yAt{y.GetDimension(0).LowerBound()}; 65 Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y}; 66 for (SubscriptValue j{0}; j < n; ++j) { 67 accumulator.Accumulate(xAt++, yAt++); 68 } 69 return accumulator.GetResult(); 70 } 71 72 template <TypeCategory RCAT, int RKIND> struct DotProduct { 73 using Result = CppTypeFor<RCAT, RKIND>; 74 template <TypeCategory XCAT, int XKIND> struct DP1 { 75 template <TypeCategory YCAT, int YKIND> struct DP2 { 76 Result operator()(const Descriptor &x, const Descriptor &y, 77 Terminator &terminator) const { 78 if constexpr (constexpr auto resultType{ 79 GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 80 if constexpr (resultType->first == RCAT && 81 resultType->second <= RKIND) { 82 return DoDotProduct<Result, XCAT, CppTypeFor<XCAT, XKIND>, 83 CppTypeFor<YCAT, YKIND>>(x, y, terminator); 84 } 85 } 86 terminator.Crash( 87 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))", 88 static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND, 89 static_cast<int>(YCAT), YKIND); 90 } 91 }; 92 Result operator()(const Descriptor &x, const Descriptor &y, 93 Terminator &terminator, TypeCategory yCat, int yKind) const { 94 return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator); 95 } 96 }; 97 Result operator()(const Descriptor &x, const Descriptor &y, 98 const char *source, int line) const { 99 Terminator terminator{source, line}; 100 auto xCatKind{x.type().GetCategoryAndKind()}; 101 auto yCatKind{y.type().GetCategoryAndKind()}; 102 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 103 return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, terminator, 104 x, y, terminator, yCatKind->first, yCatKind->second); 105 } 106 }; 107 108 extern "C" { 109 std::int8_t RTNAME(DotProductInteger1)( 110 const Descriptor &x, const Descriptor &y, const char *source, int line) { 111 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); 112 } 113 std::int16_t RTNAME(DotProductInteger2)( 114 const Descriptor &x, const Descriptor &y, const char *source, int line) { 115 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); 116 } 117 std::int32_t RTNAME(DotProductInteger4)( 118 const Descriptor &x, const Descriptor &y, const char *source, int line) { 119 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); 120 } 121 std::int64_t RTNAME(DotProductInteger8)( 122 const Descriptor &x, const Descriptor &y, const char *source, int line) { 123 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); 124 } 125 #ifdef __SIZEOF_INT128__ 126 common::int128_t RTNAME(DotProductInteger16)( 127 const Descriptor &x, const Descriptor &y, const char *source, int line) { 128 return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line); 129 } 130 #endif 131 132 // TODO: REAL/COMPLEX(2 & 3) 133 float RTNAME(DotProductReal4)( 134 const Descriptor &x, const Descriptor &y, const char *source, int line) { 135 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); 136 } 137 double RTNAME(DotProductReal8)( 138 const Descriptor &x, const Descriptor &y, const char *source, int line) { 139 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); 140 } 141 #if LONG_DOUBLE == 80 142 long double RTNAME(DotProductReal10)( 143 const Descriptor &x, const Descriptor &y, const char *source, int line) { 144 return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line); 145 } 146 #elif LONG_DOUBLE == 128 147 long double RTNAME(DotProductReal16)( 148 const Descriptor &x, const Descriptor &y, const char *source, int line) { 149 return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line); 150 } 151 #endif 152 153 void RTNAME(CppDotProductComplex4)(std::complex<float> &result, 154 const Descriptor &x, const Descriptor &y, const char *source, int line) { 155 auto z{DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line)}; 156 result = std::complex<float>{ 157 static_cast<float>(z.real()), static_cast<float>(z.imag())}; 158 } 159 void RTNAME(CppDotProductComplex8)(std::complex<double> &result, 160 const Descriptor &x, const Descriptor &y, const char *source, int line) { 161 result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line); 162 } 163 #if LONG_DOUBLE == 80 164 void RTNAME(CppDotProductComplex10)(std::complex<long double> &result, 165 const Descriptor &x, const Descriptor &y, const char *source, int line) { 166 result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line); 167 } 168 #elif LONG_DOUBLE == 128 169 void RTNAME(CppDotProductComplex16)(std::complex<long double> &result, 170 const Descriptor &x, const Descriptor &y, const char *source, int line) { 171 result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line); 172 } 173 #endif 174 175 bool RTNAME(DotProductLogical)( 176 const Descriptor &x, const Descriptor &y, const char *source, int line) { 177 return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line); 178 } 179 } // extern "C" 180 } // namespace Fortran::runtime 181