//===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library that is useful
// for sparse tensor manipulations. The functionality provided in this library
// is meant to simplify benchmarking, testing, and debugging MLIR code that
// operates on sparse tensors. The provided functionality is **not** part
// of core MLIR, however.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensorUtils.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>

//===----------------------------------------------------------------------===//
//
// Internal support for storing and reading sparse tensors.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by index (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is only visible
// externally as an opaque pointer.
//
//===----------------------------------------------------------------------===//

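// For intuition (an illustrative sketch, not code from this library):
// consider the 2x3 matrix A = [[0, 1, 0], [2, 0, 3]].  In coordinate
// scheme (a), A is the element list {({0,1}, 1), ({1,0}, 2), ({1,2}, 3)}.
// In scheme (b) with per-dimension annotations {dense, compressed} and
// the identity dimension ordering, A is stored CSR-style as:
//   pointers[1] = [0, 1, 3]   (per-row start/end positions)
//   indices[1]  = [1, 0, 2]   (column indices of the nonzeros)
//   values      = [1, 2, 3]   (the nonzero values)
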
namespace {

static constexpr int kColWidth = 1025;

/// A version of `operator*` on `uint64_t` which checks for overflows.
static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
  assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
         "Integer overflow");
  return lhs * rhs;
}

// This macro helps minimize repetition of this idiom, as well as ensuring
// we have some additional output indicating where the error is coming from.
// (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
// to track down whether an error is coming from our code vs somewhere else
// in MLIR.)
#define FATAL(...)                                                             \
  do {                                                                         \
    fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                        \
    exit(1);                                                                   \
  } while (0)

// TODO: try to unify this with `SparseTensorFile::assertMatchesShape`
// which is used by `openSparseTensorCOO`.  It's easy enough to resolve
// the `std::vector` vs pointer mismatch for `dimSizes`; but it's trickier
// to resolve the presence/absence of `perm` (without introducing extra
// overhead), so perhaps the code duplication is unavoidable.
//
/// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
/// semantic-order to target-order) are a refinement of the desired `shape`
/// (in semantic-order).
///
/// Precondition: `perm` and `shape` must be valid for `rank`.
static inline void
assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
                              uint64_t rank, const uint64_t *perm,
                              const uint64_t *shape) {
  assert(perm && shape);
  assert(rank == dimSizes.size() && "Rank mismatch");
  for (uint64_t r = 0; r < rank; r++)
    assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
           "Dimension size mismatch");
}

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer into a shared index pool rather than e.g. a direct
/// vector, since that (1) reduces the per-element memory footprint, and
/// (2) centralizes the memory reservation and (re)allocation in one place.
template <typename V>
struct Element final {
  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
  uint64_t *indices; // pointer into shared index pool
  V value;
};

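// For illustration (not part of the original code): with three rank-2
// elements sharing one pool, the layout is
//   pool:     [ i0, j0, i1, j1, i2, j2 ]
//   elements: [ {&pool[0], v0}, {&pool[2], v1}, {&pool[4], v2} ]
// so each Element stores a single pointer instead of owning its own vector.
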
/// The type of callback functions which receive an element.  We avoid
/// packaging the coordinates and value together as an `Element` object
/// because this helps keep code somewhat cleaner.
template <typename V>
using ElementConsumer =
    const std::function<void(const std::vector<uint64_t> &, V)> &;

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO final {
public:
  SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
      : dimSizes(dimSizes) {
    if (capacity) {
      elements.reserve(capacity);
      indices.reserve(capacity * getRank());
    }
  }

  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t *base = indices.data();
    uint64_t size = indices.size();
    uint64_t rank = getRank();
    assert(ind.size() == rank && "Element rank mismatch");
    for (uint64_t r = 0; r < rank; r++) {
      assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
      indices.push_back(ind[r]);
    }
    // This base only changes if indices were reallocated. In that case, we
    // need to correct all previous pointers into the vector. Note that this
    // only happens if we did not set the initial capacity right, and then only
    // for every internal vector reallocation (which with the doubling rule
    // should only incur an amortized linear overhead).
    uint64_t *newBase = indices.data();
    if (newBase != base) {
      for (uint64_t i = 0, n = elements.size(); i < n; i++)
        elements[i].indices = newBase + (elements[i].indices - base);
      base = newBase;
    }
    // Add element as (pointer into shared index pool, value) pair.
    elements.emplace_back(base + size, val);
  }

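  // For example (illustrative only): if push_back reallocates the pool
  // from base B to a new base B', an element that pointed at B + k is
  // rebased to B' + k, preserving its offset into the shared pool.
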
  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    // TODO: we may want to cache an `isSorted` bit, to avoid
    // unnecessary/redundant sorting.
    uint64_t rank = getRank();
    std::sort(elements.begin(), elements.end(),
              [rank](const Element<V> &e1, const Element<V> &e2) {
                for (uint64_t r = 0; r < rank; r++) {
                  if (e1.indices[r] == e2.indices[r])
                    continue;
                  return e1.indices[r] < e2.indices[r];
                }
                return false;
              });
  }

  /// Get the rank of the tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Getter for the dimension-sizes array.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Getter for the elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }

  /// Get the next element.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  ///
  /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *dimSizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      permsz[perm[r]] = dimSizes[r];
    }
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  const std::vector<uint64_t> dimSizes; // per-dimension sizes
  std::vector<Element<V>> elements;     // all COO elements
  std::vector<uint64_t> indices;        // shared index pool
  bool iteratorLocked = false;
  unsigned iteratorPos = 0;
};

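// Example use of the factory (hypothetical values, for illustration only):
//   uint64_t dimSizes[] = {3, 4}; // semantic-order sizes
//   uint64_t perm[] = {1, 0};     // semantic-order -> target-order
//   auto *coo = SparseTensorCOO<double>::newSparseTensorCOO(2, dimSizes, perm);
//   // coo->getDimSizes() == {4, 3}, since permsz[perm[r]] = dimSizes[r].
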
// Forward.
template <typename V>
class SparseTensorEnumeratorBase;

// Helper macro for generating error messages when some
// `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
// and then the wrong "partial method specialization" is called.
#define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);

/// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
/// takes responsibility for all the `<P,I,V>`-independent aspects
/// of the tensor (e.g., shape, sparsity, permutation).  In addition,
/// we use function overloading to implement "partial" method
/// specialization, which the C-API relies on to catch type errors
/// arising from our use of opaque pointers.
class SparseTensorStorageBase {
public:
  /// Constructs a new storage object.  The `perm` maps the tensor's
  /// semantic-ordering of dimensions to this object's storage-order.
  /// The `dimSizes` and `sparsity` arrays are already in storage-order.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
  SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
                          const uint64_t *perm, const DimLevelType *sparsity)
      : dimSizes(dimSizes), rev(getRank()),
        dimTypes(sparsity, sparsity + getRank()) {
    assert(perm && sparsity);
    const uint64_t rank = getRank();
    // Validate parameters.
    assert(rank > 0 && "Trivial shape is unsupported");
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      assert((dimTypes[r] == DimLevelType::kDense ||
              dimTypes[r] == DimLevelType::kCompressed) &&
             "Unsupported DimLevelType");
    }
    // Construct the "reverse" (i.e., inverse) permutation.
    for (uint64_t r = 0; r < rank; r++)
      rev[perm[r]] = r;
  }

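  // For example (illustrative only): if perm = {1, 2, 0}, then semantic
  // dimension r is stored as dimension perm[r], and the inverse mapping
  // satisfies rev[perm[r]] = r, i.e. rev = {2, 0, 1}.
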
  virtual ~SparseTensorStorageBase() = default;

  /// Get the rank of the tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Getter for the dimension-sizes array, in storage-order.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Safely lookup the size of the given (storage-order) dimension.
  uint64_t getDimSize(uint64_t d) const {
    assert(d < getRank());
    return dimSizes[d];
  }

  /// Getter for the "reverse" permutation, which maps this object's
  /// storage-order to the tensor's semantic-order.
  const std::vector<uint64_t> &getRev() const { return rev; }

  /// Getter for the dimension-types array, in storage-order.
  const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }

  /// Safely check if the (storage-order) dimension uses compressed storage.
  bool isCompressedDim(uint64_t d) const {
    assert(d < getRank());
    return (dimTypes[d] == DimLevelType::kCompressed);
  }

  /// Allocate a new enumerator.
#define DECL_NEWENUMERATOR(VNAME, V)                                           \
  virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
                             const uint64_t *) const {                         \
    FATAL_PIV("newEnumerator" #VNAME);                                         \
  }
  FOREVERY_V(DECL_NEWENUMERATOR)
#undef DECL_NEWENUMERATOR

  /// Overhead storage.
#define DECL_GETPOINTERS(PNAME, P)                                             \
  virtual void getPointers(std::vector<P> **, uint64_t) {                      \
    FATAL_PIV("getPointers" #PNAME);                                           \
  }
  FOREVERY_FIXED_O(DECL_GETPOINTERS)
#undef DECL_GETPOINTERS
#define DECL_GETINDICES(INAME, I)                                              \
  virtual void getIndices(std::vector<I> **, uint64_t) {                       \
    FATAL_PIV("getIndices" #INAME);                                            \
  }
  FOREVERY_FIXED_O(DECL_GETINDICES)
#undef DECL_GETINDICES

  /// Primary storage.
#define DECL_GETVALUES(VNAME, V)                                               \
  virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
  FOREVERY_V(DECL_GETVALUES)
#undef DECL_GETVALUES

  /// Element-wise insertion in lexicographic index order.
#define DECL_LEXINSERT(VNAME, V)                                               \
  virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
  FOREVERY_V(DECL_LEXINSERT)
#undef DECL_LEXINSERT

  /// Expanded insertion.
#define DECL_EXPINSERT(VNAME, V)                                               \
  virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
    FATAL_PIV("expInsert" #VNAME);                                             \
  }
  FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT

  /// Finishes insertion.
  virtual void endInsert() = 0;

protected:
  // Since this class is virtual, we must disallow public copying in
  // order to avoid "slicing".  Since this class has data members,
  // that means making copying protected.
  // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
  SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
  // Copy-assignment would be implicitly deleted (because `dimSizes`
  // is const), so we explicitly delete it for clarity.
  SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;

private:
  const std::vector<uint64_t> dimSizes;
  std::vector<uint64_t> rev;
  const std::vector<DimLevelType> dimTypes;
};

#undef FATAL_PIV

// Forward.
template <typename P, typename I, typename V>
class SparseTensorEnumerator;

/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this class provides
/// a convenient "one-size-fits-all" solution that simply takes an input tensor
/// and annotations to implement all required setup in a general manner.
template <typename P, typename I, typename V>
class SparseTensorStorage final : public SparseTensorStorageBase {
  /// Private constructor to share code between the other constructors.
  /// Beware that the object is not necessarily guaranteed to be in a
  /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
  /// doesn't entail `!(pointers[d].empty())`.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
  SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                      const uint64_t *perm, const DimLevelType *sparsity)
      : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
        indices(getRank()), idx(getRank()) {}

public:
  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the coordinate scheme tensor for the initial contents if provided.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
  SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                      const uint64_t *perm, const DimLevelType *sparsity,
                      SparseTensorCOO<V> *coo)
      : SparseTensorStorage(dimSizes, perm, sparsity) {
    // Provide hints on capacity of pointers and indices.
    // TODO: needs much fine-tuning based on actual sparsity; currently
    //       we reserve pointer/index space based on all previous dense
    //       dimensions, which works well up to first sparse dim; but
    //       we should really use nnz and dense/sparse distribution.
    bool allDense = true;
    uint64_t sz = 1;
    for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
      if (isCompressedDim(r)) {
        // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
        // `sz` by that before reserving. (For now we just use 1.)
        pointers[r].reserve(sz + 1);
        pointers[r].push_back(0);
        indices[r].reserve(sz);
        sz = 1;
        allDense = false;
      } else { // Dense dimension.
        sz = checkedMul(sz, getDimSizes()[r]);
      }
    }
    // Then assign contents from coordinate scheme tensor if provided.
    if (coo) {
      // Ensure both preconditions of `fromCOO`.
      assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
      coo->sort();
      // Now actually insert the `elements`.
      const std::vector<Element<V>> &elements = coo->getElements();
      uint64_t nnz = elements.size();
      values.reserve(nnz);
      fromCOO(elements, 0, nnz, 0);
    } else if (allDense) {
      values.resize(sz, 0);
    }
  }

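  // For intuition (illustrative, not normative): for a rank-2 tensor with
  // sparsity = {kDense, kCompressed} and the identity permutation, the
  // constructor above produces classic CSR storage, with pointers[1] as
  // the row-pointer array and indices[1] as the column-index array.
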
  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the given sparse tensor for the initial contents.
  ///
  /// Preconditions:
  /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
  /// * The `tensor` must have the same value type `V`.
  SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                      const uint64_t *perm, const DimLevelType *sparsity,
                      const SparseTensorStorageBase &tensor);

  ~SparseTensorStorage() final = default;

  /// Partially specialize these getter methods based on template types.
  void getPointers(std::vector<P> **out, uint64_t d) final {
    assert(d < getRank());
    *out = &pointers[d];
  }
  void getIndices(std::vector<I> **out, uint64_t d) final {
    assert(d < getRank());
    *out = &indices[d];
  }
  void getValues(std::vector<V> **out) final { *out = &values; }

  /// Partially specialize lexicographical insertions based on template types.
  void lexInsert(const uint64_t *cursor, V val) final {
    // First, wrap up pending insertion path.
    uint64_t diff = 0;
    uint64_t top = 0;
    if (!values.empty()) {
      diff = lexDiff(cursor);
      endPath(diff + 1);
      top = idx[diff] + 1;
    }
    // Then continue with insertion path.
    insPath(cursor, diff, top, val);
  }

  /// Partially specialize expanded insertions based on template types.
  /// Note that this method resets the values/filled-switch array back
  /// to all-zero/false while only iterating over the nonzero elements.
  void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
                 uint64_t count) final {
    if (count == 0)
      return;
    // Sort.
    std::sort(added, added + count);
    // Restore insertion path for first insert.
    const uint64_t lastDim = getRank() - 1;
    uint64_t index = added[0];
    cursor[lastDim] = index;
    lexInsert(cursor, values[index]);
    assert(filled[index]);
    values[index] = 0;
    filled[index] = false;
    // Subsequent insertions are quick.
    for (uint64_t i = 1; i < count; i++) {
      assert(index < added[i] && "non-lexicographic insertion");
      index = added[i];
      cursor[lastDim] = index;
      insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
      assert(filled[index]);
      values[index] = 0;
      filled[index] = false;
    }
  }

  /// Finalizes lexicographic insertions.
  void endInsert() final {
    if (values.empty())
      finalizeSegment(0);
    else
      endPath(0);
  }

  void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
                     const uint64_t *perm) const final {
    *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
  }

  /// Returns this sparse tensor storage scheme as a new memory-resident
  /// sparse tensor in coordinate scheme with the given dimension order.
  ///
  /// Precondition: `perm` must be valid for `getRank()`.
  SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
    SparseTensorEnumeratorBase<V> *enumerator;
    newEnumerator(&enumerator, getRank(), perm);
    SparseTensorCOO<V> *coo =
        new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
    enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
      coo->add(ind, val);
    });
    // TODO: This assertion assumes there are no stored zeros,
    // or if there are then that we don't filter them out.
    // Cf., <https://github.com/llvm/llvm-project/issues/54179>
    assert(coo->getElements().size() == values.size());
    delete enumerator;
    return coo;
  }

  /// Factory method. Constructs a sparse tensor storage scheme with the given
  /// dimensions, permutation, and per-dimension dense/sparse annotations,
  /// using the coordinate scheme tensor for the initial contents if provided.
  /// In the latter case, the coordinate scheme must respect the same
  /// permutation as is desired for the new sparse tensor storage.
  ///
  /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                  const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
    SparseTensorStorage<P, I, V> *n = nullptr;
    if (coo) {
      const auto &coosz = coo->getDimSizes();
      assertPermutedSizesMatchShape(coosz, rank, perm, shape);
      n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
    } else {
      std::vector<uint64_t> permsz(rank);
      for (uint64_t r = 0; r < rank; r++) {
        assert(shape[r] > 0 && "Dimension size zero has trivial storage");
        permsz[perm[r]] = shape[r];
      }
      // We pass the null `coo` to ensure we select the intended constructor.
      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
    }
    return n;
  }

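  // Example call (hypothetical values, mirroring what compiler-generated
  // code might emit for a 3x4 CSR-like matrix):
  //   uint64_t shape[] = {3, 4};
  //   uint64_t perm[] = {0, 1};
  //   DimLevelType sparsity[] = {DimLevelType::kDense,
  //                              DimLevelType::kCompressed};
  //   auto *t = SparseTensorStorage<uint64_t, uint64_t, double>::
  //       newSparseTensor(2, shape, perm, sparsity, coo);
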
  /// Factory method. Constructs a sparse tensor storage scheme with
  /// the given dimensions, permutation, and per-dimension dense/sparse
  /// annotations, using the sparse tensor for the initial contents.
  ///
  /// Preconditions:
  /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
  /// * The `tensor` must have the same value type `V`.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                  const DimLevelType *sparsity,
                  const SparseTensorStorageBase *source) {
    assert(source && "Got nullptr for source");
    SparseTensorEnumeratorBase<V> *enumerator;
    source->newEnumerator(&enumerator, rank, perm);
    const auto &permsz = enumerator->permutedSizes();
    assertPermutedSizesMatchShape(permsz, rank, perm, shape);
    auto *tensor =
        new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
    delete enumerator;
    return tensor;
  }

private:
  /// Appends an arbitrary new position to `pointers[d]`.  This method
  /// checks that `pos` is representable in the `P` type; however, it
  /// does not check that `pos` is semantically valid (i.e., larger than
  /// the previous position and smaller than `indices[d].capacity()`).
  void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
    assert(isCompressedDim(d));
    assert(pos <= std::numeric_limits<P>::max() &&
           "Pointer value is too large for the P-type");
    pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
  }

  /// Appends index `i` to dimension `d`, in the semantically general
  /// sense.  For non-dense dimensions, that means appending to the
  /// `indices[d]` array, checking that `i` is representable in the `I`
  /// type; however, we do not verify other semantic requirements (e.g.,
  /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
  /// in the same segment).  For dense dimensions, this method instead
  /// appends the appropriate number of zeros to the `values` array,
  /// where `full` is the number of "entries" already written to `values`
  /// for this segment (aka one after the highest index previously appended).
  void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
    if (isCompressedDim(d)) {
      assert(i <= std::numeric_limits<I>::max() &&
             "Index value is too large for the I-type");
      indices[d].push_back(static_cast<I>(i));
    } else { // Dense dimension.
      assert(i >= full && "Index was already filled");
      if (i == full)
        return; // Short-circuit, since it'll be a nop.
      if (d + 1 == getRank())
        values.insert(values.end(), i - full, 0);
      else
        finalizeSegment(d + 1, 0, i - full);
    }
  }

  /// Writes the given coordinate to `indices[d][pos]`.  This method
  /// checks that `i` is representable in the `I` type; however, it
  /// does not check that `i` is semantically valid (i.e., in bounds
  /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
  void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
    assert(isCompressedDim(d));
    // Subscript assignment to `std::vector` requires that the `pos`-th
    // entry has been initialized; thus we must be sure to check `size()`
    // here, instead of `capacity()` as would be ideal.
    assert(pos < indices[d].size() && "Index position is out of bounds");
    assert(i <= std::numeric_limits<I>::max() &&
           "Index value is too large for the I-type");
    indices[d][pos] = static_cast<I>(i);
  }

  /// Computes the assembled-size associated with the `d`-th dimension,
  /// given the assembled-size associated with the `(d-1)`-th dimension.
  /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
  /// storage, as opposed to "dimension-sizes" which are the cardinality
  /// of coordinates for that dimension.
  ///
  /// Precondition: the `pointers[d]` array must be fully initialized
  /// before calling this method.
  uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
    if (isCompressedDim(d))
      return pointers[d][parentSz];
    // else if dense:
    return parentSz * getDimSizes()[d];
  }

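  // For illustration: starting from parentSz = 1 at the root, a dense
  // dimension of size m scales the assembled-size to parentSz * m,
  // whereas a compressed dimension collapses it to pointers[d][parentSz],
  // the total number of index entries stored at level d.
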
  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the pointers and
  /// indices arrays under the given per-dimension dense/sparse annotations.
  ///
  /// Preconditions:
  /// (1) the `elements` must be lexicographically sorted.
  /// (2) the indices of every element are valid for `dimSizes` (equal rank
  ///     and pointwise less-than).
  void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
               uint64_t hi, uint64_t d) {
    uint64_t rank = getRank();
    assert(d <= rank && hi <= elements.size());
    // Once dimensions are exhausted, insert the numerical values.
    if (d == rank) {
      assert(lo < hi);
      values.push_back(elements[lo].value);
      return;
    }
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
      // Find segment in interval with same index elements in this dimension.
      uint64_t i = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == i)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      appendIndex(d, full, i);
      full = i + 1;
      fromCOO(elements, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    finalizeSegment(d, full);
  }

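  // A small worked trace (illustrative only): for the sorted rank-2
  // elements {({0,1}, a), ({2,0}, b)} with both dimensions compressed,
  // fromCOO produces pointers[0] = [0, 2], indices[0] = [0, 2],
  // pointers[1] = [0, 1, 2], indices[1] = [1, 0], and values = [a, b].
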
  /// Finalize the sparse pointer structure at this dimension.
  void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
    if (count == 0)
      return; // Short-circuit, since it'll be a nop.
    if (isCompressedDim(d)) {
      appendPointer(d, indices[d].size(), count);
    } else { // Dense dimension.
      const uint64_t sz = getDimSizes()[d];
      assert(sz >= full && "Segment is overfull");
      count = checkedMul(count, sz - full);
      // For dense storage we must enumerate all the remaining coordinates
      // in this dimension (i.e., coordinates after the last non-zero
      // element), and either fill in their zero values or else recurse
      // to finalize some deeper dimension.
      if (d + 1 == getRank())
        values.insert(values.end(), count, 0);
      else
        finalizeSegment(d + 1, 0, count);
    }
  }

  /// Wraps up a single insertion path, inner to outer.
  void endPath(uint64_t diff) {
    uint64_t rank = getRank();
    assert(diff <= rank);
    for (uint64_t i = 0; i < rank - diff; i++) {
      const uint64_t d = rank - i - 1;
      finalizeSegment(d, idx[d] + 1);
    }
  }

  /// Continues a single insertion path, outer to inner.
  void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
    uint64_t rank = getRank();
    assert(diff < rank);
    for (uint64_t d = diff; d < rank; d++) {
      uint64_t i = cursor[d];
      appendIndex(d, top, i);
      top = 0;
      idx[d] = i;
    }
    values.push_back(val);
  }

  /// Finds the lexicographic differing dimension.
  uint64_t lexDiff(const uint64_t *cursor) const {
    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
      if (cursor[r] > idx[r])
        return r;
      else
        assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
    return -1u;
  }

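  // For example (illustrative only): with the previous cursor idx =
  // {1, 2, 3}, a new cursor {1, 3, 0} first differs at dimension 1, so
  // lexDiff returns 1 and dimensions 1 and 2 must be re-opened.
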
  // Allow `SparseTensorEnumerator` to access the data-members (to avoid
  // the cost of virtual-function dispatch in inner loops), without
  // making them public to other client code.
  friend class SparseTensorEnumerator<P, I, V>;

  std::vector<std::vector<P>> pointers;
  std::vector<std::vector<I>> indices;
  std::vector<V> values;
  std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
};

/// A (higher-order) function object for enumerating the elements of some
/// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
/// method encapsulates the loop-nest for enumerating the elements of
/// the source tensor (in whatever order is best for the source tensor),
/// and applies a permutation to the coordinates/indices before handing
/// each element to the callback.  A single enumerator object can be
/// freely reused for several calls to `forallElements`, just so long
/// as each call is sequential with respect to one another.
///
/// N.B., this class stores a reference to the `SparseTensorStorageBase`
/// passed to the constructor; thus, objects of this class must not
/// outlive the sparse tensor they depend on.
///
/// Design Note: The reason we define this class instead of simply using
/// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
/// the `<P,I>` template parameters from MLIR client code (to simplify the
/// type parameters used for direct sparse-to-sparse conversion).  And the
/// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
/// than simply using this class, is to avoid the cost of virtual-method
/// dispatch within the loop-nest.
template <typename V>
class SparseTensorEnumeratorBase {
public:
  /// Constructs an enumerator with the given permutation for mapping
  /// the semantic-ordering of dimensions to the desired target-ordering.
  ///
  /// Preconditions:
  /// * the `tensor` must have the same `V` value type.
  /// * `perm` must be valid for `rank`.
  SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
                             uint64_t rank, const uint64_t *perm)
      : src(tensor), permsz(src.getRev().size()), reord(getRank()),
        cursor(getRank()) {
    assert(perm && "Received nullptr for permutation");
    assert(rank == getRank() && "Permutation rank mismatch");
    const auto &rev = src.getRev();           // source-order -> semantic-order
    const auto &dimSizes = src.getDimSizes(); // in source storage-order
    for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
      uint64_t t = perm[rev[s]];              // `t` target-order
      reord[s] = t;
      permsz[t] = dimSizes[s];
    }
  }

  virtual ~SparseTensorEnumeratorBase() = default;

  // We disallow copying to help avoid leaking the `src` reference.
  // (In addition to avoiding the problem of slicing.)
  SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
  SparseTensorEnumeratorBase &
  operator=(const SparseTensorEnumeratorBase &) = delete;

  /// Returns the source/target tensor's rank.  (The source-rank and
  /// target-rank are always equal since we only support permutations.
  /// Though once we add support for other dimension mappings, this
  /// method will have to be split in two.)
  uint64_t getRank() const { return permsz.size(); }

  /// Returns the target tensor's dimension sizes.
  const std::vector<uint64_t> &permutedSizes() const { return permsz; }

  /// Enumerates all elements of the source tensor, permutes their
  /// indices, and passes the permuted element to the callback.
  /// The callback must not store the cursor reference directly,
  /// since this function reuses the storage.  Instead, the callback
  /// must copy it if it wants to keep it.
  virtual void forallElements(ElementConsumer<V> yield) = 0;

protected:
  const SparseTensorStorageBase &src;
  std::vector<uint64_t> permsz; // in target order.
  std::vector<uint64_t> reord;  // source storage-order -> target order.
  std::vector<uint64_t> cursor; // in target order.
};

838753fe330Swren romano template <typename P, typename I, typename V>
839753fe330Swren romano class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
840753fe330Swren romano   using Base = SparseTensorEnumeratorBase<V>;
841753fe330Swren romano 
842753fe330Swren romano public:
843753fe330Swren romano   /// Constructs an enumerator with the given permutation for mapping
844753fe330Swren romano   /// the semantic-ordering of dimensions to the desired target-ordering.
845753fe330Swren romano   ///
846753fe330Swren romano   /// Precondition: `perm` must be valid for `rank`.
847753fe330Swren romano   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
848753fe330Swren romano                          uint64_t rank, const uint64_t *perm)
849753fe330Swren romano       : Base(tensor, rank, perm) {}
850753fe330Swren romano 
851f38765a8SMehdi Amini   ~SparseTensorEnumerator() final = default;
852753fe330Swren romano 
853f38765a8SMehdi Amini   void forallElements(ElementConsumer<V> yield) final {
854753fe330Swren romano     forallElements(yield, 0, 0);
855753fe330Swren romano   }
856753fe330Swren romano 
857753fe330Swren romano private:
858753fe330Swren romano   /// The recursive component of the public `forallElements`.
859753fe330Swren romano   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
860753fe330Swren romano                       uint64_t d) {
861753fe330Swren romano     // Recover the `<P,I,V>` type parameters of `src`.
862753fe330Swren romano     const auto &src =
863753fe330Swren romano         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
864753fe330Swren romano     if (d == Base::getRank()) {
865753fe330Swren romano       assert(parentPos < src.values.size() &&
866753fe330Swren romano              "Value position is out of bounds");
867753fe330Swren romano       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
868753fe330Swren romano       yield(this->cursor, src.values[parentPos]);
869753fe330Swren romano     } else if (src.isCompressedDim(d)) {
870753fe330Swren romano       // Look up the bounds of the `d`-level segment determined by the
871753fe330Swren romano       // `d-1`-level position `parentPos`.
872d8c46eb6SMehdi Amini       const std::vector<P> &pointersD = src.pointers[d];
873d8c46eb6SMehdi Amini       assert(parentPos + 1 < pointersD.size() &&
874753fe330Swren romano              "Parent pointer position is out of bounds");
875d8c46eb6SMehdi Amini       const uint64_t pstart = static_cast<uint64_t>(pointersD[parentPos]);
876d8c46eb6SMehdi Amini       const uint64_t pstop = static_cast<uint64_t>(pointersD[parentPos + 1]);
877753fe330Swren romano       // Loop-invariant code for looking up the `d`-level coordinates/indices.
878d8c46eb6SMehdi Amini       const std::vector<I> &indicesD = src.indices[d];
879d8c46eb6SMehdi Amini       assert(pstop <= indicesD.size() && "Index position is out of bounds");
880d8c46eb6SMehdi Amini       uint64_t &cursorReordD = this->cursor[this->reord[d]];
881753fe330Swren romano       for (uint64_t pos = pstart; pos < pstop; pos++) {
882d8c46eb6SMehdi Amini         cursorReordD = static_cast<uint64_t>(indicesD[pos]);
883753fe330Swren romano         forallElements(yield, pos, d + 1);
884753fe330Swren romano       }
885753fe330Swren romano     } else { // Dense dimension.
886753fe330Swren romano       const uint64_t sz = src.getDimSizes()[d];
887753fe330Swren romano       const uint64_t pstart = parentPos * sz;
888d8c46eb6SMehdi Amini       uint64_t &cursorReordD = this->cursor[this->reord[d]];
889753fe330Swren romano       for (uint64_t i = 0; i < sz; i++) {
890d8c46eb6SMehdi Amini         cursorReordD = i;
891753fe330Swren romano         forallElements(yield, pstart + i, d + 1);
892753fe330Swren romano       }
893753fe330Swren romano     }
894753fe330Swren romano   }
895753fe330Swren romano };
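
// Example usage (a minimal sketch; `tensor` is assumed to be a
// `SparseTensorStorage<uint64_t, uint64_t, double>` constructed elsewhere):
//
//   const uint64_t perm[] = {1, 0}; // transpose a rank-2 tensor
//   SparseTensorEnumerator<uint64_t, uint64_t, double> e(tensor, 2, perm);
//   e.forallElements([](const std::vector<uint64_t> &ind, double v) {
//     // `ind` is in target (permuted) order; copy it before storing it,
//     // since the enumerator reuses the cursor's storage.
//   });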
896753fe330Swren romano 
8978cb33240Swren romano /// Statistics regarding the number of nonzero subtensors in
8988cb33240Swren romano /// a source tensor, for direct sparse=>sparse conversion a la
8998cb33240Swren romano /// <https://arxiv.org/abs/2001.02609>.
9008cb33240Swren romano ///
9018cb33240Swren romano /// N.B., this class stores references to the parameters passed to
9028cb33240Swren romano /// the constructor; thus, objects of this class must not outlive
9038cb33240Swren romano /// those parameters.
90476944420Swren romano class SparseTensorNNZ final {
9058cb33240Swren romano public:
9068cb33240Swren romano   /// Allocate the statistics structure for the desired sizes and
9078cb33240Swren romano   /// sparsity (in the target tensor's storage-order).  This constructor
9088cb33240Swren romano   /// does not actually populate the statistics, however; for that see
9098cb33240Swren romano   /// `initialize`.
9108cb33240Swren romano   ///
911fa6aed2aSwren romano   /// Precondition: `dimSizes` must not contain zeros.
912fa6aed2aSwren romano   SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
9138cb33240Swren romano                   const std::vector<DimLevelType> &sparsity)
914fa6aed2aSwren romano       : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
9158cb33240Swren romano     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
9168cb33240Swren romano     bool uncompressed = true;
917*c35807f2SJacques Pienaar     (void)uncompressed;
9188cb33240Swren romano     uint64_t sz = 1; // the product of all `dimSizes[i]` for `i` strictly less than `r`.
9198cb33240Swren romano     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
9208cb33240Swren romano       switch (dimTypes[r]) {
9218cb33240Swren romano       case DimLevelType::kCompressed:
9228cb33240Swren romano         assert(uncompressed &&
9238cb33240Swren romano                "Multiple compressed layers not currently supported");
9248cb33240Swren romano         uncompressed = false;
9258cb33240Swren romano         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
9268cb33240Swren romano         break;
9278cb33240Swren romano       case DimLevelType::kDense:
9288cb33240Swren romano         assert(uncompressed &&
9298cb33240Swren romano                "Dense after compressed not currently supported");
9308cb33240Swren romano         break;
9318cb33240Swren romano       case DimLevelType::kSingleton:
9328cb33240Swren romano         // Singleton after Compressed causes no problems for allocating
9338cb33240Swren romano         // `nnz` nor for the yieldPos loop.  This remains true even
9348cb33240Swren romano         // when adding support for multiple compressed dimensions or
9358cb33240Swren romano         // for dense-after-compressed.
9368cb33240Swren romano         break;
9378cb33240Swren romano       }
9388cb33240Swren romano       sz = checkedMul(sz, dimSizes[r]);
9398cb33240Swren romano     }
9408cb33240Swren romano   }
9418cb33240Swren romano 
9428cb33240Swren romano   // We disallow copying to help avoid leaking the stored references.
9438cb33240Swren romano   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
9448cb33240Swren romano   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
9458cb33240Swren romano 
9468cb33240Swren romano   /// Returns the rank of the target tensor.
9478cb33240Swren romano   uint64_t getRank() const { return dimSizes.size(); }
9488cb33240Swren romano 
9498cb33240Swren romano   /// Enumerate the source tensor to fill in the statistics.  The
9508cb33240Swren romano   /// enumerator should already incorporate the permutation (from
9518cb33240Swren romano   /// semantic-order to the target storage-order).
9528cb33240Swren romano   template <typename V>
9538cb33240Swren romano   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
9548cb33240Swren romano     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
9558cb33240Swren romano     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
9568cb33240Swren romano     enumerator.forallElements(
9578cb33240Swren romano         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
9588cb33240Swren romano   }
9598cb33240Swren romano 
9608cb33240Swren romano   /// The type of callback functions which receive an nnz-statistic.
9618cb33240Swren romano   using NNZConsumer = const std::function<void(uint64_t)> &;
9628cb33240Swren romano 
9638cb33240Swren romano   /// Lexicographically enumerates all indices for dimensions strictly
9648cb33240Swren romano   /// less than `stopDim`, and passes their nnz statistic to the callback.
9658cb33240Swren romano   /// Since our use-case only requires the statistic not the coordinates
9668cb33240Swren romano   /// themselves, we do not bother to construct those coordinates.
9678cb33240Swren romano   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
9688cb33240Swren romano     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
9698cb33240Swren romano     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
9708cb33240Swren romano            "Cannot look up non-compressed dimensions");
9718cb33240Swren romano     forallIndices(yield, stopDim, 0, 0);
9728cb33240Swren romano   }
9738cb33240Swren romano 
9748cb33240Swren romano private:
9758cb33240Swren romano   /// Adds a new element (i.e., increments its statistics).  We use
9768cb33240Swren romano   /// a method rather than inlining into the lambda in `initialize`,
9778cb33240Swren romano   /// to avoid spurious templating over `V`.  And this method is private
9788cb33240Swren romano   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
9798cb33240Swren romano   /// by `forallElements`).
9808cb33240Swren romano   void add(const std::vector<uint64_t> &ind) {
9818cb33240Swren romano     uint64_t parentPos = 0;
9828cb33240Swren romano     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
9838cb33240Swren romano       if (dimTypes[r] == DimLevelType::kCompressed)
9848cb33240Swren romano         nnz[r][parentPos]++;
9858cb33240Swren romano       parentPos = parentPos * dimSizes[r] + ind[r];
9868cb33240Swren romano     }
9878cb33240Swren romano   }
9888cb33240Swren romano 
9898cb33240Swren romano   /// Recursive component of the public `forallIndices`.
9908cb33240Swren romano   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
9918cb33240Swren romano                      uint64_t d) const {
9928cb33240Swren romano     assert(d <= stopDim);
9938cb33240Swren romano     if (d == stopDim) {
9948cb33240Swren romano       assert(parentPos < nnz[d].size() && "Cursor is out of range");
9958cb33240Swren romano       yield(nnz[d][parentPos]);
9968cb33240Swren romano     } else {
9978cb33240Swren romano       const uint64_t sz = dimSizes[d];
9988cb33240Swren romano       const uint64_t pstart = parentPos * sz;
9998cb33240Swren romano       for (uint64_t i = 0; i < sz; i++)
10008cb33240Swren romano         forallIndices(yield, stopDim, pstart + i, d + 1);
10018cb33240Swren romano     }
10028cb33240Swren romano   }
10038cb33240Swren romano 
10048cb33240Swren romano   // All of these are in the target storage-order.
10058cb33240Swren romano   const std::vector<uint64_t> &dimSizes;
10068cb33240Swren romano   const std::vector<DimLevelType> &dimTypes;
10078cb33240Swren romano   std::vector<std::vector<uint64_t>> nnz;
10088cb33240Swren romano };
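
// Example usage (a minimal sketch, assuming `enumerator` was constructed
// with the same permutation used to put `dimSizes` and `sparsity` into
// the target storage-order):
//
//   SparseTensorNNZ nnz(dimSizes, sparsity);
//   nnz.initialize(enumerator);
//   // For a (kDense, kCompressed) matrix, visit the nnz of every row:
//   nnz.forallIndices(1, [](uint64_t n) { /* n = nonzeros in this row */ });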
10098cb33240Swren romano 
10108cb33240Swren romano template <typename P, typename I, typename V>
10118cb33240Swren romano SparseTensorStorage<P, I, V>::SparseTensorStorage(
1012fa6aed2aSwren romano     const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
10138cb33240Swren romano     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1014fa6aed2aSwren romano     : SparseTensorStorage(dimSizes, perm, sparsity) {
10158cb33240Swren romano   SparseTensorEnumeratorBase<V> *enumerator;
10168cb33240Swren romano   tensor.newEnumerator(&enumerator, getRank(), perm);
10178cb33240Swren romano   {
10188cb33240Swren romano     // Initialize the statistics structure.
10198cb33240Swren romano     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
10208cb33240Swren romano     nnz.initialize(*enumerator);
10218cb33240Swren romano     // Initialize "pointers" overhead (and allocate "indices", "values").
10228cb33240Swren romano     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
10238cb33240Swren romano     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
10248cb33240Swren romano       if (isCompressedDim(r)) {
10258cb33240Swren romano         pointers[r].reserve(parentSz + 1);
10268cb33240Swren romano         pointers[r].push_back(0);
10278cb33240Swren romano         uint64_t currentPos = 0;
10288cb33240Swren romano         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
10298cb33240Swren romano           currentPos += n;
10308cb33240Swren romano           appendPointer(r, currentPos);
10318cb33240Swren romano         });
10328cb33240Swren romano         assert(pointers[r].size() == parentSz + 1 &&
10338cb33240Swren romano                "Final pointers size doesn't match allocated size");
10348cb33240Swren romano         // That assertion entails `assembledSize(parentSz, r)`
10358cb33240Swren romano         // is now in a valid state.  That is, `pointers[r][parentSz]`
10368cb33240Swren romano         // equals the present value of `currentPos`, which is the
10378cb33240Swren romano         // correct assembled-size for `indices[r]`.
10388cb33240Swren romano       }
10398cb33240Swren romano       // Update assembled-size for the next iteration.
10408cb33240Swren romano       parentSz = assembledSize(parentSz, r);
10418cb33240Swren romano       // Ideally we need only `indices[r].reserve(parentSz)`, however
10428cb33240Swren romano       // the `std::vector` implementation forces us to initialize it too.
10438cb33240Swren romano       // That is, in the yieldPos loop we need random-access assignment
10448cb33240Swren romano       // to `indices[r]`; however, `std::vector`'s subscript-assignment
10458cb33240Swren romano       // only allows assigning to already-initialized positions.
10468cb33240Swren romano       if (isCompressedDim(r))
10478cb33240Swren romano         indices[r].resize(parentSz, 0);
10488cb33240Swren romano     }
10498cb33240Swren romano     values.resize(parentSz, 0); // Both allocate and zero-initialize.
10508cb33240Swren romano   }
10518cb33240Swren romano   // The yieldPos loop
10528cb33240Swren romano   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
10538cb33240Swren romano     uint64_t parentSz = 1, parentPos = 0;
10548cb33240Swren romano     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
10558cb33240Swren romano       if (isCompressedDim(r)) {
10568cb33240Swren romano         // If `parentPos == parentSz` then it's valid as an array-lookup;
10578cb33240Swren romano         // however, it's semantically invalid here since that entry
10588cb33240Swren romano         // does not represent a segment of `indices[r]`.  Moreover, that
10598cb33240Swren romano         // entry must be immutable for `assembledSize` to remain valid.
10608cb33240Swren romano         assert(parentPos < parentSz && "Pointers position is out of bounds");
10618cb33240Swren romano         const uint64_t currentPos = pointers[r][parentPos];
10628cb33240Swren romano         // This increment won't overflow the `P` type, since it can't
10638cb33240Swren romano         // exceed the original value of `pointers[r][parentPos+1]`
10648cb33240Swren romano         // which was already verified to be within bounds for `P`
10658cb33240Swren romano         // when it was written to the array.
10668cb33240Swren romano         pointers[r][parentPos]++;
10678cb33240Swren romano         writeIndex(r, currentPos, ind[r]);
10688cb33240Swren romano         parentPos = currentPos;
10698cb33240Swren romano       } else { // Dense dimension.
10708cb33240Swren romano         parentPos = parentPos * getDimSizes()[r] + ind[r];
10718cb33240Swren romano       }
10728cb33240Swren romano       parentSz = assembledSize(parentSz, r);
10738cb33240Swren romano     }
10748cb33240Swren romano     assert(parentPos < values.size() && "Value position is out of bounds");
10758cb33240Swren romano     values[parentPos] = val;
10768cb33240Swren romano   });
10778cb33240Swren romano   // No longer need the enumerator, so we'll delete it ASAP.
10788cb33240Swren romano   delete enumerator;
10798cb33240Swren romano   // The finalizeYieldPos loop
10808cb33240Swren romano   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
10818cb33240Swren romano     if (isCompressedDim(r)) {
10828cb33240Swren romano       assert(parentSz == pointers[r].size() - 1 &&
10838cb33240Swren romano              "Actual pointers size doesn't match the expected size");
10848cb33240Swren romano       // Can't check all of them, but at least we can check the last one.
10858cb33240Swren romano       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
10868cb33240Swren romano              "Pointers got corrupted");
10878cb33240Swren romano       // TODO: optimize this by using `memmove` or similar.
10888cb33240Swren romano       for (uint64_t n = 0; n < parentSz; n++) {
10898cb33240Swren romano         const uint64_t parentPos = parentSz - n;
10908cb33240Swren romano         pointers[r][parentPos] = pointers[r][parentPos - 1];
10918cb33240Swren romano       }
10928cb33240Swren romano       pointers[r][0] = 0;
10938cb33240Swren romano     }
10948cb33240Swren romano     parentSz = assembledSize(parentSz, r);
10958cb33240Swren romano   }
10968cb33240Swren romano }
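
// Worked example of the pointer bookkeeping above (illustrative only):
// for a (kDense, kCompressed) matrix whose rows contain {2, 0, 3} nonzeros,
// the nnz statistics produce pointers[1] = {0, 2, 2, 5}.  The yieldPos loop
// then uses pointers[1][row] as a write cursor into indices[1]/values,
// incrementing it once per element and leaving {2, 2, 5, 5}.  The
// finalizeYieldPos loop shifts the array right by one and restores the
// leading zero, recovering the valid CSR-style pointers {0, 2, 2, 5}.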
10978cb33240Swren romano 
10988a91bc7bSHarrietAkot /// Helper to convert string to lower case.
10998a91bc7bSHarrietAkot static char *toLower(char *token) {
11008a91bc7bSHarrietAkot   for (char *c = token; *c; c++)
11018a91bc7bSHarrietAkot     *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
11028a91bc7bSHarrietAkot   return token;
11038a91bc7bSHarrietAkot }
11048a91bc7bSHarrietAkot 
1105a4c53f8cSwren romano /// This class abstracts over the information stored in file headers,
1106a4c53f8cSwren romano /// as well as providing the buffers and methods for parsing those headers.
1107a4c53f8cSwren romano class SparseTensorFile final {
1108a4c53f8cSwren romano public:
11095b1c5fc5Sbixia1   enum class ValueKind {
11105b1c5fc5Sbixia1     kInvalid = 0,
11115b1c5fc5Sbixia1     kPattern = 1,
11125b1c5fc5Sbixia1     kReal = 2,
11135b1c5fc5Sbixia1     kInteger = 3,
11145b1c5fc5Sbixia1     kComplex = 4,
11155b1c5fc5Sbixia1     kUndefined = 5
11165b1c5fc5Sbixia1   };
11175b1c5fc5Sbixia1 
1118a4c53f8cSwren romano   explicit SparseTensorFile(char *filename) : filename(filename) {
1119a4c53f8cSwren romano     assert(filename && "Received nullptr for filename");
1120a4c53f8cSwren romano   }
1121a4c53f8cSwren romano 
1122a4c53f8cSwren romano   // Disallows copying, to avoid duplicating the `file` pointer.
1123a4c53f8cSwren romano   SparseTensorFile(const SparseTensorFile &) = delete;
1124a4c53f8cSwren romano   SparseTensorFile &operator=(const SparseTensorFile &) = delete;
1125a4c53f8cSwren romano 
1126a4c53f8cSwren romano   // This dtor tries to avoid leaking the `file`.  (Though it's better
1127a4c53f8cSwren romano   // to call `closeFile` explicitly when possible, since there are
1128a4c53f8cSwren romano   // circumstances where dtors are not called reliably.)
1129a4c53f8cSwren romano   ~SparseTensorFile() { closeFile(); }
1130a4c53f8cSwren romano 
1131a4c53f8cSwren romano   /// Opens the file for reading.
1132a4c53f8cSwren romano   void openFile() {
1133a4c53f8cSwren romano     if (file)
1134a4c53f8cSwren romano       FATAL("Already opened file %s\n", filename);
1135a4c53f8cSwren romano     file = fopen(filename, "r");
1136a4c53f8cSwren romano     if (!file)
1137a4c53f8cSwren romano       FATAL("Cannot find file %s\n", filename);
1138a4c53f8cSwren romano   }
1139a4c53f8cSwren romano 
1140a4c53f8cSwren romano   /// Closes the file.
1141a4c53f8cSwren romano   void closeFile() {
1142a4c53f8cSwren romano     if (file) {
1143a4c53f8cSwren romano       fclose(file);
1144a4c53f8cSwren romano       file = nullptr;
1145a4c53f8cSwren romano     }
1146a4c53f8cSwren romano   }
1147a4c53f8cSwren romano 
1148a4c53f8cSwren romano   // TODO(wrengr/bixia): figure out how to reorganize the element-parsing
1149a4c53f8cSwren romano   // loop of `openSparseTensorCOO` into methods of this class, so we can
1150a4c53f8cSwren romano   // avoid leaking access to the `line` pointer (both for general hygiene
1151a4c53f8cSwren romano   // and because we can't mark it const due to the second argument of
1152a4c53f8cSwren romano   // `strtoul`/`strtod` being `char * *restrict` rather than
1153a4c53f8cSwren romano   // `char const* *restrict`).
1154a4c53f8cSwren romano   //
1155a4c53f8cSwren romano   /// Attempts to read a line from the file.
1156a4c53f8cSwren romano   char *readLine() {
1157a4c53f8cSwren romano     if (fgets(line, kColWidth, file))
1158a4c53f8cSwren romano       return line;
1159a4c53f8cSwren romano     FATAL("Cannot read next line of %s\n", filename);
1160a4c53f8cSwren romano   }
1161a4c53f8cSwren romano 
1162a4c53f8cSwren romano   /// Reads and parses the file's header.
1163a4c53f8cSwren romano   void readHeader() {
1164a4c53f8cSwren romano     assert(file && "Attempt to readHeader() before openFile()");
1165a4c53f8cSwren romano     if (strstr(filename, ".mtx"))
1166a4c53f8cSwren romano       readMMEHeader();
1167a4c53f8cSwren romano     else if (strstr(filename, ".tns"))
1168a4c53f8cSwren romano       readExtFROSTTHeader();
1169a4c53f8cSwren romano     else
1170a4c53f8cSwren romano       FATAL("Unknown format %s\n", filename);
11715b1c5fc5Sbixia1     assert(isValid() && "Failed to read the header");
1172a4c53f8cSwren romano   }
1173a4c53f8cSwren romano 
11745b1c5fc5Sbixia1   ValueKind getValueKind() const { return valueKind_; }
11755b1c5fc5Sbixia1 
1176ff96d434Sbixia1   bool isValid() const { return valueKind_ != ValueKind::kInvalid; }
11775b1c5fc5Sbixia1 
1178a4c53f8cSwren romano   /// Gets the MME "pattern" property setting.  Is only valid after
1179a4c53f8cSwren romano   /// parsing the header.
1180a4c53f8cSwren romano   bool isPattern() const {
11815b1c5fc5Sbixia1     assert(isValid() && "Attempt to isPattern() before readHeader()");
11825b1c5fc5Sbixia1     return valueKind_ == ValueKind::kPattern;
1183a4c53f8cSwren romano   }
1184a4c53f8cSwren romano 
1185a4c53f8cSwren romano   /// Gets the MME "symmetric" property setting.  Is only valid after
1186a4c53f8cSwren romano   /// parsing the header.
1187a4c53f8cSwren romano   bool isSymmetric() const {
11885b1c5fc5Sbixia1     assert(isValid() && "Attempt to isSymmetric() before readHeader()");
1189a4c53f8cSwren romano     return isSymmetric_;
1190a4c53f8cSwren romano   }
1191a4c53f8cSwren romano 
1192a4c53f8cSwren romano   /// Gets the rank of the tensor.  Is only valid after parsing the header.
1193a4c53f8cSwren romano   uint64_t getRank() const {
11945b1c5fc5Sbixia1     assert(isValid() && "Attempt to getRank() before readHeader()");
1195a4c53f8cSwren romano     return idata[0];
1196a4c53f8cSwren romano   }
1197a4c53f8cSwren romano 
1198a4c53f8cSwren romano   /// Gets the number of non-zeros.  Is only valid after parsing the header.
1199a4c53f8cSwren romano   uint64_t getNNZ() const {
12005b1c5fc5Sbixia1     assert(isValid() && "Attempt to getNNZ() before readHeader()");
1201a4c53f8cSwren romano     return idata[1];
1202a4c53f8cSwren romano   }
1203a4c53f8cSwren romano 
1204a4c53f8cSwren romano   /// Gets the dimension-sizes array.  The pointer itself is always
1205a4c53f8cSwren romano   /// valid; however, the values stored therein are only valid after
1206a4c53f8cSwren romano   /// parsing the header.
1207a4c53f8cSwren romano   const uint64_t *getDimSizes() const { return idata + 2; }
1208a4c53f8cSwren romano 
1209a4c53f8cSwren romano   /// Safely gets the size of the given dimension.  Is only valid
1210a4c53f8cSwren romano   /// after parsing the header.
1211a4c53f8cSwren romano   uint64_t getDimSize(uint64_t d) const {
1212a4c53f8cSwren romano     assert(d < getRank());
1213a4c53f8cSwren romano     return idata[2 + d];
1214a4c53f8cSwren romano   }
1215a4c53f8cSwren romano 
1216a4c53f8cSwren romano   /// Asserts the shape subsumes the actual dimension sizes.  Is only
1217a4c53f8cSwren romano   /// valid after parsing the header.
1218a4c53f8cSwren romano   void assertMatchesShape(uint64_t rank, const uint64_t *shape) const {
1219a4c53f8cSwren romano     assert(rank == getRank() && "Rank mismatch");
1220a4c53f8cSwren romano     for (uint64_t r = 0; r < rank; r++)
1221a4c53f8cSwren romano       assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1222a4c53f8cSwren romano              "Dimension size mismatch");
1223a4c53f8cSwren romano   }
1224a4c53f8cSwren romano 
1225a4c53f8cSwren romano private:
1226a4c53f8cSwren romano   void readMMEHeader();
1227a4c53f8cSwren romano   void readExtFROSTTHeader();
1228a4c53f8cSwren romano 
1229a4c53f8cSwren romano   const char *filename;
1230a4c53f8cSwren romano   FILE *file = nullptr;
12315b1c5fc5Sbixia1   ValueKind valueKind_ = ValueKind::kInvalid;
1232a4c53f8cSwren romano   bool isSymmetric_ = false;
1233a4c53f8cSwren romano   uint64_t idata[512];
1234a4c53f8cSwren romano   char line[kColWidth];
1235a4c53f8cSwren romano };
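
// Example usage (a minimal sketch of the intended call sequence):
//
//   SparseTensorFile stfile(filename);
//   stfile.openFile();
//   stfile.readHeader();
//   for (uint64_t k = 0, nnz = stfile.getNNZ(); k < nnz; k++) {
//     char *linePtr = stfile.readLine();
//     // ... parse the indices and value out of `linePtr` ...
//   }
//   stfile.closeFile();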
1236a4c53f8cSwren romano 
12378a91bc7bSHarrietAkot /// Read the MME header of a general sparse matrix of type real.
1238a4c53f8cSwren romano void SparseTensorFile::readMMEHeader() {
12398a91bc7bSHarrietAkot   char header[64];
12408a91bc7bSHarrietAkot   char object[64];
12418a91bc7bSHarrietAkot   char format[64];
12428a91bc7bSHarrietAkot   char field[64];
12438a91bc7bSHarrietAkot   char symmetry[64];
12448a91bc7bSHarrietAkot   // Read header line.
12458a91bc7bSHarrietAkot   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1246774674ceSwren romano              symmetry) != 5)
1247774674ceSwren romano     FATAL("Corrupt header in %s\n", filename);
12485b1c5fc5Sbixia1   // Process `field`, which specifies either a pattern or the value data type.
12495b1c5fc5Sbixia1   if (strcmp(toLower(field), "pattern") == 0)
12505b1c5fc5Sbixia1     valueKind_ = ValueKind::kPattern;
12515b1c5fc5Sbixia1   else if (strcmp(toLower(field), "real") == 0)
12525b1c5fc5Sbixia1     valueKind_ = ValueKind::kReal;
12535b1c5fc5Sbixia1   else if (strcmp(toLower(field), "integer") == 0)
12545b1c5fc5Sbixia1     valueKind_ = ValueKind::kInteger;
12555b1c5fc5Sbixia1   else if (strcmp(toLower(field), "complex") == 0)
12565b1c5fc5Sbixia1     valueKind_ = ValueKind::kComplex;
12575b1c5fc5Sbixia1   else
12585b1c5fc5Sbixia1     FATAL("Unexpected header field value in %s\n", filename);
12595b1c5fc5Sbixia1 
12605b1c5fc5Sbixia1   // Set properties.
1261a4c53f8cSwren romano   isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0);
12628a91bc7bSHarrietAkot   // Make sure this is a general sparse matrix.
12638a91bc7bSHarrietAkot   if (strcmp(toLower(header), "%%matrixmarket") ||
12648a91bc7bSHarrietAkot       strcmp(toLower(object), "matrix") ||
126533e8ab8eSAart Bik       strcmp(toLower(format), "coordinate") ||
1266a4c53f8cSwren romano       (strcmp(toLower(symmetry), "general") && !isSymmetric_))
1267774674ceSwren romano     FATAL("Cannot find a general sparse matrix in %s\n", filename);
12688a91bc7bSHarrietAkot   // Skip comments.
1269e5639b3fSMehdi Amini   while (true) {
1270a4c53f8cSwren romano     readLine();
12718a91bc7bSHarrietAkot     if (line[0] != '%')
12728a91bc7bSHarrietAkot       break;
12738a91bc7bSHarrietAkot   }
12748a91bc7bSHarrietAkot   // Next line contains M N NNZ.
12758a91bc7bSHarrietAkot   idata[0] = 2; // rank
12768a91bc7bSHarrietAkot   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1277774674ceSwren romano              idata + 1) != 3)
1278774674ceSwren romano     FATAL("Cannot find size in %s\n", filename);
12798a91bc7bSHarrietAkot }
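
// For reference, a conforming MME file begins with lines such as:
//
//   %%MatrixMarket matrix coordinate real general
//   % optional comment lines
//   5 5 8
//
// i.e., the banner, optional '%' comments, then "M N NNZ".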
12808a91bc7bSHarrietAkot 
12818a91bc7bSHarrietAkot /// Read the "extended" FROSTT header. Although not part of the documented
12828a91bc7bSHarrietAkot /// format, we assume that the file starts with optional comments followed
12838a91bc7bSHarrietAkot /// by two lines that define the rank, the number of nonzeros, and the
12848a91bc7bSHarrietAkot /// dimension sizes (one per rank) of the sparse tensor.
1285a4c53f8cSwren romano void SparseTensorFile::readExtFROSTTHeader() {
12868a91bc7bSHarrietAkot   // Skip comments.
1287e5639b3fSMehdi Amini   while (true) {
1288a4c53f8cSwren romano     readLine();
12898a91bc7bSHarrietAkot     if (line[0] != '#')
12908a91bc7bSHarrietAkot       break;
12918a91bc7bSHarrietAkot   }
12928a91bc7bSHarrietAkot   // Next line contains RANK and NNZ.
1293774674ceSwren romano   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
1294774674ceSwren romano     FATAL("Cannot find metadata in %s\n", filename);
12958a91bc7bSHarrietAkot   // Followed by a line with the dimension sizes (one per rank).
1296774674ceSwren romano   for (uint64_t r = 0; r < idata[0]; r++)
1297774674ceSwren romano     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
1298774674ceSwren romano       FATAL("Cannot find dimension size %s\n", filename);
1299a4c53f8cSwren romano   readLine(); // end of line
13005b1c5fc5Sbixia1   // The FROSTT format does not define the data type of the nonzero elements.
13015b1c5fc5Sbixia1   valueKind_ = ValueKind::kUndefined;
13025b1c5fc5Sbixia1 }
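
// For reference, an extended FROSTT file begins with lines such as:
//
//   # optional comment lines
//   3 5
//   2 3 4
//
// i.e., optional '#' comments, then "RANK NNZ", then one line of
// dimension sizes (one per rank).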
13035b1c5fc5Sbixia1 
13045b1c5fc5Sbixia1 // Adds a value to a tensor in coordinate scheme. If is_symmetric_value is true,
13055b1c5fc5Sbixia1 // also adds the value to its symmetric location.
13065b1c5fc5Sbixia1 template <typename T, typename V>
13075b1c5fc5Sbixia1 static inline void addValue(T *coo, V value,
13085b1c5fc5Sbixia1                             const std::vector<uint64_t> &indices,
13095b1c5fc5Sbixia1                             bool is_symmetric_value) {
13105b1c5fc5Sbixia1   // TODO: <https://github.com/llvm/llvm-project/issues/54179>
13115b1c5fc5Sbixia1   coo->add(indices, value);
13125b1c5fc5Sbixia1 // We currently choose to deal with symmetric matrices by fully constructing
13135b1c5fc5Sbixia1   // them. In the future, we may want to make symmetry implicit for storage
13145b1c5fc5Sbixia1   // reasons.
13155b1c5fc5Sbixia1   if (is_symmetric_value)
13165b1c5fc5Sbixia1     coo->add({indices[1], indices[0]}, value);
13175b1c5fc5Sbixia1 }
13185b1c5fc5Sbixia1 
13195b1c5fc5Sbixia1 // Reads an element of a complex type for the current indices in coordinate
13205b1c5fc5Sbixia1 // scheme.
13215b1c5fc5Sbixia1 template <typename V>
13225b1c5fc5Sbixia1 static inline void readCOOValue(SparseTensorCOO<std::complex<V>> *coo,
13235b1c5fc5Sbixia1                                 const std::vector<uint64_t> &indices,
13245b1c5fc5Sbixia1                                 char **linePtr, bool is_pattern,
13255b1c5fc5Sbixia1                                 bool add_symmetric_value) {
13265b1c5fc5Sbixia1   // Read two values to make a complex. The external formats always store
13275b1c5fc5Sbixia1   // numerical values with the type double, but we cast these values to the
13285b1c5fc5Sbixia1   // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
13295b1c5fc5Sbixia1   // value 1 for all entries.
13305b1c5fc5Sbixia1   V re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
13315b1c5fc5Sbixia1   V im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
13325b1c5fc5Sbixia1   std::complex<V> value = {re, im};
13335b1c5fc5Sbixia1   addValue(coo, value, indices, add_symmetric_value);
13345b1c5fc5Sbixia1 }
13355b1c5fc5Sbixia1 
13365b1c5fc5Sbixia1 // Reads an element of a non-complex type for the current indices in coordinate
13375b1c5fc5Sbixia1 // scheme.
13385b1c5fc5Sbixia1 template <typename V,
13395b1c5fc5Sbixia1           typename std::enable_if<
13405b1c5fc5Sbixia1               !std::is_same<std::complex<float>, V>::value &&
13415b1c5fc5Sbixia1               !std::is_same<std::complex<double>, V>::value>::type * = nullptr>
13425b1c5fc5Sbixia1 static inline void readCOOValue(SparseTensorCOO<V> *coo,
13435b1c5fc5Sbixia1                                 const std::vector<uint64_t> &indices,
13445b1c5fc5Sbixia1                                 char **linePtr, bool is_pattern,
13455b1c5fc5Sbixia1                                 bool is_symmetric_value) {
13465b1c5fc5Sbixia1   // The external formats always store these numerical values with the type
13475b1c5fc5Sbixia1   // double, but we cast these values to the sparse tensor object type.
13485b1c5fc5Sbixia1   // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
13495b1c5fc5Sbixia1   double value = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
13505b1c5fc5Sbixia1   addValue(coo, value, indices, is_symmetric_value);
13518a91bc7bSHarrietAkot }
13528a91bc7bSHarrietAkot 
13538a91bc7bSHarrietAkot /// Reads a sparse tensor with the given filename into a memory-resident
13548a91bc7bSHarrietAkot /// sparse tensor in coordinate scheme.
13558a91bc7bSHarrietAkot template <typename V>
13565b1c5fc5Sbixia1 static SparseTensorCOO<V> *
13575b1c5fc5Sbixia1 openSparseTensorCOO(char *filename, uint64_t rank, const uint64_t *shape,
13585b1c5fc5Sbixia1                     const uint64_t *perm, PrimaryType valTp) {
1359a4c53f8cSwren romano   SparseTensorFile stfile(filename);
1360a4c53f8cSwren romano   stfile.openFile();
1361a4c53f8cSwren romano   stfile.readHeader();
13625b1c5fc5Sbixia1   // Check tensor element type against the value type in the input file.
13635b1c5fc5Sbixia1   SparseTensorFile::ValueKind valueKind = stfile.getValueKind();
13645b1c5fc5Sbixia1   bool tensorIsInteger =
13655b1c5fc5Sbixia1       (valTp >= PrimaryType::kI64 && valTp <= PrimaryType::kI8);
13665b1c5fc5Sbixia1   bool tensorIsReal = (valTp >= PrimaryType::kF64 && valTp <= PrimaryType::kI8);
13675b1c5fc5Sbixia1   if ((valueKind == SparseTensorFile::ValueKind::kReal && tensorIsInteger) ||
13685b1c5fc5Sbixia1       (valueKind == SparseTensorFile::ValueKind::kComplex && tensorIsReal)) {
13695b1c5fc5Sbixia1     FATAL("Tensor element type %d not compatible with values in file %s\n",
137042f5b050SDaniil Dudkin           static_cast<int>(valTp), filename);
13715b1c5fc5Sbixia1   }
1372a4c53f8cSwren romano   stfile.assertMatchesShape(rank, shape);
13738a91bc7bSHarrietAkot   // Prepare sparse tensor object with per-dimension sizes
13748a91bc7bSHarrietAkot   // and the number of nonzeros as initial capacity.
1375a4c53f8cSwren romano   uint64_t nnz = stfile.getNNZ();
1376a4c53f8cSwren romano   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, stfile.getDimSizes(),
1377a4c53f8cSwren romano                                                      perm, nnz);
13788a91bc7bSHarrietAkot   // Read all nonzero elements.
13798a91bc7bSHarrietAkot   std::vector<uint64_t> indices(rank);
13808a91bc7bSHarrietAkot   for (uint64_t k = 0; k < nnz; k++) {
1381a4c53f8cSwren romano     char *linePtr = stfile.readLine();
138203fe15ceSAart Bik     for (uint64_t r = 0; r < rank; r++) {
138303fe15ceSAart Bik       uint64_t idx = strtoul(linePtr, &linePtr, 10);
13848a91bc7bSHarrietAkot       // Add 0-based index.
13858a91bc7bSHarrietAkot       indices[perm[r]] = idx - 1;
13868a91bc7bSHarrietAkot     }
13875b1c5fc5Sbixia1     readCOOValue(coo, indices, &linePtr, stfile.isPattern(),
13885b1c5fc5Sbixia1                  stfile.isSymmetric() && indices[0] != indices[1]);
13898a91bc7bSHarrietAkot   }
13908a91bc7bSHarrietAkot   // Close the file and return tensor.
1391a4c53f8cSwren romano   stfile.closeFile();
1392a4c53f8cSwren romano   return coo;
13938a91bc7bSHarrietAkot }
13948a91bc7bSHarrietAkot 
13952046e11aSwren romano /// Writes the sparse tensor to `dest` in extended FROSTT format.
1396efa15f41SAart Bik template <typename V>
139746bdacaaSwren romano static void outSparseTensor(void *tensor, void *dest, bool sort) {
13986438783fSAart Bik   assert(tensor && dest);
13996438783fSAart Bik   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
14006438783fSAart Bik   if (sort)
14016438783fSAart Bik     coo->sort();
14026438783fSAart Bik   char *filename = static_cast<char *>(dest);
1403fa6aed2aSwren romano   auto &dimSizes = coo->getDimSizes();
14046438783fSAart Bik   auto &elements = coo->getElements();
14056438783fSAart Bik   uint64_t rank = coo->getRank();
1406efa15f41SAart Bik   uint64_t nnz = elements.size();
1407efa15f41SAart Bik   std::fstream file;
1408efa15f41SAart Bik   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1409efa15f41SAart Bik   assert(file.is_open());
1410efa15f41SAart Bik   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1411efa15f41SAart Bik   for (uint64_t r = 0; r < rank - 1; r++)
1412fa6aed2aSwren romano     file << dimSizes[r] << " ";
1413fa6aed2aSwren romano   file << dimSizes[rank - 1] << std::endl;
1414efa15f41SAart Bik   for (uint64_t i = 0; i < nnz; i++) {
1415efa15f41SAart Bik     auto &idx = elements[i].indices;
1416efa15f41SAart Bik     for (uint64_t r = 0; r < rank; r++)
1417efa15f41SAart Bik       file << (idx[r] + 1) << " ";
1418efa15f41SAart Bik     file << elements[i].value << std::endl;
1419efa15f41SAart Bik   }
1420efa15f41SAart Bik   file.flush();
1421efa15f41SAart Bik   file.close();
1422efa15f41SAart Bik   assert(file.good());
14236438783fSAart Bik }
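
// For example, a 2x3 matrix with nonzeros A[0][0] = 1 and A[1][2] = 5 is
// written out as follows (indices are printed 1-based):
//
//   ; extended FROSTT format
//   2 2
//   2 3
//   1 1 1
//   2 3 5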
14246438783fSAart Bik 
14256438783fSAart Bik /// Initializes sparse tensor from an external COO-flavored format.
14266438783fSAart Bik template <typename V>
142746bdacaaSwren romano static SparseTensorStorage<uint64_t, uint64_t, V> *
14286438783fSAart Bik toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
142920eaa88fSBixia Zheng                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
143020eaa88fSBixia Zheng   const DimLevelType *sparsity = (DimLevelType *)(sparse);
143120eaa88fSBixia Zheng #ifndef NDEBUG
143220eaa88fSBixia Zheng   // Verify that perm is a permutation of 0..(rank-1).
143320eaa88fSBixia Zheng   std::vector<uint64_t> order(perm, perm + rank);
143420eaa88fSBixia Zheng   std::sort(order.begin(), order.end());
1435774674ceSwren romano   for (uint64_t i = 0; i < rank; ++i)
1436774674ceSwren romano     if (i != order[i])
1437774674ceSwren romano       FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
143820eaa88fSBixia Zheng 
143920eaa88fSBixia Zheng   // Verify that the sparsity values are supported.
1440774674ceSwren romano   for (uint64_t i = 0; i < rank; ++i)
144120eaa88fSBixia Zheng     if (sparsity[i] != DimLevelType::kDense &&
1442774674ceSwren romano         sparsity[i] != DimLevelType::kCompressed)
1443774674ceSwren romano       FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
144420eaa88fSBixia Zheng #endif
144520eaa88fSBixia Zheng 
14466438783fSAart Bik   // Convert external format to internal COO.
144763bdcaf9Swren romano   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
14486438783fSAart Bik   std::vector<uint64_t> idx(rank);
14496438783fSAart Bik   for (uint64_t i = 0, base = 0; i < nse; i++) {
14506438783fSAart Bik     for (uint64_t r = 0; r < rank; r++)
1451d8b229a1SAart Bik       idx[perm[r]] = indices[base + r];
145263bdcaf9Swren romano     coo->add(idx, values[i]);
14536438783fSAart Bik     base += rank;
14546438783fSAart Bik   }
14556438783fSAart Bik   // Return sparse tensor storage format as opaque pointer.
145663bdcaf9Swren romano   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
145763bdcaf9Swren romano       rank, shape, perm, sparsity, coo);
145863bdcaf9Swren romano   delete coo;
145963bdcaf9Swren romano   return tensor;
14606438783fSAart Bik }
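
// Example usage (a hypothetical sketch; the DimLevelType encodings in
// `sparse` are assumptions for illustration, not values defined here):
//
//   uint64_t shape[] = {2, 2};
//   double values[] = {1.0, 5.0};
//   uint64_t indices[] = {0, 0, 1, 1}; // coordinates (0,0) and (1,1)
//   uint64_t perm[] = {0, 1};
//   uint8_t sparse[] = {/*kDense*/ 0, /*kCompressed*/ 1};
//   auto *t = toMLIRSparseTensor<double>(2, 2, shape, values, indices,
//                                        perm, sparse);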
14616438783fSAart Bik 
14626438783fSAart Bik /// Converts a sparse tensor to an external COO-flavored format.
14636438783fSAart Bik template <typename V>
146446bdacaaSwren romano static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
146546bdacaaSwren romano                                  uint64_t **pShape, V **pValues,
146646bdacaaSwren romano                                  uint64_t **pIndices) {
1467736c1b66SAart Bik   assert(tensor);
14686438783fSAart Bik   auto sparseTensor =
14696438783fSAart Bik       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
14706438783fSAart Bik   uint64_t rank = sparseTensor->getRank();
14716438783fSAart Bik   std::vector<uint64_t> perm(rank);
14726438783fSAart Bik   std::iota(perm.begin(), perm.end(), 0);
14736438783fSAart Bik   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
14746438783fSAart Bik 
14756438783fSAart Bik   const std::vector<Element<V>> &elements = coo->getElements();
14766438783fSAart Bik   uint64_t nse = elements.size();
14776438783fSAart Bik 
14786438783fSAart Bik   uint64_t *shape = new uint64_t[rank];
14796438783fSAart Bik   for (uint64_t i = 0; i < rank; i++)
1480fa6aed2aSwren romano     shape[i] = coo->getDimSizes()[i];
14816438783fSAart Bik 
14826438783fSAart Bik   V *values = new V[nse];
14836438783fSAart Bik   uint64_t *indices = new uint64_t[rank * nse];
14846438783fSAart Bik 
14856438783fSAart Bik   for (uint64_t i = 0, base = 0; i < nse; i++) {
14866438783fSAart Bik     values[i] = elements[i].value;
14876438783fSAart Bik     for (uint64_t j = 0; j < rank; j++)
14886438783fSAart Bik       indices[base + j] = elements[i].indices[j];
14896438783fSAart Bik     base += rank;
14906438783fSAart Bik   }
14916438783fSAart Bik 
14926438783fSAart Bik   delete coo;
14936438783fSAart Bik   *pRank = rank;
14946438783fSAart Bik   *pNse = nse;
14956438783fSAart Bik   *pShape = shape;
14966438783fSAart Bik   *pValues = values;
14976438783fSAart Bik   *pIndices = indices;
1498efa15f41SAart Bik }
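
// Example usage (a minimal sketch; note the caller takes ownership of the
// returned buffers and must eventually `delete[]` shape/values/indices):
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   fromMLIRSparseTensor<double>(tensor, &rank, &nse, &shape, &values,
//                                &indices);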
1499efa15f41SAart Bik 
15002046e11aSwren romano } // anonymous namespace
15018a91bc7bSHarrietAkot 
15028a91bc7bSHarrietAkot extern "C" {
15038a91bc7bSHarrietAkot 
15048a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
15058a91bc7bSHarrietAkot //
15062046e11aSwren romano // Public functions which operate on MLIR buffers (memrefs) to interact
15072046e11aSwren romano // with sparse tensors (which are only visible as opaque pointers externally).
15088a91bc7bSHarrietAkot //
15098a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
15108a91bc7bSHarrietAkot 
15118a91bc7bSHarrietAkot #define CASE(p, i, v, P, I, V)                                                 \
15128a91bc7bSHarrietAkot   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
151363bdcaf9Swren romano     SparseTensorCOO<V> *coo = nullptr;                                         \
1514845561ecSwren romano     if (action <= Action::kFromCOO) {                                          \
1515845561ecSwren romano       if (action == Action::kFromFile) {                                       \
15168a91bc7bSHarrietAkot         char *filename = static_cast<char *>(ptr);                             \
15175b1c5fc5Sbixia1         coo = openSparseTensorCOO<V>(filename, rank, shape, perm, v);          \
1518845561ecSwren romano       } else if (action == Action::kFromCOO) {                                 \
151963bdcaf9Swren romano         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
15208a91bc7bSHarrietAkot       } else {                                                                 \
1521845561ecSwren romano         assert(action == Action::kEmpty);                                      \
15228a91bc7bSHarrietAkot       }                                                                        \
152363bdcaf9Swren romano       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
152463bdcaf9Swren romano           rank, shape, perm, sparsity, coo);                                   \
152563bdcaf9Swren romano       if (action == Action::kFromFile)                                         \
152663bdcaf9Swren romano         delete coo;                                                            \
152763bdcaf9Swren romano       return tensor;                                                           \
1528bb56c2b3SMehdi Amini     }                                                                          \
15298cb33240Swren romano     if (action == Action::kSparseToSparse) {                                   \
15308cb33240Swren romano       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
15318cb33240Swren romano       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
15328cb33240Swren romano                                                            sparsity, tensor);  \
15338cb33240Swren romano     }                                                                          \
1534bb56c2b3SMehdi Amini     if (action == Action::kEmptyCOO)                                           \
1535d83a7068Swren romano       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
153663bdcaf9Swren romano     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1537845561ecSwren romano     if (action == Action::kToIterator) {                                       \
153863bdcaf9Swren romano       coo->startIterator();                                                    \
15398a91bc7bSHarrietAkot     } else {                                                                   \
1540845561ecSwren romano       assert(action == Action::kToCOO);                                        \
15418a91bc7bSHarrietAkot     }                                                                          \
154263bdcaf9Swren romano     return coo;                                                                \
15438a91bc7bSHarrietAkot   }
15448a91bc7bSHarrietAkot 
1545845561ecSwren romano #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
15464f2ec7f9SAart Bik 
1547d2215e79SRainer Orth // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1548bc04a470Swren romano // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1549bc04a470Swren romano // that this file cannot get out of sync with its header.
1550d2215e79SRainer Orth static_assert(std::is_same<index_type, uint64_t>::value,
1551d2215e79SRainer Orth               "Expected index_type == uint64_t");
1552bc04a470Swren romano 
15538a91bc7bSHarrietAkot void *
1554845561ecSwren romano _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1555d2215e79SRainer Orth                              StridedMemRefType<index_type, 1> *sref,
1556d2215e79SRainer Orth                              StridedMemRefType<index_type, 1> *pref,
1557845561ecSwren romano                              OverheadType ptrTp, OverheadType indTp,
1558845561ecSwren romano                              PrimaryType valTp, Action action, void *ptr) {
15598a91bc7bSHarrietAkot   assert(aref && sref && pref);
15608a91bc7bSHarrietAkot   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
15618a91bc7bSHarrietAkot          pref->strides[0] == 1);
15628a91bc7bSHarrietAkot   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1563845561ecSwren romano   const DimLevelType *sparsity = aref->data + aref->offset;
1564d83a7068Swren romano   const index_type *shape = sref->data + sref->offset;
1565d2215e79SRainer Orth   const index_type *perm = pref->data + pref->offset;
15668a91bc7bSHarrietAkot   uint64_t rank = aref->sizes[0];
15678a91bc7bSHarrietAkot 
1568bc04a470Swren romano   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1569bc04a470Swren romano   // This is safe because of the static_assert above.
1570bc04a470Swren romano   if (ptrTp == OverheadType::kIndex)
1571bc04a470Swren romano     ptrTp = OverheadType::kU64;
1572bc04a470Swren romano   if (indTp == OverheadType::kIndex)
1573bc04a470Swren romano     indTp = OverheadType::kU64;
1574bc04a470Swren romano 
15758a91bc7bSHarrietAkot   // Double matrices with all combinations of overhead storage.
1576845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1577845561ecSwren romano        uint64_t, double);
1578845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1579845561ecSwren romano        uint32_t, double);
1580845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1581845561ecSwren romano        uint16_t, double);
1582845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1583845561ecSwren romano        uint8_t, double);
1584845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1585845561ecSwren romano        uint64_t, double);
1586845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1587845561ecSwren romano        uint32_t, double);
1588845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1589845561ecSwren romano        uint16_t, double);
1590845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1591845561ecSwren romano        uint8_t, double);
1592845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1593845561ecSwren romano        uint64_t, double);
1594845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1595845561ecSwren romano        uint32_t, double);
1596845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1597845561ecSwren romano        uint16_t, double);
1598845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1599845561ecSwren romano        uint8_t, double);
1600845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1601845561ecSwren romano        uint64_t, double);
1602845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1603845561ecSwren romano        uint32_t, double);
1604845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1605845561ecSwren romano        uint16_t, double);
1606845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1607845561ecSwren romano        uint8_t, double);
16088a91bc7bSHarrietAkot 
16098a91bc7bSHarrietAkot   // Float matrices with all combinations of overhead storage.
1610845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1611845561ecSwren romano        uint64_t, float);
1612845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1613845561ecSwren romano        uint32_t, float);
1614845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1615845561ecSwren romano        uint16_t, float);
1616845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1617845561ecSwren romano        uint8_t, float);
1618845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1619845561ecSwren romano        uint64_t, float);
1620845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1621845561ecSwren romano        uint32_t, float);
1622845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1623845561ecSwren romano        uint16_t, float);
1624845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1625845561ecSwren romano        uint8_t, float);
1626845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1627845561ecSwren romano        uint64_t, float);
1628845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1629845561ecSwren romano        uint32_t, float);
1630845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1631845561ecSwren romano        uint16_t, float);
1632845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1633845561ecSwren romano        uint8_t, float);
1634845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1635845561ecSwren romano        uint64_t, float);
1636845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1637845561ecSwren romano        uint32_t, float);
1638845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1639845561ecSwren romano        uint16_t, float);
1640845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1641845561ecSwren romano        uint8_t, float);
16428a91bc7bSHarrietAkot 
1643ea8ed5cbSbixia1   // Two-byte floats with both overheads of the same type.
1644ea8ed5cbSbixia1   CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
1645ea8ed5cbSbixia1   CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
1646ea8ed5cbSbixia1   CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
1647ea8ed5cbSbixia1   CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
1648ea8ed5cbSbixia1   CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
1649ea8ed5cbSbixia1   CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
1650ea8ed5cbSbixia1   CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
1651ea8ed5cbSbixia1   CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);
1652ea8ed5cbSbixia1 
1653845561ecSwren romano   // Integral matrices with both overheads of the same type.
1654845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1655845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1656845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1657845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
16582046e11aSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
1659845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1660845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1661845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
16622046e11aSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
1663845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1664845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1665845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
16662046e11aSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
1667845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1668845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1669845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
16708a91bc7bSHarrietAkot 
1671736c1b66SAart Bik   // Complex matrices with wide overhead.
1672736c1b66SAart Bik   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1673736c1b66SAart Bik   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1674736c1b66SAart Bik 
16758a91bc7bSHarrietAkot   // Unsupported case (add above if needed).
1676774674ceSwren romano   // TODO: better pretty-printing of enum values!
1677774674ceSwren romano   FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
1678774674ceSwren romano         static_cast<int>(ptrTp), static_cast<int>(indTp),
1679774674ceSwren romano         static_cast<int>(valTp));
16808a91bc7bSHarrietAkot }
16818a91bc7bSHarrietAkot #undef CASE
16821313f5d3Swren romano #undef CASE_SECSAME
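// Illustrative note (a sketch, not normative): each CASE entry above tests
// the runtime triple <ptrTp, indTp, valTp> and, on a match, instantiates the
// storage scheme with the corresponding C++ types. For example, the triple
// <OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32> selects,
// roughly,
//
//   SparseTensorStorage<uint32_t, uint64_t, float>
//
// which is why every supported combination needs its own CASE entry, and why
// anything else falls through to the FATAL diagnostic above.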
16836438783fSAart Bik 
1684bfadd13dSwren romano #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1685bfadd13dSwren romano   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1686bfadd13dSwren romano                                         void *tensor) {                        \
1687bfadd13dSwren romano     assert(ref &&tensor);                                                      \
1688bfadd13dSwren romano     std::vector<V> *v;                                                         \
1689bfadd13dSwren romano     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1690bfadd13dSwren romano     ref->basePtr = ref->data = v->data();                                      \
1691bfadd13dSwren romano     ref->offset = 0;                                                           \
1692bfadd13dSwren romano     ref->sizes[0] = v->size();                                                 \
1693bfadd13dSwren romano     ref->strides[0] = 1;                                                       \
1694bfadd13dSwren romano   }
1695bfadd13dSwren romano FOREVERY_V(IMPL_SPARSEVALUES)
1696bfadd13dSwren romano #undef IMPL_SPARSEVALUES
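// Example (illustrative only): for the FOREVERY_V entry (F64, double), the
// IMPL_SPARSEVALUES macro above expands to an entry point of this shape:
//
//   void _mlir_ciface_sparseValuesF64(StridedMemRefType<double, 1> *ref,
//                                     void *tensor);
//
// which wraps the tensor's value vector in a rank-1 memref descriptor
// (zero offset, unit stride) without copying the underlying data.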
1697bfadd13dSwren romano 
1698bfadd13dSwren romano #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1699bfadd13dSwren romano   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1700bfadd13dSwren romano                            index_type d) {                                     \
1701bfadd13dSwren romano     assert(ref &&tensor);                                                      \
1702bfadd13dSwren romano     std::vector<TYPE> *v;                                                      \
1703bfadd13dSwren romano     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1704bfadd13dSwren romano     ref->basePtr = ref->data = v->data();                                      \
1705bfadd13dSwren romano     ref->offset = 0;                                                           \
1706bfadd13dSwren romano     ref->sizes[0] = v->size();                                                 \
1707bfadd13dSwren romano     ref->strides[0] = 1;                                                       \
1708bfadd13dSwren romano   }
1709a9a19f59Swren romano #define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
1710a9a19f59Swren romano   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
1711a9a19f59Swren romano FOREVERY_O(IMPL_SPARSEPOINTERS)
1712a9a19f59Swren romano #undef IMPL_SPARSEPOINTERS
1713bfadd13dSwren romano 
1714a9a19f59Swren romano #define IMPL_SPARSEINDICES(INAME, I)                                           \
1715a9a19f59Swren romano   IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
1716a9a19f59Swren romano FOREVERY_O(IMPL_SPARSEINDICES)
1717a9a19f59Swren romano #undef IMPL_SPARSEINDICES
1718bfadd13dSwren romano #undef IMPL_GETOVERHEAD
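// Example (illustrative only, and assuming FOREVERY_O contains an entry
// (64, uint64_t)): the two wrappers above then provide, among others,
//
//   void _mlir_ciface_sparsePointers64(StridedMemRefType<uint64_t, 1> *ref,
//                                      void *tensor, index_type d);
//   void _mlir_ciface_sparseIndices64(StridedMemRefType<uint64_t, 1> *ref,
//                                     void *tensor, index_type d);
//
// both of which return views (not copies) of the overhead vectors for
// dimension d.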
1719bfadd13dSwren romano 
1720bfadd13dSwren romano #define IMPL_ADDELT(VNAME, V)                                                  \
1721aef20f59SAart Bik   void *_mlir_ciface_addElt##VNAME(void *coo, StridedMemRefType<V, 0> *vref,   \
1722bfadd13dSwren romano                                    StridedMemRefType<index_type, 1> *iref,     \
1723bfadd13dSwren romano                                    StridedMemRefType<index_type, 1> *pref) {   \
1724aef20f59SAart Bik     assert(coo &&vref &&iref &&pref);                                          \
1725bfadd13dSwren romano     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1726bfadd13dSwren romano     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1727bfadd13dSwren romano     const index_type *indx = iref->data + iref->offset;                        \
1728bfadd13dSwren romano     const index_type *perm = pref->data + pref->offset;                        \
1729bfadd13dSwren romano     uint64_t isize = iref->sizes[0];                                           \
1730bfadd13dSwren romano     std::vector<index_type> indices(isize);                                    \
1731bfadd13dSwren romano     for (uint64_t r = 0; r < isize; r++)                                       \
1732bfadd13dSwren romano       indices[perm[r]] = indx[r];                                              \
1733aef20f59SAart Bik     V *value = vref->data + vref->offset;                                      \
1734aef20f59SAart Bik     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, *value);              \
1735bfadd13dSwren romano     return coo;                                                                \
1736bfadd13dSwren romano   }
1737aef20f59SAart Bik FOREVERY_V(IMPL_ADDELT)
17382046e11aSwren romano #undef IMPL_ADDELT
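// Worked example of the permutation logic above (illustrative only): with
// isize == 3, indx == [i0, i1, i2], and perm == [2, 0, 1], the loop performs
// indices[perm[r]] = indx[r], yielding indices == [i1, i2, i0]; that is,
// perm[r] names the target position of the r-th source index.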
1739bfadd13dSwren romano 
1740bfadd13dSwren romano #define IMPL_GETNEXT(VNAME, V)                                                 \
1741bfadd13dSwren romano   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1742bfadd13dSwren romano                                    StridedMemRefType<index_type, 1> *iref,     \
1743bfadd13dSwren romano                                    StridedMemRefType<V, 0> *vref) {            \
1744bfadd13dSwren romano     assert(coo &&iref &&vref);                                                 \
1745bfadd13dSwren romano     assert(iref->strides[0] == 1);                                             \
1746bfadd13dSwren romano     index_type *indx = iref->data + iref->offset;                              \
1747bfadd13dSwren romano     V *value = vref->data + vref->offset;                                      \
1748bfadd13dSwren romano     const uint64_t isize = iref->sizes[0];                                     \
1749bfadd13dSwren romano     const Element<V> *elem =                                                   \
1750bfadd13dSwren romano         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1751bfadd13dSwren romano     if (elem == nullptr)                                                       \
1752bfadd13dSwren romano       return false;                                                            \
1753bfadd13dSwren romano     for (uint64_t r = 0; r < isize; r++)                                       \
1754bfadd13dSwren romano       indx[r] = elem->indices[r];                                              \
1755bfadd13dSwren romano     *value = elem->value;                                                      \
1756bfadd13dSwren romano     return true;                                                               \
1757bfadd13dSwren romano   }
1758bfadd13dSwren romano FOREVERY_V(IMPL_GETNEXT)
1759bfadd13dSwren romano #undef IMPL_GETNEXT
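// Usage sketch (illustrative only; descriptor setup elided): the getNext
// entry points are intended to be called in a loop until exhaustion, e.g.
// for VNAME=F64:
//
//   while (_mlir_ciface_getNextF64(coo, &iref, &vref)) {
//     // iref now holds the indices and vref the value of one element.
//   }
//
// A false return means the COO iterator is exhausted.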
1760bfadd13dSwren romano 
1761bfadd13dSwren romano #define IMPL_LEXINSERT(VNAME, V)                                               \
1762aef20f59SAart Bik   void _mlir_ciface_lexInsert##VNAME(void *tensor,                             \
1763aef20f59SAart Bik                                      StridedMemRefType<index_type, 1> *cref,   \
1764aef20f59SAart Bik                                      StridedMemRefType<V, 0> *vref) {          \
1765aef20f59SAart Bik     assert(tensor &&cref &&vref);                                              \
1766bfadd13dSwren romano     assert(cref->strides[0] == 1);                                             \
1767bfadd13dSwren romano     index_type *cursor = cref->data + cref->offset;                            \
1768bfadd13dSwren romano     assert(cursor);                                                            \
1769aef20f59SAart Bik     V *value = vref->data + vref->offset;                                      \
1770aef20f59SAart Bik     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, *value); \
1771bfadd13dSwren romano   }
1772aef20f59SAart Bik FOREVERY_V(IMPL_LEXINSERT)
17732046e11aSwren romano #undef IMPL_LEXINSERT
1774bfadd13dSwren romano 
1775bfadd13dSwren romano #define IMPL_EXPINSERT(VNAME, V)                                               \
1776bfadd13dSwren romano   void _mlir_ciface_expInsert##VNAME(                                          \
1777bfadd13dSwren romano       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1778bfadd13dSwren romano       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1779bfadd13dSwren romano       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1780bfadd13dSwren romano     assert(tensor &&cref &&vref &&fref &&aref);                                \
1781bfadd13dSwren romano     assert(cref->strides[0] == 1);                                             \
1782bfadd13dSwren romano     assert(vref->strides[0] == 1);                                             \
1783bfadd13dSwren romano     assert(fref->strides[0] == 1);                                             \
1784bfadd13dSwren romano     assert(aref->strides[0] == 1);                                             \
1785bfadd13dSwren romano     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1786bfadd13dSwren romano     index_type *cursor = cref->data + cref->offset;                            \
1787bfadd13dSwren romano     V *values = vref->data + vref->offset;                                     \
1788bfadd13dSwren romano     bool *filled = fref->data + fref->offset;                                  \
1789bfadd13dSwren romano     index_type *added = aref->data + aref->offset;                             \
1790bfadd13dSwren romano     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1791bfadd13dSwren romano         cursor, values, filled, added, count);                                 \
1792bfadd13dSwren romano   }
1793bfadd13dSwren romano FOREVERY_V(IMPL_EXPINSERT)
1794bfadd13dSwren romano #undef IMPL_EXPINSERT
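// Note on the expInsert calling convention (a sketch of the intent, not a
// specification): `values` and `filled` are dense scratch buffers sized to
// the innermost dimension, `added` lists the `count` positions that were
// actually written, and `cursor` fixes the indices of all outer dimensions;
// the callee is expected to compress the scratch data back into sparse
// storage and reset the buffers for reuse.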
1795bfadd13dSwren romano 
17968a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
17978a91bc7bSHarrietAkot //
17982046e11aSwren romano // Public functions that accept only C-style data structures to interact
17992046e11aSwren romano // with sparse tensors, which are visible externally only as opaque pointers.
18008a91bc7bSHarrietAkot //
18018a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
18028a91bc7bSHarrietAkot 
1803d2215e79SRainer Orth index_type sparseDimSize(void *tensor, index_type d) {
18048a91bc7bSHarrietAkot   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
18058a91bc7bSHarrietAkot }
18068a91bc7bSHarrietAkot 
1807f66e5769SAart Bik void endInsert(void *tensor) {
1808f66e5769SAart Bik   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1809f66e5769SAart Bik }
1810f66e5769SAart Bik 
181105c17bc4Swren romano #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
181205c17bc4Swren romano   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
181305c17bc4Swren romano     return outSparseTensor<V>(coo, dest, sort);                                \
181405c17bc4Swren romano   }
181505c17bc4Swren romano FOREVERY_V(IMPL_OUTSPARSETENSOR)
181605c17bc4Swren romano #undef IMPL_OUTSPARSETENSOR
181705c17bc4Swren romano 
18188a91bc7bSHarrietAkot void delSparseTensor(void *tensor) {
18198a91bc7bSHarrietAkot   delete static_cast<SparseTensorStorageBase *>(tensor);
18208a91bc7bSHarrietAkot }
18218a91bc7bSHarrietAkot 
182263bdcaf9Swren romano #define IMPL_DELCOO(VNAME, V)                                                  \
182363bdcaf9Swren romano   void delSparseTensorCOO##VNAME(void *coo) {                                  \
182463bdcaf9Swren romano     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
182563bdcaf9Swren romano   }
18261313f5d3Swren romano FOREVERY_V(IMPL_DELCOO)
182763bdcaf9Swren romano #undef IMPL_DELCOO
182863bdcaf9Swren romano 
182905c17bc4Swren romano char *getTensorFilename(index_type id) {
183005c17bc4Swren romano   char var[80];
183105c17bc4Swren romano   snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
183205c17bc4Swren romano   char *env = getenv(var);
183305c17bc4Swren romano   if (!env)
183405c17bc4Swren romano     FATAL("Environment variable %s is not set\n", var);
183505c17bc4Swren romano   return env;
183605c17bc4Swren romano }
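// Usage sketch (illustrative only): with the environment set up as
//
//   export TENSOR0=/path/to/matrix.mtx
//
// a call to getTensorFilename(0) returns "/path/to/matrix.mtx"; an unset
// variable is a fatal error rather than a nullptr return.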
183705c17bc4Swren romano 
1838a4c53f8cSwren romano void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
1839a4c53f8cSwren romano   assert(out && "Received nullptr for out-parameter");
1840a4c53f8cSwren romano   SparseTensorFile stfile(filename);
1841a4c53f8cSwren romano   stfile.openFile();
1842a4c53f8cSwren romano   stfile.readHeader();
1843a4c53f8cSwren romano   stfile.closeFile();
1844a4c53f8cSwren romano   const uint64_t rank = stfile.getRank();
1845a4c53f8cSwren romano   const uint64_t *dimSizes = stfile.getDimSizes();
1846a4c53f8cSwren romano   out->reserve(rank);
1847a4c53f8cSwren romano   out->assign(dimSizes, dimSizes + rank);
1848a4c53f8cSwren romano }
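// Usage sketch (illustrative only):
//
//   std::vector<uint64_t> shape;
//   readSparseTensorShape(getTensorFilename(0), &shape);
//   // shape now holds one entry per dimension of the file's tensor.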
1849a4c53f8cSwren romano 
185020eaa88fSBixia Zheng // TODO: generalize beyond 64-bit indices.
18511313f5d3Swren romano #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
18521313f5d3Swren romano   void *convertToMLIRSparseTensor##VNAME(                                      \
18531313f5d3Swren romano       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
18541313f5d3Swren romano       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
18551313f5d3Swren romano     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
18561313f5d3Swren romano                                  sparse);                                      \
18578a91bc7bSHarrietAkot   }
18581313f5d3Swren romano FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
18591313f5d3Swren romano #undef IMPL_CONVERTTOMLIRSPARSETENSOR
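// Usage sketch (illustrative only; assumes an element-major layout for
// `indices` and a DimLevelType-style encoding for `sparse`, both of which
// are properties of toMLIRSparseTensor rather than guarantees made here):
// building a 2x2 matrix with nonzeros at (0,0) and (1,1) for VNAME=F64:
//
//   uint64_t shape[] = {2, 2};
//   double values[] = {1.0, 2.0};
//   uint64_t indices[] = {0, 0, 1, 1}; // (i, j) per element, flattened
//   uint64_t perm[] = {0, 1};          // identity dimension ordering
//   uint8_t sparse[] = {1, 1};         // assumed "compressed" encoding
//   void *tensor =
//       convertToMLIRSparseTensorF64(2, 2, shape, values, indices, perm,
//                                    sparse);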
18608a91bc7bSHarrietAkot 
18612f49e6b0SBixia Zheng // TODO: Currently, values are copied from SparseTensorStorage to
18622046e11aSwren romano // SparseTensorCOO, then to the output.  We may want to reduce the number
18632046e11aSwren romano // of copies.
18642f49e6b0SBixia Zheng //
18656438783fSAart Bik // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
18666438783fSAart Bik // compressed
18671313f5d3Swren romano #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
18681313f5d3Swren romano   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
18691313f5d3Swren romano                                           uint64_t *pNse, uint64_t **pShape,   \
18701313f5d3Swren romano                                           V **pValues, uint64_t **pIndices) {  \
18711313f5d3Swren romano     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
18722f49e6b0SBixia Zheng   }
18731313f5d3Swren romano FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
18741313f5d3Swren romano #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
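// Usage sketch (illustrative only): reading back a tensor built through this
// API, for VNAME=F64:
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   convertFromMLIRSparseTensorF64(tensor, &rank, &nse, &shape, &values,
//                                  &indices);
//
// Per the copy TODO above, the out-parameters receive freshly allocated
// buffers, presumably leaving the caller responsible for releasing them.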
1875efa15f41SAart Bik 
18768a91bc7bSHarrietAkot } // extern "C"
18778a91bc7bSHarrietAkot 
18788a91bc7bSHarrietAkot #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS