1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 
19 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
20 
21 #include <algorithm>
22 #include <cassert>
23 #include <cctype>
24 #include <cstdio>
25 #include <cstdlib>
26 #include <cstring>
27 #include <fstream>
28 #include <functional>
29 #include <iostream>
30 #include <limits>
31 #include <numeric>
32 
33 //===----------------------------------------------------------------------===//
34 //
35 // Internal support for storing and reading sparse tensors.
36 //
37 // The following memory-resident sparse storage schemes are supported:
38 //
39 // (a) A coordinate scheme for temporarily storing and lexicographically
40 //     sorting a sparse tensor by index (SparseTensorCOO).
41 //
42 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
44 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
45 //
46 // The following external formats are supported:
47 //
48 // (1) Matrix Market Exchange (MME): *.mtx
49 //     https://math.nist.gov/MatrixMarket/formats.html
50 //
51 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
52 //     http://frostt.io/tensors/file-formats.html
53 //
54 // Two public APIs are supported:
55 //
56 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
57 //     tensors. These methods should be used exclusively by MLIR
58 //     compiler-generated code.
59 //
60 // (II) Methods that accept C-style data structures to interact with sparse
61 //      tensors. These methods can be used by any external runtime that wants
62 //      to interact with MLIR compiler-generated code.
63 //
// In both cases (I) and (II), the SparseTensorStorage format is exposed
// externally only as an opaque pointer.
66 //
67 //===----------------------------------------------------------------------===//
68 
69 namespace {
70 
71 static constexpr int kColWidth = 1025;
72 
73 /// A version of `operator*` on `uint64_t` which checks for overflows.
74 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
75   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
76          "Integer overflow");
77   return lhs * rhs;
78 }
79 
// This macro helps minimize repetition of the fprintf-then-exit idiom, and
// also prefixes the message so there is some indication of where the error
// is coming from.  (Since `fprintf` doesn't provide a stacktrace, this makes
// it easier to tell whether an error originates in this library or somewhere
// else in MLIR.)
85 #define FATAL(...)                                                             \
86   do {                                                                         \
87     fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                        \
88     exit(1);                                                                   \
89   } while (0)
90 
91 // TODO: try to unify this with `SparseTensorFile::assertMatchesShape`
92 // which is used by `openSparseTensorCOO`.  It's easy enough to resolve
93 // the `std::vector` vs pointer mismatch for `dimSizes`; but it's trickier
94 // to resolve the presence/absence of `perm` (without introducing extra
95 // overhead), so perhaps the code duplication is unavoidable.
96 //
97 /// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
98 /// semantic-order to target-order) are a refinement of the desired `shape`
99 /// (in semantic-order).
100 ///
101 /// Precondition: `perm` and `shape` must be valid for `rank`.
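///
/// For example (illustrative values only): with `perm = {1, 0}` and a
/// semantic-order `shape = {3, 4}`, the target-order `dimSizes` must be
/// `{4, 3}` (entries of `shape` equal to 0 are left unconstrained).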
102 static inline void
103 assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
104                               uint64_t rank, const uint64_t *perm,
105                               const uint64_t *shape) {
106   assert(perm && shape);
107   assert(rank == dimSizes.size() && "Rank mismatch");
108   for (uint64_t r = 0; r < rank; r++)
109     assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
110            "Dimension size mismatch");
111 }
112 
113 /// A sparse tensor element in coordinate scheme (value and indices).
114 /// For example, a rank-1 vector element would look like
115 ///   ({i}, a[i])
116 /// and a rank-5 tensor element like
117 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer into a shared index pool rather than, e.g., a direct
/// vector, since this (1) reduces the per-element memory footprint, and
/// (2) centralizes the memory reservation and (re)allocation in one place.
121 template <typename V>
122 struct Element final {
  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
124   uint64_t *indices; // pointer into shared index pool
125   V value;
126 };
127 
128 /// The type of callback functions which receive an element.  We avoid
129 /// packaging the coordinates and value together as an `Element` object
130 /// because this helps keep code somewhat cleaner.
131 template <typename V>
132 using ElementConsumer =
133     const std::function<void(const std::vector<uint64_t> &, V)> &;
134 
135 /// A memory-resident sparse tensor in coordinate scheme (collection of
136 /// elements). This data structure is used to read a sparse tensor from
137 /// any external format into memory and sort the elements lexicographically
138 /// by indices before passing it back to the client (most packed storage
139 /// formats require the elements to appear in lexicographic index order).
140 template <typename V>
141 struct SparseTensorCOO final {
142 public:
143   SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
144       : dimSizes(dimSizes) {
145     if (capacity) {
146       elements.reserve(capacity);
147       indices.reserve(capacity * getRank());
148     }
149   }
150 
151   /// Adds element as indices and value.
152   void add(const std::vector<uint64_t> &ind, V val) {
153     assert(!iteratorLocked && "Attempt to add() after startIterator()");
154     uint64_t *base = indices.data();
155     uint64_t size = indices.size();
156     uint64_t rank = getRank();
157     assert(ind.size() == rank && "Element rank mismatch");
158     for (uint64_t r = 0; r < rank; r++) {
159       assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
160       indices.push_back(ind[r]);
161     }
    // The base address only changes if the `indices` vector was reallocated.
    // In that case, we need to correct all previous pointers into the vector.
    // Note that this only happens if we did not set the initial capacity
    // right, and then only on each internal vector reallocation (which, with
    // the doubling rule, incurs at most an amortized linear overhead).
167     uint64_t *newBase = indices.data();
168     if (newBase != base) {
169       for (uint64_t i = 0, n = elements.size(); i < n; i++)
170         elements[i].indices = newBase + (elements[i].indices - base);
171       base = newBase;
172     }
173     // Add element as (pointer into shared index pool, value) pair.
174     elements.emplace_back(base + size, val);
175   }
176 
177   /// Sorts elements lexicographically by index.
178   void sort() {
179     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
180     // TODO: we may want to cache an `isSorted` bit, to avoid
181     // unnecessary/redundant sorting.
182     uint64_t rank = getRank();
183     std::sort(elements.begin(), elements.end(),
184               [rank](const Element<V> &e1, const Element<V> &e2) {
185                 for (uint64_t r = 0; r < rank; r++) {
186                   if (e1.indices[r] == e2.indices[r])
187                     continue;
188                   return e1.indices[r] < e2.indices[r];
189                 }
190                 return false;
191               });
192   }
193 
194   /// Get the rank of the tensor.
195   uint64_t getRank() const { return dimSizes.size(); }
196 
197   /// Getter for the dimension-sizes array.
198   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
199 
200   /// Getter for the elements array.
201   const std::vector<Element<V>> &getElements() const { return elements; }
202 
203   /// Switch into iterator mode.
204   void startIterator() {
205     iteratorLocked = true;
206     iteratorPos = 0;
207   }
208 
209   /// Get the next element.
210   const Element<V> *getNext() {
211     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
212     if (iteratorPos < elements.size())
213       return &(elements[iteratorPos++]);
214     iteratorLocked = false;
215     return nullptr;
216   }
217 
218   /// Factory method. Permutes the original dimensions according to
219   /// the given ordering and expects subsequent add() calls to honor
220   /// that same ordering for the given indices. The result is a
221   /// fully permuted coordinate scheme.
222   ///
223   /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
224   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
225                                                 const uint64_t *dimSizes,
226                                                 const uint64_t *perm,
227                                                 uint64_t capacity = 0) {
228     std::vector<uint64_t> permsz(rank);
229     for (uint64_t r = 0; r < rank; r++) {
230       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
231       permsz[perm[r]] = dimSizes[r];
232     }
233     return new SparseTensorCOO<V>(permsz, capacity);
234   }
235 
236 private:
237   const std::vector<uint64_t> dimSizes; // per-dimension sizes
238   std::vector<Element<V>> elements;     // all COO elements
239   std::vector<uint64_t> indices;        // shared index pool
240   bool iteratorLocked = false;
  uint64_t iteratorPos = 0;
242 };
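
// A minimal usage sketch of `SparseTensorCOO` (the sizes, indices, and values
// below are made up purely for illustration):
//
//   uint64_t dimSizes[] = {3, 4};
//   uint64_t perm[] = {0, 1}; // identity dimension ordering
//   SparseTensorCOO<double> *coo =
//       SparseTensorCOO<double>::newSparseTensorCOO(2, dimSizes, perm,
//                                                   /*capacity=*/2);
//   coo->add({0, 1}, 1.0);
//   coo->add({2, 3}, 2.0);
//   coo->sort(); // required before handing off to packed storage
//   coo->startIterator();
//   while (const Element<double> *e = coo->getNext())
//     ; // use e->indices[0], e->indices[1], e->value
//   delete coo;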
243 
244 // Forward.
245 template <typename V>
246 class SparseTensorEnumeratorBase;
247 
248 // Helper macro for generating error messages when some
249 // `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
250 // and then the wrong "partial method specialization" is called.
251 #define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
252 
253 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
254 /// takes responsibility for all the `<P,I,V>`-independent aspects
255 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
256 /// we use function overloading to implement "partial" method
257 /// specialization, which the C-API relies on to catch type errors
258 /// arising from our use of opaque pointers.
259 class SparseTensorStorageBase {
260 public:
261   /// Constructs a new storage object.  The `perm` maps the tensor's
262   /// semantic-ordering of dimensions to this object's storage-order.
263   /// The `dimSizes` and `sparsity` arrays are already in storage-order.
264   ///
265   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
266   SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
267                           const uint64_t *perm, const DimLevelType *sparsity)
268       : dimSizes(dimSizes), rev(getRank()),
269         dimTypes(sparsity, sparsity + getRank()) {
270     assert(perm && sparsity);
271     const uint64_t rank = getRank();
272     // Validate parameters.
273     assert(rank > 0 && "Trivial shape is unsupported");
274     for (uint64_t r = 0; r < rank; r++) {
275       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
276       assert((dimTypes[r] == DimLevelType::kDense ||
277               dimTypes[r] == DimLevelType::kCompressed) &&
278              "Unsupported DimLevelType");
279     }
280     // Construct the "reverse" (i.e., inverse) permutation.
281     for (uint64_t r = 0; r < rank; r++)
282       rev[perm[r]] = r;
283   }
284 
285   virtual ~SparseTensorStorageBase() = default;
286 
287   /// Get the rank of the tensor.
288   uint64_t getRank() const { return dimSizes.size(); }
289 
290   /// Getter for the dimension-sizes array, in storage-order.
291   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
292 
  /// Safely look up the size of the given (storage-order) dimension.
294   uint64_t getDimSize(uint64_t d) const {
295     assert(d < getRank());
296     return dimSizes[d];
297   }
298 
299   /// Getter for the "reverse" permutation, which maps this object's
300   /// storage-order to the tensor's semantic-order.
301   const std::vector<uint64_t> &getRev() const { return rev; }
302 
303   /// Getter for the dimension-types array, in storage-order.
304   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
305 
306   /// Safely check if the (storage-order) dimension uses compressed storage.
307   bool isCompressedDim(uint64_t d) const {
308     assert(d < getRank());
309     return (dimTypes[d] == DimLevelType::kCompressed);
310   }
311 
312   /// Allocate a new enumerator.
313 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
314   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
315                              const uint64_t *) const {                         \
316     FATAL_PIV("newEnumerator" #VNAME);                                         \
317   }
318   FOREVERY_V(DECL_NEWENUMERATOR)
319 #undef DECL_NEWENUMERATOR
320 
321   /// Overhead storage.
322 #define DECL_GETPOINTERS(PNAME, P)                                             \
323   virtual void getPointers(std::vector<P> **, uint64_t) {                      \
324     FATAL_PIV("getPointers" #PNAME);                                           \
325   }
326   FOREVERY_FIXED_O(DECL_GETPOINTERS)
327 #undef DECL_GETPOINTERS
328 #define DECL_GETINDICES(INAME, I)                                              \
329   virtual void getIndices(std::vector<I> **, uint64_t) {                       \
330     FATAL_PIV("getIndices" #INAME);                                            \
331   }
332   FOREVERY_FIXED_O(DECL_GETINDICES)
333 #undef DECL_GETINDICES
334 
335   /// Primary storage.
336 #define DECL_GETVALUES(VNAME, V)                                               \
337   virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
338   FOREVERY_V(DECL_GETVALUES)
339 #undef DECL_GETVALUES
340 
341   /// Element-wise insertion in lexicographic index order.
342 #define DECL_LEXINSERT(VNAME, V)                                               \
343   virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
344   FOREVERY_V(DECL_LEXINSERT)
345 #undef DECL_LEXINSERT
346 
347   /// Expanded insertion.
348 #define DECL_EXPINSERT(VNAME, V)                                               \
349   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
350     FATAL_PIV("expInsert" #VNAME);                                             \
351   }
352   FOREVERY_V(DECL_EXPINSERT)
353 #undef DECL_EXPINSERT
354 
355   /// Finishes insertion.
356   virtual void endInsert() = 0;
357 
358 protected:
359   // Since this class is virtual, we must disallow public copying in
360   // order to avoid "slicing".  Since this class has data members,
361   // that means making copying protected.
362   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
363   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
364   // Copy-assignment would be implicitly deleted (because `dimSizes`
365   // is const), so we explicitly delete it for clarity.
366   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
367 
368 private:
369   const std::vector<uint64_t> dimSizes;
370   std::vector<uint64_t> rev;
371   const std::vector<DimLevelType> dimTypes;
372 };
373 
374 #undef FATAL_PIV
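
// For example (illustrative only): suppose a `SparseTensorStorage<uint64_t,
// uint64_t, double>` is held behind a `SparseTensorStorageBase *base`.  Only
// the `double` overload of `getValues` is overridden by that subclass, so
//
//   std::vector<float> *out;
//   base->getValues(&out); // wrong V: hits the base default, which fails fast
//
// reports a `<P,I,V>` type mismatch instead of silently reinterpreting memory.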
375 
376 // Forward.
377 template <typename P, typename I, typename V>
378 class SparseTensorEnumerator;
379 
380 /// A memory-resident sparse tensor using a storage scheme based on
381 /// per-dimension sparse/dense annotations. This data structure provides a
382 /// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this class provides
384 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
385 /// and annotations to implement all required setup in a general manner.
386 template <typename P, typename I, typename V>
387 class SparseTensorStorage final : public SparseTensorStorageBase {
388   /// Private constructor to share code between the other constructors.
389   /// Beware that the object is not necessarily guaranteed to be in a
390   /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
391   /// doesn't entail `!(pointers[d].empty())`.
392   ///
393   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
394   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
395                       const uint64_t *perm, const DimLevelType *sparsity)
396       : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
397         indices(getRank()), idx(getRank()) {}
398 
399 public:
400   /// Constructs a sparse tensor storage scheme with the given dimensions,
401   /// permutation, and per-dimension dense/sparse annotations, using
402   /// the coordinate scheme tensor for the initial contents if provided.
403   ///
404   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
405   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
406                       const uint64_t *perm, const DimLevelType *sparsity,
407                       SparseTensorCOO<V> *coo)
408       : SparseTensorStorage(dimSizes, perm, sparsity) {
409     // Provide hints on capacity of pointers and indices.
410     // TODO: needs much fine-tuning based on actual sparsity; currently
411     //       we reserve pointer/index space based on all previous dense
412     //       dimensions, which works well up to first sparse dim; but
413     //       we should really use nnz and dense/sparse distribution.
414     bool allDense = true;
415     uint64_t sz = 1;
416     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
417       if (isCompressedDim(r)) {
418         // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
419         // `sz` by that before reserving. (For now we just use 1.)
420         pointers[r].reserve(sz + 1);
421         pointers[r].push_back(0);
422         indices[r].reserve(sz);
423         sz = 1;
424         allDense = false;
425       } else { // Dense dimension.
426         sz = checkedMul(sz, getDimSizes()[r]);
427       }
428     }
429     // Then assign contents from coordinate scheme tensor if provided.
430     if (coo) {
431       // Ensure both preconditions of `fromCOO`.
432       assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
433       coo->sort();
434       // Now actually insert the `elements`.
435       const std::vector<Element<V>> &elements = coo->getElements();
436       uint64_t nnz = elements.size();
437       values.reserve(nnz);
438       fromCOO(elements, 0, nnz, 0);
439     } else if (allDense) {
440       values.resize(sz, 0);
441     }
442   }
443 
444   /// Constructs a sparse tensor storage scheme with the given dimensions,
445   /// permutation, and per-dimension dense/sparse annotations, using
446   /// the given sparse tensor for the initial contents.
447   ///
448   /// Preconditions:
449   /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
450   /// * The `tensor` must have the same value type `V`.
451   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
452                       const uint64_t *perm, const DimLevelType *sparsity,
453                       const SparseTensorStorageBase &tensor);
454 
455   ~SparseTensorStorage() final = default;
456 
457   /// Partially specialize these getter methods based on template types.
458   void getPointers(std::vector<P> **out, uint64_t d) final {
459     assert(d < getRank());
460     *out = &pointers[d];
461   }
462   void getIndices(std::vector<I> **out, uint64_t d) final {
463     assert(d < getRank());
464     *out = &indices[d];
465   }
466   void getValues(std::vector<V> **out) final { *out = &values; }
467 
468   /// Partially specialize lexicographical insertions based on template types.
469   void lexInsert(const uint64_t *cursor, V val) final {
470     // First, wrap up pending insertion path.
471     uint64_t diff = 0;
472     uint64_t top = 0;
473     if (!values.empty()) {
474       diff = lexDiff(cursor);
475       endPath(diff + 1);
476       top = idx[diff] + 1;
477     }
478     // Then continue with insertion path.
479     insPath(cursor, diff, top, val);
480   }
481 
482   /// Partially specialize expanded insertions based on template types.
483   /// Note that this method resets the values/filled-switch array back
484   /// to all-zero/false while only iterating over the nonzero elements.
485   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
486                  uint64_t count) final {
487     if (count == 0)
488       return;
489     // Sort.
490     std::sort(added, added + count);
491     // Restore insertion path for first insert.
492     const uint64_t lastDim = getRank() - 1;
493     uint64_t index = added[0];
494     cursor[lastDim] = index;
495     lexInsert(cursor, values[index]);
496     assert(filled[index]);
497     values[index] = 0;
498     filled[index] = false;
499     // Subsequent insertions are quick.
500     for (uint64_t i = 1; i < count; i++) {
501       assert(index < added[i] && "non-lexicographic insertion");
502       index = added[i];
503       cursor[lastDim] = index;
504       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
505       assert(filled[index]);
506       values[index] = 0;
507       filled[index] = false;
508     }
509   }
510 
511   /// Finalizes lexicographic insertions.
512   void endInsert() final {
513     if (values.empty())
514       finalizeSegment(0);
515     else
516       endPath(0);
517   }
518 
519   void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
520                      const uint64_t *perm) const final {
521     *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
522   }
523 
524   /// Returns this sparse tensor storage scheme as a new memory-resident
525   /// sparse tensor in coordinate scheme with the given dimension order.
526   ///
527   /// Precondition: `perm` must be valid for `getRank()`.
528   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
529     SparseTensorEnumeratorBase<V> *enumerator;
530     newEnumerator(&enumerator, getRank(), perm);
531     SparseTensorCOO<V> *coo =
532         new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
533     enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
534       coo->add(ind, val);
535     });
536     // TODO: This assertion assumes there are no stored zeros,
537     // or if there are then that we don't filter them out.
538     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
539     assert(coo->getElements().size() == values.size());
540     delete enumerator;
541     return coo;
542   }
543 
544   /// Factory method. Constructs a sparse tensor storage scheme with the given
545   /// dimensions, permutation, and per-dimension dense/sparse annotations,
546   /// using the coordinate scheme tensor for the initial contents if provided.
547   /// In the latter case, the coordinate scheme must respect the same
548   /// permutation as is desired for the new sparse tensor storage.
549   ///
550   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
551   static SparseTensorStorage<P, I, V> *
552   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
553                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
554     SparseTensorStorage<P, I, V> *n = nullptr;
555     if (coo) {
556       const auto &coosz = coo->getDimSizes();
557       assertPermutedSizesMatchShape(coosz, rank, perm, shape);
558       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
559     } else {
560       std::vector<uint64_t> permsz(rank);
561       for (uint64_t r = 0; r < rank; r++) {
562         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
563         permsz[perm[r]] = shape[r];
564       }
565       // We pass the null `coo` to ensure we select the intended constructor.
566       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
567     }
568     return n;
569   }
570 
571   /// Factory method. Constructs a sparse tensor storage scheme with
572   /// the given dimensions, permutation, and per-dimension dense/sparse
573   /// annotations, using the sparse tensor for the initial contents.
574   ///
575   /// Preconditions:
576   /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
577   /// * The `tensor` must have the same value type `V`.
578   static SparseTensorStorage<P, I, V> *
579   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
580                   const DimLevelType *sparsity,
581                   const SparseTensorStorageBase *source) {
582     assert(source && "Got nullptr for source");
583     SparseTensorEnumeratorBase<V> *enumerator;
584     source->newEnumerator(&enumerator, rank, perm);
585     const auto &permsz = enumerator->permutedSizes();
586     assertPermutedSizesMatchShape(permsz, rank, perm, shape);
587     auto *tensor =
588         new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
589     delete enumerator;
590     return tensor;
591   }
592 
593 private:
594   /// Appends an arbitrary new position to `pointers[d]`.  This method
595   /// checks that `pos` is representable in the `P` type; however, it
596   /// does not check that `pos` is semantically valid (i.e., larger than
597   /// the previous position and smaller than `indices[d].capacity()`).
598   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
599     assert(isCompressedDim(d));
600     assert(pos <= std::numeric_limits<P>::max() &&
601            "Pointer value is too large for the P-type");
602     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
603   }
604 
605   /// Appends index `i` to dimension `d`, in the semantically general
606   /// sense.  For non-dense dimensions, that means appending to the
607   /// `indices[d]` array, checking that `i` is representable in the `I`
608   /// type; however, we do not verify other semantic requirements (e.g.,
609   /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
610   /// in the same segment).  For dense dimensions, this method instead
611   /// appends the appropriate number of zeros to the `values` array,
612   /// where `full` is the number of "entries" already written to `values`
613   /// for this segment (aka one after the highest index previously appended).
614   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
615     if (isCompressedDim(d)) {
616       assert(i <= std::numeric_limits<I>::max() &&
617              "Index value is too large for the I-type");
618       indices[d].push_back(static_cast<I>(i));
619     } else { // Dense dimension.
620       assert(i >= full && "Index was already filled");
621       if (i == full)
622         return; // Short-circuit, since it'll be a nop.
623       if (d + 1 == getRank())
624         values.insert(values.end(), i - full, 0);
625       else
626         finalizeSegment(d + 1, 0, i - full);
627     }
628   }
629 
630   /// Writes the given coordinate to `indices[d][pos]`.  This method
631   /// checks that `i` is representable in the `I` type; however, it
632   /// does not check that `i` is semantically valid (i.e., in bounds
633   /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
634   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
635     assert(isCompressedDim(d));
636     // Subscript assignment to `std::vector` requires that the `pos`-th
637     // entry has been initialized; thus we must be sure to check `size()`
638     // here, instead of `capacity()` as would be ideal.
639     assert(pos < indices[d].size() && "Index position is out of bounds");
640     assert(i <= std::numeric_limits<I>::max() &&
641            "Index value is too large for the I-type");
642     indices[d][pos] = static_cast<I>(i);
643   }
644 
645   /// Computes the assembled-size associated with the `d`-th dimension,
646   /// given the assembled-size associated with the `(d-1)`-th dimension.
647   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
648   /// storage, as opposed to "dimension-sizes" which are the cardinality
649   /// of coordinates for that dimension.
650   ///
651   /// Precondition: the `pointers[d]` array must be fully initialized
652   /// before calling this method.
653   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
654     if (isCompressedDim(d))
655       return pointers[d][parentSz];
656     // else if dense:
657     return parentSz * getDimSizes()[d];
658   }
659 
660   /// Initializes sparse tensor storage scheme from a memory-resident sparse
661   /// tensor in coordinate scheme. This method prepares the pointers and
662   /// indices arrays under the given per-dimension dense/sparse annotations.
663   ///
664   /// Preconditions:
665   /// (1) the `elements` must be lexicographically sorted.
666   /// (2) the indices of every element are valid for `dimSizes` (equal rank
667   ///     and pointwise less-than).
668   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
669                uint64_t hi, uint64_t d) {
670     uint64_t rank = getRank();
671     assert(d <= rank && hi <= elements.size());
672     // Once dimensions are exhausted, insert the numerical values.
673     if (d == rank) {
674       assert(lo < hi);
675       values.push_back(elements[lo].value);
676       return;
677     }
678     // Visit all elements in this interval.
679     uint64_t full = 0;
680     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
681       // Find segment in interval with same index elements in this dimension.
682       uint64_t i = elements[lo].indices[d];
683       uint64_t seg = lo + 1;
684       while (seg < hi && elements[seg].indices[d] == i)
685         seg++;
686       // Handle segment in interval for sparse or dense dimension.
687       appendIndex(d, full, i);
688       full = i + 1;
689       fromCOO(elements, lo, seg, d + 1);
690       // And move on to next segment in interval.
691       lo = seg;
692     }
693     // Finalize the sparse pointer structure at this dimension.
694     finalizeSegment(d, full);
695   }
696 
697   /// Finalize the sparse pointer structure at this dimension.
698   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
699     if (count == 0)
700       return; // Short-circuit, since it'll be a nop.
701     if (isCompressedDim(d)) {
702       appendPointer(d, indices[d].size(), count);
703     } else { // Dense dimension.
704       const uint64_t sz = getDimSizes()[d];
705       assert(sz >= full && "Segment is overfull");
706       count = checkedMul(count, sz - full);
707       // For dense storage we must enumerate all the remaining coordinates
708       // in this dimension (i.e., coordinates after the last non-zero
709       // element), and either fill in their zero values or else recurse
710       // to finalize some deeper dimension.
711       if (d + 1 == getRank())
712         values.insert(values.end(), count, 0);
713       else
714         finalizeSegment(d + 1, 0, count);
715     }
716   }
717 
718   /// Wraps up a single insertion path, inner to outer.
719   void endPath(uint64_t diff) {
720     uint64_t rank = getRank();
721     assert(diff <= rank);
722     for (uint64_t i = 0; i < rank - diff; i++) {
723       const uint64_t d = rank - i - 1;
724       finalizeSegment(d, idx[d] + 1);
725     }
726   }
727 
728   /// Continues a single insertion path, outer to inner.
729   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
730     uint64_t rank = getRank();
731     assert(diff < rank);
732     for (uint64_t d = diff; d < rank; d++) {
733       uint64_t i = cursor[d];
734       appendIndex(d, top, i);
735       top = 0;
736       idx[d] = i;
737     }
738     values.push_back(val);
739   }
740 
  /// Finds the lexicographically differing dimension.
742   uint64_t lexDiff(const uint64_t *cursor) const {
743     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
744       if (cursor[r] > idx[r])
745         return r;
746       else
747         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
749     return -1u;
750   }
751 
752   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
753   // the cost of virtual-function dispatch in inner loops), without
754   // making them public to other client code.
755   friend class SparseTensorEnumerator<P, I, V>;
756 
757   std::vector<std::vector<P>> pointers;
758   std::vector<std::vector<I>> indices;
759   std::vector<V> values;
760   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
761 };
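
// A minimal usage sketch of `SparseTensorStorage` (types and values chosen
// purely for illustration; `coo` is a `SparseTensorCOO<double> *` built as
// sketched above, and a null `coo` yields a tensor with no initial contents):
//
//   uint64_t shape[] = {3, 4};
//   uint64_t perm[] = {0, 1};
//   DimLevelType sparsity[] = {DimLevelType::kDense,
//                              DimLevelType::kCompressed};
//   auto *tensor =
//       SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
//           2, shape, perm, sparsity, coo);
//   std::vector<double> *values;
//   tensor->getValues(&values);
//   std::vector<uint64_t> *ind;
//   tensor->getIndices(&ind, 1); // indices of the compressed dimension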
762 
763 /// A (higher-order) function object for enumerating the elements of some
764 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
765 /// method encapsulates the loop-nest for enumerating the elements of
766 /// the source tensor (in whatever order is best for the source tensor),
767 /// and applies a permutation to the coordinates/indices before handing
768 /// each element to the callback.  A single enumerator object can be
769 /// freely reused for several calls to `forallElements`, just so long
770 /// as each call is sequential with respect to one another.
771 ///
772 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
773 /// passed to the constructor; thus, objects of this class must not
774 /// outlive the sparse tensor they depend on.
775 ///
776 /// Design Note: The reason we define this class instead of simply using
777 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
778 /// the `<P,I>` template parameters from MLIR client code (to simplify the
779 /// type parameters used for direct sparse-to-sparse conversion).  And the
780 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
781 /// than simply using this class, is to avoid the cost of virtual-method
782 /// dispatch within the loop-nest.
783 template <typename V>
784 class SparseTensorEnumeratorBase {
785 public:
786   /// Constructs an enumerator with the given permutation for mapping
787   /// the semantic-ordering of dimensions to the desired target-ordering.
788   ///
789   /// Preconditions:
790   /// * the `tensor` must have the same `V` value type.
791   /// * `perm` must be valid for `rank`.
792   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
793                              uint64_t rank, const uint64_t *perm)
794       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
795         cursor(getRank()) {
796     assert(perm && "Received nullptr for permutation");
797     assert(rank == getRank() && "Permutation rank mismatch");
798     const auto &rev = src.getRev();           // source-order -> semantic-order
799     const auto &dimSizes = src.getDimSizes(); // in source storage-order
800     for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
801       uint64_t t = perm[rev[s]];              // `t` target-order
802       reord[s] = t;
803       permsz[t] = dimSizes[s];
804     }
805   }
806 
807   virtual ~SparseTensorEnumeratorBase() = default;
808 
809   // We disallow copying to help avoid leaking the `src` reference.
810   // (In addition to avoiding the problem of slicing.)
811   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
812   SparseTensorEnumeratorBase &
813   operator=(const SparseTensorEnumeratorBase &) = delete;
814 
815   /// Returns the source/target tensor's rank.  (The source-rank and
816   /// target-rank are always equal since we only support permutations.
817   /// Though once we add support for other dimension mappings, this
818   /// method will have to be split in two.)
819   uint64_t getRank() const { return permsz.size(); }
820 
821   /// Returns the target tensor's dimension sizes.
822   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
823 
824   /// Enumerates all elements of the source tensor, permutes their
825   /// indices, and passes the permuted element to the callback.
  /// The callback must not store the cursor reference directly,
  /// since this function reuses the storage.  Instead, the callback
  /// must copy the cursor if it wants to keep it.
829   virtual void forallElements(ElementConsumer<V> yield) = 0;
830 
831 protected:
832   const SparseTensorStorageBase &src;
833   std::vector<uint64_t> permsz; // in target order.
834   std::vector<uint64_t> reord;  // source storage-order -> target order.
835   std::vector<uint64_t> cursor; // in target order.
836 };
837 
838 template <typename P, typename I, typename V>
839 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
840   using Base = SparseTensorEnumeratorBase<V>;
841 
842 public:
843   /// Constructs an enumerator with the given permutation for mapping
844   /// the semantic-ordering of dimensions to the desired target-ordering.
845   ///
846   /// Precondition: `perm` must be valid for `rank`.
847   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
848                          uint64_t rank, const uint64_t *perm)
849       : Base(tensor, rank, perm) {}
850 
851   ~SparseTensorEnumerator() final = default;
852 
853   void forallElements(ElementConsumer<V> yield) final {
854     forallElements(yield, 0, 0);
855   }
856 
857 private:
858   /// The recursive component of the public `forallElements`.
859   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
860                       uint64_t d) {
861     // Recover the `<P,I,V>` type parameters of `src`.
862     const auto &src =
863         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
864     if (d == Base::getRank()) {
865       assert(parentPos < src.values.size() &&
866              "Value position is out of bounds");
867       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
868       yield(this->cursor, src.values[parentPos]);
869     } else if (src.isCompressedDim(d)) {
870       // Look up the bounds of the `d`-level segment determined by the
871       // `d-1`-level position `parentPos`.
872       const std::vector<P> &pointersD = src.pointers[d];
873       assert(parentPos + 1 < pointersD.size() &&
874              "Parent pointer position is out of bounds");
875       const uint64_t pstart = static_cast<uint64_t>(pointersD[parentPos]);
876       const uint64_t pstop = static_cast<uint64_t>(pointersD[parentPos + 1]);
877       // Loop-invariant code for looking up the `d`-level coordinates/indices.
878       const std::vector<I> &indicesD = src.indices[d];
879       assert(pstop <= indicesD.size() && "Index position is out of bounds");
880       uint64_t &cursorReordD = this->cursor[this->reord[d]];
881       for (uint64_t pos = pstart; pos < pstop; pos++) {
882         cursorReordD = static_cast<uint64_t>(indicesD[pos]);
883         forallElements(yield, pos, d + 1);
884       }
885     } else { // Dense dimension.
886       const uint64_t sz = src.getDimSizes()[d];
887       const uint64_t pstart = parentPos * sz;
888       uint64_t &cursorReordD = this->cursor[this->reord[d]];
889       for (uint64_t i = 0; i < sz; i++) {
890         cursorReordD = i;
891         forallElements(yield, pstart + i, d + 1);
892       }
893     }
894   }
895 };
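
// A minimal sketch of enumerating a tensor's elements under a permutation
// (assuming `tensor` is the rank-2 storage object sketched above):
//
//   SparseTensorEnumeratorBase<double> *enumerator;
//   uint64_t perm[] = {1, 0}; // report each element's indices transposed
//   tensor->newEnumerator(&enumerator, 2, perm);
//   enumerator->forallElements(
//       [](const std::vector<uint64_t> &ind, double val) {
//         // `ind` is already permuted into the target order.
//       });
//   delete enumerator;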
896 
897 /// Statistics regarding the number of nonzero subtensors in
898 /// a source tensor, for direct sparse=>sparse conversion a la
899 /// <https://arxiv.org/abs/2001.02609>.
900 ///
901 /// N.B., this class stores references to the parameters passed to
902 /// the constructor; thus, objects of this class must not outlive
903 /// those parameters.
904 class SparseTensorNNZ final {
905 public:
906   /// Allocate the statistics structure for the desired sizes and
907   /// sparsity (in the target tensor's storage-order).  This constructor
908   /// does not actually populate the statistics, however; for that see
909   /// `initialize`.
910   ///
911   /// Precondition: `dimSizes` must not contain zeros.
912   SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
913                   const std::vector<DimLevelType> &sparsity)
914       : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
915     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
916     bool uncompressed = true;
    uint64_t sz = 1; // product of `dimSizes[r']` over all `r' < r`.
918     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
919       switch (dimTypes[r]) {
920       case DimLevelType::kCompressed:
921         assert(uncompressed &&
922                "Multiple compressed layers not currently supported");
923         uncompressed = false;
924         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
925         break;
926       case DimLevelType::kDense:
927         assert(uncompressed &&
928                "Dense after compressed not currently supported");
929         break;
930       case DimLevelType::kSingleton:
931         // Singleton after Compressed causes no problems for allocating
932         // `nnz` nor for the yieldPos loop.  This remains true even
933         // when adding support for multiple compressed dimensions or
934         // for dense-after-compressed.
935         break;
936       }
937       sz = checkedMul(sz, dimSizes[r]);
938     }
939   }
940 
941   // We disallow copying to help avoid leaking the stored references.
942   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
943   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
944 
945   /// Returns the rank of the target tensor.
946   uint64_t getRank() const { return dimSizes.size(); }
947 
948   /// Enumerate the source tensor to fill in the statistics.  The
949   /// enumerator should already incorporate the permutation (from
950   /// semantic-order to the target storage-order).
951   template <typename V>
952   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
953     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
954     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
955     enumerator.forallElements(
956         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
957   }
958 
959   /// The type of callback functions which receive an nnz-statistic.
960   using NNZConsumer = const std::function<void(uint64_t)> &;
961 
  /// Lexicographically enumerates all indices for dimensions strictly
963   /// less than `stopDim`, and passes their nnz statistic to the callback.
964   /// Since our use-case only requires the statistic not the coordinates
965   /// themselves, we do not bother to construct those coordinates.
966   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
967     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
968     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
969            "Cannot look up non-compressed dimensions");
970     forallIndices(yield, stopDim, 0, 0);
971   }
972 
973 private:
974   /// Adds a new element (i.e., increment its statistics).  We use
975   /// a method rather than inlining into the lambda in `initialize`,
976   /// to avoid spurious templating over `V`.  And this method is private
977   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
978   /// by `forallElements`).
979   void add(const std::vector<uint64_t> &ind) {
980     uint64_t parentPos = 0;
981     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
982       if (dimTypes[r] == DimLevelType::kCompressed)
983         nnz[r][parentPos]++;
984       parentPos = parentPos * dimSizes[r] + ind[r];
985     }
986   }
987 
988   /// Recursive component of the public `forallIndices`.
989   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
990                      uint64_t d) const {
991     assert(d <= stopDim);
992     if (d == stopDim) {
993       assert(parentPos < nnz[d].size() && "Cursor is out of range");
994       yield(nnz[d][parentPos]);
995     } else {
996       const uint64_t sz = dimSizes[d];
997       const uint64_t pstart = parentPos * sz;
998       for (uint64_t i = 0; i < sz; i++)
999         forallIndices(yield, stopDim, pstart + i, d + 1);
1000     }
1001   }
1002 
1003   // All of these are in the target storage-order.
1004   const std::vector<uint64_t> &dimSizes;
1005   const std::vector<DimLevelType> &dimTypes;
1006   std::vector<std::vector<uint64_t>> nnz;
1007 };
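
// Illustrative use (mirroring the sparse-to-sparse constructor below, where
// dimension `d` is a compressed dimension):
//
//   SparseTensorNNZ nnz(dimSizes, dimTypes); // both in target storage-order
//   nnz.initialize(*enumerator);             // tally the nonzero statistics
//   nnz.forallIndices(d, [](uint64_t n) {
//     // `n` is the nnz statistic for one index-prefix of dimensions < d.
//   });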
1008 
1009 template <typename P, typename I, typename V>
1010 SparseTensorStorage<P, I, V>::SparseTensorStorage(
1011     const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
1012     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1013     : SparseTensorStorage(dimSizes, perm, sparsity) {
1014   SparseTensorEnumeratorBase<V> *enumerator;
1015   tensor.newEnumerator(&enumerator, getRank(), perm);
1016   {
1017     // Initialize the statistics structure.
1018     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1019     nnz.initialize(*enumerator);
1020     // Initialize "pointers" overhead (and allocate "indices", "values").
1021     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1022     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1023       if (isCompressedDim(r)) {
1024         pointers[r].reserve(parentSz + 1);
1025         pointers[r].push_back(0);
1026         uint64_t currentPos = 0;
1027         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1028           currentPos += n;
1029           appendPointer(r, currentPos);
1030         });
1031         assert(pointers[r].size() == parentSz + 1 &&
1032                "Final pointers size doesn't match allocated size");
1033         // That assertion entails `assembledSize(parentSz, r)`
1034         // is now in a valid state.  That is, `pointers[r][parentSz]`
1035         // equals the present value of `currentPos`, which is the
1036         // correct assembled-size for `indices[r]`.
1037       }
1038       // Update assembled-size for the next iteration.
1039       parentSz = assembledSize(parentSz, r);
1040       // Ideally we need only `indices[r].reserve(parentSz)`, however
1041       // the `std::vector` implementation forces us to initialize it too.
1042       // That is, in the yieldPos loop we need random-access assignment
1043       // to `indices[r]`; however, `std::vector`'s subscript-assignment
1044       // only allows assigning to already-initialized positions.
1045       if (isCompressedDim(r))
1046         indices[r].resize(parentSz, 0);
1047     }
1048     values.resize(parentSz, 0); // Both allocate and zero-initialize.
1049   }
1050   // The yieldPos loop
1051   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1052     uint64_t parentSz = 1, parentPos = 0;
1053     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1054       if (isCompressedDim(r)) {
1055         // If `parentPos == parentSz` then it's valid as an array-lookup;
1056         // however, it's semantically invalid here since that entry
1057         // does not represent a segment of `indices[r]`.  Moreover, that
1058         // entry must be immutable for `assembledSize` to remain valid.
1059         assert(parentPos < parentSz && "Pointers position is out of bounds");
1060         const uint64_t currentPos = pointers[r][parentPos];
1061         // This increment won't overflow the `P` type, since it can't
1062         // exceed the original value of `pointers[r][parentPos+1]`
1063         // which was already verified to be within bounds for `P`
1064         // when it was written to the array.
1065         pointers[r][parentPos]++;
1066         writeIndex(r, currentPos, ind[r]);
1067         parentPos = currentPos;
1068       } else { // Dense dimension.
1069         parentPos = parentPos * getDimSizes()[r] + ind[r];
1070       }
1071       parentSz = assembledSize(parentSz, r);
1072     }
1073     assert(parentPos < values.size() && "Value position is out of bounds");
1074     values[parentPos] = val;
1075   });
1076   // No longer need the enumerator, so we'll delete it ASAP.
1077   delete enumerator;
1078   // The finalizeYieldPos loop
1079   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1080     if (isCompressedDim(r)) {
1081       assert(parentSz == pointers[r].size() - 1 &&
1082              "Actual pointers size doesn't match the expected size");
1083       // Can't check all of them, but at least we can check the last one.
1084       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1085              "Pointers got corrupted");
1086       // TODO: optimize this by using `memmove` or similar.
1087       for (uint64_t n = 0; n < parentSz; n++) {
1088         const uint64_t parentPos = parentSz - n;
1089         pointers[r][parentPos] = pointers[r][parentPos - 1];
1090       }
1091       pointers[r][0] = 0;
1092     }
1093     parentSz = assembledSize(parentSz, r);
1094   }
1095 }
1096 
1097 /// Helper to convert string to lower case.
1098 static char *toLower(char *token) {
1099   for (char *c = token; *c; c++)
1100     *c = tolower(*c);
1101   return token;
1102 }
1103 
1104 /// This class abstracts over the information stored in file headers,
1105 /// as well as providing the buffers and methods for parsing those headers.
1106 class SparseTensorFile final {
1107 public:
1108   enum class ValueKind {
1109     kInvalid = 0,
1110     kPattern = 1,
1111     kReal = 2,
1112     kInteger = 3,
1113     kComplex = 4,
1114     kUndefined = 5
1115   };
1116 
1117   explicit SparseTensorFile(char *filename) : filename(filename) {
1118     assert(filename && "Received nullptr for filename");
1119   }
1120 
1121   // Disallows copying, to avoid duplicating the `file` pointer.
1122   SparseTensorFile(const SparseTensorFile &) = delete;
1123   SparseTensorFile &operator=(const SparseTensorFile &) = delete;
1124 
1125   // This dtor tries to avoid leaking the `file`.  (Though it's better
1126   // to call `closeFile` explicitly when possible, since there are
1127   // circumstances where dtors are not called reliably.)
1128   ~SparseTensorFile() { closeFile(); }
1129 
1130   /// Opens the file for reading.
1131   void openFile() {
1132     if (file)
1133       FATAL("Already opened file %s\n", filename);
1134     file = fopen(filename, "r");
1135     if (!file)
1136       FATAL("Cannot find file %s\n", filename);
1137   }
1138 
1139   /// Closes the file.
1140   void closeFile() {
1141     if (file) {
1142       fclose(file);
1143       file = nullptr;
1144     }
1145   }
1146 
1147   // TODO(wrengr/bixia): figure out how to reorganize the element-parsing
1148   // loop of `openSparseTensorCOO` into methods of this class, so we can
1149   // avoid leaking access to the `line` pointer (both for general hygiene
1150   // and because we can't mark it const due to the second argument of
  // `strtoul`/`strtod` being `char * *restrict` rather than
1152   // `char const* *restrict`).
1153   //
1154   /// Attempts to read a line from the file.
1155   char *readLine() {
1156     if (fgets(line, kColWidth, file))
1157       return line;
1158     FATAL("Cannot read next line of %s\n", filename);
1159   }
1160 
1161   /// Reads and parses the file's header.
1162   void readHeader() {
1163     assert(file && "Attempt to readHeader() before openFile()");
1164     if (strstr(filename, ".mtx"))
1165       readMMEHeader();
1166     else if (strstr(filename, ".tns"))
1167       readExtFROSTTHeader();
1168     else
1169       FATAL("Unknown format %s\n", filename);
1170     assert(isValid() && "Failed to read the header");
1171   }
1172 
1173   ValueKind getValueKind() const { return valueKind_; }
1174 
1175   bool isValid() const { return valueKind_ != ValueKind::kInvalid; }
1176 
1177   /// Gets the MME "pattern" property setting.  Is only valid after
1178   /// parsing the header.
1179   bool isPattern() const {
1180     assert(isValid() && "Attempt to isPattern() before readHeader()");
1181     return valueKind_ == ValueKind::kPattern;
1182   }
1183 
1184   /// Gets the MME "symmetric" property setting.  Is only valid after
1185   /// parsing the header.
1186   bool isSymmetric() const {
1187     assert(isValid() && "Attempt to isSymmetric() before readHeader()");
1188     return isSymmetric_;
1189   }
1190 
1191   /// Gets the rank of the tensor.  Is only valid after parsing the header.
1192   uint64_t getRank() const {
1193     assert(isValid() && "Attempt to getRank() before readHeader()");
1194     return idata[0];
1195   }
1196 
1197   /// Gets the number of non-zeros.  Is only valid after parsing the header.
1198   uint64_t getNNZ() const {
1199     assert(isValid() && "Attempt to getNNZ() before readHeader()");
1200     return idata[1];
1201   }
1202 
1203   /// Gets the dimension-sizes array.  The pointer itself is always
1204   /// valid; however, the values stored therein are only valid after
1205   /// parsing the header.
1206   const uint64_t *getDimSizes() const { return idata + 2; }
1207 
1208   /// Safely gets the size of the given dimension.  Is only valid
1209   /// after parsing the header.
1210   uint64_t getDimSize(uint64_t d) const {
1211     assert(d < getRank());
1212     return idata[2 + d];
1213   }
1214 
1215   /// Asserts the shape subsumes the actual dimension sizes.  Is only
1216   /// valid after parsing the header.
1217   void assertMatchesShape(uint64_t rank, const uint64_t *shape) const {
1218     assert(rank == getRank() && "Rank mismatch");
1219     for (uint64_t r = 0; r < rank; r++)
1220       assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1221              "Dimension size mismatch");
1222   }
1223 
1224 private:
1225   void readMMEHeader();
1226   void readExtFROSTTHeader();
1227 
1228   const char *filename;
1229   FILE *file = nullptr;
1230   ValueKind valueKind_ = ValueKind::kInvalid;
1231   bool isSymmetric_ = false;
1232   uint64_t idata[512];
1233   char line[kColWidth];
1234 };
1235 
/// Read the MME header of a general sparse matrix; the values may be
/// pattern, real, integer, or complex.
1237 void SparseTensorFile::readMMEHeader() {
1238   char header[64];
1239   char object[64];
1240   char format[64];
1241   char field[64];
1242   char symmetry[64];
1243   // Read header line.
1244   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1245              symmetry) != 5)
1246     FATAL("Corrupt header in %s\n", filename);
  // Process `field`, which specifies either "pattern" or the data type
  // of the values.
1248   if (strcmp(toLower(field), "pattern") == 0)
1249     valueKind_ = ValueKind::kPattern;
1250   else if (strcmp(toLower(field), "real") == 0)
1251     valueKind_ = ValueKind::kReal;
1252   else if (strcmp(toLower(field), "integer") == 0)
1253     valueKind_ = ValueKind::kInteger;
1254   else if (strcmp(toLower(field), "complex") == 0)
1255     valueKind_ = ValueKind::kComplex;
1256   else
1257     FATAL("Unexpected header field value in %s\n", filename);
1258 
1259   // Set properties.
1260   isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0);
1261   // Make sure this is a general sparse matrix.
1262   if (strcmp(toLower(header), "%%matrixmarket") ||
1263       strcmp(toLower(object), "matrix") ||
1264       strcmp(toLower(format), "coordinate") ||
1265       (strcmp(toLower(symmetry), "general") && !isSymmetric_))
1266     FATAL("Cannot find a general sparse matrix in %s\n", filename);
1267   // Skip comments.
1268   while (true) {
1269     readLine();
1270     if (line[0] != '%')
1271       break;
1272   }
1273   // Next line contains M N NNZ.
1274   idata[0] = 2; // rank
1275   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1276              idata + 1) != 3)
1277     FATAL("Cannot find size in %s\n", filename);
1278 }
1279 
1280 /// Read the "extended" FROSTT header. Although not part of the documented
1281 /// format, we assume that the file starts with optional comments followed
1282 /// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
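///
/// For illustration (a sketch of this assumed layout), a 3-dimensional
/// 2x3x4 tensor with 5 nonzeros would start with
///
///   # optional comment lines
///   3 5
///   2 3 4
///
/// after which `idata[]` holds {3, 5, 2, 3, 4} (i.e., rank, nnz, dim sizes).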
1284 void SparseTensorFile::readExtFROSTTHeader() {
1285   // Skip comments.
1286   while (true) {
1287     readLine();
1288     if (line[0] != '#')
1289       break;
1290   }
1291   // Next line contains RANK and NNZ.
1292   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
1293     FATAL("Cannot find metadata in %s\n", filename);
1294   // Followed by a line with the dimension sizes (one per rank).
1295   for (uint64_t r = 0; r < idata[0]; r++)
1296     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
      FATAL("Cannot find dimension size in %s\n", filename);
1298   readLine(); // end of line
1299   // The FROSTT format does not define the data type of the nonzero elements.
1300   valueKind_ = ValueKind::kUndefined;
1301 }
1302 
1303 // Adds a value to a tensor in coordinate scheme. If is_symmetric_value is true,
1304 // also adds the value to its symmetric location.
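// For example, when is_symmetric_value is true, an entry read at indices
// {1, 4} with value 3.0 is added both at {1, 4} and at {4, 1}.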
1305 template <typename T, typename V>
1306 static inline void addValue(T *coo, V value,
                            const std::vector<uint64_t> &indices,
1308                             bool is_symmetric_value) {
1309   // TODO: <https://github.com/llvm/llvm-project/issues/54179>
1310   coo->add(indices, value);
  // We currently choose to deal with symmetric matrices by fully constructing
1312   // them. In the future, we may want to make symmetry implicit for storage
1313   // reasons.
1314   if (is_symmetric_value)
1315     coo->add({indices[1], indices[0]}, value);
1316 }
1317 
1318 // Reads an element of a complex type for the current indices in coordinate
1319 // scheme.
1320 template <typename V>
1321 static inline void readCOOValue(SparseTensorCOO<std::complex<V>> *coo,
                                const std::vector<uint64_t> &indices,
1323                                 char **linePtr, bool is_pattern,
1324                                 bool add_symmetric_value) {
  // Read two values to form a complex number. The external formats always
  // store numerical values with the type double, but we cast these values to
  // the sparse tensor object type. For a pattern tensor, we arbitrarily pick
  // the value 1 for all entries (both the real and imaginary parts here).
1329   V re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
1330   V im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
1331   std::complex<V> value = {re, im};
1332   addValue(coo, value, indices, add_symmetric_value);
1333 }
1334 
1335 // Reads an element of a non-complex type for the current indices in coordinate
1336 // scheme.
1337 template <typename V,
1338           typename std::enable_if<
1339               !std::is_same<std::complex<float>, V>::value &&
1340               !std::is_same<std::complex<double>, V>::value>::type * = nullptr>
static inline void readCOOValue(SparseTensorCOO<V> *coo,
                                const std::vector<uint64_t> &indices,
1343                                 char **linePtr, bool is_pattern,
1344                                 bool is_symmetric_value) {
1345   // The external formats always store these numerical values with the type
1346   // double, but we cast these values to the sparse tensor object type.
1347   // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1348   double value = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
1349   addValue(coo, value, indices, is_symmetric_value);
1350 }
1351 
/// Reads a sparse tensor from the file with the given filename into a
/// memory-resident sparse tensor in coordinate scheme.
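///
/// For example (a sketch assuming rank 2, the identity permutation, and a
/// real matrix), the MME data line "3 4 5.0" is parsed into the 0-based
/// indices {2, 3} with value 5.0 in the resulting COO object.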
1354 template <typename V>
1355 static SparseTensorCOO<V> *
1356 openSparseTensorCOO(char *filename, uint64_t rank, const uint64_t *shape,
1357                     const uint64_t *perm, PrimaryType valTp) {
1358   SparseTensorFile stfile(filename);
1359   stfile.openFile();
1360   stfile.readHeader();
1361   // Check tensor element type against the value type in the input file.
1362   SparseTensorFile::ValueKind valueKind = stfile.getValueKind();
1363   bool tensorIsInteger =
1364       (valTp >= PrimaryType::kI64 && valTp <= PrimaryType::kI8);
1365   bool tensorIsReal = (valTp >= PrimaryType::kF64 && valTp <= PrimaryType::kI8);
1366   if ((valueKind == SparseTensorFile::ValueKind::kReal && tensorIsInteger) ||
1367       (valueKind == SparseTensorFile::ValueKind::kComplex && tensorIsReal)) {
1368     FATAL("Tensor element type %d not compatible with values in file %s\n",
1369           valTp, filename);
1370   }
1371   stfile.assertMatchesShape(rank, shape);
1372   // Prepare sparse tensor object with per-dimension sizes
1373   // and the number of nonzeros as initial capacity.
1374   uint64_t nnz = stfile.getNNZ();
1375   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, stfile.getDimSizes(),
1376                                                      perm, nnz);
1377   // Read all nonzero elements.
1378   std::vector<uint64_t> indices(rank);
1379   for (uint64_t k = 0; k < nnz; k++) {
1380     char *linePtr = stfile.readLine();
1381     for (uint64_t r = 0; r < rank; r++) {
1382       uint64_t idx = strtoul(linePtr, &linePtr, 10);
1383       // Add 0-based index.
1384       indices[perm[r]] = idx - 1;
1385     }
1386     readCOOValue(coo, indices, &linePtr, stfile.isPattern(),
1387                  stfile.isSymmetric() && indices[0] != indices[1]);
1388   }
1389   // Close the file and return tensor.
1390   stfile.closeFile();
1391   return coo;
1392 }
1393 
/// Writes the sparse tensor to the file named by `dest`, using the extended
/// FROSTT format.
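///
/// For example (a small sketch), a 2x3 matrix with nonzeros (0,1)=5 and
/// (1,2)=7 is written as
///
///   ; extended FROSTT format
///   2 2
///   2 3
///   1 2 5
///   2 3 7
///
/// i.e., rank and nnz, the dimension sizes, and then the 1-based entries.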
1395 template <typename V>
1396 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1397   assert(tensor && dest);
1398   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1399   if (sort)
1400     coo->sort();
1401   char *filename = static_cast<char *>(dest);
1402   auto &dimSizes = coo->getDimSizes();
1403   auto &elements = coo->getElements();
1404   uint64_t rank = coo->getRank();
1405   uint64_t nnz = elements.size();
1406   std::fstream file;
1407   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1408   assert(file.is_open());
1409   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1410   for (uint64_t r = 0; r < rank - 1; r++)
1411     file << dimSizes[r] << " ";
1412   file << dimSizes[rank - 1] << std::endl;
1413   for (uint64_t i = 0; i < nnz; i++) {
1414     auto &idx = elements[i].indices;
1415     for (uint64_t r = 0; r < rank; r++)
1416       file << (idx[r] + 1) << " ";
1417     file << elements[i].value << std::endl;
1418   }
1419   file.flush();
1420   file.close();
1421   assert(file.good());
1422 }
1423 
/// Initializes a sparse tensor from an external COO-flavored format.
1425 template <typename V>
1426 static SparseTensorStorage<uint64_t, uint64_t, V> *
1427 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1428                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
1429   const DimLevelType *sparsity = (DimLevelType *)(sparse);
1430 #ifndef NDEBUG
1431   // Verify that perm is a permutation of 0..(rank-1).
1432   std::vector<uint64_t> order(perm, perm + rank);
1433   std::sort(order.begin(), order.end());
1434   for (uint64_t i = 0; i < rank; ++i)
1435     if (i != order[i])
1436       FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
1437 
1438   // Verify that the sparsity values are supported.
1439   for (uint64_t i = 0; i < rank; ++i)
1440     if (sparsity[i] != DimLevelType::kDense &&
1441         sparsity[i] != DimLevelType::kCompressed)
1442       FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
1443 #endif
1444 
1445   // Convert external format to internal COO.
1446   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1447   std::vector<uint64_t> idx(rank);
1448   for (uint64_t i = 0, base = 0; i < nse; i++) {
1449     for (uint64_t r = 0; r < rank; r++)
1450       idx[perm[r]] = indices[base + r];
1451     coo->add(idx, values[i]);
1452     base += rank;
1453   }
1454   // Return sparse tensor storage format as opaque pointer.
1455   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1456       rank, shape, perm, sparsity, coo);
1457   delete coo;
1458   return tensor;
1459 }
1460 
1461 /// Converts a sparse tensor to an external COO-flavored format.
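/// The `*pShape`, `*pValues`, and `*pIndices` buffers are `new[]`-allocated
/// below, with indices packed per element, i.e., the j-th index of the i-th
/// element is stored at `(*pIndices)[i * rank + j]`.  The caller is expected
/// to release these buffers (e.g., with `delete[]`) once done with them.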
1462 template <typename V>
1463 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1464                                  uint64_t **pShape, V **pValues,
1465                                  uint64_t **pIndices) {
1466   assert(tensor);
1467   auto sparseTensor =
1468       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1469   uint64_t rank = sparseTensor->getRank();
1470   std::vector<uint64_t> perm(rank);
1471   std::iota(perm.begin(), perm.end(), 0);
1472   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1473 
1474   const std::vector<Element<V>> &elements = coo->getElements();
1475   uint64_t nse = elements.size();
1476 
1477   uint64_t *shape = new uint64_t[rank];
1478   for (uint64_t i = 0; i < rank; i++)
1479     shape[i] = coo->getDimSizes()[i];
1480 
1481   V *values = new V[nse];
1482   uint64_t *indices = new uint64_t[rank * nse];
1483 
1484   for (uint64_t i = 0, base = 0; i < nse; i++) {
1485     values[i] = elements[i].value;
1486     for (uint64_t j = 0; j < rank; j++)
1487       indices[base + j] = elements[i].indices[j];
1488     base += rank;
1489   }
1490 
1491   delete coo;
1492   *pRank = rank;
1493   *pNse = nse;
1494   *pShape = shape;
1495   *pValues = values;
1496   *pIndices = indices;
1497 }
1498 
1499 } // anonymous namespace
1500 
1501 extern "C" {
1502 
1503 //===----------------------------------------------------------------------===//
1504 //
1505 // Public functions which operate on MLIR buffers (memrefs) to interact
1506 // with sparse tensors (which are only visible as opaque pointers externally).
1507 //
1508 //===----------------------------------------------------------------------===//
1509 
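// The CASE macro dispatches on the requested type combination: when the
// runtime (ptrTp, indTp, valTp) triple matches (p, i, v), the templates are
// instantiated with the corresponding C++ types (P, I, V) and the requested
// `action` is performed (construct storage from a file, a COO, or empty;
// convert sparse-to-sparse; create an empty COO; or convert/iterate back to
// COO).  Unmatched combinations fall through to the next CASE and, ultimately,
// to the FATAL call at the end of _mlir_ciface_newSparseTensor.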
1510 #define CASE(p, i, v, P, I, V)                                                 \
1511   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1512     SparseTensorCOO<V> *coo = nullptr;                                         \
1513     if (action <= Action::kFromCOO) {                                          \
1514       if (action == Action::kFromFile) {                                       \
1515         char *filename = static_cast<char *>(ptr);                             \
1516         coo = openSparseTensorCOO<V>(filename, rank, shape, perm, v);          \
1517       } else if (action == Action::kFromCOO) {                                 \
1518         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1519       } else {                                                                 \
1520         assert(action == Action::kEmpty);                                      \
1521       }                                                                        \
1522       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1523           rank, shape, perm, sparsity, coo);                                   \
1524       if (action == Action::kFromFile)                                         \
1525         delete coo;                                                            \
1526       return tensor;                                                           \
1527     }                                                                          \
1528     if (action == Action::kSparseToSparse) {                                   \
1529       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
1530       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
1531                                                            sparsity, tensor);  \
1532     }                                                                          \
1533     if (action == Action::kEmptyCOO)                                           \
1534       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1535     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1536     if (action == Action::kToIterator) {                                       \
1537       coo->startIterator();                                                    \
1538     } else {                                                                   \
1539       assert(action == Action::kToCOO);                                        \
1540     }                                                                          \
1541     return coo;                                                                \
1542   }
1543 
1544 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
1545 
1546 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1547 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1548 // that this file cannot get out of sync with its header.
1549 static_assert(std::is_same<index_type, uint64_t>::value,
1550               "Expected index_type == uint64_t");
1551 
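// For illustration only (a hypothetical C++ caller; in practice this entry
// point is invoked from MLIR compiler-generated code, and the aggregate
// initialization below assumes the StridedMemRefType field order basePtr,
// data, offset, sizes, strides), constructing an empty 3x4
// dense-by-compressed matrix of doubles might look like:
//
//   DimLevelType lvl[2] = {DimLevelType::kDense, DimLevelType::kCompressed};
//   index_type shape[2] = {3, 4};
//   index_type perm[2] = {0, 1};
//   StridedMemRefType<DimLevelType, 1> aref{lvl, lvl, 0, {2}, {1}};
//   StridedMemRefType<index_type, 1> sref{shape, shape, 0, {2}, {1}};
//   StridedMemRefType<index_type, 1> pref{perm, perm, 0, {2}, {1}};
//   void *tensor = _mlir_ciface_newSparseTensor(
//       &aref, &sref, &pref, OverheadType::kIndex, OverheadType::kIndex,
//       PrimaryType::kF64, Action::kEmpty, nullptr);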
1552 void *
1553 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1554                              StridedMemRefType<index_type, 1> *sref,
1555                              StridedMemRefType<index_type, 1> *pref,
1556                              OverheadType ptrTp, OverheadType indTp,
1557                              PrimaryType valTp, Action action, void *ptr) {
1558   assert(aref && sref && pref);
1559   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1560          pref->strides[0] == 1);
1561   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1562   const DimLevelType *sparsity = aref->data + aref->offset;
1563   const index_type *shape = sref->data + sref->offset;
1564   const index_type *perm = pref->data + pref->offset;
1565   uint64_t rank = aref->sizes[0];
1566 
1567   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1568   // This is safe because of the static_assert above.
1569   if (ptrTp == OverheadType::kIndex)
1570     ptrTp = OverheadType::kU64;
1571   if (indTp == OverheadType::kIndex)
1572     indTp = OverheadType::kU64;
1573 
1574   // Double matrices with all combinations of overhead storage.
1575   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1576        uint64_t, double);
1577   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1578        uint32_t, double);
1579   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1580        uint16_t, double);
1581   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1582        uint8_t, double);
1583   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1584        uint64_t, double);
1585   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1586        uint32_t, double);
1587   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1588        uint16_t, double);
1589   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1590        uint8_t, double);
1591   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1592        uint64_t, double);
1593   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1594        uint32_t, double);
1595   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1596        uint16_t, double);
1597   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1598        uint8_t, double);
1599   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1600        uint64_t, double);
1601   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1602        uint32_t, double);
1603   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1604        uint16_t, double);
1605   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1606        uint8_t, double);
1607 
1608   // Float matrices with all combinations of overhead storage.
1609   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1610        uint64_t, float);
1611   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1612        uint32_t, float);
1613   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1614        uint16_t, float);
1615   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1616        uint8_t, float);
1617   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1618        uint64_t, float);
1619   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1620        uint32_t, float);
1621   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1622        uint16_t, float);
1623   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1624        uint8_t, float);
1625   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1626        uint64_t, float);
1627   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1628        uint32_t, float);
1629   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1630        uint16_t, float);
1631   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1632        uint8_t, float);
1633   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1634        uint64_t, float);
1635   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1636        uint32_t, float);
1637   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1638        uint16_t, float);
1639   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1640        uint8_t, float);
1641 
1642   // Two-byte floats with both overheads of the same type.
1643   CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
1644   CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
1645   CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
1646   CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
1647   CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
1648   CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
1649   CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
1650   CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);
1651 
1652   // Integral matrices with both overheads of the same type.
1653   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1654   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1655   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1656   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1657   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
1658   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1659   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1660   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1661   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
1662   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1663   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1664   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1665   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
1666   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1667   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1668   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1669 
1670   // Complex matrices with wide overhead.
1671   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1672   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1673 
1674   // Unsupported case (add above if needed).
1675   // TODO: better pretty-printing of enum values!
1676   FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
1677         static_cast<int>(ptrTp), static_cast<int>(indTp),
1678         static_cast<int>(valTp));
1679 }
1680 #undef CASE
1681 #undef CASE_SECSAME
1682 
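// Each FOREVERY_V expansion below generates one entry point per value type
// (e.g., presumably _mlir_ciface_sparseValuesF64 for double), which exposes
// the underlying values array as a rank-1 memref without copying.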
1683 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1684   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1685                                         void *tensor) {                        \
1686     assert(ref &&tensor);                                                      \
1687     std::vector<V> *v;                                                         \
1688     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1689     ref->basePtr = ref->data = v->data();                                      \
1690     ref->offset = 0;                                                           \
1691     ref->sizes[0] = v->size();                                                 \
1692     ref->strides[0] = 1;                                                       \
1693   }
1694 FOREVERY_V(IMPL_SPARSEVALUES)
1695 #undef IMPL_SPARSEVALUES
1696 
1697 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1698   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1699                            index_type d) {                                     \
1700     assert(ref &&tensor);                                                      \
1701     std::vector<TYPE> *v;                                                      \
1702     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1703     ref->basePtr = ref->data = v->data();                                      \
1704     ref->offset = 0;                                                           \
1705     ref->sizes[0] = v->size();                                                 \
1706     ref->strides[0] = 1;                                                       \
1707   }
1708 #define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
1709   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
1710 FOREVERY_O(IMPL_SPARSEPOINTERS)
1711 #undef IMPL_SPARSEPOINTERS
1712 
1713 #define IMPL_SPARSEINDICES(INAME, I)                                           \
1714   IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
1715 FOREVERY_O(IMPL_SPARSEINDICES)
1716 #undef IMPL_SPARSEINDICES
1717 #undef IMPL_GETOVERHEAD
1718 
1719 #define IMPL_ADDELT(VNAME, V)                                                  \
1720   void *_mlir_ciface_addElt##VNAME(void *coo, StridedMemRefType<V, 0> *vref,   \
1721                                    StridedMemRefType<index_type, 1> *iref,     \
1722                                    StridedMemRefType<index_type, 1> *pref) {   \
1723     assert(coo &&vref &&iref &&pref);                                          \
1724     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1725     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1726     const index_type *indx = iref->data + iref->offset;                        \
1727     const index_type *perm = pref->data + pref->offset;                        \
1728     uint64_t isize = iref->sizes[0];                                           \
1729     std::vector<index_type> indices(isize);                                    \
1730     for (uint64_t r = 0; r < isize; r++)                                       \
1731       indices[perm[r]] = indx[r];                                              \
1732     V *value = vref->data + vref->offset;                                      \
1733     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, *value);              \
1734     return coo;                                                                \
1735   }
1736 FOREVERY_V(IMPL_ADDELT)
1737 #undef IMPL_ADDELT
1738 
1739 #define IMPL_GETNEXT(VNAME, V)                                                 \
1740   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1741                                    StridedMemRefType<index_type, 1> *iref,     \
1742                                    StridedMemRefType<V, 0> *vref) {            \
1743     assert(coo &&iref &&vref);                                                 \
1744     assert(iref->strides[0] == 1);                                             \
1745     index_type *indx = iref->data + iref->offset;                              \
1746     V *value = vref->data + vref->offset;                                      \
1747     const uint64_t isize = iref->sizes[0];                                     \
1748     const Element<V> *elem =                                                   \
1749         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1750     if (elem == nullptr)                                                       \
1751       return false;                                                            \
1752     for (uint64_t r = 0; r < isize; r++)                                       \
1753       indx[r] = elem->indices[r];                                              \
1754     *value = elem->value;                                                      \
1755     return true;                                                               \
1756   }
1757 FOREVERY_V(IMPL_GETNEXT)
1758 #undef IMPL_GETNEXT
1759 
1760 #define IMPL_LEXINSERT(VNAME, V)                                               \
1761   void _mlir_ciface_lexInsert##VNAME(void *tensor,                             \
1762                                      StridedMemRefType<index_type, 1> *cref,   \
1763                                      StridedMemRefType<V, 0> *vref) {          \
1764     assert(tensor &&cref &&vref);                                              \
1765     assert(cref->strides[0] == 1);                                             \
1766     index_type *cursor = cref->data + cref->offset;                            \
1767     assert(cursor);                                                            \
1768     V *value = vref->data + vref->offset;                                      \
1769     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, *value); \
1770   }
1771 FOREVERY_V(IMPL_LEXINSERT)
1772 #undef IMPL_LEXINSERT
1773 
1774 #define IMPL_EXPINSERT(VNAME, V)                                               \
1775   void _mlir_ciface_expInsert##VNAME(                                          \
1776       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1777       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1778       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1779     assert(tensor &&cref &&vref &&fref &&aref);                                \
1780     assert(cref->strides[0] == 1);                                             \
1781     assert(vref->strides[0] == 1);                                             \
1782     assert(fref->strides[0] == 1);                                             \
1783     assert(aref->strides[0] == 1);                                             \
1784     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1785     index_type *cursor = cref->data + cref->offset;                            \
1786     V *values = vref->data + vref->offset;                                     \
1787     bool *filled = fref->data + fref->offset;                                  \
1788     index_type *added = aref->data + aref->offset;                             \
1789     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1790         cursor, values, filled, added, count);                                 \
1791   }
1792 FOREVERY_V(IMPL_EXPINSERT)
1793 #undef IMPL_EXPINSERT
1794 
1795 //===----------------------------------------------------------------------===//
1796 //
1797 // Public functions which accept only C-style data structures to interact
1798 // with sparse tensors (which are only visible as opaque pointers externally).
1799 //
1800 //===----------------------------------------------------------------------===//
1801 
1802 index_type sparseDimSize(void *tensor, index_type d) {
1803   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1804 }
1805 
1806 void endInsert(void *tensor) {
1807   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1808 }
1809 
1810 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
1811   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
1812     return outSparseTensor<V>(coo, dest, sort);                                \
1813   }
1814 FOREVERY_V(IMPL_OUTSPARSETENSOR)
1815 #undef IMPL_OUTSPARSETENSOR
1816 
1817 void delSparseTensor(void *tensor) {
1818   delete static_cast<SparseTensorStorageBase *>(tensor);
1819 }
1820 
1821 #define IMPL_DELCOO(VNAME, V)                                                  \
1822   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1823     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1824   }
1825 FOREVERY_V(IMPL_DELCOO)
1826 #undef IMPL_DELCOO
1827 
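/// Helper to obtain the filename of tensor `id` from the environment,
/// following the convention that environment variable TENSOR<id> holds the
/// path.  For example (with a made-up path), setting TENSOR0=data/matrix.mtx
/// makes getTensorFilename(0) return that string.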
1828 char *getTensorFilename(index_type id) {
1829   char var[80];
  snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
1831   char *env = getenv(var);
1832   if (!env)
1833     FATAL("Environment variable %s is not set\n", var);
1834   return env;
1835 }
1836 
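/// Reads the header of the sparse tensor file `filename` and returns its
/// dimension sizes in `out` (for example, {5, 5} for the 5x5 MME matrix
/// sketched earlier).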
1837 void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
1838   assert(out && "Received nullptr for out-parameter");
1839   SparseTensorFile stfile(filename);
1840   stfile.openFile();
1841   stfile.readHeader();
1842   stfile.closeFile();
1843   const uint64_t rank = stfile.getRank();
1844   const uint64_t *dimSizes = stfile.getDimSizes();
1845   out->reserve(rank);
1846   out->assign(dimSizes, dimSizes + rank);
1847 }
1848 
1849 // TODO: generalize beyond 64-bit indices.
1850 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
1851   void *convertToMLIRSparseTensor##VNAME(                                      \
1852       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
1853       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
1854     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
1855                                  sparse);                                      \
1856   }
1857 FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
1858 #undef IMPL_CONVERTTOMLIRSPARSETENSOR
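// For illustration only (a hypothetical external caller with made-up data;
// the F64 suffix is assumed to map to double, as in the dispatch above),
// building a 3x3 dense-by-compressed tensor with two nonzeros could look like:
//
//   uint64_t shape[] = {3, 3};
//   double values[] = {1.0, 2.0};
//   uint64_t indices[] = {0, 0,   // (0,0) -> 1.0
//                         2, 1};  // (2,1) -> 2.0
//   uint64_t perm[] = {0, 1};
//   uint8_t sparse[] = {static_cast<uint8_t>(DimLevelType::kDense),
//                       static_cast<uint8_t>(DimLevelType::kCompressed)};
//   void *t = convertToMLIRSparseTensorF64(2, 2, shape, values, indices,
//                                          perm, sparse);
//   ...
//   delSparseTensor(t);  // presumably how the opaque storage is released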
1859 
1860 // TODO: Currently, values are copied from SparseTensorStorage to
1861 // SparseTensorCOO, then to the output.  We may want to reduce the number
1862 // of copies.
1863 //
1864 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1865 // compressed
1866 #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
1867   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
1868                                           uint64_t *pNse, uint64_t **pShape,   \
1869                                           V **pValues, uint64_t **pIndices) {  \
1870     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
1871   }
1872 FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
1873 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
1874 
1875 } // extern "C"
1876 
1877 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1878