1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <complex>
25 #include <cctype>
26 #include <cinttypes>
27 #include <cstdio>
28 #include <cstdlib>
29 #include <cstring>
30 #include <fstream>
31 #include <functional>
32 #include <iostream>
33 #include <limits>
34 #include <numeric>
35 #include <vector>
36 
37 using complex64 = std::complex<double>;
38 using complex32 = std::complex<float>;
39 
40 //===----------------------------------------------------------------------===//
41 //
42 // Internal support for storing and reading sparse tensors.
43 //
44 // The following memory-resident sparse storage schemes are supported:
45 //
46 // (a) A coordinate scheme for temporarily storing and lexicographically
47 //     sorting a sparse tensor by index (SparseTensorCOO).
48 //
49 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
50 //     per-dimension sparse/dense annnotations together with a dimension
51 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
52 //
53 // The following external formats are supported:
54 //
55 // (1) Matrix Market Exchange (MME): *.mtx
56 //     https://math.nist.gov/MatrixMarket/formats.html
57 //
58 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
59 //     http://frostt.io/tensors/file-formats.html
60 //
61 // Two public APIs are supported:
62 //
63 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
64 //     tensors. These methods should be used exclusively by MLIR
65 //     compiler-generated code.
66 //
67 // (II) Methods that accept C-style data structures to interact with sparse
68 //      tensors. These methods can be used by any external runtime that wants
69 //      to interact with MLIR compiler-generated code.
70 //
71 // In both cases (I) and (II), the SparseTensorStorage format is externally
72 // only visible as an opaque pointer.
73 //
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 
78 static constexpr int kColWidth = 1025;
79 
80 /// A version of `operator*` on `uint64_t` which checks for overflows.
81 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
82   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
83          "Integer overflow");
84   return lhs * rhs;
85 }
86 
87 // This macro helps minimize repetition of this idiom, as well as ensuring
88 // we have some additional output indicating where the error is coming from.
89 // (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
90 // to track down whether an error is coming from our code vs somewhere else
91 // in MLIR.)
92 #define FATAL(...)                                                             \
93   {                                                                            \
94     fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                        \
95     exit(1);                                                                   \
96   }
97 
98 // TODO: adjust this so it can be used by `openSparseTensorCOO` too.
99 // That version doesn't have the permutation, and the `dimSizes` are
100 // a pointer/C-array rather than `std::vector`.
101 //
102 /// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
103 /// semantic-order to target-order) are a refinement of the desired `shape`
104 /// (in semantic-order).
105 ///
106 /// Precondition: `perm` and `shape` must be valid for `rank`.
107 static inline void
108 assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
109                               uint64_t rank, const uint64_t *perm,
110                               const uint64_t *shape) {
111   assert(perm && shape);
112   assert(rank == dimSizes.size() && "Rank mismatch");
113   for (uint64_t r = 0; r < rank; r++)
114     assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
115            "Dimension size mismatch");
116 }
117 
118 /// A sparse tensor element in coordinate scheme (value and indices).
119 /// For example, a rank-1 vector element would look like
120 ///   ({i}, a[i])
121 /// and a rank-5 tensor element like
122 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
123 /// We use pointer to a shared index pool rather than e.g. a direct
124 /// vector since that (1) reduces the per-element memory footprint, and
125 /// (2) centralizes the memory reservation and (re)allocation to one place.
126 template <typename V>
127 struct Element final {
128   Element(uint64_t *ind, V val) : indices(ind), value(val){};
129   uint64_t *indices; // pointer into shared index pool
130   V value;
131 };
132 
133 /// The type of callback functions which receive an element.  We avoid
134 /// packaging the coordinates and value together as an `Element` object
135 /// because this helps keep code somewhat cleaner.
136 template <typename V>
137 using ElementConsumer =
138     const std::function<void(const std::vector<uint64_t> &, V)> &;
139 
140 /// A memory-resident sparse tensor in coordinate scheme (collection of
141 /// elements). This data structure is used to read a sparse tensor from
142 /// any external format into memory and sort the elements lexicographically
143 /// by indices before passing it back to the client (most packed storage
144 /// formats require the elements to appear in lexicographic index order).
145 template <typename V>
146 struct SparseTensorCOO final {
147 public:
148   SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
149       : dimSizes(dimSizes) {
150     if (capacity) {
151       elements.reserve(capacity);
152       indices.reserve(capacity * getRank());
153     }
154   }
155 
156   /// Adds element as indices and value.
157   void add(const std::vector<uint64_t> &ind, V val) {
158     assert(!iteratorLocked && "Attempt to add() after startIterator()");
159     uint64_t *base = indices.data();
160     uint64_t size = indices.size();
161     uint64_t rank = getRank();
162     assert(ind.size() == rank && "Element rank mismatch");
163     for (uint64_t r = 0; r < rank; r++) {
164       assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
165       indices.push_back(ind[r]);
166     }
167     // This base only changes if indices were reallocated. In that case, we
168     // need to correct all previous pointers into the vector. Note that this
169     // only happens if we did not set the initial capacity right, and then only
170     // for every internal vector reallocation (which with the doubling rule
171     // should only incur an amortized linear overhead).
172     uint64_t *newBase = indices.data();
173     if (newBase != base) {
174       for (uint64_t i = 0, n = elements.size(); i < n; i++)
175         elements[i].indices = newBase + (elements[i].indices - base);
176       base = newBase;
177     }
178     // Add element as (pointer into shared index pool, value) pair.
179     elements.emplace_back(base + size, val);
180   }
181 
182   /// Sorts elements lexicographically by index.
183   void sort() {
184     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
185     // TODO: we may want to cache an `isSorted` bit, to avoid
186     // unnecessary/redundant sorting.
187     uint64_t rank = getRank();
188     std::sort(elements.begin(), elements.end(),
189               [rank](const Element<V> &e1, const Element<V> &e2) {
190                 for (uint64_t r = 0; r < rank; r++) {
191                   if (e1.indices[r] == e2.indices[r])
192                     continue;
193                   return e1.indices[r] < e2.indices[r];
194                 }
195                 return false;
196               });
197   }
198 
199   /// Get the rank of the tensor.
200   uint64_t getRank() const { return dimSizes.size(); }
201 
202   /// Getter for the dimension-sizes array.
203   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
204 
205   /// Getter for the elements array.
206   const std::vector<Element<V>> &getElements() const { return elements; }
207 
208   /// Switch into iterator mode.
209   void startIterator() {
210     iteratorLocked = true;
211     iteratorPos = 0;
212   }
213 
214   /// Get the next element.
215   const Element<V> *getNext() {
216     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
217     if (iteratorPos < elements.size())
218       return &(elements[iteratorPos++]);
219     iteratorLocked = false;
220     return nullptr;
221   }
222 
223   /// Factory method. Permutes the original dimensions according to
224   /// the given ordering and expects subsequent add() calls to honor
225   /// that same ordering for the given indices. The result is a
226   /// fully permuted coordinate scheme.
227   ///
228   /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
229   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
230                                                 const uint64_t *dimSizes,
231                                                 const uint64_t *perm,
232                                                 uint64_t capacity = 0) {
233     std::vector<uint64_t> permsz(rank);
234     for (uint64_t r = 0; r < rank; r++) {
235       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
236       permsz[perm[r]] = dimSizes[r];
237     }
238     return new SparseTensorCOO<V>(permsz, capacity);
239   }
240 
241 private:
242   const std::vector<uint64_t> dimSizes; // per-dimension sizes
243   std::vector<Element<V>> elements;     // all COO elements
244   std::vector<uint64_t> indices;        // shared index pool
245   bool iteratorLocked = false;
246   unsigned iteratorPos = 0;
247 };
248 
249 // See <https://en.wikipedia.org/wiki/X_Macro>
250 //
251 // `FOREVERY_SIMPLEX_V` only specifies the non-complex `V` types, because
252 // the ABI for complex types has compiler/architecture dependent complexities
253 // we need to work around.  Namely, when a function takes a parameter of
254 // C/C++ type `complex32` (per se), then there is additional padding that
255 // causes it not to match the LLVM type `!llvm.struct<(f32, f32)>`.  This
256 // only happens with the `complex32` type itself, not with pointers/arrays
257 // of complex values.  So far `complex64` doesn't exhibit this ABI
258 // incompatibility, but we exclude it anyways just to be safe.
259 #define FOREVERY_SIMPLEX_V(DO)                                                 \
260   DO(F64, double)                                                              \
261   DO(F32, float)                                                               \
262   DO(I64, int64_t)                                                             \
263   DO(I32, int32_t)                                                             \
264   DO(I16, int16_t)                                                             \
265   DO(I8, int8_t)
266 
267 #define FOREVERY_V(DO)                                                         \
268   FOREVERY_SIMPLEX_V(DO)                                                       \
269   DO(C64, complex64)                                                           \
270   DO(C32, complex32)
271 
272 // This x-macro calls its argument on every overhead type which has
273 // fixed-width.  It excludes `index_type` because that type is often
274 // handled specially (e.g., by translating it into the architecture-dependent
275 // equivalent fixed-width overhead type).
276 #define FOREVERY_FIXED_O(DO)                                                   \
277   DO(64, uint64_t)                                                             \
278   DO(32, uint32_t)                                                             \
279   DO(16, uint16_t)                                                             \
280   DO(8, uint8_t)
281 
282 // This x-macro calls its argument on every overhead type, including
283 // `index_type`.  Our naming convention uses an empty suffix for
284 // `index_type`, so the missing first argument when we call `DO`
285 // gets resolved to the empty token which can then be concatenated
286 // as intended.  (This behavior is standard per C99 6.10.3/4 and
287 // C++11 N3290 16.3/4; whereas in C++03 16.3/10 it was undefined behavior.)
288 #define FOREVERY_O(DO)                                                         \
289   FOREVERY_FIXED_O(DO)                                                         \
290   DO(, index_type)
291 
292 // Forward.
293 template <typename V>
294 class SparseTensorEnumeratorBase;
295 
296 // Helper macro for generating error messages when some
297 // `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
298 // and then the wrong "partial method specialization" is called.
299 #define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
300 
301 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
302 /// takes responsibility for all the `<P,I,V>`-independent aspects
303 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
304 /// we use function overloading to implement "partial" method
305 /// specialization, which the C-API relies on to catch type errors
306 /// arising from our use of opaque pointers.
307 class SparseTensorStorageBase {
308 public:
309   /// Constructs a new storage object.  The `perm` maps the tensor's
310   /// semantic-ordering of dimensions to this object's storage-order.
311   /// The `dimSizes` and `sparsity` arrays are already in storage-order.
312   ///
313   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
314   SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
315                           const uint64_t *perm, const DimLevelType *sparsity)
316       : dimSizes(dimSizes), rev(getRank()),
317         dimTypes(sparsity, sparsity + getRank()) {
318     assert(perm && sparsity);
319     const uint64_t rank = getRank();
320     // Validate parameters.
321     assert(rank > 0 && "Trivial shape is unsupported");
322     for (uint64_t r = 0; r < rank; r++) {
323       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
324       assert((dimTypes[r] == DimLevelType::kDense ||
325               dimTypes[r] == DimLevelType::kCompressed) &&
326              "Unsupported DimLevelType");
327     }
328     // Construct the "reverse" (i.e., inverse) permutation.
329     for (uint64_t r = 0; r < rank; r++)
330       rev[perm[r]] = r;
331   }
332 
333   virtual ~SparseTensorStorageBase() = default;
334 
335   /// Get the rank of the tensor.
336   uint64_t getRank() const { return dimSizes.size(); }
337 
338   /// Getter for the dimension-sizes array, in storage-order.
339   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
340 
341   /// Safely lookup the size of the given (storage-order) dimension.
342   uint64_t getDimSize(uint64_t d) const {
343     assert(d < getRank());
344     return dimSizes[d];
345   }
346 
347   /// Getter for the "reverse" permutation, which maps this object's
348   /// storage-order to the tensor's semantic-order.
349   const std::vector<uint64_t> &getRev() const { return rev; }
350 
351   /// Getter for the dimension-types array, in storage-order.
352   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
353 
354   /// Safely check if the (storage-order) dimension uses compressed storage.
355   bool isCompressedDim(uint64_t d) const {
356     assert(d < getRank());
357     return (dimTypes[d] == DimLevelType::kCompressed);
358   }
359 
360   /// Allocate a new enumerator.
361 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
362   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
363                              const uint64_t *) const {                         \
364     FATAL_PIV("newEnumerator" #VNAME);                                         \
365   }
366   FOREVERY_V(DECL_NEWENUMERATOR)
367 #undef DECL_NEWENUMERATOR
368 
369   /// Overhead storage.
370 #define DECL_GETPOINTERS(PNAME, P)                                             \
371   virtual void getPointers(std::vector<P> **, uint64_t) {                      \
372     FATAL_PIV("getPointers" #PNAME);                                           \
373   }
374   FOREVERY_FIXED_O(DECL_GETPOINTERS)
375 #undef DECL_GETPOINTERS
376 #define DECL_GETINDICES(INAME, I)                                              \
377   virtual void getIndices(std::vector<I> **, uint64_t) {                       \
378     FATAL_PIV("getIndices" #INAME);                                            \
379   }
380   FOREVERY_FIXED_O(DECL_GETINDICES)
381 #undef DECL_GETINDICES
382 
383   /// Primary storage.
384 #define DECL_GETVALUES(VNAME, V)                                               \
385   virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
386   FOREVERY_V(DECL_GETVALUES)
387 #undef DECL_GETVALUES
388 
389   /// Element-wise insertion in lexicographic index order.
390 #define DECL_LEXINSERT(VNAME, V)                                               \
391   virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
392   FOREVERY_V(DECL_LEXINSERT)
393 #undef DECL_LEXINSERT
394 
395   /// Expanded insertion.
396 #define DECL_EXPINSERT(VNAME, V)                                               \
397   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
398     FATAL_PIV("expInsert" #VNAME);                                             \
399   }
400   FOREVERY_V(DECL_EXPINSERT)
401 #undef DECL_EXPINSERT
402 
403   /// Finishes insertion.
404   virtual void endInsert() = 0;
405 
406 protected:
407   // Since this class is virtual, we must disallow public copying in
408   // order to avoid "slicing".  Since this class has data members,
409   // that means making copying protected.
410   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
411   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
412   // Copy-assignment would be implicitly deleted (because `dimSizes`
413   // is const), so we explicitly delete it for clarity.
414   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
415 
416 private:
417   const std::vector<uint64_t> dimSizes;
418   std::vector<uint64_t> rev;
419   const std::vector<DimLevelType> dimTypes;
420 };
421 
422 #undef FATAL_PIV
423 
424 // Forward.
425 template <typename P, typename I, typename V>
426 class SparseTensorEnumerator;
427 
428 /// A memory-resident sparse tensor using a storage scheme based on
429 /// per-dimension sparse/dense annotations. This data structure provides a
430 /// bufferized form of a sparse tensor type. In contrast to generating setup
431 /// methods for each differently annotated sparse tensor, this method provides
432 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
433 /// and annotations to implement all required setup in a general manner.
434 template <typename P, typename I, typename V>
435 class SparseTensorStorage final : public SparseTensorStorageBase {
436   /// Private constructor to share code between the other constructors.
437   /// Beware that the object is not necessarily guaranteed to be in a
438   /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
439   /// doesn't entail `!(pointers[d].empty())`.
440   ///
441   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
442   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
443                       const uint64_t *perm, const DimLevelType *sparsity)
444       : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
445         indices(getRank()), idx(getRank()) {}
446 
447 public:
448   /// Constructs a sparse tensor storage scheme with the given dimensions,
449   /// permutation, and per-dimension dense/sparse annotations, using
450   /// the coordinate scheme tensor for the initial contents if provided.
451   ///
452   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
453   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
454                       const uint64_t *perm, const DimLevelType *sparsity,
455                       SparseTensorCOO<V> *coo)
456       : SparseTensorStorage(dimSizes, perm, sparsity) {
457     // Provide hints on capacity of pointers and indices.
458     // TODO: needs much fine-tuning based on actual sparsity; currently
459     //       we reserve pointer/index space based on all previous dense
460     //       dimensions, which works well up to first sparse dim; but
461     //       we should really use nnz and dense/sparse distribution.
462     bool allDense = true;
463     uint64_t sz = 1;
464     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
465       if (isCompressedDim(r)) {
466         // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
467         // `sz` by that before reserving. (For now we just use 1.)
468         pointers[r].reserve(sz + 1);
469         pointers[r].push_back(0);
470         indices[r].reserve(sz);
471         sz = 1;
472         allDense = false;
473       } else { // Dense dimension.
474         sz = checkedMul(sz, getDimSizes()[r]);
475       }
476     }
477     // Then assign contents from coordinate scheme tensor if provided.
478     if (coo) {
479       // Ensure both preconditions of `fromCOO`.
480       assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
481       coo->sort();
482       // Now actually insert the `elements`.
483       const std::vector<Element<V>> &elements = coo->getElements();
484       uint64_t nnz = elements.size();
485       values.reserve(nnz);
486       fromCOO(elements, 0, nnz, 0);
487     } else if (allDense) {
488       values.resize(sz, 0);
489     }
490   }
491 
492   /// Constructs a sparse tensor storage scheme with the given dimensions,
493   /// permutation, and per-dimension dense/sparse annotations, using
494   /// the given sparse tensor for the initial contents.
495   ///
496   /// Preconditions:
497   /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
498   /// * The `tensor` must have the same value type `V`.
499   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
500                       const uint64_t *perm, const DimLevelType *sparsity,
501                       const SparseTensorStorageBase &tensor);
502 
503   ~SparseTensorStorage() final override = default;
504 
505   /// Partially specialize these getter methods based on template types.
506   void getPointers(std::vector<P> **out, uint64_t d) final override {
507     assert(d < getRank());
508     *out = &pointers[d];
509   }
510   void getIndices(std::vector<I> **out, uint64_t d) final override {
511     assert(d < getRank());
512     *out = &indices[d];
513   }
514   void getValues(std::vector<V> **out) final override { *out = &values; }
515 
516   /// Partially specialize lexicographical insertions based on template types.
517   void lexInsert(const uint64_t *cursor, V val) final override {
518     // First, wrap up pending insertion path.
519     uint64_t diff = 0;
520     uint64_t top = 0;
521     if (!values.empty()) {
522       diff = lexDiff(cursor);
523       endPath(diff + 1);
524       top = idx[diff] + 1;
525     }
526     // Then continue with insertion path.
527     insPath(cursor, diff, top, val);
528   }
529 
530   /// Partially specialize expanded insertions based on template types.
531   /// Note that this method resets the values/filled-switch array back
532   /// to all-zero/false while only iterating over the nonzero elements.
533   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
534                  uint64_t count) final override {
535     if (count == 0)
536       return;
537     // Sort.
538     std::sort(added, added + count);
539     // Restore insertion path for first insert.
540     const uint64_t lastDim = getRank() - 1;
541     uint64_t index = added[0];
542     cursor[lastDim] = index;
543     lexInsert(cursor, values[index]);
544     assert(filled[index]);
545     values[index] = 0;
546     filled[index] = false;
547     // Subsequent insertions are quick.
548     for (uint64_t i = 1; i < count; i++) {
549       assert(index < added[i] && "non-lexicographic insertion");
550       index = added[i];
551       cursor[lastDim] = index;
552       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
553       assert(filled[index]);
554       values[index] = 0;
555       filled[index] = false;
556     }
557   }
558 
559   /// Finalizes lexicographic insertions.
560   void endInsert() final override {
561     if (values.empty())
562       finalizeSegment(0);
563     else
564       endPath(0);
565   }
566 
567   void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
568                      const uint64_t *perm) const final override {
569     *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
570   }
571 
572   /// Returns this sparse tensor storage scheme as a new memory-resident
573   /// sparse tensor in coordinate scheme with the given dimension order.
574   ///
575   /// Precondition: `perm` must be valid for `getRank()`.
576   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
577     SparseTensorEnumeratorBase<V> *enumerator;
578     newEnumerator(&enumerator, getRank(), perm);
579     SparseTensorCOO<V> *coo =
580         new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
581     enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
582       coo->add(ind, val);
583     });
584     // TODO: This assertion assumes there are no stored zeros,
585     // or if there are then that we don't filter them out.
586     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
587     assert(coo->getElements().size() == values.size());
588     delete enumerator;
589     return coo;
590   }
591 
592   /// Factory method. Constructs a sparse tensor storage scheme with the given
593   /// dimensions, permutation, and per-dimension dense/sparse annotations,
594   /// using the coordinate scheme tensor for the initial contents if provided.
595   /// In the latter case, the coordinate scheme must respect the same
596   /// permutation as is desired for the new sparse tensor storage.
597   ///
598   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
599   static SparseTensorStorage<P, I, V> *
600   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
601                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
602     SparseTensorStorage<P, I, V> *n = nullptr;
603     if (coo) {
604       const auto &coosz = coo->getDimSizes();
605       assertPermutedSizesMatchShape(coosz, rank, perm, shape);
606       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
607     } else {
608       std::vector<uint64_t> permsz(rank);
609       for (uint64_t r = 0; r < rank; r++) {
610         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
611         permsz[perm[r]] = shape[r];
612       }
613       // We pass the null `coo` to ensure we select the intended constructor.
614       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
615     }
616     return n;
617   }
618 
619   /// Factory method. Constructs a sparse tensor storage scheme with
620   /// the given dimensions, permutation, and per-dimension dense/sparse
621   /// annotations, using the sparse tensor for the initial contents.
622   ///
623   /// Preconditions:
624   /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
625   /// * The `tensor` must have the same value type `V`.
626   static SparseTensorStorage<P, I, V> *
627   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
628                   const DimLevelType *sparsity,
629                   const SparseTensorStorageBase *source) {
630     assert(source && "Got nullptr for source");
631     SparseTensorEnumeratorBase<V> *enumerator;
632     source->newEnumerator(&enumerator, rank, perm);
633     const auto &permsz = enumerator->permutedSizes();
634     assertPermutedSizesMatchShape(permsz, rank, perm, shape);
635     auto *tensor =
636         new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
637     delete enumerator;
638     return tensor;
639   }
640 
641 private:
642   /// Appends an arbitrary new position to `pointers[d]`.  This method
643   /// checks that `pos` is representable in the `P` type; however, it
644   /// does not check that `pos` is semantically valid (i.e., larger than
645   /// the previous position and smaller than `indices[d].capacity()`).
646   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
647     assert(isCompressedDim(d));
648     assert(pos <= std::numeric_limits<P>::max() &&
649            "Pointer value is too large for the P-type");
650     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
651   }
652 
653   /// Appends index `i` to dimension `d`, in the semantically general
654   /// sense.  For non-dense dimensions, that means appending to the
655   /// `indices[d]` array, checking that `i` is representable in the `I`
656   /// type; however, we do not verify other semantic requirements (e.g.,
657   /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
658   /// in the same segment).  For dense dimensions, this method instead
659   /// appends the appropriate number of zeros to the `values` array,
660   /// where `full` is the number of "entries" already written to `values`
661   /// for this segment (aka one after the highest index previously appended).
662   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
663     if (isCompressedDim(d)) {
664       assert(i <= std::numeric_limits<I>::max() &&
665              "Index value is too large for the I-type");
666       indices[d].push_back(static_cast<I>(i));
667     } else { // Dense dimension.
668       assert(i >= full && "Index was already filled");
669       if (i == full)
670         return; // Short-circuit, since it'll be a nop.
671       if (d + 1 == getRank())
672         values.insert(values.end(), i - full, 0);
673       else
674         finalizeSegment(d + 1, 0, i - full);
675     }
676   }
677 
678   /// Writes the given coordinate to `indices[d][pos]`.  This method
679   /// checks that `i` is representable in the `I` type; however, it
680   /// does not check that `i` is semantically valid (i.e., in bounds
681   /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
682   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
683     assert(isCompressedDim(d));
684     // Subscript assignment to `std::vector` requires that the `pos`-th
685     // entry has been initialized; thus we must be sure to check `size()`
686     // here, instead of `capacity()` as would be ideal.
687     assert(pos < indices[d].size() && "Index position is out of bounds");
688     assert(i <= std::numeric_limits<I>::max() &&
689            "Index value is too large for the I-type");
690     indices[d][pos] = static_cast<I>(i);
691   }
692 
693   /// Computes the assembled-size associated with the `d`-th dimension,
694   /// given the assembled-size associated with the `(d-1)`-th dimension.
695   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
696   /// storage, as opposed to "dimension-sizes" which are the cardinality
697   /// of coordinates for that dimension.
698   ///
699   /// Precondition: the `pointers[d]` array must be fully initialized
700   /// before calling this method.
701   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
702     if (isCompressedDim(d))
703       return pointers[d][parentSz];
704     // else if dense:
705     return parentSz * getDimSizes()[d];
706   }
707 
708   /// Initializes sparse tensor storage scheme from a memory-resident sparse
709   /// tensor in coordinate scheme. This method prepares the pointers and
710   /// indices arrays under the given per-dimension dense/sparse annotations.
711   ///
712   /// Preconditions:
713   /// (1) the `elements` must be lexicographically sorted.
714   /// (2) the indices of every element are valid for `dimSizes` (equal rank
715   ///     and pointwise less-than).
716   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
717                uint64_t hi, uint64_t d) {
718     uint64_t rank = getRank();
719     assert(d <= rank && hi <= elements.size());
720     // Once dimensions are exhausted, insert the numerical values.
721     if (d == rank) {
722       assert(lo < hi);
723       values.push_back(elements[lo].value);
724       return;
725     }
726     // Visit all elements in this interval.
727     uint64_t full = 0;
728     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
729       // Find segment in interval with same index elements in this dimension.
730       uint64_t i = elements[lo].indices[d];
731       uint64_t seg = lo + 1;
732       while (seg < hi && elements[seg].indices[d] == i)
733         seg++;
734       // Handle segment in interval for sparse or dense dimension.
735       appendIndex(d, full, i);
736       full = i + 1;
737       fromCOO(elements, lo, seg, d + 1);
738       // And move on to next segment in interval.
739       lo = seg;
740     }
741     // Finalize the sparse pointer structure at this dimension.
742     finalizeSegment(d, full);
743   }
744 
745   /// Finalize the sparse pointer structure at this dimension.
746   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
747     if (count == 0)
748       return; // Short-circuit, since it'll be a nop.
749     if (isCompressedDim(d)) {
750       appendPointer(d, indices[d].size(), count);
751     } else { // Dense dimension.
752       const uint64_t sz = getDimSizes()[d];
753       assert(sz >= full && "Segment is overfull");
754       count = checkedMul(count, sz - full);
755       // For dense storage we must enumerate all the remaining coordinates
756       // in this dimension (i.e., coordinates after the last non-zero
757       // element), and either fill in their zero values or else recurse
758       // to finalize some deeper dimension.
759       if (d + 1 == getRank())
760         values.insert(values.end(), count, 0);
761       else
762         finalizeSegment(d + 1, 0, count);
763     }
764   }
765 
766   /// Wraps up a single insertion path, inner to outer.
767   void endPath(uint64_t diff) {
768     uint64_t rank = getRank();
769     assert(diff <= rank);
770     for (uint64_t i = 0; i < rank - diff; i++) {
771       const uint64_t d = rank - i - 1;
772       finalizeSegment(d, idx[d] + 1);
773     }
774   }
775 
776   /// Continues a single insertion path, outer to inner.
777   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
778     uint64_t rank = getRank();
779     assert(diff < rank);
780     for (uint64_t d = diff; d < rank; d++) {
781       uint64_t i = cursor[d];
782       appendIndex(d, top, i);
783       top = 0;
784       idx[d] = i;
785     }
786     values.push_back(val);
787   }
788 
789   /// Finds the lexicographic differing dimension.
790   uint64_t lexDiff(const uint64_t *cursor) const {
791     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
792       if (cursor[r] > idx[r])
793         return r;
794       else
795         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
796     assert(0 && "duplication insertion");
797     return -1u;
798   }
799 
800   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
801   // the cost of virtual-function dispatch in inner loops), without
802   // making them public to other client code.
803   friend class SparseTensorEnumerator<P, I, V>;
804 
805   std::vector<std::vector<P>> pointers;
806   std::vector<std::vector<I>> indices;
807   std::vector<V> values;
808   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
809 };
810 
811 /// A (higher-order) function object for enumerating the elements of some
812 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
813 /// method encapsulates the loop-nest for enumerating the elements of
814 /// the source tensor (in whatever order is best for the source tensor),
815 /// and applies a permutation to the coordinates/indices before handing
816 /// each element to the callback.  A single enumerator object can be
817 /// freely reused for several calls to `forallElements`, just so long
818 /// as each call is sequential with respect to one another.
819 ///
820 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
821 /// passed to the constructor; thus, objects of this class must not
822 /// outlive the sparse tensor they depend on.
823 ///
824 /// Design Note: The reason we define this class instead of simply using
825 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
826 /// the `<P,I>` template parameters from MLIR client code (to simplify the
827 /// type parameters used for direct sparse-to-sparse conversion).  And the
828 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
829 /// than simply using this class, is to avoid the cost of virtual-method
830 /// dispatch within the loop-nest.
831 template <typename V>
832 class SparseTensorEnumeratorBase {
833 public:
834   /// Constructs an enumerator with the given permutation for mapping
835   /// the semantic-ordering of dimensions to the desired target-ordering.
836   ///
837   /// Preconditions:
838   /// * the `tensor` must have the same `V` value type.
839   /// * `perm` must be valid for `rank`.
840   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
841                              uint64_t rank, const uint64_t *perm)
842       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
843         cursor(getRank()) {
844     assert(perm && "Received nullptr for permutation");
845     assert(rank == getRank() && "Permutation rank mismatch");
846     const auto &rev = src.getRev();           // source-order -> semantic-order
847     const auto &dimSizes = src.getDimSizes(); // in source storage-order
848     for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
849       uint64_t t = perm[rev[s]];              // `t` target-order
850       reord[s] = t;
851       permsz[t] = dimSizes[s];
852     }
853   }
854 
855   virtual ~SparseTensorEnumeratorBase() = default;
856 
857   // We disallow copying to help avoid leaking the `src` reference.
858   // (In addition to avoiding the problem of slicing.)
859   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
860   SparseTensorEnumeratorBase &
861   operator=(const SparseTensorEnumeratorBase &) = delete;
862 
863   /// Returns the source/target tensor's rank.  (The source-rank and
864   /// target-rank are always equal since we only support permutations.
865   /// Though once we add support for other dimension mappings, this
866   /// method will have to be split in two.)
867   uint64_t getRank() const { return permsz.size(); }
868 
869   /// Returns the target tensor's dimension sizes.
870   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
871 
872   /// Enumerates all elements of the source tensor, permutes their
873   /// indices, and passes the permuted element to the callback.
874   /// The callback must not store the cursor reference directly,
875   /// since this function reuses the storage.  Instead, the callback
876   /// must copy it if they want to keep it.
877   virtual void forallElements(ElementConsumer<V> yield) = 0;
878 
879 protected:
880   const SparseTensorStorageBase &src;
881   std::vector<uint64_t> permsz; // in target order.
882   std::vector<uint64_t> reord;  // source storage-order -> target order.
883   std::vector<uint64_t> cursor; // in target order.
884 };
885 
886 template <typename P, typename I, typename V>
887 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
888   using Base = SparseTensorEnumeratorBase<V>;
889 
890 public:
891   /// Constructs an enumerator with the given permutation for mapping
892   /// the semantic-ordering of dimensions to the desired target-ordering.
893   ///
894   /// Precondition: `perm` must be valid for `rank`.
895   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
896                          uint64_t rank, const uint64_t *perm)
897       : Base(tensor, rank, perm) {}
898 
899   ~SparseTensorEnumerator() final override = default;
900 
901   void forallElements(ElementConsumer<V> yield) final override {
902     forallElements(yield, 0, 0);
903   }
904 
905 private:
906   /// The recursive component of the public `forallElements`.
907   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
908                       uint64_t d) {
909     // Recover the `<P,I,V>` type parameters of `src`.
910     const auto &src =
911         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
912     if (d == Base::getRank()) {
913       assert(parentPos < src.values.size() &&
914              "Value position is out of bounds");
915       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
916       yield(this->cursor, src.values[parentPos]);
917     } else if (src.isCompressedDim(d)) {
918       // Look up the bounds of the `d`-level segment determined by the
919       // `d-1`-level position `parentPos`.
920       const std::vector<P> &pointers_d = src.pointers[d];
921       assert(parentPos + 1 < pointers_d.size() &&
922              "Parent pointer position is out of bounds");
923       const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
924       const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
925       // Loop-invariant code for looking up the `d`-level coordinates/indices.
926       const std::vector<I> &indices_d = src.indices[d];
927       assert(pstop <= indices_d.size() && "Index position is out of bounds");
928       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
929       for (uint64_t pos = pstart; pos < pstop; pos++) {
930         cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
931         forallElements(yield, pos, d + 1);
932       }
933     } else { // Dense dimension.
934       const uint64_t sz = src.getDimSizes()[d];
935       const uint64_t pstart = parentPos * sz;
936       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
937       for (uint64_t i = 0; i < sz; i++) {
938         cursor_reord_d = i;
939         forallElements(yield, pstart + i, d + 1);
940       }
941     }
942   }
943 };
944 
945 /// Statistics regarding the number of nonzero subtensors in
946 /// a source tensor, for direct sparse=>sparse conversion a la
947 /// <https://arxiv.org/abs/2001.02609>.
948 ///
949 /// N.B., this class stores references to the parameters passed to
950 /// the constructor; thus, objects of this class must not outlive
951 /// those parameters.
952 class SparseTensorNNZ final {
953 public:
954   /// Allocate the statistics structure for the desired sizes and
955   /// sparsity (in the target tensor's storage-order).  This constructor
956   /// does not actually populate the statistics, however; for that see
957   /// `initialize`.
958   ///
959   /// Precondition: `dimSizes` must not contain zeros.
960   SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
961                   const std::vector<DimLevelType> &sparsity)
962       : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
963     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
964     bool uncompressed = true;
965     uint64_t sz = 1; // the product of all `dimSizes` strictly less than `r`.
966     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
967       switch (dimTypes[r]) {
968       case DimLevelType::kCompressed:
969         assert(uncompressed &&
970                "Multiple compressed layers not currently supported");
971         uncompressed = false;
972         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
973         break;
974       case DimLevelType::kDense:
975         assert(uncompressed &&
976                "Dense after compressed not currently supported");
977         break;
978       case DimLevelType::kSingleton:
979         // Singleton after Compressed causes no problems for allocating
980         // `nnz` nor for the yieldPos loop.  This remains true even
981         // when adding support for multiple compressed dimensions or
982         // for dense-after-compressed.
983         break;
984       }
985       sz = checkedMul(sz, dimSizes[r]);
986     }
987   }
988 
989   // We disallow copying to help avoid leaking the stored references.
990   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
991   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
992 
993   /// Returns the rank of the target tensor.
994   uint64_t getRank() const { return dimSizes.size(); }
995 
996   /// Enumerate the source tensor to fill in the statistics.  The
997   /// enumerator should already incorporate the permutation (from
998   /// semantic-order to the target storage-order).
999   template <typename V>
1000   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
1001     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
1002     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
1003     enumerator.forallElements(
1004         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
1005   }
1006 
1007   /// The type of callback functions which receive an nnz-statistic.
1008   using NNZConsumer = const std::function<void(uint64_t)> &;
1009 
1010   /// Lexicographically enumerates all indicies for dimensions strictly
1011   /// less than `stopDim`, and passes their nnz statistic to the callback.
1012   /// Since our use-case only requires the statistic not the coordinates
1013   /// themselves, we do not bother to construct those coordinates.
1014   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
1015     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
1016     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
1017            "Cannot look up non-compressed dimensions");
1018     forallIndices(yield, stopDim, 0, 0);
1019   }
1020 
1021 private:
1022   /// Adds a new element (i.e., increment its statistics).  We use
1023   /// a method rather than inlining into the lambda in `initialize`,
1024   /// to avoid spurious templating over `V`.  And this method is private
1025   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
1026   /// by `forallElements`).
1027   void add(const std::vector<uint64_t> &ind) {
1028     uint64_t parentPos = 0;
1029     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1030       if (dimTypes[r] == DimLevelType::kCompressed)
1031         nnz[r][parentPos]++;
1032       parentPos = parentPos * dimSizes[r] + ind[r];
1033     }
1034   }
1035 
1036   /// Recursive component of the public `forallIndices`.
1037   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
1038                      uint64_t d) const {
1039     assert(d <= stopDim);
1040     if (d == stopDim) {
1041       assert(parentPos < nnz[d].size() && "Cursor is out of range");
1042       yield(nnz[d][parentPos]);
1043     } else {
1044       const uint64_t sz = dimSizes[d];
1045       const uint64_t pstart = parentPos * sz;
1046       for (uint64_t i = 0; i < sz; i++)
1047         forallIndices(yield, stopDim, pstart + i, d + 1);
1048     }
1049   }
1050 
1051   // All of these are in the target storage-order.
1052   const std::vector<uint64_t> &dimSizes;
1053   const std::vector<DimLevelType> &dimTypes;
1054   std::vector<std::vector<uint64_t>> nnz;
1055 };
1056 
1057 template <typename P, typename I, typename V>
1058 SparseTensorStorage<P, I, V>::SparseTensorStorage(
1059     const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
1060     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1061     : SparseTensorStorage(dimSizes, perm, sparsity) {
1062   SparseTensorEnumeratorBase<V> *enumerator;
1063   tensor.newEnumerator(&enumerator, getRank(), perm);
1064   {
1065     // Initialize the statistics structure.
1066     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1067     nnz.initialize(*enumerator);
1068     // Initialize "pointers" overhead (and allocate "indices", "values").
1069     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1070     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1071       if (isCompressedDim(r)) {
1072         pointers[r].reserve(parentSz + 1);
1073         pointers[r].push_back(0);
1074         uint64_t currentPos = 0;
1075         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1076           currentPos += n;
1077           appendPointer(r, currentPos);
1078         });
1079         assert(pointers[r].size() == parentSz + 1 &&
1080                "Final pointers size doesn't match allocated size");
1081         // That assertion entails `assembledSize(parentSz, r)`
1082         // is now in a valid state.  That is, `pointers[r][parentSz]`
1083         // equals the present value of `currentPos`, which is the
1084         // correct assembled-size for `indices[r]`.
1085       }
1086       // Update assembled-size for the next iteration.
1087       parentSz = assembledSize(parentSz, r);
1088       // Ideally we need only `indices[r].reserve(parentSz)`, however
1089       // the `std::vector` implementation forces us to initialize it too.
1090       // That is, in the yieldPos loop we need random-access assignment
1091       // to `indices[r]`; however, `std::vector`'s subscript-assignment
1092       // only allows assigning to already-initialized positions.
1093       if (isCompressedDim(r))
1094         indices[r].resize(parentSz, 0);
1095     }
1096     values.resize(parentSz, 0); // Both allocate and zero-initialize.
1097   }
1098   // The yieldPos loop
1099   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1100     uint64_t parentSz = 1, parentPos = 0;
1101     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1102       if (isCompressedDim(r)) {
1103         // If `parentPos == parentSz` then it's valid as an array-lookup;
1104         // however, it's semantically invalid here since that entry
1105         // does not represent a segment of `indices[r]`.  Moreover, that
1106         // entry must be immutable for `assembledSize` to remain valid.
1107         assert(parentPos < parentSz && "Pointers position is out of bounds");
1108         const uint64_t currentPos = pointers[r][parentPos];
1109         // This increment won't overflow the `P` type, since it can't
1110         // exceed the original value of `pointers[r][parentPos+1]`
1111         // which was already verified to be within bounds for `P`
1112         // when it was written to the array.
1113         pointers[r][parentPos]++;
1114         writeIndex(r, currentPos, ind[r]);
1115         parentPos = currentPos;
1116       } else { // Dense dimension.
1117         parentPos = parentPos * getDimSizes()[r] + ind[r];
1118       }
1119       parentSz = assembledSize(parentSz, r);
1120     }
1121     assert(parentPos < values.size() && "Value position is out of bounds");
1122     values[parentPos] = val;
1123   });
1124   // No longer need the enumerator, so we'll delete it ASAP.
1125   delete enumerator;
1126   // The finalizeYieldPos loop
1127   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1128     if (isCompressedDim(r)) {
1129       assert(parentSz == pointers[r].size() - 1 &&
1130              "Actual pointers size doesn't match the expected size");
1131       // Can't check all of them, but at least we can check the last one.
1132       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1133              "Pointers got corrupted");
1134       // TODO: optimize this by using `memmove` or similar.
1135       for (uint64_t n = 0; n < parentSz; n++) {
1136         const uint64_t parentPos = parentSz - n;
1137         pointers[r][parentPos] = pointers[r][parentPos - 1];
1138       }
1139       pointers[r][0] = 0;
1140     }
1141     parentSz = assembledSize(parentSz, r);
1142   }
1143 }
1144 
1145 /// Helper to convert string to lower case.
1146 static char *toLower(char *token) {
1147   for (char *c = token; *c; c++)
1148     *c = tolower(*c);
1149   return token;
1150 }
1151 
1152 /// Read the MME header of a general sparse matrix of type real.
1153 static void readMMEHeader(FILE *file, char *filename, char *line,
1154                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
1155   char header[64];
1156   char object[64];
1157   char format[64];
1158   char field[64];
1159   char symmetry[64];
1160   // Read header line.
1161   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1162              symmetry) != 5)
1163     FATAL("Corrupt header in %s\n", filename);
1164   // Set properties
1165   *isPattern = (strcmp(toLower(field), "pattern") == 0);
1166   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
1167   // Make sure this is a general sparse matrix.
1168   if (strcmp(toLower(header), "%%matrixmarket") ||
1169       strcmp(toLower(object), "matrix") ||
1170       strcmp(toLower(format), "coordinate") ||
1171       (strcmp(toLower(field), "real") && !(*isPattern)) ||
1172       (strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
1173     FATAL("Cannot find a general sparse matrix in %s\n", filename);
1174   // Skip comments.
1175   while (true) {
1176     if (!fgets(line, kColWidth, file))
1177       FATAL("Cannot find data in %s\n", filename);
1178     if (line[0] != '%')
1179       break;
1180   }
1181   // Next line contains M N NNZ.
1182   idata[0] = 2; // rank
1183   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1184              idata + 1) != 3)
1185     FATAL("Cannot find size in %s\n", filename);
1186 }
1187 
1188 /// Read the "extended" FROSTT header. Although not part of the documented
1189 /// format, we assume that the file starts with optional comments followed
1190 /// by two lines that define the rank, the number of nonzeros, and the
1191 /// dimensions sizes (one per rank) of the sparse tensor.
1192 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
1193                                 uint64_t *idata) {
1194   // Skip comments.
1195   while (true) {
1196     if (!fgets(line, kColWidth, file))
1197       FATAL("Cannot find data in %s\n", filename);
1198     if (line[0] != '#')
1199       break;
1200   }
1201   // Next line contains RANK and NNZ.
1202   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
1203     FATAL("Cannot find metadata in %s\n", filename);
1204   // Followed by a line with the dimension sizes (one per rank).
1205   for (uint64_t r = 0; r < idata[0]; r++)
1206     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
1207       FATAL("Cannot find dimension size %s\n", filename);
1208   fgets(line, kColWidth, file); // end of line
1209 }
1210 
1211 /// Reads a sparse tensor with the given filename into a memory-resident
1212 /// sparse tensor in coordinate scheme.
1213 template <typename V>
1214 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
1215                                                const uint64_t *shape,
1216                                                const uint64_t *perm) {
1217   // Open the file.
1218   assert(filename && "Received nullptr for filename");
1219   FILE *file = fopen(filename, "r");
1220   if (!file)
1221     FATAL("Cannot find file %s\n", filename);
1222   // Perform some file format dependent set up.
1223   char line[kColWidth];
1224   uint64_t idata[512];
1225   bool isPattern = false;
1226   bool isSymmetric = false;
1227   if (strstr(filename, ".mtx")) {
1228     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
1229   } else if (strstr(filename, ".tns")) {
1230     readExtFROSTTHeader(file, filename, line, idata);
1231   } else {
1232     FATAL("Unknown format %s\n", filename);
1233   }
1234   // Prepare sparse tensor object with per-dimension sizes
1235   // and the number of nonzeros as initial capacity.
1236   assert(rank == idata[0] && "rank mismatch");
1237   uint64_t nnz = idata[1];
1238   for (uint64_t r = 0; r < rank; r++)
1239     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1240            "dimension size mismatch");
1241   SparseTensorCOO<V> *tensor =
1242       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
1243   // Read all nonzero elements.
1244   std::vector<uint64_t> indices(rank);
1245   for (uint64_t k = 0; k < nnz; k++) {
1246     if (!fgets(line, kColWidth, file))
1247       FATAL("Cannot find next line of data in %s\n", filename);
1248     char *linePtr = line;
1249     for (uint64_t r = 0; r < rank; r++) {
1250       uint64_t idx = strtoul(linePtr, &linePtr, 10);
1251       // Add 0-based index.
1252       indices[perm[r]] = idx - 1;
1253     }
1254     // The external formats always store the numerical values with the type
1255     // double, but we cast these values to the sparse tensor object type.
1256     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1257     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
1258     tensor->add(indices, value);
1259     // We currently chose to deal with symmetric matrices by fully constructing
1260     // them. In the future, we may want to make symmetry implicit for storage
1261     // reasons.
1262     if (isSymmetric && indices[0] != indices[1])
1263       tensor->add({indices[1], indices[0]}, value);
1264   }
1265   // Close the file and return tensor.
1266   fclose(file);
1267   return tensor;
1268 }
1269 
1270 /// Writes the sparse tensor to extended FROSTT format.
1271 template <typename V>
1272 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1273   assert(tensor && dest);
1274   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1275   if (sort)
1276     coo->sort();
1277   char *filename = static_cast<char *>(dest);
1278   auto &dimSizes = coo->getDimSizes();
1279   auto &elements = coo->getElements();
1280   uint64_t rank = coo->getRank();
1281   uint64_t nnz = elements.size();
1282   std::fstream file;
1283   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1284   assert(file.is_open());
1285   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1286   for (uint64_t r = 0; r < rank - 1; r++)
1287     file << dimSizes[r] << " ";
1288   file << dimSizes[rank - 1] << std::endl;
1289   for (uint64_t i = 0; i < nnz; i++) {
1290     auto &idx = elements[i].indices;
1291     for (uint64_t r = 0; r < rank; r++)
1292       file << (idx[r] + 1) << " ";
1293     file << elements[i].value << std::endl;
1294   }
1295   file.flush();
1296   file.close();
1297   assert(file.good());
1298 }
1299 
1300 /// Initializes sparse tensor from an external COO-flavored format.
1301 template <typename V>
1302 static SparseTensorStorage<uint64_t, uint64_t, V> *
1303 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1304                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
1305   const DimLevelType *sparsity = (DimLevelType *)(sparse);
1306 #ifndef NDEBUG
1307   // Verify that perm is a permutation of 0..(rank-1).
1308   std::vector<uint64_t> order(perm, perm + rank);
1309   std::sort(order.begin(), order.end());
1310   for (uint64_t i = 0; i < rank; ++i)
1311     if (i != order[i])
1312       FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
1313 
1314   // Verify that the sparsity values are supported.
1315   for (uint64_t i = 0; i < rank; ++i)
1316     if (sparsity[i] != DimLevelType::kDense &&
1317         sparsity[i] != DimLevelType::kCompressed)
1318       FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
1319 #endif
1320 
1321   // Convert external format to internal COO.
1322   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1323   std::vector<uint64_t> idx(rank);
1324   for (uint64_t i = 0, base = 0; i < nse; i++) {
1325     for (uint64_t r = 0; r < rank; r++)
1326       idx[perm[r]] = indices[base + r];
1327     coo->add(idx, values[i]);
1328     base += rank;
1329   }
1330   // Return sparse tensor storage format as opaque pointer.
1331   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1332       rank, shape, perm, sparsity, coo);
1333   delete coo;
1334   return tensor;
1335 }
1336 
1337 /// Converts a sparse tensor to an external COO-flavored format.
1338 template <typename V>
1339 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1340                                  uint64_t **pShape, V **pValues,
1341                                  uint64_t **pIndices) {
1342   assert(tensor);
1343   auto sparseTensor =
1344       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1345   uint64_t rank = sparseTensor->getRank();
1346   std::vector<uint64_t> perm(rank);
1347   std::iota(perm.begin(), perm.end(), 0);
1348   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1349 
1350   const std::vector<Element<V>> &elements = coo->getElements();
1351   uint64_t nse = elements.size();
1352 
1353   uint64_t *shape = new uint64_t[rank];
1354   for (uint64_t i = 0; i < rank; i++)
1355     shape[i] = coo->getDimSizes()[i];
1356 
1357   V *values = new V[nse];
1358   uint64_t *indices = new uint64_t[rank * nse];
1359 
1360   for (uint64_t i = 0, base = 0; i < nse; i++) {
1361     values[i] = elements[i].value;
1362     for (uint64_t j = 0; j < rank; j++)
1363       indices[base + j] = elements[i].indices[j];
1364     base += rank;
1365   }
1366 
1367   delete coo;
1368   *pRank = rank;
1369   *pNse = nse;
1370   *pShape = shape;
1371   *pValues = values;
1372   *pIndices = indices;
1373 }
1374 
1375 } // namespace
1376 
1377 extern "C" {
1378 
1379 //===----------------------------------------------------------------------===//
1380 //
1381 // Public API with methods that operate on MLIR buffers (memrefs) to interact
1382 // with sparse tensors, which are only visible as opaque pointers externally.
1383 // These methods should be used exclusively by MLIR compiler-generated code.
1384 //
1385 // Some macro magic is used to generate implementations for all required type
1386 // combinations that can be called from MLIR compiler-generated code.
1387 //
1388 //===----------------------------------------------------------------------===//
1389 
1390 #define CASE(p, i, v, P, I, V)                                                 \
1391   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1392     SparseTensorCOO<V> *coo = nullptr;                                         \
1393     if (action <= Action::kFromCOO) {                                          \
1394       if (action == Action::kFromFile) {                                       \
1395         char *filename = static_cast<char *>(ptr);                             \
1396         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1397       } else if (action == Action::kFromCOO) {                                 \
1398         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1399       } else {                                                                 \
1400         assert(action == Action::kEmpty);                                      \
1401       }                                                                        \
1402       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1403           rank, shape, perm, sparsity, coo);                                   \
1404       if (action == Action::kFromFile)                                         \
1405         delete coo;                                                            \
1406       return tensor;                                                           \
1407     }                                                                          \
1408     if (action == Action::kSparseToSparse) {                                   \
1409       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
1410       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
1411                                                            sparsity, tensor);  \
1412     }                                                                          \
1413     if (action == Action::kEmptyCOO)                                           \
1414       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1415     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1416     if (action == Action::kToIterator) {                                       \
1417       coo->startIterator();                                                    \
1418     } else {                                                                   \
1419       assert(action == Action::kToCOO);                                        \
1420     }                                                                          \
1421     return coo;                                                                \
1422   }
1423 
1424 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
1425 
1426 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1427 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1428 // that this file cannot get out of sync with its header.
1429 static_assert(std::is_same<index_type, uint64_t>::value,
1430               "Expected index_type == uint64_t");
1431 
1432 /// Constructs a new sparse tensor. This is the "swiss army knife"
1433 /// method for materializing sparse tensors into the computation.
1434 ///
1435 /// Action:
1436 /// kEmpty = returns empty storage to fill later
1437 /// kFromFile = returns storage, where ptr contains filename to read
1438 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
1439 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
1440 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1441 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
1442 void *
1443 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1444                              StridedMemRefType<index_type, 1> *sref,
1445                              StridedMemRefType<index_type, 1> *pref,
1446                              OverheadType ptrTp, OverheadType indTp,
1447                              PrimaryType valTp, Action action, void *ptr) {
1448   assert(aref && sref && pref);
1449   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1450          pref->strides[0] == 1);
1451   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1452   const DimLevelType *sparsity = aref->data + aref->offset;
1453   const index_type *shape = sref->data + sref->offset;
1454   const index_type *perm = pref->data + pref->offset;
1455   uint64_t rank = aref->sizes[0];
1456 
1457   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1458   // This is safe because of the static_assert above.
1459   if (ptrTp == OverheadType::kIndex)
1460     ptrTp = OverheadType::kU64;
1461   if (indTp == OverheadType::kIndex)
1462     indTp = OverheadType::kU64;
1463 
1464   // Double matrices with all combinations of overhead storage.
1465   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1466        uint64_t, double);
1467   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1468        uint32_t, double);
1469   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1470        uint16_t, double);
1471   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1472        uint8_t, double);
1473   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1474        uint64_t, double);
1475   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1476        uint32_t, double);
1477   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1478        uint16_t, double);
1479   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1480        uint8_t, double);
1481   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1482        uint64_t, double);
1483   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1484        uint32_t, double);
1485   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1486        uint16_t, double);
1487   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1488        uint8_t, double);
1489   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1490        uint64_t, double);
1491   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1492        uint32_t, double);
1493   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1494        uint16_t, double);
1495   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1496        uint8_t, double);
1497 
1498   // Float matrices with all combinations of overhead storage.
1499   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1500        uint64_t, float);
1501   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1502        uint32_t, float);
1503   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1504        uint16_t, float);
1505   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1506        uint8_t, float);
1507   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1508        uint64_t, float);
1509   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1510        uint32_t, float);
1511   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1512        uint16_t, float);
1513   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1514        uint8_t, float);
1515   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1516        uint64_t, float);
1517   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1518        uint32_t, float);
1519   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1520        uint16_t, float);
1521   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1522        uint8_t, float);
1523   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1524        uint64_t, float);
1525   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1526        uint32_t, float);
1527   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1528        uint16_t, float);
1529   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1530        uint8_t, float);
1531 
1532   // Integral matrices with both overheads of the same type.
1533   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1534   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1535   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1536   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1537   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1538   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1539   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1540   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1541   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1542   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1543   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1544   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1545   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1546 
1547   // Complex matrices with wide overhead.
1548   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1549   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1550 
1551   // Unsupported case (add above if needed).
1552   // TODO: better pretty-printing of enum values!
1553   FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
1554         static_cast<int>(ptrTp), static_cast<int>(indTp),
1555         static_cast<int>(valTp));
1556 }
1557 #undef CASE
1558 #undef CASE_SECSAME
1559 
1560 /// Methods that provide direct access to values.
1561 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1562   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1563                                         void *tensor) {                        \
1564     assert(ref &&tensor);                                                      \
1565     std::vector<V> *v;                                                         \
1566     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1567     ref->basePtr = ref->data = v->data();                                      \
1568     ref->offset = 0;                                                           \
1569     ref->sizes[0] = v->size();                                                 \
1570     ref->strides[0] = 1;                                                       \
1571   }
1572 FOREVERY_V(IMPL_SPARSEVALUES)
1573 #undef IMPL_SPARSEVALUES
1574 
1575 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1576   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1577                            index_type d) {                                     \
1578     assert(ref &&tensor);                                                      \
1579     std::vector<TYPE> *v;                                                      \
1580     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1581     ref->basePtr = ref->data = v->data();                                      \
1582     ref->offset = 0;                                                           \
1583     ref->sizes[0] = v->size();                                                 \
1584     ref->strides[0] = 1;                                                       \
1585   }
1586 /// Methods that provide direct access to pointers.
1587 #define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
1588   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
1589 FOREVERY_O(IMPL_SPARSEPOINTERS)
1590 #undef IMPL_SPARSEPOINTERS
1591 
1592 /// Methods that provide direct access to indices.
1593 #define IMPL_SPARSEINDICES(INAME, I)                                           \
1594   IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
1595 FOREVERY_O(IMPL_SPARSEINDICES)
1596 #undef IMPL_SPARSEINDICES
1597 #undef IMPL_GETOVERHEAD
1598 
1599 /// Helper to add value to coordinate scheme, one per value type.
1600 #define IMPL_ADDELT(VNAME, V)                                                  \
1601   void *_mlir_ciface_addElt##VNAME(void *coo, V value,                         \
1602                                    StridedMemRefType<index_type, 1> *iref,     \
1603                                    StridedMemRefType<index_type, 1> *pref) {   \
1604     assert(coo &&iref &&pref);                                                 \
1605     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1606     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1607     const index_type *indx = iref->data + iref->offset;                        \
1608     const index_type *perm = pref->data + pref->offset;                        \
1609     uint64_t isize = iref->sizes[0];                                           \
1610     std::vector<index_type> indices(isize);                                    \
1611     for (uint64_t r = 0; r < isize; r++)                                       \
1612       indices[perm[r]] = indx[r];                                              \
1613     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value);               \
1614     return coo;                                                                \
1615   }
1616 FOREVERY_SIMPLEX_V(IMPL_ADDELT)
1617 // `complex64` apparently doesn't encounter any ABI issues (yet).
1618 IMPL_ADDELT(C64, complex64)
1619 // TODO: cleaner way to avoid ABI padding problem?
1620 IMPL_ADDELT(C32ABI, complex32)
1621 void *_mlir_ciface_addEltC32(void *coo, float r, float i,
1622                              StridedMemRefType<index_type, 1> *iref,
1623                              StridedMemRefType<index_type, 1> *pref) {
1624   return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
1625 }
1626 #undef IMPL_ADDELT
1627 
1628 /// Helper to enumerate elements of coordinate scheme, one per value type.
1629 #define IMPL_GETNEXT(VNAME, V)                                                 \
1630   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1631                                    StridedMemRefType<index_type, 1> *iref,     \
1632                                    StridedMemRefType<V, 0> *vref) {            \
1633     assert(coo &&iref &&vref);                                                 \
1634     assert(iref->strides[0] == 1);                                             \
1635     index_type *indx = iref->data + iref->offset;                              \
1636     V *value = vref->data + vref->offset;                                      \
1637     const uint64_t isize = iref->sizes[0];                                     \
1638     const Element<V> *elem =                                                   \
1639         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1640     if (elem == nullptr)                                                       \
1641       return false;                                                            \
1642     for (uint64_t r = 0; r < isize; r++)                                       \
1643       indx[r] = elem->indices[r];                                              \
1644     *value = elem->value;                                                      \
1645     return true;                                                               \
1646   }
1647 FOREVERY_V(IMPL_GETNEXT)
1648 #undef IMPL_GETNEXT
1649 
1650 /// Insert elements in lexicographical index order, one per value type.
1651 #define IMPL_LEXINSERT(VNAME, V)                                               \
1652   void _mlir_ciface_lexInsert##VNAME(                                          \
1653       void *tensor, StridedMemRefType<index_type, 1> *cref, V val) {           \
1654     assert(tensor &&cref);                                                     \
1655     assert(cref->strides[0] == 1);                                             \
1656     index_type *cursor = cref->data + cref->offset;                            \
1657     assert(cursor);                                                            \
1658     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1659   }
1660 FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
1661 // `complex64` apparently doesn't encounter any ABI issues (yet).
1662 IMPL_LEXINSERT(C64, complex64)
1663 // TODO: cleaner way to avoid ABI padding problem?
1664 IMPL_LEXINSERT(C32ABI, complex32)
1665 void _mlir_ciface_lexInsertC32(void *tensor,
1666                                StridedMemRefType<index_type, 1> *cref, float r,
1667                                float i) {
1668   _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
1669 }
1670 #undef IMPL_LEXINSERT
1671 
1672 /// Insert using expansion, one per value type.
1673 #define IMPL_EXPINSERT(VNAME, V)                                               \
1674   void _mlir_ciface_expInsert##VNAME(                                          \
1675       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1676       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1677       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1678     assert(tensor &&cref &&vref &&fref &&aref);                                \
1679     assert(cref->strides[0] == 1);                                             \
1680     assert(vref->strides[0] == 1);                                             \
1681     assert(fref->strides[0] == 1);                                             \
1682     assert(aref->strides[0] == 1);                                             \
1683     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1684     index_type *cursor = cref->data + cref->offset;                            \
1685     V *values = vref->data + vref->offset;                                     \
1686     bool *filled = fref->data + fref->offset;                                  \
1687     index_type *added = aref->data + aref->offset;                             \
1688     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1689         cursor, values, filled, added, count);                                 \
1690   }
1691 FOREVERY_V(IMPL_EXPINSERT)
1692 #undef IMPL_EXPINSERT
1693 
1694 /// Output a sparse tensor, one per value type.
1695 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
1696   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
1697     return outSparseTensor<V>(coo, dest, sort);                                \
1698   }
1699 FOREVERY_V(IMPL_OUTSPARSETENSOR)
1700 #undef IMPL_OUTSPARSETENSOR
1701 
1702 //===----------------------------------------------------------------------===//
1703 //
1704 // Public API with methods that accept C-style data structures to interact
1705 // with sparse tensors, which are only visible as opaque pointers externally.
1706 // These methods can be used both by MLIR compiler-generated code as well as by
1707 // an external runtime that wants to interact with MLIR compiler-generated code.
1708 //
1709 //===----------------------------------------------------------------------===//
1710 
1711 /// Helper method to read a sparse tensor filename from the environment,
1712 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1713 char *getTensorFilename(index_type id) {
1714   char var[80];
1715   sprintf(var, "TENSOR%" PRIu64, id);
1716   char *env = getenv(var);
1717   if (!env)
1718     FATAL("Environment variable %s is not set\n", var);
1719   return env;
1720 }
1721 
1722 /// Returns size of sparse tensor in given dimension.
1723 index_type sparseDimSize(void *tensor, index_type d) {
1724   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1725 }
1726 
1727 /// Finalizes lexicographic insertions.
1728 void endInsert(void *tensor) {
1729   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1730 }
1731 
1732 /// Releases sparse tensor storage.
1733 void delSparseTensor(void *tensor) {
1734   delete static_cast<SparseTensorStorageBase *>(tensor);
1735 }
1736 
1737 /// Releases sparse tensor coordinate scheme.
1738 #define IMPL_DELCOO(VNAME, V)                                                  \
1739   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1740     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1741   }
1742 FOREVERY_V(IMPL_DELCOO)
1743 #undef IMPL_DELCOO
1744 
1745 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
1746 /// data structures. The expected parameters are:
1747 ///
1748 ///   rank:    rank of tensor
1749 ///   nse:     number of specified elements (usually the nonzeros)
1750 ///   shape:   array with dimension size for each rank
1751 ///   values:  a "nse" array with values for all specified elements
1752 ///   indices: a flat "nse x rank" array with indices for all specified elements
1753 ///   perm:    the permutation of the dimensions in the storage
1754 ///   sparse:  the sparsity for the dimensions
1755 ///
1756 /// For example, the sparse matrix
1757 ///     | 1.0 0.0 0.0 |
1758 ///     | 0.0 5.0 3.0 |
1759 /// can be passed as
1760 ///      rank    = 2
1761 ///      nse     = 3
1762 ///      shape   = [2, 3]
1763 ///      values  = [1.0, 5.0, 3.0]
1764 ///      indices = [ 0, 0,  1, 1,  1, 2]
1765 //
1766 // TODO: generalize beyond 64-bit indices.
1767 //
1768 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
1769   void *convertToMLIRSparseTensor##VNAME(                                      \
1770       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
1771       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
1772     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
1773                                  sparse);                                      \
1774   }
1775 FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
1776 #undef IMPL_CONVERTTOMLIRSPARSETENSOR
1777 
1778 /// Converts a sparse tensor to COO-flavored format expressed using C-style
1779 /// data structures. The expected output parameters are pointers for these
1780 /// values:
1781 ///
1782 ///   rank:    rank of tensor
1783 ///   nse:     number of specified elements (usually the nonzeros)
1784 ///   shape:   array with dimension size for each rank
1785 ///   values:  a "nse" array with values for all specified elements
1786 ///   indices: a flat "nse x rank" array with indices for all specified elements
1787 ///
1788 /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
1789 /// from convertToMLIRSparseTensor.
1790 ///
1791 //  TODO: Currently, values are copied from SparseTensorStorage to
1792 //  SparseTensorCOO, then to the output. We may want to reduce the number of
1793 //  copies.
1794 //
1795 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1796 // compressed
1797 //
1798 #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
1799   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
1800                                           uint64_t *pNse, uint64_t **pShape,   \
1801                                           V **pValues, uint64_t **pIndices) {  \
1802     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
1803   }
1804 FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
1805 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
1806 
1807 } // extern "C"
1808 
1809 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1810