//===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library that is useful
// for sparse tensor manipulations. The functionality provided in this library
// is meant to simplify benchmarking, testing, and debugging MLIR code that
// operates on sparse tensors. The provided functionality is **not** part
// of core MLIR, however.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cinttypes>
#include <complex>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>

using complex64 = std::complex<double>;
using complex32 = std::complex<float>;

//===----------------------------------------------------------------------===//
//
// Internal support for storing and reading sparse tensors.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by index (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is visible
// externally only as an opaque pointer.
//
//===----------------------------------------------------------------------===//

namespace {

static constexpr int kColWidth = 1025;

/// A version of `operator*` on `uint64_t` which checks for overflows.
static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
  assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
         "Integer overflow");
  return lhs * rhs;
}

// TODO: adjust this so it can be used by `openSparseTensorCOO` too.
// That version doesn't have the permutation, and its `dimSizes` is
// a pointer/C-array rather than a `std::vector`.
//
/// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
/// semantic-order to target-order) are a refinement of the desired `shape`
/// (in semantic-order).
///
/// Precondition: `perm` and `shape` must be valid for `rank`.
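///
/// For example (illustrative values only): with `perm = {2, 0, 1}`,
/// semantic dimension `r` is stored as target dimension `perm[r]`, so
/// the semantic `shape = {10, 20, 0}` (where 0 means a dynamic size)
/// is matched by the target-order `dimSizes = {20, 30, 10}`:
///
///   const std::vector<uint64_t> dimSizes{20, 30, 10}; // target-order
///   const uint64_t perm[] = {2, 0, 1};
///   const uint64_t shape[] = {10, 20, 0}; // semantic-order
///   assertPermutedSizesMatchShape(dimSizes, 3, perm, shape); // passes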
static inline void
assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
                              uint64_t rank, const uint64_t *perm,
                              const uint64_t *shape) {
  assert(perm && shape);
  assert(rank == dimSizes.size() && "Rank mismatch");
  for (uint64_t r = 0; r < rank; r++)
    assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
           "Dimension size mismatch");
}

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer into a shared index pool rather than e.g. a direct
/// vector, since that (1) reduces the per-element memory footprint, and
/// (2) centralizes the memory reservation and (re)allocation to one place.
template <typename V>
struct Element final {
  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
  uint64_t *indices; // pointer into shared index pool
  V value;
};

/// The type of callback functions which receive an element.  We avoid
/// packaging the coordinates and value together as an `Element` object
/// because this helps keep code somewhat cleaner.
template <typename V>
using ElementConsumer =
    const std::function<void(const std::vector<uint64_t> &, V)> &;

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
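///
/// For example (illustrative only), a client might build and sort a
/// 2x3 matrix with two nonzero entries as follows:
///
///   const uint64_t dimSizes[] = {2, 3};
///   const uint64_t perm[] = {0, 1}; // identity ordering
///   auto *coo = SparseTensorCOO<double>::newSparseTensorCOO(
///       /*rank=*/2, dimSizes, perm);
///   coo->add({0, 1}, 1.0);
///   coo->add({1, 2}, 2.0);
///   coo->sort(); // required before packing into SparseTensorStorage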
template <typename V>
struct SparseTensorCOO final {
public:
  SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
      : dimSizes(dimSizes) {
    if (capacity) {
      elements.reserve(capacity);
      indices.reserve(capacity * getRank());
    }
  }

  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t *base = indices.data();
    uint64_t size = indices.size();
    uint64_t rank = getRank();
    assert(ind.size() == rank && "Element rank mismatch");
    for (uint64_t r = 0; r < rank; r++) {
      assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
      indices.push_back(ind[r]);
    }
    // The base only changes if the `indices` vector was reallocated. In that
    // case, we need to correct all previous pointers into the vector. Note
    // that this only happens if we did not set the initial capacity right,
    // and then only for every internal vector reallocation (which with the
    // doubling rule should only incur an amortized linear overhead).
    uint64_t *newBase = indices.data();
    if (newBase != base) {
      for (uint64_t i = 0, n = elements.size(); i < n; i++)
        elements[i].indices = newBase + (elements[i].indices - base);
      base = newBase;
    }
    // Add element as (pointer into shared index pool, value) pair.
    elements.emplace_back(base + size, val);
  }

  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    // TODO: we may want to cache an `isSorted` bit, to avoid
    // unnecessary/redundant sorting.
    std::sort(elements.begin(), elements.end(),
              [this](const Element<V> &e1, const Element<V> &e2) {
                uint64_t rank = getRank();
                for (uint64_t r = 0; r < rank; r++) {
                  if (e1.indices[r] == e2.indices[r])
                    continue;
                  return e1.indices[r] < e2.indices[r];
                }
                return false;
              });
  }

  /// Get the rank of the tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Getter for the dimension-sizes array.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Getter for the elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }

  /// Get the next element.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  ///
  /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *dimSizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      permsz[perm[r]] = dimSizes[r];
    }
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  const std::vector<uint64_t> dimSizes; // per-dimension sizes
  std::vector<Element<V>> elements;     // all COO elements
  std::vector<uint64_t> indices;        // shared index pool
  bool iteratorLocked = false;
  unsigned iteratorPos = 0;
};

// See <https://en.wikipedia.org/wiki/X_Macro>
//
// `FOREVERY_SIMPLEX_V` only specifies the non-complex `V` types, because
// the ABI for complex types has compiler/architecture dependent complexities
// we need to work around.  Namely, when a function takes a parameter of
// C/C++ type `complex32` (per se), then there is additional padding that
// causes it not to match the LLVM type `!llvm.struct<(f32, f32)>`.  This
// only happens with the `complex32` type itself, not with pointers/arrays
// of complex values.  So far `complex64` doesn't exhibit this ABI
// incompatibility, but we exclude it anyway just to be safe.
#define FOREVERY_SIMPLEX_V(DO)                                                 \
  DO(F64, double)                                                              \
  DO(F32, float)                                                               \
  DO(I64, int64_t)                                                             \
  DO(I32, int32_t)                                                             \
  DO(I16, int16_t)                                                             \
  DO(I8, int8_t)

#define FOREVERY_V(DO)                                                         \
  FOREVERY_SIMPLEX_V(DO)                                                       \
  DO(C64, complex64)                                                           \
  DO(C32, complex32)
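
// For example, `FOREVERY_V(DECL_GETVALUES)` below expands to one
// `getValues` overload per value type, i.e.,
//   virtual void getValues(std::vector<double> **) { fatal("getValuesF64"); }
//   virtual void getValues(std::vector<float> **) { fatal("getValuesF32"); }
// and so on for the remaining (VNAME, V) pairs.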

// Forward.
template <typename V>
class SparseTensorEnumeratorBase;

/// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
/// takes responsibility for all the `<P,I,V>`-independent aspects
/// of the tensor (e.g., shape, sparsity, permutation).  In addition,
/// we use function overloading to implement "partial" method
/// specialization, which the C-API relies on to catch type errors
/// arising from our use of opaque pointers.
class SparseTensorStorageBase {
public:
  /// Constructs a new storage object.  The `perm` maps the tensor's
  /// semantic-ordering of dimensions to this object's storage-order.
  /// The `dimSizes` and `sparsity` arrays are already in storage-order.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
  SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
                          const uint64_t *perm, const DimLevelType *sparsity)
      : dimSizes(dimSizes), rev(getRank()),
        dimTypes(sparsity, sparsity + getRank()) {
    assert(perm && sparsity);
    const uint64_t rank = getRank();
    // Validate parameters.
    assert(rank > 0 && "Trivial shape is unsupported");
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      assert((dimTypes[r] == DimLevelType::kDense ||
              dimTypes[r] == DimLevelType::kCompressed) &&
             "Unsupported DimLevelType");
    }
    // Construct the "reverse" (i.e., inverse) permutation.
    for (uint64_t r = 0; r < rank; r++)
      rev[perm[r]] = r;
  }

  virtual ~SparseTensorStorageBase() = default;

  /// Get the rank of the tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Getter for the dimension-sizes array, in storage-order.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Safely lookup the size of the given (storage-order) dimension.
  uint64_t getDimSize(uint64_t d) const {
    assert(d < getRank());
    return dimSizes[d];
  }

  /// Getter for the "reverse" permutation, which maps this object's
  /// storage-order to the tensor's semantic-order.
  const std::vector<uint64_t> &getRev() const { return rev; }

  /// Getter for the dimension-types array, in storage-order.
  const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }

  /// Safely check if the (storage-order) dimension uses compressed storage.
  bool isCompressedDim(uint64_t d) const {
    assert(d < getRank());
    return (dimTypes[d] == DimLevelType::kCompressed);
  }

  /// Allocate a new enumerator.
#define DECL_NEWENUMERATOR(VNAME, V)                                           \
  virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
                             const uint64_t *) const {                         \
    fatal("newEnumerator" #VNAME);                                             \
  }
  FOREVERY_V(DECL_NEWENUMERATOR)
#undef DECL_NEWENUMERATOR

  /// Overhead storage.
  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }

  /// Primary storage.
#define DECL_GETVALUES(VNAME, V)                                               \
  virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
  FOREVERY_V(DECL_GETVALUES)
#undef DECL_GETVALUES

  /// Element-wise insertion in lexicographic index order.
#define DECL_LEXINSERT(VNAME, V)                                               \
  virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
  FOREVERY_V(DECL_LEXINSERT)
#undef DECL_LEXINSERT

  /// Expanded insertion.
#define DECL_EXPINSERT(VNAME, V)                                               \
  virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
    fatal("expInsert" #VNAME);                                                 \
  }
  FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT

  /// Finishes insertion.
  virtual void endInsert() = 0;

protected:
  // Since this class is virtual, we must disallow public copying in
  // order to avoid "slicing".  Since this class has data members,
  // that means making copying protected.
  // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
  SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
  // Copy-assignment would be implicitly deleted (because `dimSizes`
  // is const), so we explicitly delete it for clarity.
  SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;

private:
  static void fatal(const char *tp) {
    fprintf(stderr, "unsupported %s\n", tp);
    exit(1);
  }

  const std::vector<uint64_t> dimSizes;
  std::vector<uint64_t> rev;
  const std::vector<DimLevelType> dimTypes;
};

// Forward.
template <typename P, typename I, typename V>
class SparseTensorEnumerator;

/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this method provides
/// a convenient "one-size-fits-all" solution that simply takes an input tensor
/// and annotations to implement all required setup in a general manner.
template <typename P, typename I, typename V>
class SparseTensorStorage final : public SparseTensorStorageBase {
  /// Private constructor to share code between the other constructors.
  /// Beware that the object is not necessarily guaranteed to be in a
  /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
  /// doesn't entail `!(pointers[d].empty())`.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
  SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                      const uint64_t *perm, const DimLevelType *sparsity)
      : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
        indices(getRank()), idx(getRank()) {}

public:
  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the coordinate scheme tensor for the initial contents if provided.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
  SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                      const uint64_t *perm, const DimLevelType *sparsity,
                      SparseTensorCOO<V> *coo)
      : SparseTensorStorage(dimSizes, perm, sparsity) {
    // Provide hints on capacity of pointers and indices.
    // TODO: needs much fine-tuning based on actual sparsity; currently
    //       we reserve pointer/index space based on all previous dense
    //       dimensions, which works well up to first sparse dim; but
    //       we should really use nnz and dense/sparse distribution.
    bool allDense = true;
    uint64_t sz = 1;
    for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
      if (isCompressedDim(r)) {
        // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
        // `sz` by that before reserving. (For now we just use 1.)
        pointers[r].reserve(sz + 1);
        pointers[r].push_back(0);
        indices[r].reserve(sz);
        sz = 1;
        allDense = false;
      } else { // Dense dimension.
        sz = checkedMul(sz, getDimSizes()[r]);
      }
    }
    // Then assign contents from coordinate scheme tensor if provided.
    if (coo) {
      // Ensure both preconditions of `fromCOO`.
      assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
      coo->sort();
      // Now actually insert the `elements`.
      const std::vector<Element<V>> &elements = coo->getElements();
      uint64_t nnz = elements.size();
      values.reserve(nnz);
      fromCOO(elements, 0, nnz, 0);
    } else if (allDense) {
      values.resize(sz, 0);
    }
  }

  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the given sparse tensor for the initial contents.
  ///
  /// Preconditions:
  /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
  /// * The `tensor` must have the same value type `V`.
  SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                      const uint64_t *perm, const DimLevelType *sparsity,
                      const SparseTensorStorageBase &tensor);

  ~SparseTensorStorage() final override = default;

  /// Partially specialize these getter methods based on template types.
  void getPointers(std::vector<P> **out, uint64_t d) final override {
    assert(d < getRank());
    *out = &pointers[d];
  }
  void getIndices(std::vector<I> **out, uint64_t d) final override {
    assert(d < getRank());
    *out = &indices[d];
  }
  void getValues(std::vector<V> **out) final override { *out = &values; }

  /// Partially specialize lexicographical insertions based on template types.
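  ///
  /// For example (illustrative only), inserting the rank-2 entries
  /// (0,2)=a, (1,0)=b, (1,3)=c in this order finalizes the pending
  /// path at the first dimension where the new cursor differs from
  /// the previous one: the second insertion differs at dimension 0,
  /// the third at dimension 1.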
  void lexInsert(const uint64_t *cursor, V val) final override {
    // First, wrap up pending insertion path.
    uint64_t diff = 0;
    uint64_t top = 0;
    if (!values.empty()) {
      diff = lexDiff(cursor);
      endPath(diff + 1);
      top = idx[diff] + 1;
    }
    // Then continue with insertion path.
    insPath(cursor, diff, top, val);
  }

  /// Partially specialize expanded insertions based on template types.
  /// Note that this method resets the values/filled-switch array back
  /// to all-zero/false while only iterating over the nonzero elements.
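  ///
  /// For example (illustrative only), for an expanded innermost dimension
  /// of size 8 where positions 2 and 5 hold values `v2` and `v5`:
  ///
  ///   // values = {0, 0, v2, 0, 0, v5, 0, 0}, filled analogously,
  ///   // added = {2, 5}:
  ///   tensor->expInsert(cursor, values, filled, added, /*count=*/2);
  ///   // afterwards values/filled are all-zero/false again.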
  void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
                 uint64_t count) final override {
    if (count == 0)
      return;
    // Sort.
    std::sort(added, added + count);
    // Restore insertion path for first insert.
    const uint64_t lastDim = getRank() - 1;
    uint64_t index = added[0];
    cursor[lastDim] = index;
    lexInsert(cursor, values[index]);
    assert(filled[index]);
    values[index] = 0;
    filled[index] = false;
    // Subsequent insertions are quick.
    for (uint64_t i = 1; i < count; i++) {
      assert(index < added[i] && "non-lexicographic insertion");
      index = added[i];
      cursor[lastDim] = index;
      insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
      assert(filled[index]);
      values[index] = 0;
      filled[index] = false;
    }
  }

  /// Finalizes lexicographic insertions.
  void endInsert() final override {
    if (values.empty())
      finalizeSegment(0);
    else
      endPath(0);
  }

  void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
                     const uint64_t *perm) const final override {
    *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
  }

  /// Returns this sparse tensor storage scheme as a new memory-resident
  /// sparse tensor in coordinate scheme with the given dimension order.
  ///
  /// Precondition: `perm` must be valid for `getRank()`.
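  ///
  /// For example (illustrative only), given some storage object `tensor`,
  /// passing the identity permutation yields the elements in the tensor's
  /// semantic dimension order:
  ///
  ///   std::vector<uint64_t> perm(tensor.getRank());
  ///   std::iota(perm.begin(), perm.end(), 0); // identity
  ///   SparseTensorCOO<V> *coo = tensor.toCOO(perm.data());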
  SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
    SparseTensorEnumeratorBase<V> *enumerator;
    newEnumerator(&enumerator, getRank(), perm);
    SparseTensorCOO<V> *coo =
        new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
    enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
      coo->add(ind, val);
    });
    // TODO: This assertion assumes there are no stored zeros,
    // or if there are then that we don't filter them out.
    // Cf., <https://github.com/llvm/llvm-project/issues/54179>
    assert(coo->getElements().size() == values.size());
    delete enumerator;
    return coo;
  }

  /// Factory method. Constructs a sparse tensor storage scheme with the given
  /// dimensions, permutation, and per-dimension dense/sparse annotations,
  /// using the coordinate scheme tensor for the initial contents if provided.
  /// In the latter case, the coordinate scheme must respect the same
  /// permutation as is desired for the new sparse tensor storage.
  ///
  /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                  const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
    SparseTensorStorage<P, I, V> *n = nullptr;
    if (coo) {
      const auto &coosz = coo->getDimSizes();
      assertPermutedSizesMatchShape(coosz, rank, perm, shape);
      n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
    } else {
      std::vector<uint64_t> permsz(rank);
      for (uint64_t r = 0; r < rank; r++) {
        assert(shape[r] > 0 && "Dimension size zero has trivial storage");
        permsz[perm[r]] = shape[r];
      }
      // We pass the null `coo` to ensure we select the intended constructor.
      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
    }
    return n;
  }

  /// Factory method. Constructs a sparse tensor storage scheme with
  /// the given dimensions, permutation, and per-dimension dense/sparse
  /// annotations, using the sparse tensor for the initial contents.
  ///
  /// Preconditions:
  /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
  /// * The `tensor` must have the same value type `V`.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                  const DimLevelType *sparsity,
                  const SparseTensorStorageBase *source) {
    assert(source && "Got nullptr for source");
    SparseTensorEnumeratorBase<V> *enumerator;
    source->newEnumerator(&enumerator, rank, perm);
    const auto &permsz = enumerator->permutedSizes();
    assertPermutedSizesMatchShape(permsz, rank, perm, shape);
    auto *tensor =
        new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
    delete enumerator;
    return tensor;
  }

private:
  /// Appends an arbitrary new position to `pointers[d]`.  This method
  /// checks that `pos` is representable in the `P` type; however, it
  /// does not check that `pos` is semantically valid (i.e., larger than
  /// the previous position and smaller than `indices[d].capacity()`).
  void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
    assert(isCompressedDim(d));
    assert(pos <= std::numeric_limits<P>::max() &&
           "Pointer value is too large for the P-type");
    pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
  }

  /// Appends index `i` to dimension `d`, in the semantically general
  /// sense.  For non-dense dimensions, that means appending to the
  /// `indices[d]` array, checking that `i` is representable in the `I`
  /// type; however, we do not verify other semantic requirements (e.g.,
  /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
  /// in the same segment).  For dense dimensions, this method instead
  /// appends the appropriate number of zeros to the `values` array,
  /// where `full` is the number of "entries" already written to `values`
  /// for this segment (aka one after the highest index previously appended).
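  ///
  /// For example (illustrative only), on a dense innermost dimension,
  /// `appendIndex(d, /*full=*/2, /*i=*/5)` appends three zeros to
  /// `values`, padding the skipped coordinates 2, 3, and 4.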
  void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
    if (isCompressedDim(d)) {
      assert(i <= std::numeric_limits<I>::max() &&
             "Index value is too large for the I-type");
      indices[d].push_back(static_cast<I>(i));
    } else { // Dense dimension.
      assert(i >= full && "Index was already filled");
      if (i == full)
        return; // Short-circuit, since it'll be a nop.
      if (d + 1 == getRank())
        values.insert(values.end(), i - full, 0);
      else
        finalizeSegment(d + 1, 0, i - full);
    }
  }

  /// Writes the given coordinate to `indices[d][pos]`.  This method
  /// checks that `i` is representable in the `I` type; however, it
  /// does not check that `i` is semantically valid (i.e., in bounds
  /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
  void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
    assert(isCompressedDim(d));
    // Subscript assignment to `std::vector` requires that the `pos`-th
    // entry has been initialized; thus we must be sure to check `size()`
    // here, instead of `capacity()` as would be ideal.
    assert(pos < indices[d].size() && "Index position is out of bounds");
    assert(i <= std::numeric_limits<I>::max() &&
           "Index value is too large for the I-type");
    indices[d][pos] = static_cast<I>(i);
  }

  /// Computes the assembled-size associated with the `d`-th dimension,
  /// given the assembled-size associated with the `(d-1)`-th dimension.
  /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
  /// storage, as opposed to "dimension-sizes" which are the cardinality
  /// of coordinates for that dimension.
  ///
  /// Precondition: the `pointers[d]` array must be fully initialized
  /// before calling this method.
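  ///
  /// For example (illustrative only): a dense dimension of size 4 under
  /// `parentSz == 3` has assembled-size 12, whereas a compressed dimension
  /// has assembled-size `pointers[d][parentSz]`, i.e., the total number
  /// of stored indices at that level.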
  uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
    if (isCompressedDim(d))
      return pointers[d][parentSz];
    // else if dense:
    return parentSz * getDimSizes()[d];
  }

  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the pointers and
  /// indices arrays under the given per-dimension dense/sparse annotations.
  ///
  /// Preconditions:
  /// (1) the `elements` must be lexicographically sorted.
  /// (2) the indices of every element are valid for `dimSizes` (equal rank
  ///     and pointwise less-than).
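  ///
  /// For example (illustrative only): packing the sorted 4x4 COO matrix
  ///   {(0,1)=a, (0,3)=b, (2,2)=c}
  /// with dimension 0 dense and dimension 1 compressed yields the usual
  /// CSR-like arrays
  ///   pointers[1] = {0, 2, 2, 3, 3}
  ///   indices[1]  = {1, 3, 2}
  ///   values      = {a, b, c}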
  void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
               uint64_t hi, uint64_t d) {
    uint64_t rank = getRank();
    assert(d <= rank && hi <= elements.size());
    // Once dimensions are exhausted, insert the numerical values.
    if (d == rank) {
      assert(lo < hi);
      values.push_back(elements[lo].value);
      return;
    }
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
      // Find segment in interval with same index elements in this dimension.
      uint64_t i = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == i)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      appendIndex(d, full, i);
      full = i + 1;
      fromCOO(elements, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    finalizeSegment(d, full);
  }

  /// Finalize the sparse pointer structure at this dimension.
  void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
    if (count == 0)
      return; // Short-circuit, since it'll be a nop.
    if (isCompressedDim(d)) {
      appendPointer(d, indices[d].size(), count);
    } else { // Dense dimension.
      const uint64_t sz = getDimSizes()[d];
      assert(sz >= full && "Segment is overfull");
      count = checkedMul(count, sz - full);
      // For dense storage we must enumerate all the remaining coordinates
      // in this dimension (i.e., coordinates after the last non-zero
      // element), and either fill in their zero values or else recurse
      // to finalize some deeper dimension.
      if (d + 1 == getRank())
        values.insert(values.end(), count, 0);
      else
        finalizeSegment(d + 1, 0, count);
    }
  }

  /// Wraps up a single insertion path, inner to outer.
  void endPath(uint64_t diff) {
    uint64_t rank = getRank();
    assert(diff <= rank);
    for (uint64_t i = 0; i < rank - diff; i++) {
      const uint64_t d = rank - i - 1;
      finalizeSegment(d, idx[d] + 1);
    }
  }

  /// Continues a single insertion path, outer to inner.
  void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
    uint64_t rank = getRank();
    assert(diff < rank);
    for (uint64_t d = diff; d < rank; d++) {
      uint64_t i = cursor[d];
      appendIndex(d, top, i);
      top = 0;
      idx[d] = i;
    }
    values.push_back(val);
  }

  /// Finds the lexicographically first dimension in which the given
  /// cursor differs from the current `idx` cursor.
  uint64_t lexDiff(const uint64_t *cursor) const {
    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
      if (cursor[r] > idx[r])
        return r;
      else
        assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
    return -1u;
  }

  // Allow `SparseTensorEnumerator` to access the data-members (to avoid
  // the cost of virtual-function dispatch in inner loops), without
  // making them public to other client code.
  friend class SparseTensorEnumerator<P, I, V>;

  std::vector<std::vector<P>> pointers;
  std::vector<std::vector<I>> indices;
  std::vector<V> values;
  std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
};

/// A (higher-order) function object for enumerating the elements of some
/// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
/// method encapsulates the loop-nest for enumerating the elements of
/// the source tensor (in whatever order is best for the source tensor),
/// and applies a permutation to the coordinates/indices before handing
/// each element to the callback.  A single enumerator object can be
/// freely reused for several calls to `forallElements`, just so long
/// as each call is sequential with respect to one another.
///
/// N.B., this class stores a reference to the `SparseTensorStorageBase`
/// passed to the constructor; thus, objects of this class must not
/// outlive the sparse tensor they depend on.
///
/// Design Note: The reason we define this class instead of simply using
/// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
/// the `<P,I>` template parameters from MLIR client code (to simplify the
/// type parameters used for direct sparse-to-sparse conversion).  And the
/// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
/// than simply using this class, is to avoid the cost of virtual-method
/// dispatch within the loop-nest.
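///
/// For example (illustrative only), given some `tensor`, `rank`, and
/// `perm`, a client can sum all stored values without knowing the
/// `<P,I>` type parameters of the source tensor:
///
///   SparseTensorEnumeratorBase<double> *enumerator;
///   tensor->newEnumerator(&enumerator, rank, perm);
///   double sum = 0;
///   enumerator->forallElements(
///       [&sum](const std::vector<uint64_t> &ind, double val) { sum += val; });
///   delete enumerator;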
794753fe330Swren romano template <typename V>
795753fe330Swren romano class SparseTensorEnumeratorBase {
796753fe330Swren romano public:
797753fe330Swren romano   /// Constructs an enumerator with the given permutation for mapping
798753fe330Swren romano   /// the semantic-ordering of dimensions to the desired target-ordering.
799753fe330Swren romano   ///
800753fe330Swren romano   /// Preconditions:
801753fe330Swren romano   /// * the `tensor` must have the same `V` value type.
802753fe330Swren romano   /// * `perm` must be valid for `rank`.
803753fe330Swren romano   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
804753fe330Swren romano                              uint64_t rank, const uint64_t *perm)
805753fe330Swren romano       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
806753fe330Swren romano         cursor(getRank()) {
807753fe330Swren romano     assert(perm && "Received nullptr for permutation");
808753fe330Swren romano     assert(rank == getRank() && "Permutation rank mismatch");
809*fa6aed2aSwren romano     const auto &rev = src.getRev();           // source-order -> semantic-order
810*fa6aed2aSwren romano     const auto &dimSizes = src.getDimSizes(); // in source storage-order
811753fe330Swren romano     for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
812753fe330Swren romano       uint64_t t = perm[rev[s]];              // `t` target-order
813753fe330Swren romano       reord[s] = t;
814*fa6aed2aSwren romano       permsz[t] = dimSizes[s];
815753fe330Swren romano     }
816753fe330Swren romano   }
817753fe330Swren romano 
818753fe330Swren romano   virtual ~SparseTensorEnumeratorBase() = default;
819753fe330Swren romano 
820753fe330Swren romano   // We disallow copying to help avoid leaking the `src` reference.
821753fe330Swren romano   // (In addition to avoiding the problem of slicing.)
822753fe330Swren romano   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
823753fe330Swren romano   SparseTensorEnumeratorBase &
824753fe330Swren romano   operator=(const SparseTensorEnumeratorBase &) = delete;
825753fe330Swren romano 
826753fe330Swren romano   /// Returns the source/target tensor's rank.  (The source-rank and
827753fe330Swren romano   /// target-rank are always equal since we only support permutations.
828753fe330Swren romano   /// Though once we add support for other dimension mappings, this
829753fe330Swren romano   /// method will have to be split in two.)
830753fe330Swren romano   uint64_t getRank() const { return permsz.size(); }
831753fe330Swren romano 
832753fe330Swren romano   /// Returns the target tensor's dimension sizes.
833753fe330Swren romano   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
834753fe330Swren romano 
835753fe330Swren romano   /// Enumerates all elements of the source tensor, permutes their
836753fe330Swren romano   /// indices, and passes the permuted element to the callback.
837753fe330Swren romano   /// The callback must not store the cursor reference directly,
838753fe330Swren romano   /// since this function reuses the storage.  Instead, the callback
839753fe330Swren romano   /// must copy it if they want to keep it.
840753fe330Swren romano   virtual void forallElements(ElementConsumer<V> yield) = 0;
841753fe330Swren romano 
842753fe330Swren romano protected:
843753fe330Swren romano   const SparseTensorStorageBase &src;
844753fe330Swren romano   std::vector<uint64_t> permsz; // in target order.
845753fe330Swren romano   std::vector<uint64_t> reord;  // source storage-order -> target order.
846753fe330Swren romano   std::vector<uint64_t> cursor; // in target order.
847753fe330Swren romano };
848753fe330Swren romano 
849753fe330Swren romano template <typename P, typename I, typename V>
850753fe330Swren romano class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
851753fe330Swren romano   using Base = SparseTensorEnumeratorBase<V>;
852753fe330Swren romano 
853753fe330Swren romano public:
854753fe330Swren romano   /// Constructs an enumerator with the given permutation for mapping
855753fe330Swren romano   /// the semantic-ordering of dimensions to the desired target-ordering.
856753fe330Swren romano   ///
857753fe330Swren romano   /// Precondition: `perm` must be valid for `rank`.
858753fe330Swren romano   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
859753fe330Swren romano                          uint64_t rank, const uint64_t *perm)
860753fe330Swren romano       : Base(tensor, rank, perm) {}
861753fe330Swren romano 
862753fe330Swren romano   ~SparseTensorEnumerator() final override = default;
863753fe330Swren romano 
864753fe330Swren romano   void forallElements(ElementConsumer<V> yield) final override {
865753fe330Swren romano     forallElements(yield, 0, 0);
866753fe330Swren romano   }
867753fe330Swren romano 
868753fe330Swren romano private:
869753fe330Swren romano   /// The recursive component of the public `forallElements`.
870753fe330Swren romano   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
871753fe330Swren romano                       uint64_t d) {
872753fe330Swren romano     // Recover the `<P,I,V>` type parameters of `src`.
873753fe330Swren romano     const auto &src =
874753fe330Swren romano         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
875753fe330Swren romano     if (d == Base::getRank()) {
876753fe330Swren romano       assert(parentPos < src.values.size() &&
877753fe330Swren romano              "Value position is out of bounds");
878753fe330Swren romano       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
879753fe330Swren romano       yield(this->cursor, src.values[parentPos]);
880753fe330Swren romano     } else if (src.isCompressedDim(d)) {
881753fe330Swren romano       // Look up the bounds of the `d`-level segment determined by the
882753fe330Swren romano       // `d-1`-level position `parentPos`.
883753fe330Swren romano       const std::vector<P> &pointers_d = src.pointers[d];
884753fe330Swren romano       assert(parentPos + 1 < pointers_d.size() &&
885753fe330Swren romano              "Parent pointer position is out of bounds");
886753fe330Swren romano       const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
887753fe330Swren romano       const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
888753fe330Swren romano       // Loop-invariant code for looking up the `d`-level coordinates/indices.
889753fe330Swren romano       const std::vector<I> &indices_d = src.indices[d];
8903b13f880SAart Bik       assert(pstop <= indices_d.size() && "Index position is out of bounds");
891753fe330Swren romano       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
892753fe330Swren romano       for (uint64_t pos = pstart; pos < pstop; pos++) {
893753fe330Swren romano         cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
894753fe330Swren romano         forallElements(yield, pos, d + 1);
895753fe330Swren romano       }
896753fe330Swren romano     } else { // Dense dimension.
897753fe330Swren romano       const uint64_t sz = src.getDimSizes()[d];
898753fe330Swren romano       const uint64_t pstart = parentPos * sz;
899753fe330Swren romano       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
900753fe330Swren romano       for (uint64_t i = 0; i < sz; i++) {
901753fe330Swren romano         cursor_reord_d = i;
902753fe330Swren romano         forallElements(yield, pstart + i, d + 1);
903753fe330Swren romano       }
904753fe330Swren romano     }
905753fe330Swren romano   }
906753fe330Swren romano };
907753fe330Swren romano 
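// Worked trace (editorial sketch, not from any test): consider a 2x3 matrix
// with level types (dense, compressed) and nonzeros 1 at (0,0), 2 at (0,2),
// and 3 at (1,1), stored as
//   pointers[1] = [0, 2, 3], indices[1] = [0, 2, 1], values = [1, 2, 3].
// The recursion above fixes the dense dimension d=0 first (positions 0 and
// 1).  For row 0 it scans the segment [pointers[1][0], pointers[1][1]) =
// [0, 2), yielding (0,0)->1 and (0,2)->2; for row 1 it scans [2, 3),
// yielding (1,1)->3.
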
9088cb33240Swren romano /// Statistics regarding the number of nonzero subtensors in
9098cb33240Swren romano /// a source tensor, for direct sparse=>sparse conversion a la
9108cb33240Swren romano /// <https://arxiv.org/abs/2001.02609>.
9118cb33240Swren romano ///
9128cb33240Swren romano /// N.B., this class stores references to the parameters passed to
9138cb33240Swren romano /// the constructor; thus, objects of this class must not outlive
9148cb33240Swren romano /// those parameters.
91576944420Swren romano class SparseTensorNNZ final {
9168cb33240Swren romano public:
9178cb33240Swren romano   /// Allocate the statistics structure for the desired sizes and
9188cb33240Swren romano   /// sparsity (in the target tensor's storage-order).  This constructor
9198cb33240Swren romano   /// does not actually populate the statistics, however; for that see
9208cb33240Swren romano   /// `initialize`.
9218cb33240Swren romano   ///
922*fa6aed2aSwren romano   /// Precondition: `dimSizes` must not contain zeros.
923*fa6aed2aSwren romano   SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
9248cb33240Swren romano                   const std::vector<DimLevelType> &sparsity)
925*fa6aed2aSwren romano       : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
9268cb33240Swren romano     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
9278cb33240Swren romano     bool uncompressed = true;
9288cb33240Swren romano     uint64_t sz = 1; // product of `dimSizes[i]` over all `i` strictly less than `r`.
9298cb33240Swren romano     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
9308cb33240Swren romano       switch (dimTypes[r]) {
9318cb33240Swren romano       case DimLevelType::kCompressed:
9328cb33240Swren romano         assert(uncompressed &&
9338cb33240Swren romano                "Multiple compressed layers not currently supported");
9348cb33240Swren romano         uncompressed = false;
9358cb33240Swren romano         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
9368cb33240Swren romano         break;
9378cb33240Swren romano       case DimLevelType::kDense:
9388cb33240Swren romano         assert(uncompressed &&
9398cb33240Swren romano                "Dense after compressed not currently supported");
9408cb33240Swren romano         break;
9418cb33240Swren romano       case DimLevelType::kSingleton:
9428cb33240Swren romano         // Singleton after Compressed causes no problems for allocating
9438cb33240Swren romano         // `nnz` nor for the yieldPos loop.  This remains true even
9448cb33240Swren romano         // when adding support for multiple compressed dimensions or
9458cb33240Swren romano         // for dense-after-compressed.
9468cb33240Swren romano         break;
9478cb33240Swren romano       }
9488cb33240Swren romano       sz = checkedMul(sz, dimSizes[r]);
9498cb33240Swren romano     }
9508cb33240Swren romano   }
9518cb33240Swren romano 
9528cb33240Swren romano   // We disallow copying to help avoid leaking the stored references.
9538cb33240Swren romano   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
9548cb33240Swren romano   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
9558cb33240Swren romano 
9568cb33240Swren romano   /// Returns the rank of the target tensor.
9578cb33240Swren romano   uint64_t getRank() const { return dimSizes.size(); }
9588cb33240Swren romano 
9598cb33240Swren romano   /// Enumerate the source tensor to fill in the statistics.  The
9608cb33240Swren romano   /// enumerator should already incorporate the permutation (from
9618cb33240Swren romano   /// semantic-order to the target storage-order).
9628cb33240Swren romano   template <typename V>
9638cb33240Swren romano   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
9648cb33240Swren romano     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
9658cb33240Swren romano     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
9668cb33240Swren romano     enumerator.forallElements(
9678cb33240Swren romano         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
9688cb33240Swren romano   }
9698cb33240Swren romano 
9708cb33240Swren romano   /// The type of callback functions which receive an nnz-statistic.
9718cb33240Swren romano   using NNZConsumer = const std::function<void(uint64_t)> &;
9728cb33240Swren romano 
9738cb33240Swren romano   /// Lexicographically enumerates all indices for dimensions strictly
9748cb33240Swren romano   /// less than `stopDim`, and passes their nnz statistic to the callback.
9758cb33240Swren romano   /// Since our use-case only requires the statistic not the coordinates
9768cb33240Swren romano   /// themselves, we do not bother to construct those coordinates.
9778cb33240Swren romano   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
9788cb33240Swren romano     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
9798cb33240Swren romano     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
9808cb33240Swren romano            "Cannot look up non-compressed dimensions");
9818cb33240Swren romano     forallIndices(yield, stopDim, 0, 0);
9828cb33240Swren romano   }
9838cb33240Swren romano 
9848cb33240Swren romano private:
9858cb33240Swren romano   /// Adds a new element (i.e., increments its statistics).  We use
9868cb33240Swren romano   /// a method rather than inlining into the lambda in `initialize`,
9878cb33240Swren romano   /// to avoid spurious templating over `V`.  And this method is private
9888cb33240Swren romano   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
9898cb33240Swren romano   /// by `forallElements`).
9908cb33240Swren romano   void add(const std::vector<uint64_t> &ind) {
9918cb33240Swren romano     uint64_t parentPos = 0;
9928cb33240Swren romano     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
9938cb33240Swren romano       if (dimTypes[r] == DimLevelType::kCompressed)
9948cb33240Swren romano         nnz[r][parentPos]++;
9958cb33240Swren romano       parentPos = parentPos * dimSizes[r] + ind[r];
9968cb33240Swren romano     }
9978cb33240Swren romano   }
9988cb33240Swren romano 
9998cb33240Swren romano   /// Recursive component of the public `forallIndices`.
10008cb33240Swren romano   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
10018cb33240Swren romano                      uint64_t d) const {
10028cb33240Swren romano     assert(d <= stopDim);
10038cb33240Swren romano     if (d == stopDim) {
10048cb33240Swren romano       assert(parentPos < nnz[d].size() && "Cursor is out of range");
10058cb33240Swren romano       yield(nnz[d][parentPos]);
10068cb33240Swren romano     } else {
10078cb33240Swren romano       const uint64_t sz = dimSizes[d];
10088cb33240Swren romano       const uint64_t pstart = parentPos * sz;
10098cb33240Swren romano       for (uint64_t i = 0; i < sz; i++)
10108cb33240Swren romano         forallIndices(yield, stopDim, pstart + i, d + 1);
10118cb33240Swren romano     }
10128cb33240Swren romano   }
10138cb33240Swren romano 
10148cb33240Swren romano   // All of these are in the target storage-order.
10158cb33240Swren romano   const std::vector<uint64_t> &dimSizes;
10168cb33240Swren romano   const std::vector<DimLevelType> &dimTypes;
10178cb33240Swren romano   std::vector<std::vector<uint64_t>> nnz;
10188cb33240Swren romano };
10198cb33240Swren romano 
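// Editorial sketch of the statistics: for the same 2x3 matrix as above, with
// level types (kDense, kCompressed) and nonzeros at (0,0), (0,2), and (1,1),
// each `add` bumps one per-row counter, so nnz[1] == {2, 1}.  A subsequent
// `forallIndices(1, yield)` yields 2 and then 1, which is exactly the
// increment sequence that the SparseTensorStorage constructor below folds
// into pointers[1] = [0, 2, 3].
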
10208cb33240Swren romano template <typename P, typename I, typename V>
10218cb33240Swren romano SparseTensorStorage<P, I, V>::SparseTensorStorage(
1022*fa6aed2aSwren romano     const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
10238cb33240Swren romano     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1024*fa6aed2aSwren romano     : SparseTensorStorage(dimSizes, perm, sparsity) {
10258cb33240Swren romano   SparseTensorEnumeratorBase<V> *enumerator;
10268cb33240Swren romano   tensor.newEnumerator(&enumerator, getRank(), perm);
10278cb33240Swren romano   {
10288cb33240Swren romano     // Initialize the statistics structure.
10298cb33240Swren romano     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
10308cb33240Swren romano     nnz.initialize(*enumerator);
10318cb33240Swren romano     // Initialize "pointers" overhead (and allocate "indices", "values").
10328cb33240Swren romano     uint64_t parentSz = 1; // assembled-size (not dimension-size) of dimension `r-1`.
10338cb33240Swren romano     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
10348cb33240Swren romano       if (isCompressedDim(r)) {
10358cb33240Swren romano         pointers[r].reserve(parentSz + 1);
10368cb33240Swren romano         pointers[r].push_back(0);
10378cb33240Swren romano         uint64_t currentPos = 0;
10388cb33240Swren romano         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
10398cb33240Swren romano           currentPos += n;
10408cb33240Swren romano           appendPointer(r, currentPos);
10418cb33240Swren romano         });
10428cb33240Swren romano         assert(pointers[r].size() == parentSz + 1 &&
10438cb33240Swren romano                "Final pointers size doesn't match allocated size");
10448cb33240Swren romano         // That assertion entails `assembledSize(parentSz, r)`
10458cb33240Swren romano         // is now in a valid state.  That is, `pointers[r][parentSz]`
10468cb33240Swren romano         // equals the present value of `currentPos`, which is the
10478cb33240Swren romano         // correct assembled-size for `indices[r]`.
10488cb33240Swren romano       }
10498cb33240Swren romano       // Update assembled-size for the next iteration.
10508cb33240Swren romano       parentSz = assembledSize(parentSz, r);
10518cb33240Swren romano       // Ideally we need only `indices[r].reserve(parentSz)`; however,
10528cb33240Swren romano       // the `std::vector` implementation forces us to initialize it too.
10538cb33240Swren romano       // That is, in the yieldPos loop we need random-access assignment
10548cb33240Swren romano       // to `indices[r]`; however, `std::vector`'s subscript-assignment
10558cb33240Swren romano       // only allows assigning to already-initialized positions.
10568cb33240Swren romano       if (isCompressedDim(r))
10578cb33240Swren romano         indices[r].resize(parentSz, 0);
10588cb33240Swren romano     }
10598cb33240Swren romano     values.resize(parentSz, 0); // Both allocate and zero-initialize.
10608cb33240Swren romano   }
10618cb33240Swren romano   // The yieldPos loop
10628cb33240Swren romano   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
10638cb33240Swren romano     uint64_t parentSz = 1, parentPos = 0;
10648cb33240Swren romano     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
10658cb33240Swren romano       if (isCompressedDim(r)) {
10668cb33240Swren romano         // If `parentPos == parentSz` then it's valid as an array-lookup;
10678cb33240Swren romano         // however, it's semantically invalid here since that entry
10688cb33240Swren romano         // does not represent a segment of `indices[r]`.  Moreover, that
10698cb33240Swren romano         // entry must be immutable for `assembledSize` to remain valid.
10708cb33240Swren romano         assert(parentPos < parentSz && "Pointers position is out of bounds");
10718cb33240Swren romano         const uint64_t currentPos = pointers[r][parentPos];
10728cb33240Swren romano         // This increment won't overflow the `P` type, since it can't
10738cb33240Swren romano         // exceed the original value of `pointers[r][parentPos+1]`
10748cb33240Swren romano         // which was already verified to be within bounds for `P`
10758cb33240Swren romano         // when it was written to the array.
10768cb33240Swren romano         pointers[r][parentPos]++;
10778cb33240Swren romano         writeIndex(r, currentPos, ind[r]);
10788cb33240Swren romano         parentPos = currentPos;
10798cb33240Swren romano       } else { // Dense dimension.
10808cb33240Swren romano         parentPos = parentPos * getDimSizes()[r] + ind[r];
10818cb33240Swren romano       }
10828cb33240Swren romano       parentSz = assembledSize(parentSz, r);
10838cb33240Swren romano     }
10848cb33240Swren romano     assert(parentPos < values.size() && "Value position is out of bounds");
10858cb33240Swren romano     values[parentPos] = val;
10868cb33240Swren romano   });
10878cb33240Swren romano   // No longer need the enumerator, so we'll delete it ASAP.
10888cb33240Swren romano   delete enumerator;
10898cb33240Swren romano   // The finalizeYieldPos loop
10908cb33240Swren romano   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
10918cb33240Swren romano     if (isCompressedDim(r)) {
10928cb33240Swren romano       assert(parentSz == pointers[r].size() - 1 &&
10938cb33240Swren romano              "Actual pointers size doesn't match the expected size");
10948cb33240Swren romano       // Can't check all of them, but at least we can check the last one.
10958cb33240Swren romano       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
10968cb33240Swren romano              "Pointers got corrupted");
10978cb33240Swren romano       // TODO: optimize this by using `memmove` or similar.
10988cb33240Swren romano       for (uint64_t n = 0; n < parentSz; n++) {
10998cb33240Swren romano         const uint64_t parentPos = parentSz - n;
11008cb33240Swren romano         pointers[r][parentPos] = pointers[r][parentPos - 1];
11018cb33240Swren romano       }
11028cb33240Swren romano       pointers[r][0] = 0;
11038cb33240Swren romano     }
11048cb33240Swren romano     parentSz = assembledSize(parentSz, r);
11058cb33240Swren romano   }
11068cb33240Swren romano }
11078cb33240Swren romano 
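// Editorial trace of the finalizeYieldPos shift above (assuming a compressed
// dimension r with parentSz == 3 whose pointers started as [0, 2, 3, 3]):
// the yieldPos loop advances each entry to the end of its segment, leaving
// [2, 3, 3, 3]; the shift then copies every entry down one slot and writes 0
// at the front, restoring the standard cumulative form [0, 2, 3, 3].
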
11088a91bc7bSHarrietAkot /// Helper to convert string to lower case.
11098a91bc7bSHarrietAkot static char *toLower(char *token) {
11108a91bc7bSHarrietAkot   for (char *c = token; *c; c++)
11118a91bc7bSHarrietAkot     *c = tolower(static_cast<unsigned char>(*c));
11128a91bc7bSHarrietAkot   return token;
11138a91bc7bSHarrietAkot }
11148a91bc7bSHarrietAkot 
11158a91bc7bSHarrietAkot /// Read the MME header of a general sparse matrix of type real.
111603fe15ceSAart Bik static void readMMEHeader(FILE *file, char *filename, char *line,
111733e8ab8eSAart Bik                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
11188a91bc7bSHarrietAkot   char header[64];
11198a91bc7bSHarrietAkot   char object[64];
11208a91bc7bSHarrietAkot   char format[64];
11218a91bc7bSHarrietAkot   char field[64];
11228a91bc7bSHarrietAkot   char symmetry[64];
11238a91bc7bSHarrietAkot   // Read header line.
11248a91bc7bSHarrietAkot   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
11258a91bc7bSHarrietAkot              symmetry) != 5) {
112603fe15ceSAart Bik     fprintf(stderr, "Corrupt header in %s\n", filename);
11278a91bc7bSHarrietAkot     exit(1);
11288a91bc7bSHarrietAkot   }
112933e8ab8eSAart Bik   // Set properties
113033e8ab8eSAart Bik   *isPattern = (strcmp(toLower(field), "pattern") == 0);
1131bb56c2b3SMehdi Amini   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
11328a91bc7bSHarrietAkot   // Make sure this is a general sparse matrix.
11338a91bc7bSHarrietAkot   if (strcmp(toLower(header), "%%matrixmarket") ||
11348a91bc7bSHarrietAkot       strcmp(toLower(object), "matrix") ||
113533e8ab8eSAart Bik       strcmp(toLower(format), "coordinate") ||
113633e8ab8eSAart Bik       (strcmp(toLower(field), "real") && !(*isPattern)) ||
1137bb56c2b3SMehdi Amini       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
113833e8ab8eSAart Bik     fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
11398a91bc7bSHarrietAkot     exit(1);
11408a91bc7bSHarrietAkot   }
11418a91bc7bSHarrietAkot   // Skip comments.
1142e5639b3fSMehdi Amini   while (true) {
114303fe15ceSAart Bik     if (!fgets(line, kColWidth, file)) {
114403fe15ceSAart Bik       fprintf(stderr, "Cannot find data in %s\n", filename);
11458a91bc7bSHarrietAkot       exit(1);
11468a91bc7bSHarrietAkot     }
11478a91bc7bSHarrietAkot     if (line[0] != '%')
11488a91bc7bSHarrietAkot       break;
11498a91bc7bSHarrietAkot   }
11508a91bc7bSHarrietAkot   // Next line contains M N NNZ.
11518a91bc7bSHarrietAkot   idata[0] = 2; // rank
11528a91bc7bSHarrietAkot   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
11538a91bc7bSHarrietAkot              idata + 1) != 3) {
115403fe15ceSAart Bik     fprintf(stderr, "Cannot find size in %s\n", filename);
11558a91bc7bSHarrietAkot     exit(1);
11568a91bc7bSHarrietAkot   }
11578a91bc7bSHarrietAkot }
11588a91bc7bSHarrietAkot 
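// For reference, a conforming (hypothetical) *.mtx file begins with
//   %%MatrixMarket matrix coordinate real general
//   % optional comments
//   3 3 2
// which readMMEHeader below parses into idata = {2 /*rank*/, 2 /*nnz*/, 3, 3}.
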
11598a91bc7bSHarrietAkot /// Read the "extended" FROSTT header. Although not part of the documented
11608a91bc7bSHarrietAkot /// format, we assume that the file starts with optional comments followed
11618a91bc7bSHarrietAkot /// by two lines that define the rank, the number of nonzeros, and the
11628a91bc7bSHarrietAkot /// dimension sizes (one per rank) of the sparse tensor.
116303fe15ceSAart Bik static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
116403fe15ceSAart Bik                                 uint64_t *idata) {
11658a91bc7bSHarrietAkot   // Skip comments.
1166e5639b3fSMehdi Amini   while (true) {
116703fe15ceSAart Bik     if (!fgets(line, kColWidth, file)) {
116803fe15ceSAart Bik       fprintf(stderr, "Cannot find data in %s\n", filename);
11698a91bc7bSHarrietAkot       exit(1);
11708a91bc7bSHarrietAkot     }
11718a91bc7bSHarrietAkot     if (line[0] != '#')
11728a91bc7bSHarrietAkot       break;
11738a91bc7bSHarrietAkot   }
11748a91bc7bSHarrietAkot   // Next line contains RANK and NNZ.
11758a91bc7bSHarrietAkot   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
117603fe15ceSAart Bik     fprintf(stderr, "Cannot find metadata in %s\n", filename);
11778a91bc7bSHarrietAkot     exit(1);
11788a91bc7bSHarrietAkot   }
11798a91bc7bSHarrietAkot   // Followed by a line with the dimension sizes (one per rank).
11808a91bc7bSHarrietAkot   for (uint64_t r = 0; r < idata[0]; r++) {
11818a91bc7bSHarrietAkot     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
118203fe15ceSAart Bik       fprintf(stderr, "Cannot find dimension size %s\n", filename);
11838a91bc7bSHarrietAkot       exit(1);
11848a91bc7bSHarrietAkot     }
11858a91bc7bSHarrietAkot   }
118603fe15ceSAart Bik   (void)fgets(line, kColWidth, file); // end of line (trailing content ignored)
11878a91bc7bSHarrietAkot }
11888a91bc7bSHarrietAkot 
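// For reference, a (hypothetical) *.tns file for a 2x3x4 tensor with five
// nonzeros begins with
//   # optional comments
//   3 5
//   2 3 4
// which readExtFROSTTHeader above parses into
// idata = {3 /*rank*/, 5 /*nnz*/, 2, 3, 4}.
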
11898a91bc7bSHarrietAkot /// Reads a sparse tensor with the given filename into a memory-resident
11908a91bc7bSHarrietAkot /// sparse tensor in coordinate scheme.
11918a91bc7bSHarrietAkot template <typename V>
11928a91bc7bSHarrietAkot static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
1193d83a7068Swren romano                                                const uint64_t *shape,
11948a91bc7bSHarrietAkot                                                const uint64_t *perm) {
11958a91bc7bSHarrietAkot   // Open the file.
11968a91bc7bSHarrietAkot   assert(filename && "Received nullptr for filename");
11978a91bc7bSHarrietAkot   FILE *file = fopen(filename, "r");
11983734c078Swren romano   if (!file) {
11993734c078Swren romano     fprintf(stderr, "Cannot find file %s\n", filename);
12008a91bc7bSHarrietAkot     exit(1);
12018a91bc7bSHarrietAkot   }
12028a91bc7bSHarrietAkot   // Perform some file format dependent set up.
120303fe15ceSAart Bik   char line[kColWidth];
12048a91bc7bSHarrietAkot   uint64_t idata[512];
120533e8ab8eSAart Bik   bool isPattern = false;
1206bb56c2b3SMehdi Amini   bool isSymmetric = false;
12078a91bc7bSHarrietAkot   if (strstr(filename, ".mtx")) {
120833e8ab8eSAart Bik     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
12098a91bc7bSHarrietAkot   } else if (strstr(filename, ".tns")) {
121003fe15ceSAart Bik     readExtFROSTTHeader(file, filename, line, idata);
12118a91bc7bSHarrietAkot   } else {
12128a91bc7bSHarrietAkot     fprintf(stderr, "Unknown format %s\n", filename);
12138a91bc7bSHarrietAkot     exit(1);
12148a91bc7bSHarrietAkot   }
12158a91bc7bSHarrietAkot   // Prepare sparse tensor object with per-dimension sizes
12168a91bc7bSHarrietAkot   // and the number of nonzeros as initial capacity.
12178a91bc7bSHarrietAkot   assert(rank == idata[0] && "rank mismatch");
12188a91bc7bSHarrietAkot   uint64_t nnz = idata[1];
12198a91bc7bSHarrietAkot   for (uint64_t r = 0; r < rank; r++)
1220d83a7068Swren romano     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
12218a91bc7bSHarrietAkot            "dimension size mismatch");
12228a91bc7bSHarrietAkot   SparseTensorCOO<V> *tensor =
12238a91bc7bSHarrietAkot       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
12248a91bc7bSHarrietAkot   // Read all nonzero elements.
12258a91bc7bSHarrietAkot   std::vector<uint64_t> indices(rank);
12268a91bc7bSHarrietAkot   for (uint64_t k = 0; k < nnz; k++) {
122703fe15ceSAart Bik     if (!fgets(line, kColWidth, file)) {
122803fe15ceSAart Bik       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
12298a91bc7bSHarrietAkot       exit(1);
12308a91bc7bSHarrietAkot     }
123103fe15ceSAart Bik     char *linePtr = line;
123203fe15ceSAart Bik     for (uint64_t r = 0; r < rank; r++) {
123303fe15ceSAart Bik       uint64_t idx = strtoul(linePtr, &linePtr, 10);
12348a91bc7bSHarrietAkot       // Add 0-based index.
12358a91bc7bSHarrietAkot       indices[perm[r]] = idx - 1;
12368a91bc7bSHarrietAkot     }
12378a91bc7bSHarrietAkot     // The external formats always store the numerical values with the type
12388a91bc7bSHarrietAkot     // double, but we cast these values to the sparse tensor object type.
123933e8ab8eSAart Bik     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
124033e8ab8eSAart Bik     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
12418a91bc7bSHarrietAkot     tensor->add(indices, value);
124202710413SBixia Zheng     // We currently choose to deal with symmetric matrices by fully constructing
124302710413SBixia Zheng     // them. In the future, we may want to make symmetry implicit for storage
124402710413SBixia Zheng     // reasons.
1245bb56c2b3SMehdi Amini     if (isSymmetric && indices[0] != indices[1])
124602710413SBixia Zheng       tensor->add({indices[1], indices[0]}, value);
12478a91bc7bSHarrietAkot   }
12488a91bc7bSHarrietAkot   // Close the file and return tensor.
12498a91bc7bSHarrietAkot   fclose(file);
12508a91bc7bSHarrietAkot   return tensor;
12518a91bc7bSHarrietAkot }
12528a91bc7bSHarrietAkot 
1253efa15f41SAart Bik /// Writes the sparse tensor to extended FROSTT format.
1254efa15f41SAart Bik template <typename V>
125546bdacaaSwren romano static void outSparseTensor(void *tensor, void *dest, bool sort) {
12566438783fSAart Bik   assert(tensor && dest);
12576438783fSAart Bik   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
12586438783fSAart Bik   if (sort)
12596438783fSAart Bik     coo->sort();
12606438783fSAart Bik   char *filename = static_cast<char *>(dest);
1261*fa6aed2aSwren romano   auto &dimSizes = coo->getDimSizes();
12626438783fSAart Bik   auto &elements = coo->getElements();
12636438783fSAart Bik   uint64_t rank = coo->getRank();
1264efa15f41SAart Bik   uint64_t nnz = elements.size();
1265efa15f41SAart Bik   std::fstream file;
1266efa15f41SAart Bik   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1267efa15f41SAart Bik   assert(file.is_open());
1268efa15f41SAart Bik   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1269efa15f41SAart Bik   for (uint64_t r = 0; r < rank - 1; r++)
1270*fa6aed2aSwren romano     file << dimSizes[r] << " ";
1271*fa6aed2aSwren romano   file << dimSizes[rank - 1] << std::endl;
1272efa15f41SAart Bik   for (uint64_t i = 0; i < nnz; i++) {
1273efa15f41SAart Bik     auto &idx = elements[i].indices;
1274efa15f41SAart Bik     for (uint64_t r = 0; r < rank; r++)
1275efa15f41SAart Bik       file << (idx[r] + 1) << " ";
1276efa15f41SAart Bik     file << elements[i].value << std::endl;
1277efa15f41SAart Bik   }
1278efa15f41SAart Bik   file.flush();
1279efa15f41SAart Bik   file.close();
1280efa15f41SAart Bik   assert(file.good());
12816438783fSAart Bik }
12826438783fSAart Bik 
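// Example output (editorial sketch): a 2x3 matrix with nonzeros 1.5 at (0,0)
// and 2 at (1,2) is written as
//   ; extended FROSTT format
//   2 2
//   2 3
//   1 1 1.5
//   2 3 2
// where indices are shifted to 1-based on output (`idx[r] + 1`).
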
12836438783fSAart Bik /// Initializes a sparse tensor from an external COO-flavored format.
12846438783fSAart Bik template <typename V>
128546bdacaaSwren romano static SparseTensorStorage<uint64_t, uint64_t, V> *
12866438783fSAart Bik toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
128720eaa88fSBixia Zheng                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
128820eaa88fSBixia Zheng   const DimLevelType *sparsity = reinterpret_cast<const DimLevelType *>(sparse);
128920eaa88fSBixia Zheng #ifndef NDEBUG
129020eaa88fSBixia Zheng   // Verify that perm is a permutation of 0..(rank-1).
129120eaa88fSBixia Zheng   std::vector<uint64_t> order(perm, perm + rank);
129220eaa88fSBixia Zheng   std::sort(order.begin(), order.end());
12931e47888dSAart Bik   for (uint64_t i = 0; i < rank; ++i) {
129420eaa88fSBixia Zheng     if (i != order[i]) {
1295988d4b0dSAart Bik       fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank - 1);
129620eaa88fSBixia Zheng       exit(1);
129720eaa88fSBixia Zheng     }
129820eaa88fSBixia Zheng   }
129920eaa88fSBixia Zheng 
130020eaa88fSBixia Zheng   // Verify that the sparsity values are supported.
13011e47888dSAart Bik   for (uint64_t i = 0; i < rank; ++i) {
130220eaa88fSBixia Zheng     if (sparsity[i] != DimLevelType::kDense &&
130320eaa88fSBixia Zheng         sparsity[i] != DimLevelType::kCompressed) {
130420eaa88fSBixia Zheng       fprintf(stderr, "Unsupported sparsity value %d\n",
130520eaa88fSBixia Zheng               static_cast<int>(sparsity[i]));
130620eaa88fSBixia Zheng       exit(1);
130720eaa88fSBixia Zheng     }
130820eaa88fSBixia Zheng   }
130920eaa88fSBixia Zheng #endif
131020eaa88fSBixia Zheng 
13116438783fSAart Bik   // Convert external format to internal COO.
131263bdcaf9Swren romano   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
13136438783fSAart Bik   std::vector<uint64_t> idx(rank);
13146438783fSAart Bik   for (uint64_t i = 0, base = 0; i < nse; i++) {
13156438783fSAart Bik     for (uint64_t r = 0; r < rank; r++)
1316d8b229a1SAart Bik       idx[perm[r]] = indices[base + r];
131763bdcaf9Swren romano     coo->add(idx, values[i]);
13186438783fSAart Bik     base += rank;
13196438783fSAart Bik   }
13206438783fSAart Bik   // Return sparse tensor storage format as opaque pointer.
132163bdcaf9Swren romano   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
132263bdcaf9Swren romano       rank, shape, perm, sparsity, coo);
132363bdcaf9Swren romano   delete coo;
132463bdcaf9Swren romano   return tensor;
13256438783fSAart Bik }
13266438783fSAart Bik 
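// Minimal call sketch (all array names are local to this example): build a
// 2x2 tensor with dense rows and compressed columns from two COO entries.
//
//   uint64_t shape[] = {2, 2};
//   double values[] = {1.0, 2.0};
//   uint64_t indices[] = {0, 0, 1, 1}; // row-major (i,j) pairs: (0,0), (1,1)
//   uint64_t perm[] = {0, 1};
//   uint8_t sparse[] = {static_cast<uint8_t>(DimLevelType::kDense),
//                       static_cast<uint8_t>(DimLevelType::kCompressed)};
//   auto *t = toMLIRSparseTensor<double>(2, 2, shape, values, indices, perm,
//                                        sparse);
//   // ... use `t` ..., then release it via delSparseTensor(t).
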
13276438783fSAart Bik /// Converts a sparse tensor to an external COO-flavored format.
13286438783fSAart Bik template <typename V>
132946bdacaaSwren romano static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
133046bdacaaSwren romano                                  uint64_t **pShape, V **pValues,
133146bdacaaSwren romano                                  uint64_t **pIndices) {
1332736c1b66SAart Bik   assert(tensor);
13336438783fSAart Bik   auto sparseTensor =
13346438783fSAart Bik       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
13356438783fSAart Bik   uint64_t rank = sparseTensor->getRank();
13366438783fSAart Bik   std::vector<uint64_t> perm(rank);
13376438783fSAart Bik   std::iota(perm.begin(), perm.end(), 0);
13386438783fSAart Bik   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
13396438783fSAart Bik 
13406438783fSAart Bik   const std::vector<Element<V>> &elements = coo->getElements();
13416438783fSAart Bik   uint64_t nse = elements.size();
13426438783fSAart Bik 
13436438783fSAart Bik   uint64_t *shape = new uint64_t[rank];
13446438783fSAart Bik   for (uint64_t i = 0; i < rank; i++)
1345*fa6aed2aSwren romano     shape[i] = coo->getDimSizes()[i];
13466438783fSAart Bik 
13476438783fSAart Bik   V *values = new V[nse];
13486438783fSAart Bik   uint64_t *indices = new uint64_t[rank * nse];
13496438783fSAart Bik 
13506438783fSAart Bik   for (uint64_t i = 0, base = 0; i < nse; i++) {
13516438783fSAart Bik     values[i] = elements[i].value;
13526438783fSAart Bik     for (uint64_t j = 0; j < rank; j++)
13536438783fSAart Bik       indices[base + j] = elements[i].indices[j];
13546438783fSAart Bik     base += rank;
13556438783fSAart Bik   }
13566438783fSAart Bik 
13576438783fSAart Bik   delete coo;
13586438783fSAart Bik   *pRank = rank;
13596438783fSAart Bik   *pNse = nse;
13606438783fSAart Bik   *pShape = shape;
13616438783fSAart Bik   *pValues = values;
13626438783fSAart Bik   *pIndices = indices;
1363efa15f41SAart Bik }
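
// Note: the three output buffers (`*pShape`, `*pValues`, `*pIndices`) are
// `new[]`-allocated above; the external caller takes ownership and must
// eventually `delete[]` them.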
1364efa15f41SAart Bik 
1365be0a7e9fSMehdi Amini } // namespace
13668a91bc7bSHarrietAkot 
13678a91bc7bSHarrietAkot extern "C" {
13688a91bc7bSHarrietAkot 
13698a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
13708a91bc7bSHarrietAkot //
13718a91bc7bSHarrietAkot // Public API with methods that operate on MLIR buffers (memrefs) to interact
13728a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally.
13738a91bc7bSHarrietAkot // These methods should be used exclusively by MLIR compiler-generated code.
13748a91bc7bSHarrietAkot //
13758a91bc7bSHarrietAkot // Some macro magic is used to generate implementations for all required type
13768a91bc7bSHarrietAkot // combinations that can be called from MLIR compiler-generated code.
13778a91bc7bSHarrietAkot //
13788a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
13798a91bc7bSHarrietAkot 
13808a91bc7bSHarrietAkot #define CASE(p, i, v, P, I, V)                                                 \
13818a91bc7bSHarrietAkot   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
138263bdcaf9Swren romano     SparseTensorCOO<V> *coo = nullptr;                                         \
1383845561ecSwren romano     if (action <= Action::kFromCOO) {                                          \
1384845561ecSwren romano       if (action == Action::kFromFile) {                                       \
13858a91bc7bSHarrietAkot         char *filename = static_cast<char *>(ptr);                             \
138663bdcaf9Swren romano         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1387845561ecSwren romano       } else if (action == Action::kFromCOO) {                                 \
138863bdcaf9Swren romano         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
13898a91bc7bSHarrietAkot       } else {                                                                 \
1390845561ecSwren romano         assert(action == Action::kEmpty);                                      \
13918a91bc7bSHarrietAkot       }                                                                        \
139263bdcaf9Swren romano       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
139363bdcaf9Swren romano           rank, shape, perm, sparsity, coo);                                   \
139463bdcaf9Swren romano       if (action == Action::kFromFile)                                         \
139563bdcaf9Swren romano         delete coo;                                                            \
139663bdcaf9Swren romano       return tensor;                                                           \
1397bb56c2b3SMehdi Amini     }                                                                          \
13988cb33240Swren romano     if (action == Action::kSparseToSparse) {                                   \
13998cb33240Swren romano       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
14008cb33240Swren romano       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
14018cb33240Swren romano                                                            sparsity, tensor);  \
14028cb33240Swren romano     }                                                                          \
1403bb56c2b3SMehdi Amini     if (action == Action::kEmptyCOO)                                           \
1404d83a7068Swren romano       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
140563bdcaf9Swren romano     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1406845561ecSwren romano     if (action == Action::kToIterator) {                                       \
140763bdcaf9Swren romano       coo->startIterator();                                                    \
14088a91bc7bSHarrietAkot     } else {                                                                   \
1409845561ecSwren romano       assert(action == Action::kToCOO);                                        \
14108a91bc7bSHarrietAkot     }                                                                          \
141163bdcaf9Swren romano     return coo;                                                                \
14128a91bc7bSHarrietAkot   }
14138a91bc7bSHarrietAkot 
1414845561ecSwren romano #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
14154f2ec7f9SAart Bik 
1416d2215e79SRainer Orth // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1417bc04a470Swren romano // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1418bc04a470Swren romano // that this file cannot get out of sync with its header.
1419d2215e79SRainer Orth static_assert(std::is_same<index_type, uint64_t>::value,
1420d2215e79SRainer Orth               "Expected index_type == uint64_t");
1421bc04a470Swren romano 
14228a91bc7bSHarrietAkot /// Constructs a new sparse tensor. This is the "swiss army knife"
14238a91bc7bSHarrietAkot /// method for materializing sparse tensors into the computation.
14248a91bc7bSHarrietAkot ///
1425845561ecSwren romano /// Action:
14268a91bc7bSHarrietAkot /// kEmpty = returns empty storage to fill later
14278a91bc7bSHarrietAkot /// kFromFile = returns storage, where ptr contains filename to read
14288a91bc7bSHarrietAkot /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
14298a91bc7bSHarrietAkot /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
14308a91bc7bSHarrietAkot /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1431845561ecSwren romano /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
14328a91bc7bSHarrietAkot void *
1433845561ecSwren romano _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1434d2215e79SRainer Orth                              StridedMemRefType<index_type, 1> *sref,
1435d2215e79SRainer Orth                              StridedMemRefType<index_type, 1> *pref,
1436845561ecSwren romano                              OverheadType ptrTp, OverheadType indTp,
1437845561ecSwren romano                              PrimaryType valTp, Action action, void *ptr) {
14388a91bc7bSHarrietAkot   assert(aref && sref && pref);
14398a91bc7bSHarrietAkot   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
14408a91bc7bSHarrietAkot          pref->strides[0] == 1);
14418a91bc7bSHarrietAkot   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1442845561ecSwren romano   const DimLevelType *sparsity = aref->data + aref->offset;
1443d83a7068Swren romano   const index_type *shape = sref->data + sref->offset;
1444d2215e79SRainer Orth   const index_type *perm = pref->data + pref->offset;
14458a91bc7bSHarrietAkot   uint64_t rank = aref->sizes[0];
14468a91bc7bSHarrietAkot 
1447bc04a470Swren romano   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1448bc04a470Swren romano   // This is safe because of the static_assert above.
1449bc04a470Swren romano   if (ptrTp == OverheadType::kIndex)
1450bc04a470Swren romano     ptrTp = OverheadType::kU64;
1451bc04a470Swren romano   if (indTp == OverheadType::kIndex)
1452bc04a470Swren romano     indTp = OverheadType::kU64;
1453bc04a470Swren romano 
14548a91bc7bSHarrietAkot   // Double matrices with all combinations of overhead storage.
1455845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1456845561ecSwren romano        uint64_t, double);
1457845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1458845561ecSwren romano        uint32_t, double);
1459845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1460845561ecSwren romano        uint16_t, double);
1461845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1462845561ecSwren romano        uint8_t, double);
1463845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1464845561ecSwren romano        uint64_t, double);
1465845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1466845561ecSwren romano        uint32_t, double);
1467845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1468845561ecSwren romano        uint16_t, double);
1469845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1470845561ecSwren romano        uint8_t, double);
1471845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1472845561ecSwren romano        uint64_t, double);
1473845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1474845561ecSwren romano        uint32_t, double);
1475845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1476845561ecSwren romano        uint16_t, double);
1477845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1478845561ecSwren romano        uint8_t, double);
1479845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1480845561ecSwren romano        uint64_t, double);
1481845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1482845561ecSwren romano        uint32_t, double);
1483845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1484845561ecSwren romano        uint16_t, double);
1485845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1486845561ecSwren romano        uint8_t, double);
14878a91bc7bSHarrietAkot 
14888a91bc7bSHarrietAkot   // Float matrices with all combinations of overhead storage.
1489845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1490845561ecSwren romano        uint64_t, float);
1491845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1492845561ecSwren romano        uint32_t, float);
1493845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1494845561ecSwren romano        uint16_t, float);
1495845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1496845561ecSwren romano        uint8_t, float);
1497845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1498845561ecSwren romano        uint64_t, float);
1499845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1500845561ecSwren romano        uint32_t, float);
1501845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1502845561ecSwren romano        uint16_t, float);
1503845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1504845561ecSwren romano        uint8_t, float);
1505845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1506845561ecSwren romano        uint64_t, float);
1507845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1508845561ecSwren romano        uint32_t, float);
1509845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1510845561ecSwren romano        uint16_t, float);
1511845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1512845561ecSwren romano        uint8_t, float);
1513845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1514845561ecSwren romano        uint64_t, float);
1515845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1516845561ecSwren romano        uint32_t, float);
1517845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1518845561ecSwren romano        uint16_t, float);
1519845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1520845561ecSwren romano        uint8_t, float);
15218a91bc7bSHarrietAkot 
1522845561ecSwren romano   // Integral matrices with both overheads of the same type.
1523845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1524845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1525845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1526845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1527845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1528845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1529845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1530845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1531845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1532845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1533845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1534845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1535845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
15368a91bc7bSHarrietAkot 
1537736c1b66SAart Bik   // Complex matrices with wide overhead.
1538736c1b66SAart Bik   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1539736c1b66SAart Bik   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1540736c1b66SAart Bik 
15418a91bc7bSHarrietAkot   // Unsupported case (add above if needed).
15428a91bc7bSHarrietAkot   fputs("unsupported combination of types\n", stderr);
15438a91bc7bSHarrietAkot   exit(1);
15448a91bc7bSHarrietAkot }
15458a91bc7bSHarrietAkot #undef CASE
15461313f5d3Swren romano #undef CASE_SECSAME
15476438783fSAart Bik 
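// Typical call sequence (editorial sketch; memref arguments elided): the
// compiler first materializes an empty coordinate scheme with
// Action::kEmptyCOO, populates it through the _mlir_ciface_addElt* methods
// below, then passes it back with Action::kFromCOO to obtain the final
// SparseTensorStorage; the temporary COO is released separately.
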
1548bfadd13dSwren romano /// Methods that provide direct access to values.
1549bfadd13dSwren romano #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1550bfadd13dSwren romano   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1551bfadd13dSwren romano                                         void *tensor) {                        \
1552bfadd13dSwren romano     assert(ref && tensor);                                                     \
1553bfadd13dSwren romano     std::vector<V> *v;                                                         \
1554bfadd13dSwren romano     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1555bfadd13dSwren romano     ref->basePtr = ref->data = v->data();                                      \
1556bfadd13dSwren romano     ref->offset = 0;                                                           \
1557bfadd13dSwren romano     ref->sizes[0] = v->size();                                                 \
1558bfadd13dSwren romano     ref->strides[0] = 1;                                                       \
1559bfadd13dSwren romano   }
1560bfadd13dSwren romano FOREVERY_V(IMPL_SPARSEVALUES)
1561bfadd13dSwren romano #undef IMPL_SPARSEVALUES
1562bfadd13dSwren romano 
1563bfadd13dSwren romano #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1564bfadd13dSwren romano   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1565bfadd13dSwren romano                            index_type d) {                                     \
1566bfadd13dSwren romano     assert(ref && tensor);                                                     \
1567bfadd13dSwren romano     std::vector<TYPE> *v;                                                      \
1568bfadd13dSwren romano     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1569bfadd13dSwren romano     ref->basePtr = ref->data = v->data();                                      \
1570bfadd13dSwren romano     ref->offset = 0;                                                           \
1571bfadd13dSwren romano     ref->sizes[0] = v->size();                                                 \
1572bfadd13dSwren romano     ref->strides[0] = 1;                                                       \
1573bfadd13dSwren romano   }
1574bfadd13dSwren romano /// Methods that provide direct access to pointers.
1575bfadd13dSwren romano IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
1576bfadd13dSwren romano IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
1577bfadd13dSwren romano IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
1578bfadd13dSwren romano IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
1579bfadd13dSwren romano IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
1580bfadd13dSwren romano 
1581bfadd13dSwren romano /// Methods that provide direct access to indices.
1582bfadd13dSwren romano IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1583bfadd13dSwren romano IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
1584bfadd13dSwren romano IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
1585bfadd13dSwren romano IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
1586bfadd13dSwren romano IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
1587bfadd13dSwren romano #undef IMPL_GETOVERHEAD
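
// Editorial note: the accessors above (values, pointers, indices) do not
// copy; the returned memref aliases the tensor's internal std::vector
// storage, so it is invalidated by any later operation that reallocates the
// underlying buffer.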
1588bfadd13dSwren romano 
1589bfadd13dSwren romano /// Helper to add value to coordinate scheme, one per value type.
1590bfadd13dSwren romano #define IMPL_ADDELT(VNAME, V)                                                  \
1591bfadd13dSwren romano   void *_mlir_ciface_addElt##VNAME(void *coo, V value,                         \
1592bfadd13dSwren romano                                    StridedMemRefType<index_type, 1> *iref,     \
1593bfadd13dSwren romano                                    StridedMemRefType<index_type, 1> *pref) {   \
1594bfadd13dSwren romano     assert(coo && iref && pref);                                               \
1595bfadd13dSwren romano     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1596bfadd13dSwren romano     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1597bfadd13dSwren romano     const index_type *indx = iref->data + iref->offset;                        \
1598bfadd13dSwren romano     const index_type *perm = pref->data + pref->offset;                        \
1599bfadd13dSwren romano     uint64_t isize = iref->sizes[0];                                           \
1600bfadd13dSwren romano     std::vector<index_type> indices(isize);                                    \
1601bfadd13dSwren romano     for (uint64_t r = 0; r < isize; r++)                                       \
1602bfadd13dSwren romano       indices[perm[r]] = indx[r];                                              \
1603bfadd13dSwren romano     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value);               \
1604bfadd13dSwren romano     return coo;                                                                \
1605bfadd13dSwren romano   }
1606bfadd13dSwren romano FOREVERY_SIMPLEX_V(IMPL_ADDELT)
1607bfadd13dSwren romano // `complex64` apparently doesn't encounter any ABI issues (yet).
1608bfadd13dSwren romano IMPL_ADDELT(C64, complex64)
1609bfadd13dSwren romano // TODO: cleaner way to avoid ABI padding problem?
1610bfadd13dSwren romano IMPL_ADDELT(C32ABI, complex32)
1611bfadd13dSwren romano void *_mlir_ciface_addEltC32(void *coo, float r, float i,
1612bfadd13dSwren romano                              StridedMemRefType<index_type, 1> *iref,
1613bfadd13dSwren romano                              StridedMemRefType<index_type, 1> *pref) {
1614bfadd13dSwren romano   return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
1615bfadd13dSwren romano }
1616bfadd13dSwren romano #undef IMPL_ADDELT
1617bfadd13dSwren romano 
1618bfadd13dSwren romano /// Helper to enumerate elements of coordinate scheme, one per value type.
1619bfadd13dSwren romano #define IMPL_GETNEXT(VNAME, V)                                                 \
1620bfadd13dSwren romano   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1621bfadd13dSwren romano                                    StridedMemRefType<index_type, 1> *iref,     \
1622bfadd13dSwren romano                                    StridedMemRefType<V, 0> *vref) {            \
1623bfadd13dSwren romano     assert(coo && iref && vref);                                               \
1624bfadd13dSwren romano     assert(iref->strides[0] == 1);                                             \
1625bfadd13dSwren romano     index_type *indx = iref->data + iref->offset;                              \
1626bfadd13dSwren romano     V *value = vref->data + vref->offset;                                      \
1627bfadd13dSwren romano     const uint64_t isize = iref->sizes[0];                                     \
1628bfadd13dSwren romano     const Element<V> *elem =                                                   \
1629bfadd13dSwren romano         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1630bfadd13dSwren romano     if (elem == nullptr)                                                       \
1631bfadd13dSwren romano       return false;                                                            \
1632bfadd13dSwren romano     for (uint64_t r = 0; r < isize; r++)                                       \
1633bfadd13dSwren romano       indx[r] = elem->indices[r];                                              \
1634bfadd13dSwren romano     *value = elem->value;                                                      \
1635bfadd13dSwren romano     return true;                                                               \
1636bfadd13dSwren romano   }
1637bfadd13dSwren romano FOREVERY_V(IMPL_GETNEXT)
1638bfadd13dSwren romano #undef IMPL_GETNEXT
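
// Iteration sketch (editorial): after obtaining an iterator-ready COO via
// Action::kToIterator (which calls startIterator), compiler-generated code
// repeatedly invokes _mlir_ciface_getNext* until it returns false, receiving
// one index tuple and one value per call.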
1639bfadd13dSwren romano 
1640bfadd13dSwren romano /// Insert elements in lexicographical index order, one per value type.
1641bfadd13dSwren romano #define IMPL_LEXINSERT(VNAME, V)                                               \
1642bfadd13dSwren romano   void _mlir_ciface_lexInsert##VNAME(                                          \
1643bfadd13dSwren romano       void *tensor, StridedMemRefType<index_type, 1> *cref, V val) {           \
1644bfadd13dSwren romano     assert(tensor && cref);                                                    \
1645bfadd13dSwren romano     assert(cref->strides[0] == 1);                                             \
1646bfadd13dSwren romano     index_type *cursor = cref->data + cref->offset;                            \
1647bfadd13dSwren romano     assert(cursor);                                                            \
1648bfadd13dSwren romano     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1649bfadd13dSwren romano   }
1650bfadd13dSwren romano FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
1651bfadd13dSwren romano // `complex64` apparently doesn't encounter any ABI issues (yet).
1652bfadd13dSwren romano IMPL_LEXINSERT(C64, complex64)
1653bfadd13dSwren romano // TODO: cleaner way to avoid ABI padding problem?
1654bfadd13dSwren romano IMPL_LEXINSERT(C32ABI, complex32)
1655bfadd13dSwren romano void _mlir_ciface_lexInsertC32(void *tensor,
1656bfadd13dSwren romano                                StridedMemRefType<index_type, 1> *cref, float r,
1657bfadd13dSwren romano                                float i) {
1658bfadd13dSwren romano   _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
1659bfadd13dSwren romano }
1660bfadd13dSwren romano #undef IMPL_LEXINSERT
1661bfadd13dSwren romano 
1662bfadd13dSwren romano /// Insert using expansion, one per value type.
1663bfadd13dSwren romano #define IMPL_EXPINSERT(VNAME, V)                                               \
1664bfadd13dSwren romano   void _mlir_ciface_expInsert##VNAME(                                          \
1665bfadd13dSwren romano       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1666bfadd13dSwren romano       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1667bfadd13dSwren romano       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1668bfadd13dSwren romano     assert(tensor && cref && vref && fref && aref);                            \
1669bfadd13dSwren romano     assert(cref->strides[0] == 1);                                             \
1670bfadd13dSwren romano     assert(vref->strides[0] == 1);                                             \
1671bfadd13dSwren romano     assert(fref->strides[0] == 1);                                             \
1672bfadd13dSwren romano     assert(aref->strides[0] == 1);                                             \
1673bfadd13dSwren romano     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1674bfadd13dSwren romano     index_type *cursor = cref->data + cref->offset;                            \
1675bfadd13dSwren romano     V *values = vref->data + vref->offset;                                     \
1676bfadd13dSwren romano     bool *filled = fref->data + fref->offset;                                  \
1677bfadd13dSwren romano     index_type *added = aref->data + aref->offset;                             \
1678bfadd13dSwren romano     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1679bfadd13dSwren romano         cursor, values, filled, added, count);                                 \
1680bfadd13dSwren romano   }
1681bfadd13dSwren romano FOREVERY_V(IMPL_EXPINSERT)
1682bfadd13dSwren romano #undef IMPL_EXPINSERT
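
// Caller-side sketch (hypothetical, for illustration): `values` and `filled`
// together form a dense expansion of the innermost dimension, `added` records
// which positions of that expansion were written, and `count` is the number of
// such positions. Committing a single value 2.0 written at position 5 of the
// expanded row might look like (memref plumbing elided for brevity):
//
//   values[5] = 2.0;
//   filled[5] = true;
//   added[0] = 5;
//   _mlir_ciface_expInsertF64(tensor, &cref, &vref, &fref, &aref, /*count=*/1);
//
// after which the expanded arrays are reset so they can be reused for the
// next row.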
1683bfadd13dSwren romano 
16846438783fSAart Bik /// Output a sparse tensor, one per value type.
16851313f5d3Swren romano #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
16861313f5d3Swren romano   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
16871313f5d3Swren romano     return outSparseTensor<V>(coo, dest, sort);                                \
16886438783fSAart Bik   }
16891313f5d3Swren romano FOREVERY_V(IMPL_OUTSPARSETENSOR)
16901313f5d3Swren romano #undef IMPL_OUTSPARSETENSOR
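
// Hypothetical usage sketch, assuming `dest` is interpreted as a C-string
// filename by the underlying writer:
//
//   outSparseTensorF64(coo, const_cast<char *>("tensor.tns"), /*sort=*/true);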
16918a91bc7bSHarrietAkot 
16928a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
16938a91bc7bSHarrietAkot //
16948a91bc7bSHarrietAkot // Public API with methods that accept C-style data structures to interact
16958a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally.
16968a91bc7bSHarrietAkot // These methods can be used both by MLIR compiler-generated code and by an
16978a91bc7bSHarrietAkot // external runtime that wants to interact with MLIR compiler-generated code.
16988a91bc7bSHarrietAkot //
16998a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
17008a91bc7bSHarrietAkot 
17018a91bc7bSHarrietAkot /// Helper method to read a sparse tensor filename from the environment,
17028a91bc7bSHarrietAkot /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1703d2215e79SRainer Orth char *getTensorFilename(index_type id) {
17048a91bc7bSHarrietAkot   char var[80];
17058a91bc7bSHarrietAkot   snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
17068a91bc7bSHarrietAkot   char *env = getenv(var);
17073734c078Swren romano   if (!env) {
17083734c078Swren romano     fprintf(stderr, "Environment variable %s is not set\n", var);
17093734c078Swren romano     exit(1);
17103734c078Swren romano   }
17118a91bc7bSHarrietAkot   return env;
17128a91bc7bSHarrietAkot }
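
// Typical usage from a test driver (illustrative only):
//
//   // shell: export TENSOR0=/path/to/matrix.mtx
//   char *filename = getTensorFilename(0); // reads ${TENSOR0}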
17138a91bc7bSHarrietAkot 
17148a91bc7bSHarrietAkot /// Returns the size of the sparse tensor in the given dimension.
1715d2215e79SRainer Orth index_type sparseDimSize(void *tensor, index_type d) {
17168a91bc7bSHarrietAkot   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
17178a91bc7bSHarrietAkot }
17188a91bc7bSHarrietAkot 
1719f66e5769SAart Bik /// Finalizes lexicographic insertions.
1720f66e5769SAart Bik void endInsert(void *tensor) {
1721f66e5769SAart Bik   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1722f66e5769SAart Bik }
1723f66e5769SAart Bik 
17248a91bc7bSHarrietAkot /// Releases sparse tensor storage.
17258a91bc7bSHarrietAkot void delSparseTensor(void *tensor) {
17268a91bc7bSHarrietAkot   delete static_cast<SparseTensorStorageBase *>(tensor);
17278a91bc7bSHarrietAkot }
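
// Illustrative lifecycle sketch: once all insertions into a tensor handle are
// done, finalize and eventually release it:
//
//   endInsert(tensor);       // finalize lexicographic/expanded insertions
//   // ... query via sparseDimSize(tensor, d), etc. ...
//   delSparseTensor(tensor); // release the underlying storage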
17288a91bc7bSHarrietAkot 
172963bdcaf9Swren romano /// Releases sparse tensor coordinate scheme.
173063bdcaf9Swren romano #define IMPL_DELCOO(VNAME, V)                                                  \
173163bdcaf9Swren romano   void delSparseTensorCOO##VNAME(void *coo) {                                  \
173263bdcaf9Swren romano     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
173363bdcaf9Swren romano   }
17341313f5d3Swren romano FOREVERY_V(IMPL_DELCOO)
173563bdcaf9Swren romano #undef IMPL_DELCOO
173663bdcaf9Swren romano 
17378a91bc7bSHarrietAkot /// Initializes sparse tensor from a COO-flavored format expressed using C-style
17388a91bc7bSHarrietAkot /// data structures. The expected parameters are:
17398a91bc7bSHarrietAkot ///
17408a91bc7bSHarrietAkot ///   rank:    rank of tensor
17418a91bc7bSHarrietAkot ///   nse:     number of specified elements (usually the nonzeros)
17428a91bc7bSHarrietAkot ///   shape:   array of size "rank" with the size of each dimension
17438a91bc7bSHarrietAkot ///   values:  a "nse" array with the values of all specified elements
17448a91bc7bSHarrietAkot ///   indices: a flat "nse x rank" array with the indices of all specified elements
174520eaa88fSBixia Zheng ///   perm:    the permutation applied to the dimensions in the storage
174620eaa88fSBixia Zheng ///   sparse:  the sparsity annotation for each dimension (e.g., dense or compressed)
17478a91bc7bSHarrietAkot ///
17488a91bc7bSHarrietAkot /// For example, the sparse matrix
17498a91bc7bSHarrietAkot ///     | 1.0 0.0 0.0 |
17508a91bc7bSHarrietAkot ///     | 0.0 5.0 3.0 |
17518a91bc7bSHarrietAkot /// can be passed as
17528a91bc7bSHarrietAkot ///      rank    = 2
17538a91bc7bSHarrietAkot ///      nse     = 3
17548a91bc7bSHarrietAkot ///      shape   = [2, 3]
17558a91bc7bSHarrietAkot ///      values  = [1.0, 5.0, 3.0]
17568a91bc7bSHarrietAkot ///      indices = [ 0, 0,  1, 1,  1, 2]
17578a91bc7bSHarrietAkot //
175820eaa88fSBixia Zheng // TODO: generalize beyond 64-bit indices.
17598a91bc7bSHarrietAkot //
17601313f5d3Swren romano #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
17611313f5d3Swren romano   void *convertToMLIRSparseTensor##VNAME(                                      \
17621313f5d3Swren romano       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
17631313f5d3Swren romano       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
17641313f5d3Swren romano     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
17651313f5d3Swren romano                                  sparse);                                      \
17668a91bc7bSHarrietAkot   }
17671313f5d3Swren romano FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
17681313f5d3Swren romano #undef IMPL_CONVERTTOMLIRSPARSETENSOR
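
// Worked example for the matrix documented above (hypothetical caller; the
// `sparse` codes assume the DimLevelType encoding 0 = dense, 1 = compressed):
//
//   uint64_t shape[]   = {2, 3};
//   double   values[]  = {1.0, 5.0, 3.0};
//   uint64_t indices[] = {0, 0, 1, 1, 1, 2};
//   uint64_t perm[]    = {0, 1};
//   uint8_t  sparse[]  = {0, 1}; // rows dense, columns compressed
//   void *tensor = convertToMLIRSparseTensorF64(/*rank=*/2, /*nse=*/3, shape,
//                                               values, indices, perm, sparse);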
17698a91bc7bSHarrietAkot 
17702f49e6b0SBixia Zheng /// Converts a sparse tensor to COO-flavored format expressed using C-style
17712f49e6b0SBixia Zheng /// data structures. The output parameters are pointers through which the
17722f49e6b0SBixia Zheng /// following values are returned:
17732f49e6b0SBixia Zheng ///
17742f49e6b0SBixia Zheng ///   rank:    rank of tensor
17752f49e6b0SBixia Zheng ///   nse:     number of specified elements (usually the nonzeros)
17762f49e6b0SBixia Zheng ///   shape:   array of size "rank" with the size of each dimension
17772f49e6b0SBixia Zheng ///   values:  a "nse" array with the values of all specified elements
17782f49e6b0SBixia Zheng ///   indices: a flat "nse x rank" array with the indices of all specified elements
17792f49e6b0SBixia Zheng ///
17802f49e6b0SBixia Zheng /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
17812f49e6b0SBixia Zheng /// from convertToMLIRSparseTensor.
17822f49e6b0SBixia Zheng ///
17832f49e6b0SBixia Zheng //  TODO: Currently, values are copied from SparseTensorStorage to
17842f49e6b0SBixia Zheng //  SparseTensorCOO, then to the output. We may want to reduce the number of
17852f49e6b0SBixia Zheng //  copies.
17862f49e6b0SBixia Zheng //
17876438783fSAart Bik // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
17886438783fSAart Bik // compressed
17892f49e6b0SBixia Zheng //
17901313f5d3Swren romano #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
17911313f5d3Swren romano   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
17921313f5d3Swren romano                                           uint64_t *pNse, uint64_t **pShape,   \
17931313f5d3Swren romano                                           V **pValues, uint64_t **pIndices) {  \
17941313f5d3Swren romano     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
17952f49e6b0SBixia Zheng   }
17961313f5d3Swren romano FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
17971313f5d3Swren romano #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
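
// Hypothetical round-trip sketch: recovering COO data from a tensor handle
// (the output arrays are allocated by the library; see the copy TODO above):
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   convertFromMLIRSparseTensorF64(tensor, &rank, &nse, &shape, &values,
//                                  &indices);
//   // values[k] is the element at indices[k * rank .. (k + 1) * rank).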
1798efa15f41SAart Bik 
17998a91bc7bSHarrietAkot } // extern "C"
18008a91bc7bSHarrietAkot 
18018a91bc7bSHarrietAkot #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1802