1 //===-- runtime/matmul.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 // Implements all forms of MATMUL (Fortran 2018 16.9.124) 10 // 11 // There are two main entry points; one establishes a descriptor for the 12 // result and allocates it, and the other expects a result descriptor that 13 // points to existing storage. 14 // 15 // This implementation must handle all combinations of numeric types and 16 // kinds (100 - 165 cases depending on the target), plus all combinations 17 // of logical kinds (16). A single template undergoes many instantiations 18 // to cover all of the valid possibilities. 19 // 20 // Places where BLAS routines could be called are marked as TODO items. 21 22 #include "flang/Runtime/matmul.h" 23 #include "terminator.h" 24 #include "tools.h" 25 #include "flang/Runtime/cpp-type.h" 26 #include "flang/Runtime/descriptor.h" 27 28 namespace Fortran::runtime { 29 30 template <TypeCategory RCAT, int RKIND, typename XT, typename YT> 31 class Accumulator { 32 public: 33 // Accumulate floating-point results in (at least) double precision 34 using Result = CppTypeFor<RCAT, 35 RCAT == TypeCategory::Real || RCAT == TypeCategory::Complex 36 ? std::max(RKIND, static_cast<int>(sizeof(double))) 37 : RKIND>; 38 Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} 39 void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) { 40 if constexpr (RCAT == TypeCategory::Logical) { 41 sum_ = sum_ || 42 (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); 43 } else { 44 sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) * 45 static_cast<Result>(*y_.Element<YT>(yAt)); 46 } 47 } 48 Result GetResult() const { return sum_; } 49 50 private: 51 const Descriptor &x_, &y_; 52 Result sum_{}; 53 }; 54 55 // Implements an instance of MATMUL for given argument types. 56 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, 57 typename YT> 58 static inline void DoMatmul( 59 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, 60 const Descriptor &x, const Descriptor &y, Terminator &terminator) { 61 int xRank{x.rank()}; 62 int yRank{y.rank()}; 63 int resRank{xRank + yRank - 2}; 64 if (xRank * yRank != 2 * resRank) { 65 terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank); 66 } 67 SubscriptValue extent[2]{ 68 xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(), 69 resRank == 2 ? y.GetDimension(1).Extent() : 0}; 70 if constexpr (IS_ALLOCATING) { 71 result.Establish( 72 RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); 73 for (int j{0}; j < resRank; ++j) { 74 result.GetDimension(j).SetBounds(1, extent[j]); 75 } 76 if (int stat{result.Allocate()}) { 77 terminator.Crash( 78 "MATMUL: could not allocate memory for result; STAT=%d", stat); 79 } 80 } else { 81 RUNTIME_CHECK(terminator, resRank == result.rank()); 82 RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND})); 83 RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); 84 RUNTIME_CHECK(terminator, 85 resRank == 1 || result.GetDimension(1).Extent() == extent[1]); 86 } 87 using WriteResult = 88 CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT, 89 RKIND>; 90 SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; 91 if (n != y.GetDimension(0).Extent()) { 92 terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)", 93 static_cast<std::intmax_t>(n), 94 static_cast<std::intmax_t>(y.GetDimension(0).Extent())); 95 } 96 SubscriptValue xAt[2], yAt[2], resAt[2]; 97 x.GetLowerBounds(xAt); 98 y.GetLowerBounds(yAt); 99 result.GetLowerBounds(resAt); 100 if (resRank == 2) { // M*M -> M 101 if constexpr (std::is_same_v<XT, YT>) { 102 if constexpr (std::is_same_v<XT, float>) { 103 // TODO: call BLAS-3 SGEMM 104 } else if constexpr (std::is_same_v<XT, double>) { 105 // TODO: call BLAS-3 DGEMM 106 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 107 // TODO: call BLAS-3 CGEMM 108 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 109 // TODO: call BLAS-3 ZGEMM 110 } 111 } 112 SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; 113 for (SubscriptValue i{0}; i < extent[0]; ++i) { 114 for (SubscriptValue j{0}; j < extent[1]; ++j) { 115 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 116 yAt[1] = y1 + j; 117 for (SubscriptValue k{0}; k < n; ++k) { 118 xAt[1] = x1 + k; 119 yAt[0] = y0 + k; 120 accumulator.Accumulate(xAt, yAt); 121 } 122 resAt[1] = res1 + j; 123 *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 124 } 125 ++resAt[0]; 126 ++xAt[0]; 127 } 128 } else { 129 if constexpr (std::is_same_v<XT, YT>) { 130 if constexpr (std::is_same_v<XT, float>) { 131 // TODO: call BLAS-2 SGEMV 132 } else if constexpr (std::is_same_v<XT, double>) { 133 // TODO: call BLAS-2 DGEMV 134 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 135 // TODO: call BLAS-2 CGEMV 136 } else if constexpr (std::is_same_v<XT, std::complex<float>>) { 137 // TODO: call BLAS-2 ZGEMV 138 } 139 } 140 if (xRank == 2) { // M*V -> V 141 SubscriptValue x1{xAt[1]}, y0{yAt[0]}; 142 for (SubscriptValue j{0}; j < extent[0]; ++j) { 143 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 144 for (SubscriptValue k{0}; k < n; ++k) { 145 xAt[1] = x1 + k; 146 yAt[0] = y0 + k; 147 accumulator.Accumulate(xAt, yAt); 148 } 149 *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 150 ++resAt[0]; 151 ++xAt[0]; 152 } 153 } else { // V*M -> V 154 SubscriptValue x0{xAt[0]}, y0{yAt[0]}; 155 for (SubscriptValue j{0}; j < extent[0]; ++j) { 156 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; 157 for (SubscriptValue k{0}; k < n; ++k) { 158 xAt[0] = x0 + k; 159 yAt[0] = y0 + k; 160 accumulator.Accumulate(xAt, yAt); 161 } 162 *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); 163 ++resAt[0]; 164 ++yAt[1]; 165 } 166 } 167 } 168 } 169 170 // Maps the dynamic type information from the arguments' descriptors 171 // to the right instantiation of DoMatmul() for valid combinations of 172 // types. 173 template <bool IS_ALLOCATING> struct Matmul { 174 using ResultDescriptor = 175 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; 176 template <TypeCategory XCAT, int XKIND> struct MM1 { 177 template <TypeCategory YCAT, int YKIND> struct MM2 { 178 void operator()(ResultDescriptor &result, const Descriptor &x, 179 const Descriptor &y, Terminator &terminator) const { 180 if constexpr (constexpr auto resultType{ 181 GetResultType(XCAT, XKIND, YCAT, YKIND)}) { 182 if constexpr (common::IsNumericTypeCategory(resultType->first) || 183 resultType->first == TypeCategory::Logical) { 184 return DoMatmul<IS_ALLOCATING, resultType->first, 185 resultType->second, CppTypeFor<XCAT, XKIND>, 186 CppTypeFor<YCAT, YKIND>>(result, x, y, terminator); 187 } 188 } 189 terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", 190 static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); 191 } 192 }; 193 void operator()(ResultDescriptor &result, const Descriptor &x, 194 const Descriptor &y, Terminator &terminator, TypeCategory yCat, 195 int yKind) const { 196 ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator); 197 } 198 }; 199 void operator()(ResultDescriptor &result, const Descriptor &x, 200 const Descriptor &y, const char *sourceFile, int line) const { 201 Terminator terminator{sourceFile, line}; 202 auto xCatKind{x.type().GetCategoryAndKind()}; 203 auto yCatKind{y.type().GetCategoryAndKind()}; 204 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); 205 ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result, 206 x, y, terminator, yCatKind->first, yCatKind->second); 207 } 208 }; 209 210 extern "C" { 211 void RTNAME(Matmul)(Descriptor &result, const Descriptor &x, 212 const Descriptor &y, const char *sourceFile, int line) { 213 Matmul<true>{}(result, x, y, sourceFile, line); 214 } 215 void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x, 216 const Descriptor &y, const char *sourceFile, int line) { 217 Matmul<false>{}(result, x, y, sourceFile, line); 218 } 219 } // extern "C" 220 } // namespace Fortran::runtime 221