1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 
19 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
20 
21 #include <algorithm>
22 #include <cassert>
23 #include <cctype>
24 #include <cstdio>
25 #include <cstdlib>
26 #include <cstring>
27 #include <fstream>
28 #include <functional>
29 #include <iostream>
30 #include <limits>
31 #include <numeric>
32 
33 //===----------------------------------------------------------------------===//
34 //
35 // Internal support for storing and reading sparse tensors.
36 //
37 // The following memory-resident sparse storage schemes are supported:
38 //
39 // (a) A coordinate scheme for temporarily storing and lexicographically
40 //     sorting a sparse tensor by index (SparseTensorCOO).
41 //
42 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
44 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
45 //
46 // The following external formats are supported:
47 //
48 // (1) Matrix Market Exchange (MME): *.mtx
49 //     https://math.nist.gov/MatrixMarket/formats.html
50 //
51 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
52 //     http://frostt.io/tensors/file-formats.html
53 //
54 // Two public APIs are supported:
55 //
56 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
57 //     tensors. These methods should be used exclusively by MLIR
58 //     compiler-generated code.
59 //
60 // (II) Methods that accept C-style data structures to interact with sparse
61 //      tensors. These methods can be used by any external runtime that wants
62 //      to interact with MLIR compiler-generated code.
63 //
64 // In both cases (I) and (II), the SparseTensorStorage format is externally
65 // only visible as an opaque pointer.
66 //
67 //===----------------------------------------------------------------------===//
68 
69 namespace {
70 
// NOTE(review): not referenced in this part of the file; presumably a
// per-line character-buffer width used when reading the external formats
// listed above -- confirm against its uses.
static constexpr int kColWidth{1025};
72 
/// Multiplies two `uint64_t` values, asserting (when assertions are
/// enabled) that the product is representable.  With `NDEBUG` the check
/// vanishes and the result is the ordinary modulo-2^64 product.
static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
  // The product overflows exactly when `lhs` is nonzero and `rhs`
  // exceeds the largest value that `lhs` can be scaled by.
  assert(!(lhs != 0 && rhs > std::numeric_limits<uint64_t>::max() / lhs) &&
         "Integer overflow");
  const uint64_t product = lhs * rhs;
  return product;
}
79 
// Reports a fatal error on `stderr` and terminates the process with exit
// status 1.  Every message is prefixed with "SparseTensorUtils: " so that,
// absent a stack trace from `fprintf`, it is easy to tell whether an error
// comes from this library or from somewhere else in MLIR.  The first
// variadic argument must be a string literal, because it is concatenated
// with the prefix to form the format string.  The `do { ... } while (false)`
// wrapper makes `FATAL(...);` behave as a single statement.
#define FATAL(...)                                                             \
  do {                                                                         \
    std::fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                   \
    std::exit(1);                                                              \
  } while (false)
90 
// TODO: try to unify this with `SparseTensorFile::assertMatchesShape`
// which is used by `openSparseTensorCOO`.  It's easy enough to resolve
// the `std::vector` vs pointer mismatch for `dimSizes`; but it's trickier
// to resolve the presence/absence of `perm` (without introducing extra
// overhead), so perhaps the code duplication is unavoidable.
//
/// Asserts that `dimSizes` (given in target-order) refines the desired
/// `shape` (given in semantic-order) once mapped through `perm`, which
/// sends each semantic dimension to its target dimension.  A `shape`
/// entry of zero accepts any size.  Compiles to nothing under `NDEBUG`.
///
/// Precondition: `perm` and `shape` must be valid for `rank`.
static inline void
assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
                              uint64_t rank, const uint64_t *perm,
                              const uint64_t *shape) {
  assert(perm && shape);
  assert(rank == dimSizes.size() && "Rank mismatch");
  for (uint64_t d = 0; d < rank; ++d) {
    // Semantic dimension `d` is stored at target position `perm[d]`.
    assert((shape[d] == 0 || shape[d] == dimSizes[perm[d]]) &&
           "Dimension size mismatch");
  }
}
112 
/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer to a shared index pool rather than e.g. a direct
/// vector since that (1) reduces the per-element memory footprint, and
/// (2) centralizes the memory reservation and (re)allocation to one place.
template <typename V>
struct Element final {
  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
  uint64_t *indices; // pointer into shared index pool
  V value;
};

/// The type of callback functions which receive an element.  We avoid
/// packaging the coordinates and value together as an `Element` object
/// because this helps keep code somewhat cleaner.
template <typename V>
using ElementConsumer =
    const std::function<void(const std::vector<uint64_t> &, V)> &;

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO final {
  /// Constructs an empty tensor with the given per-dimension sizes.  A
  /// nonzero `capacity` reserves room for that many elements (and their
  /// indices) up front, so that `add()` need not reallocate.
  SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
      : dimSizes(dimSizes) {
    if (capacity) {
      elements.reserve(capacity);
      indices.reserve(capacity * getRank());
    }
  }

  /// Adds element as indices and value.  Must not be called after
  /// `startIterator()` until iteration has run to completion.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t *base = indices.data();
    uint64_t size = indices.size();
    uint64_t rank = getRank();
    assert(ind.size() == rank && "Element rank mismatch");
    for (uint64_t r = 0; r < rank; r++) {
      assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
      indices.push_back(ind[r]);
    }
    // This base only changes if indices were reallocated. In that case, we
    // need to correct all previous pointers into the vector. Note that this
    // only happens if we did not set the initial capacity right, and then only
    // for every internal vector reallocation (which with the doubling rule
    // should only incur an amortized linear overhead).
    uint64_t *newBase = indices.data();
    if (newBase != base) {
      for (uint64_t i = 0, n = elements.size(); i < n; i++)
        elements[i].indices = newBase + (elements[i].indices - base);
      base = newBase;
    }
    // Add element as (pointer into shared index pool, value) pair.
    elements.emplace_back(base + size, val);
  }

  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    // TODO: we may want to cache an `isSorted` bit, to avoid
    // unnecessary/redundant sorting.
    uint64_t rank = getRank();
    std::sort(elements.begin(), elements.end(),
              [rank](const Element<V> &e1, const Element<V> &e2) {
                for (uint64_t r = 0; r < rank; r++) {
                  if (e1.indices[r] == e2.indices[r])
                    continue;
                  return e1.indices[r] < e2.indices[r];
                }
                return false;
              });
  }

  /// Get the rank of the tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Getter for the dimension-sizes array.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Getter for the elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }

  /// Get the next element, or `nullptr` once all elements have been
  /// visited (which also leaves iterator mode).
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  ///
  /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *dimSizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      permsz[perm[r]] = dimSizes[r];
    }
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  const std::vector<uint64_t> dimSizes; // per-dimension sizes
  std::vector<Element<V>> elements;     // all COO elements
  std::vector<uint64_t> indices;        // shared index pool
  bool iteratorLocked = false;
  // Cursor into `elements`.  It must be at least as wide as
  // `elements.size()`: a 32-bit cursor would wrap back to 0 after 2^32
  // elements, making `getNext()` restart from the first element forever.
  uint64_t iteratorPos = 0;
};
243 
// Forward declaration of the value-type-specific enumerator base class,
// which `SparseTensorStorageBase::newEnumerator` refers to.
template <class V> class SparseTensorEnumeratorBase;
247 
// Helper macro for generating error messages when some
// `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
// and then the wrong "partial method specialization" is called.
//
// `NAME` must be a string literal (or adjacent string literals, such as
// `"getValues" "F64"`), which is spliced directly into the format string.
// It is deliberately not stringized: stringizing an argument that is
// already a literal would embed its quote characters in the message.
// The message ends with a newline, and the expansion has no trailing
// semicolon; call sites supply their own.
#define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " NAME "\n")
252 
253 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
254 /// takes responsibility for all the `<P,I,V>`-independent aspects
255 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
256 /// we use function overloading to implement "partial" method
257 /// specialization, which the C-API relies on to catch type errors
258 /// arising from our use of opaque pointers.
259 class SparseTensorStorageBase {
260 public:
261   /// Constructs a new storage object.  The `perm` maps the tensor's
262   /// semantic-ordering of dimensions to this object's storage-order.
263   /// The `dimSizes` and `sparsity` arrays are already in storage-order.
264   ///
265   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
266   SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
267                           const uint64_t *perm, const DimLevelType *sparsity)
268       : dimSizes(dimSizes), rev(getRank()),
269         dimTypes(sparsity, sparsity + getRank()) {
270     assert(perm && sparsity);
271     const uint64_t rank = getRank();
272     // Validate parameters.
273     assert(rank > 0 && "Trivial shape is unsupported");
274     for (uint64_t r = 0; r < rank; r++) {
275       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
276       assert((dimTypes[r] == DimLevelType::kDense ||
277               dimTypes[r] == DimLevelType::kCompressed) &&
278              "Unsupported DimLevelType");
279     }
280     // Construct the "reverse" (i.e., inverse) permutation.
281     for (uint64_t r = 0; r < rank; r++)
282       rev[perm[r]] = r;
283   }
284 
285   virtual ~SparseTensorStorageBase() = default;
286 
287   /// Get the rank of the tensor.
288   uint64_t getRank() const { return dimSizes.size(); }
289 
290   /// Getter for the dimension-sizes array, in storage-order.
291   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
292 
293   /// Safely lookup the size of the given (storage-order) dimension.
294   uint64_t getDimSize(uint64_t d) const {
295     assert(d < getRank());
296     return dimSizes[d];
297   }
298 
299   /// Getter for the "reverse" permutation, which maps this object's
300   /// storage-order to the tensor's semantic-order.
301   const std::vector<uint64_t> &getRev() const { return rev; }
302 
303   /// Getter for the dimension-types array, in storage-order.
304   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
305 
306   /// Safely check if the (storage-order) dimension uses compressed storage.
307   bool isCompressedDim(uint64_t d) const {
308     assert(d < getRank());
309     return (dimTypes[d] == DimLevelType::kCompressed);
310   }
311 
312   /// Allocate a new enumerator.
313 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
314   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
315                              const uint64_t *) const {                         \
316     FATAL_PIV("newEnumerator" #VNAME);                                         \
317   }
318   FOREVERY_V(DECL_NEWENUMERATOR)
319 #undef DECL_NEWENUMERATOR
320 
321   /// Overhead storage.
322 #define DECL_GETPOINTERS(PNAME, P)                                             \
323   virtual void getPointers(std::vector<P> **, uint64_t) {                      \
324     FATAL_PIV("getPointers" #PNAME);                                           \
325   }
326   FOREVERY_FIXED_O(DECL_GETPOINTERS)
327 #undef DECL_GETPOINTERS
328 #define DECL_GETINDICES(INAME, I)                                              \
329   virtual void getIndices(std::vector<I> **, uint64_t) {                       \
330     FATAL_PIV("getIndices" #INAME);                                            \
331   }
332   FOREVERY_FIXED_O(DECL_GETINDICES)
333 #undef DECL_GETINDICES
334 
335   /// Primary storage.
336 #define DECL_GETVALUES(VNAME, V)                                               \
337   virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
338   FOREVERY_V(DECL_GETVALUES)
339 #undef DECL_GETVALUES
340 
341   /// Element-wise insertion in lexicographic index order.
342 #define DECL_LEXINSERT(VNAME, V)                                               \
343   virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
344   FOREVERY_V(DECL_LEXINSERT)
345 #undef DECL_LEXINSERT
346 
347   /// Expanded insertion.
348 #define DECL_EXPINSERT(VNAME, V)                                               \
349   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
350     FATAL_PIV("expInsert" #VNAME);                                             \
351   }
352   FOREVERY_V(DECL_EXPINSERT)
353 #undef DECL_EXPINSERT
354 
355   /// Finishes insertion.
356   virtual void endInsert() = 0;
357 
358 protected:
359   // Since this class is virtual, we must disallow public copying in
360   // order to avoid "slicing".  Since this class has data members,
361   // that means making copying protected.
362   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
363   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
364   // Copy-assignment would be implicitly deleted (because `dimSizes`
365   // is const), so we explicitly delete it for clarity.
366   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
367 
368 private:
369   const std::vector<uint64_t> dimSizes;
370   std::vector<uint64_t> rev;
371   const std::vector<DimLevelType> dimTypes;
372 };
373 
374 #undef FATAL_PIV
375 
// Forward declaration of the concrete enumerator, which
// `SparseTensorStorage::newEnumerator` instantiates.
template <class P, class I, class V> class SparseTensorEnumerator;
379 
380 /// A memory-resident sparse tensor using a storage scheme based on
381 /// per-dimension sparse/dense annotations. This data structure provides a
382 /// bufferized form of a sparse tensor type. In contrast to generating setup
383 /// methods for each differently annotated sparse tensor, this method provides
384 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
385 /// and annotations to implement all required setup in a general manner.
386 template <typename P, typename I, typename V>
387 class SparseTensorStorage final : public SparseTensorStorageBase {
388   /// Private constructor to share code between the other constructors.
389   /// Beware that the object is not necessarily guaranteed to be in a
390   /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
391   /// doesn't entail `!(pointers[d].empty())`.
392   ///
393   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
394   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
395                       const uint64_t *perm, const DimLevelType *sparsity)
396       : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
397         indices(getRank()), idx(getRank()) {}
398 
399 public:
400   /// Constructs a sparse tensor storage scheme with the given dimensions,
401   /// permutation, and per-dimension dense/sparse annotations, using
402   /// the coordinate scheme tensor for the initial contents if provided.
403   ///
404   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
405   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
406                       const uint64_t *perm, const DimLevelType *sparsity,
407                       SparseTensorCOO<V> *coo)
408       : SparseTensorStorage(dimSizes, perm, sparsity) {
409     // Provide hints on capacity of pointers and indices.
410     // TODO: needs much fine-tuning based on actual sparsity; currently
411     //       we reserve pointer/index space based on all previous dense
412     //       dimensions, which works well up to first sparse dim; but
413     //       we should really use nnz and dense/sparse distribution.
414     bool allDense = true;
415     uint64_t sz = 1;
416     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
417       if (isCompressedDim(r)) {
418         // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
419         // `sz` by that before reserving. (For now we just use 1.)
420         pointers[r].reserve(sz + 1);
421         pointers[r].push_back(0);
422         indices[r].reserve(sz);
423         sz = 1;
424         allDense = false;
425       } else { // Dense dimension.
426         sz = checkedMul(sz, getDimSizes()[r]);
427       }
428     }
429     // Then assign contents from coordinate scheme tensor if provided.
430     if (coo) {
431       // Ensure both preconditions of `fromCOO`.
432       assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
433       coo->sort();
434       // Now actually insert the `elements`.
435       const std::vector<Element<V>> &elements = coo->getElements();
436       uint64_t nnz = elements.size();
437       values.reserve(nnz);
438       fromCOO(elements, 0, nnz, 0);
439     } else if (allDense) {
440       values.resize(sz, 0);
441     }
442   }
443 
444   /// Constructs a sparse tensor storage scheme with the given dimensions,
445   /// permutation, and per-dimension dense/sparse annotations, using
446   /// the given sparse tensor for the initial contents.
447   ///
448   /// Preconditions:
449   /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
450   /// * The `tensor` must have the same value type `V`.
451   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
452                       const uint64_t *perm, const DimLevelType *sparsity,
453                       const SparseTensorStorageBase &tensor);
454 
455   ~SparseTensorStorage() final = default;
456 
457   /// Partially specialize these getter methods based on template types.
458   void getPointers(std::vector<P> **out, uint64_t d) final {
459     assert(d < getRank());
460     *out = &pointers[d];
461   }
462   void getIndices(std::vector<I> **out, uint64_t d) final {
463     assert(d < getRank());
464     *out = &indices[d];
465   }
466   void getValues(std::vector<V> **out) final { *out = &values; }
467 
468   /// Partially specialize lexicographical insertions based on template types.
469   void lexInsert(const uint64_t *cursor, V val) final {
470     // First, wrap up pending insertion path.
471     uint64_t diff = 0;
472     uint64_t top = 0;
473     if (!values.empty()) {
474       diff = lexDiff(cursor);
475       endPath(diff + 1);
476       top = idx[diff] + 1;
477     }
478     // Then continue with insertion path.
479     insPath(cursor, diff, top, val);
480   }
481 
482   /// Partially specialize expanded insertions based on template types.
483   /// Note that this method resets the values/filled-switch array back
484   /// to all-zero/false while only iterating over the nonzero elements.
485   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
486                  uint64_t count) final {
487     if (count == 0)
488       return;
489     // Sort.
490     std::sort(added, added + count);
491     // Restore insertion path for first insert.
492     const uint64_t lastDim = getRank() - 1;
493     uint64_t index = added[0];
494     cursor[lastDim] = index;
495     lexInsert(cursor, values[index]);
496     assert(filled[index]);
497     values[index] = 0;
498     filled[index] = false;
499     // Subsequent insertions are quick.
500     for (uint64_t i = 1; i < count; i++) {
501       assert(index < added[i] && "non-lexicographic insertion");
502       index = added[i];
503       cursor[lastDim] = index;
504       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
505       assert(filled[index]);
506       values[index] = 0;
507       filled[index] = false;
508     }
509   }
510 
511   /// Finalizes lexicographic insertions.
512   void endInsert() final {
513     if (values.empty())
514       finalizeSegment(0);
515     else
516       endPath(0);
517   }
518 
519   void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
520                      const uint64_t *perm) const final {
521     *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
522   }
523 
524   /// Returns this sparse tensor storage scheme as a new memory-resident
525   /// sparse tensor in coordinate scheme with the given dimension order.
526   ///
527   /// Precondition: `perm` must be valid for `getRank()`.
528   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
529     SparseTensorEnumeratorBase<V> *enumerator;
530     newEnumerator(&enumerator, getRank(), perm);
531     SparseTensorCOO<V> *coo =
532         new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
533     enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
534       coo->add(ind, val);
535     });
536     // TODO: This assertion assumes there are no stored zeros,
537     // or if there are then that we don't filter them out.
538     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
539     assert(coo->getElements().size() == values.size());
540     delete enumerator;
541     return coo;
542   }
543 
544   /// Factory method. Constructs a sparse tensor storage scheme with the given
545   /// dimensions, permutation, and per-dimension dense/sparse annotations,
546   /// using the coordinate scheme tensor for the initial contents if provided.
547   /// In the latter case, the coordinate scheme must respect the same
548   /// permutation as is desired for the new sparse tensor storage.
549   ///
550   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
551   static SparseTensorStorage<P, I, V> *
552   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
553                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
554     SparseTensorStorage<P, I, V> *n = nullptr;
555     if (coo) {
556       const auto &coosz = coo->getDimSizes();
557       assertPermutedSizesMatchShape(coosz, rank, perm, shape);
558       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
559     } else {
560       std::vector<uint64_t> permsz(rank);
561       for (uint64_t r = 0; r < rank; r++) {
562         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
563         permsz[perm[r]] = shape[r];
564       }
565       // We pass the null `coo` to ensure we select the intended constructor.
566       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
567     }
568     return n;
569   }
570 
571   /// Factory method. Constructs a sparse tensor storage scheme with
572   /// the given dimensions, permutation, and per-dimension dense/sparse
573   /// annotations, using the sparse tensor for the initial contents.
574   ///
575   /// Preconditions:
576   /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
577   /// * The `tensor` must have the same value type `V`.
578   static SparseTensorStorage<P, I, V> *
579   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
580                   const DimLevelType *sparsity,
581                   const SparseTensorStorageBase *source) {
582     assert(source && "Got nullptr for source");
583     SparseTensorEnumeratorBase<V> *enumerator;
584     source->newEnumerator(&enumerator, rank, perm);
585     const auto &permsz = enumerator->permutedSizes();
586     assertPermutedSizesMatchShape(permsz, rank, perm, shape);
587     auto *tensor =
588         new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
589     delete enumerator;
590     return tensor;
591   }
592 
593 private:
594   /// Appends an arbitrary new position to `pointers[d]`.  This method
595   /// checks that `pos` is representable in the `P` type; however, it
596   /// does not check that `pos` is semantically valid (i.e., larger than
597   /// the previous position and smaller than `indices[d].capacity()`).
598   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
599     assert(isCompressedDim(d));
600     assert(pos <= std::numeric_limits<P>::max() &&
601            "Pointer value is too large for the P-type");
602     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
603   }
604 
605   /// Appends index `i` to dimension `d`, in the semantically general
606   /// sense.  For non-dense dimensions, that means appending to the
607   /// `indices[d]` array, checking that `i` is representable in the `I`
608   /// type; however, we do not verify other semantic requirements (e.g.,
609   /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
610   /// in the same segment).  For dense dimensions, this method instead
611   /// appends the appropriate number of zeros to the `values` array,
612   /// where `full` is the number of "entries" already written to `values`
613   /// for this segment (aka one after the highest index previously appended).
614   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
615     if (isCompressedDim(d)) {
616       assert(i <= std::numeric_limits<I>::max() &&
617              "Index value is too large for the I-type");
618       indices[d].push_back(static_cast<I>(i));
619     } else { // Dense dimension.
620       assert(i >= full && "Index was already filled");
621       if (i == full)
622         return; // Short-circuit, since it'll be a nop.
623       if (d + 1 == getRank())
624         values.insert(values.end(), i - full, 0);
625       else
626         finalizeSegment(d + 1, 0, i - full);
627     }
628   }
629 
630   /// Writes the given coordinate to `indices[d][pos]`.  This method
631   /// checks that `i` is representable in the `I` type; however, it
632   /// does not check that `i` is semantically valid (i.e., in bounds
633   /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
634   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
635     assert(isCompressedDim(d));
636     // Subscript assignment to `std::vector` requires that the `pos`-th
637     // entry has been initialized; thus we must be sure to check `size()`
638     // here, instead of `capacity()` as would be ideal.
639     assert(pos < indices[d].size() && "Index position is out of bounds");
640     assert(i <= std::numeric_limits<I>::max() &&
641            "Index value is too large for the I-type");
642     indices[d][pos] = static_cast<I>(i);
643   }
644 
645   /// Computes the assembled-size associated with the `d`-th dimension,
646   /// given the assembled-size associated with the `(d-1)`-th dimension.
647   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
648   /// storage, as opposed to "dimension-sizes" which are the cardinality
649   /// of coordinates for that dimension.
650   ///
651   /// Precondition: the `pointers[d]` array must be fully initialized
652   /// before calling this method.
653   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
654     if (isCompressedDim(d))
655       return pointers[d][parentSz];
656     // else if dense:
657     return parentSz * getDimSizes()[d];
658   }
659 
660   /// Initializes sparse tensor storage scheme from a memory-resident sparse
661   /// tensor in coordinate scheme. This method prepares the pointers and
662   /// indices arrays under the given per-dimension dense/sparse annotations.
663   ///
664   /// Preconditions:
665   /// (1) the `elements` must be lexicographically sorted.
666   /// (2) the indices of every element are valid for `dimSizes` (equal rank
667   ///     and pointwise less-than).
668   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
669                uint64_t hi, uint64_t d) {
670     uint64_t rank = getRank();
671     assert(d <= rank && hi <= elements.size());
672     // Once dimensions are exhausted, insert the numerical values.
673     if (d == rank) {
674       assert(lo < hi);
675       values.push_back(elements[lo].value);
676       return;
677     }
678     // Visit all elements in this interval.
679     uint64_t full = 0;
680     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
681       // Find segment in interval with same index elements in this dimension.
682       uint64_t i = elements[lo].indices[d];
683       uint64_t seg = lo + 1;
684       while (seg < hi && elements[seg].indices[d] == i)
685         seg++;
686       // Handle segment in interval for sparse or dense dimension.
687       appendIndex(d, full, i);
688       full = i + 1;
689       fromCOO(elements, lo, seg, d + 1);
690       // And move on to next segment in interval.
691       lo = seg;
692     }
693     // Finalize the sparse pointer structure at this dimension.
694     finalizeSegment(d, full);
695   }
696 
697   /// Finalize the sparse pointer structure at this dimension.
698   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
699     if (count == 0)
700       return; // Short-circuit, since it'll be a nop.
701     if (isCompressedDim(d)) {
702       appendPointer(d, indices[d].size(), count);
703     } else { // Dense dimension.
704       const uint64_t sz = getDimSizes()[d];
705       assert(sz >= full && "Segment is overfull");
706       count = checkedMul(count, sz - full);
707       // For dense storage we must enumerate all the remaining coordinates
708       // in this dimension (i.e., coordinates after the last non-zero
709       // element), and either fill in their zero values or else recurse
710       // to finalize some deeper dimension.
711       if (d + 1 == getRank())
712         values.insert(values.end(), count, 0);
713       else
714         finalizeSegment(d + 1, 0, count);
715     }
716   }
717 
718   /// Wraps up a single insertion path, inner to outer.
719   void endPath(uint64_t diff) {
720     uint64_t rank = getRank();
721     assert(diff <= rank);
722     for (uint64_t i = 0; i < rank - diff; i++) {
723       const uint64_t d = rank - i - 1;
724       finalizeSegment(d, idx[d] + 1);
725     }
726   }
727 
728   /// Continues a single insertion path, outer to inner.
729   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
730     uint64_t rank = getRank();
731     assert(diff < rank);
732     for (uint64_t d = diff; d < rank; d++) {
733       uint64_t i = cursor[d];
734       appendIndex(d, top, i);
735       top = 0;
736       idx[d] = i;
737     }
738     values.push_back(val);
739   }
740 
741   /// Finds the lexicographic differing dimension.
742   uint64_t lexDiff(const uint64_t *cursor) const {
743     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
744       if (cursor[r] > idx[r])
745         return r;
746       else
747         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
748     assert(0 && "duplication insertion");
749     return -1u;
750   }
751 
752   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
753   // the cost of virtual-function dispatch in inner loops), without
754   // making them public to other client code.
755   friend class SparseTensorEnumerator<P, I, V>;
756 
757   std::vector<std::vector<P>> pointers;
758   std::vector<std::vector<I>> indices;
759   std::vector<V> values;
760   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
761 };
762 
763 /// A (higher-order) function object for enumerating the elements of some
764 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
765 /// method encapsulates the loop-nest for enumerating the elements of
766 /// the source tensor (in whatever order is best for the source tensor),
767 /// and applies a permutation to the coordinates/indices before handing
768 /// each element to the callback.  A single enumerator object can be
769 /// freely reused for several calls to `forallElements`, just so long
770 /// as each call is sequential with respect to one another.
771 ///
772 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
773 /// passed to the constructor; thus, objects of this class must not
774 /// outlive the sparse tensor they depend on.
775 ///
776 /// Design Note: The reason we define this class instead of simply using
777 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
778 /// the `<P,I>` template parameters from MLIR client code (to simplify the
779 /// type parameters used for direct sparse-to-sparse conversion).  And the
780 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
781 /// than simply using this class, is to avoid the cost of virtual-method
782 /// dispatch within the loop-nest.
783 template <typename V>
784 class SparseTensorEnumeratorBase {
785 public:
786   /// Constructs an enumerator with the given permutation for mapping
787   /// the semantic-ordering of dimensions to the desired target-ordering.
788   ///
789   /// Preconditions:
790   /// * the `tensor` must have the same `V` value type.
791   /// * `perm` must be valid for `rank`.
792   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
793                              uint64_t rank, const uint64_t *perm)
794       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
795         cursor(getRank()) {
796     assert(perm && "Received nullptr for permutation");
797     assert(rank == getRank() && "Permutation rank mismatch");
798     const auto &rev = src.getRev();           // source-order -> semantic-order
799     const auto &dimSizes = src.getDimSizes(); // in source storage-order
800     for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
801       uint64_t t = perm[rev[s]];              // `t` target-order
802       reord[s] = t;
803       permsz[t] = dimSizes[s];
804     }
805   }
806 
807   virtual ~SparseTensorEnumeratorBase() = default;
808 
809   // We disallow copying to help avoid leaking the `src` reference.
810   // (In addition to avoiding the problem of slicing.)
811   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
812   SparseTensorEnumeratorBase &
813   operator=(const SparseTensorEnumeratorBase &) = delete;
814 
815   /// Returns the source/target tensor's rank.  (The source-rank and
816   /// target-rank are always equal since we only support permutations.
817   /// Though once we add support for other dimension mappings, this
818   /// method will have to be split in two.)
819   uint64_t getRank() const { return permsz.size(); }
820 
821   /// Returns the target tensor's dimension sizes.
822   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
823 
824   /// Enumerates all elements of the source tensor, permutes their
825   /// indices, and passes the permuted element to the callback.
826   /// The callback must not store the cursor reference directly,
827   /// since this function reuses the storage.  Instead, the callback
828   /// must copy it if they want to keep it.
829   virtual void forallElements(ElementConsumer<V> yield) = 0;
830 
831 protected:
832   const SparseTensorStorageBase &src;
833   std::vector<uint64_t> permsz; // in target order.
834   std::vector<uint64_t> reord;  // source storage-order -> target order.
835   std::vector<uint64_t> cursor; // in target order.
836 };
837 
838 template <typename P, typename I, typename V>
839 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
840   using Base = SparseTensorEnumeratorBase<V>;
841 
842 public:
843   /// Constructs an enumerator with the given permutation for mapping
844   /// the semantic-ordering of dimensions to the desired target-ordering.
845   ///
846   /// Precondition: `perm` must be valid for `rank`.
847   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
848                          uint64_t rank, const uint64_t *perm)
849       : Base(tensor, rank, perm) {}
850 
851   ~SparseTensorEnumerator() final = default;
852 
853   void forallElements(ElementConsumer<V> yield) final {
854     forallElements(yield, 0, 0);
855   }
856 
857 private:
858   /// The recursive component of the public `forallElements`.
859   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
860                       uint64_t d) {
861     // Recover the `<P,I,V>` type parameters of `src`.
862     const auto &src =
863         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
864     if (d == Base::getRank()) {
865       assert(parentPos < src.values.size() &&
866              "Value position is out of bounds");
867       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
868       yield(this->cursor, src.values[parentPos]);
869     } else if (src.isCompressedDim(d)) {
870       // Look up the bounds of the `d`-level segment determined by the
871       // `d-1`-level position `parentPos`.
872       const std::vector<P> &pointersD = src.pointers[d];
873       assert(parentPos + 1 < pointersD.size() &&
874              "Parent pointer position is out of bounds");
875       const uint64_t pstart = static_cast<uint64_t>(pointersD[parentPos]);
876       const uint64_t pstop = static_cast<uint64_t>(pointersD[parentPos + 1]);
877       // Loop-invariant code for looking up the `d`-level coordinates/indices.
878       const std::vector<I> &indicesD = src.indices[d];
879       assert(pstop <= indicesD.size() && "Index position is out of bounds");
880       uint64_t &cursorReordD = this->cursor[this->reord[d]];
881       for (uint64_t pos = pstart; pos < pstop; pos++) {
882         cursorReordD = static_cast<uint64_t>(indicesD[pos]);
883         forallElements(yield, pos, d + 1);
884       }
885     } else { // Dense dimension.
886       const uint64_t sz = src.getDimSizes()[d];
887       const uint64_t pstart = parentPos * sz;
888       uint64_t &cursorReordD = this->cursor[this->reord[d]];
889       for (uint64_t i = 0; i < sz; i++) {
890         cursorReordD = i;
891         forallElements(yield, pstart + i, d + 1);
892       }
893     }
894   }
895 };
896 
897 /// Statistics regarding the number of nonzero subtensors in
898 /// a source tensor, for direct sparse=>sparse conversion a la
899 /// <https://arxiv.org/abs/2001.02609>.
900 ///
901 /// N.B., this class stores references to the parameters passed to
902 /// the constructor; thus, objects of this class must not outlive
903 /// those parameters.
904 class SparseTensorNNZ final {
905 public:
906   /// Allocate the statistics structure for the desired sizes and
907   /// sparsity (in the target tensor's storage-order).  This constructor
908   /// does not actually populate the statistics, however; for that see
909   /// `initialize`.
910   ///
911   /// Precondition: `dimSizes` must not contain zeros.
912   SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
913                   const std::vector<DimLevelType> &sparsity)
914       : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
915     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
916     bool uncompressed = true;
917     (void)uncompressed;
918     uint64_t sz = 1; // the product of all `dimSizes` strictly less than `r`.
919     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
920       switch (dimTypes[r]) {
921       case DimLevelType::kCompressed:
922         assert(uncompressed &&
923                "Multiple compressed layers not currently supported");
924         uncompressed = false;
925         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
926         break;
927       case DimLevelType::kDense:
928         assert(uncompressed &&
929                "Dense after compressed not currently supported");
930         break;
931       case DimLevelType::kSingleton:
932         // Singleton after Compressed causes no problems for allocating
933         // `nnz` nor for the yieldPos loop.  This remains true even
934         // when adding support for multiple compressed dimensions or
935         // for dense-after-compressed.
936         break;
937       }
938       sz = checkedMul(sz, dimSizes[r]);
939     }
940   }
941 
942   // We disallow copying to help avoid leaking the stored references.
943   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
944   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
945 
946   /// Returns the rank of the target tensor.
947   uint64_t getRank() const { return dimSizes.size(); }
948 
949   /// Enumerate the source tensor to fill in the statistics.  The
950   /// enumerator should already incorporate the permutation (from
951   /// semantic-order to the target storage-order).
952   template <typename V>
953   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
954     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
955     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
956     enumerator.forallElements(
957         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
958   }
959 
960   /// The type of callback functions which receive an nnz-statistic.
961   using NNZConsumer = const std::function<void(uint64_t)> &;
962 
963   /// Lexicographically enumerates all indicies for dimensions strictly
964   /// less than `stopDim`, and passes their nnz statistic to the callback.
965   /// Since our use-case only requires the statistic not the coordinates
966   /// themselves, we do not bother to construct those coordinates.
967   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
968     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
969     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
970            "Cannot look up non-compressed dimensions");
971     forallIndices(yield, stopDim, 0, 0);
972   }
973 
974 private:
975   /// Adds a new element (i.e., increment its statistics).  We use
976   /// a method rather than inlining into the lambda in `initialize`,
977   /// to avoid spurious templating over `V`.  And this method is private
978   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
979   /// by `forallElements`).
980   void add(const std::vector<uint64_t> &ind) {
981     uint64_t parentPos = 0;
982     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
983       if (dimTypes[r] == DimLevelType::kCompressed)
984         nnz[r][parentPos]++;
985       parentPos = parentPos * dimSizes[r] + ind[r];
986     }
987   }
988 
989   /// Recursive component of the public `forallIndices`.
990   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
991                      uint64_t d) const {
992     assert(d <= stopDim);
993     if (d == stopDim) {
994       assert(parentPos < nnz[d].size() && "Cursor is out of range");
995       yield(nnz[d][parentPos]);
996     } else {
997       const uint64_t sz = dimSizes[d];
998       const uint64_t pstart = parentPos * sz;
999       for (uint64_t i = 0; i < sz; i++)
1000         forallIndices(yield, stopDim, pstart + i, d + 1);
1001     }
1002   }
1003 
1004   // All of these are in the target storage-order.
1005   const std::vector<uint64_t> &dimSizes;
1006   const std::vector<DimLevelType> &dimTypes;
1007   std::vector<std::vector<uint64_t>> nnz;
1008 };
1009 
1010 template <typename P, typename I, typename V>
1011 SparseTensorStorage<P, I, V>::SparseTensorStorage(
1012     const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
1013     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1014     : SparseTensorStorage(dimSizes, perm, sparsity) {
1015   SparseTensorEnumeratorBase<V> *enumerator;
1016   tensor.newEnumerator(&enumerator, getRank(), perm);
1017   {
1018     // Initialize the statistics structure.
1019     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1020     nnz.initialize(*enumerator);
1021     // Initialize "pointers" overhead (and allocate "indices", "values").
1022     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1023     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1024       if (isCompressedDim(r)) {
1025         pointers[r].reserve(parentSz + 1);
1026         pointers[r].push_back(0);
1027         uint64_t currentPos = 0;
1028         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1029           currentPos += n;
1030           appendPointer(r, currentPos);
1031         });
1032         assert(pointers[r].size() == parentSz + 1 &&
1033                "Final pointers size doesn't match allocated size");
1034         // That assertion entails `assembledSize(parentSz, r)`
1035         // is now in a valid state.  That is, `pointers[r][parentSz]`
1036         // equals the present value of `currentPos`, which is the
1037         // correct assembled-size for `indices[r]`.
1038       }
1039       // Update assembled-size for the next iteration.
1040       parentSz = assembledSize(parentSz, r);
1041       // Ideally we need only `indices[r].reserve(parentSz)`, however
1042       // the `std::vector` implementation forces us to initialize it too.
1043       // That is, in the yieldPos loop we need random-access assignment
1044       // to `indices[r]`; however, `std::vector`'s subscript-assignment
1045       // only allows assigning to already-initialized positions.
1046       if (isCompressedDim(r))
1047         indices[r].resize(parentSz, 0);
1048     }
1049     values.resize(parentSz, 0); // Both allocate and zero-initialize.
1050   }
1051   // The yieldPos loop
1052   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1053     uint64_t parentSz = 1, parentPos = 0;
1054     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1055       if (isCompressedDim(r)) {
1056         // If `parentPos == parentSz` then it's valid as an array-lookup;
1057         // however, it's semantically invalid here since that entry
1058         // does not represent a segment of `indices[r]`.  Moreover, that
1059         // entry must be immutable for `assembledSize` to remain valid.
1060         assert(parentPos < parentSz && "Pointers position is out of bounds");
1061         const uint64_t currentPos = pointers[r][parentPos];
1062         // This increment won't overflow the `P` type, since it can't
1063         // exceed the original value of `pointers[r][parentPos+1]`
1064         // which was already verified to be within bounds for `P`
1065         // when it was written to the array.
1066         pointers[r][parentPos]++;
1067         writeIndex(r, currentPos, ind[r]);
1068         parentPos = currentPos;
1069       } else { // Dense dimension.
1070         parentPos = parentPos * getDimSizes()[r] + ind[r];
1071       }
1072       parentSz = assembledSize(parentSz, r);
1073     }
1074     assert(parentPos < values.size() && "Value position is out of bounds");
1075     values[parentPos] = val;
1076   });
1077   // No longer need the enumerator, so we'll delete it ASAP.
1078   delete enumerator;
1079   // The finalizeYieldPos loop
1080   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1081     if (isCompressedDim(r)) {
1082       assert(parentSz == pointers[r].size() - 1 &&
1083              "Actual pointers size doesn't match the expected size");
1084       // Can't check all of them, but at least we can check the last one.
1085       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1086              "Pointers got corrupted");
1087       // TODO: optimize this by using `memmove` or similar.
1088       for (uint64_t n = 0; n < parentSz; n++) {
1089         const uint64_t parentPos = parentSz - n;
1090         pointers[r][parentPos] = pointers[r][parentPos - 1];
1091       }
1092       pointers[r][0] = 0;
1093     }
1094     parentSz = assembledSize(parentSz, r);
1095   }
1096 }
1097 
/// Helper to convert a NUL-terminated string to lower case, in place.
/// Returns `token` to allow chaining (e.g. `strcmp(toLower(s), "...")`).
///
/// Each byte is passed to `tolower` as an `unsigned char`: passing a
/// negative `char` (e.g. a non-ASCII byte where `char` is signed) is
/// undefined behavior.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
  return token;
}
1104 
1105 /// This class abstracts over the information stored in file headers,
1106 /// as well as providing the buffers and methods for parsing those headers.
1107 class SparseTensorFile final {
1108 public:
1109   enum class ValueKind {
1110     kInvalid = 0,
1111     kPattern = 1,
1112     kReal = 2,
1113     kInteger = 3,
1114     kComplex = 4,
1115     kUndefined = 5
1116   };
1117 
1118   explicit SparseTensorFile(char *filename) : filename(filename) {
1119     assert(filename && "Received nullptr for filename");
1120   }
1121 
1122   // Disallows copying, to avoid duplicating the `file` pointer.
1123   SparseTensorFile(const SparseTensorFile &) = delete;
1124   SparseTensorFile &operator=(const SparseTensorFile &) = delete;
1125 
1126   // This dtor tries to avoid leaking the `file`.  (Though it's better
1127   // to call `closeFile` explicitly when possible, since there are
1128   // circumstances where dtors are not called reliably.)
1129   ~SparseTensorFile() { closeFile(); }
1130 
1131   /// Opens the file for reading.
1132   void openFile() {
1133     if (file)
1134       FATAL("Already opened file %s\n", filename);
1135     file = fopen(filename, "r");
1136     if (!file)
1137       FATAL("Cannot find file %s\n", filename);
1138   }
1139 
1140   /// Closes the file.
1141   void closeFile() {
1142     if (file) {
1143       fclose(file);
1144       file = nullptr;
1145     }
1146   }
1147 
1148   // TODO(wrengr/bixia): figure out how to reorganize the element-parsing
1149   // loop of `openSparseTensorCOO` into methods of this class, so we can
1150   // avoid leaking access to the `line` pointer (both for general hygiene
1151   // and because we can't mark it const due to the second argument of
1152   // `strtoul`/`strtoud` being `char * *restrict` rather than
1153   // `char const* *restrict`).
1154   //
1155   /// Attempts to read a line from the file.
1156   char *readLine() {
1157     if (fgets(line, kColWidth, file))
1158       return line;
1159     FATAL("Cannot read next line of %s\n", filename);
1160   }
1161 
1162   /// Reads and parses the file's header.
1163   void readHeader() {
1164     assert(file && "Attempt to readHeader() before openFile()");
1165     if (strstr(filename, ".mtx"))
1166       readMMEHeader();
1167     else if (strstr(filename, ".tns"))
1168       readExtFROSTTHeader();
1169     else
1170       FATAL("Unknown format %s\n", filename);
1171     assert(isValid() && "Failed to read the header");
1172   }
1173 
1174   ValueKind getValueKind() const { return valueKind_; }
1175 
1176   bool isValid() const { return valueKind_ != ValueKind::kInvalid; }
1177 
1178   /// Gets the MME "pattern" property setting.  Is only valid after
1179   /// parsing the header.
1180   bool isPattern() const {
1181     assert(isValid() && "Attempt to isPattern() before readHeader()");
1182     return valueKind_ == ValueKind::kPattern;
1183   }
1184 
1185   /// Gets the MME "symmetric" property setting.  Is only valid after
1186   /// parsing the header.
1187   bool isSymmetric() const {
1188     assert(isValid() && "Attempt to isSymmetric() before readHeader()");
1189     return isSymmetric_;
1190   }
1191 
1192   /// Gets the rank of the tensor.  Is only valid after parsing the header.
1193   uint64_t getRank() const {
1194     assert(isValid() && "Attempt to getRank() before readHeader()");
1195     return idata[0];
1196   }
1197 
1198   /// Gets the number of non-zeros.  Is only valid after parsing the header.
1199   uint64_t getNNZ() const {
1200     assert(isValid() && "Attempt to getNNZ() before readHeader()");
1201     return idata[1];
1202   }
1203 
1204   /// Gets the dimension-sizes array.  The pointer itself is always
1205   /// valid; however, the values stored therein are only valid after
1206   /// parsing the header.
1207   const uint64_t *getDimSizes() const { return idata + 2; }
1208 
1209   /// Safely gets the size of the given dimension.  Is only valid
1210   /// after parsing the header.
1211   uint64_t getDimSize(uint64_t d) const {
1212     assert(d < getRank());
1213     return idata[2 + d];
1214   }
1215 
1216   /// Asserts the shape subsumes the actual dimension sizes.  Is only
1217   /// valid after parsing the header.
1218   void assertMatchesShape(uint64_t rank, const uint64_t *shape) const {
1219     assert(rank == getRank() && "Rank mismatch");
1220     for (uint64_t r = 0; r < rank; r++)
1221       assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1222              "Dimension size mismatch");
1223   }
1224 
1225 private:
1226   void readMMEHeader();
1227   void readExtFROSTTHeader();
1228 
1229   const char *filename;
1230   FILE *file = nullptr;
1231   ValueKind valueKind_ = ValueKind::kInvalid;
1232   bool isSymmetric_ = false;
1233   uint64_t idata[512];
1234   char line[kColWidth];
1235 };
1236 
1237 /// Read the MME header of a general sparse matrix of type real.
1238 void SparseTensorFile::readMMEHeader() {
1239   char header[64];
1240   char object[64];
1241   char format[64];
1242   char field[64];
1243   char symmetry[64];
1244   // Read header line.
1245   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1246              symmetry) != 5)
1247     FATAL("Corrupt header in %s\n", filename);
1248   // Process `field`, which specify pattern or the data type of the values.
1249   if (strcmp(toLower(field), "pattern") == 0)
1250     valueKind_ = ValueKind::kPattern;
1251   else if (strcmp(toLower(field), "real") == 0)
1252     valueKind_ = ValueKind::kReal;
1253   else if (strcmp(toLower(field), "integer") == 0)
1254     valueKind_ = ValueKind::kInteger;
1255   else if (strcmp(toLower(field), "complex") == 0)
1256     valueKind_ = ValueKind::kComplex;
1257   else
1258     FATAL("Unexpected header field value in %s\n", filename);
1259 
1260   // Set properties.
1261   isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0);
1262   // Make sure this is a general sparse matrix.
1263   if (strcmp(toLower(header), "%%matrixmarket") ||
1264       strcmp(toLower(object), "matrix") ||
1265       strcmp(toLower(format), "coordinate") ||
1266       (strcmp(toLower(symmetry), "general") && !isSymmetric_))
1267     FATAL("Cannot find a general sparse matrix in %s\n", filename);
1268   // Skip comments.
1269   while (true) {
1270     readLine();
1271     if (line[0] != '%')
1272       break;
1273   }
1274   // Next line contains M N NNZ.
1275   idata[0] = 2; // rank
1276   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1277              idata + 1) != 3)
1278     FATAL("Cannot find size in %s\n", filename);
1279 }
1280 
1281 /// Read the "extended" FROSTT header. Although not part of the documented
1282 /// format, we assume that the file starts with optional comments followed
1283 /// by two lines that define the rank, the number of nonzeros, and the
1284 /// dimensions sizes (one per rank) of the sparse tensor.
1285 void SparseTensorFile::readExtFROSTTHeader() {
1286   // Skip comments.
1287   while (true) {
1288     readLine();
1289     if (line[0] != '#')
1290       break;
1291   }
1292   // Next line contains RANK and NNZ.
1293   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
1294     FATAL("Cannot find metadata in %s\n", filename);
1295   // Followed by a line with the dimension sizes (one per rank).
1296   for (uint64_t r = 0; r < idata[0]; r++)
1297     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
1298       FATAL("Cannot find dimension size %s\n", filename);
1299   readLine(); // end of line
1300   // The FROSTT format does not define the data type of the nonzero elements.
1301   valueKind_ = ValueKind::kUndefined;
1302 }
1303 
// Adds a value to a tensor in coordinate scheme. If is_symmetric_value is true,
// also adds the value at its transposed location (which requires rank 2).
//
// `indices` is taken by const reference: this is called once per nonzero while
// reading a file, so copying the vector on every call is wasted work.
template <typename T, typename V>
static inline void addValue(T *coo, V value,
                            const std::vector<uint64_t> &indices,
                            bool is_symmetric_value) {
  // TODO: <https://github.com/llvm/llvm-project/issues/54179>
  coo->add(indices, value);
  // We currently chose to deal with symmetric matrices by fully constructing
  // them. In the future, we may want to make symmetry implicit for storage
  // reasons.
  if (is_symmetric_value) {
    assert(indices.size() == 2 && "Symmetric values require a rank-2 tensor");
    coo->add({indices[1], indices[0]}, value);
  }
}
1318 
1319 // Reads an element of a complex type for the current indices in coordinate
1320 // scheme.
1321 template <typename V>
1322 static inline void readCOOValue(SparseTensorCOO<std::complex<V>> *coo,
1323                                 const std::vector<uint64_t> indices,
1324                                 char **linePtr, bool is_pattern,
1325                                 bool add_symmetric_value) {
1326   // Read two values to make a complex. The external formats always store
1327   // numerical values with the type double, but we cast these values to the
1328   // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
1329   // value 1 for all entries.
1330   V re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
1331   V im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
1332   std::complex<V> value = {re, im};
1333   addValue(coo, value, indices, add_symmetric_value);
1334 }
1335 
1336 // Reads an element of a non-complex type for the current indices in coordinate
1337 // scheme.
1338 template <typename V,
1339           typename std::enable_if<
1340               !std::is_same<std::complex<float>, V>::value &&
1341               !std::is_same<std::complex<double>, V>::value>::type * = nullptr>
1342 static void inline readCOOValue(SparseTensorCOO<V> *coo,
1343                                 const std::vector<uint64_t> indices,
1344                                 char **linePtr, bool is_pattern,
1345                                 bool is_symmetric_value) {
1346   // The external formats always store these numerical values with the type
1347   // double, but we cast these values to the sparse tensor object type.
1348   // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1349   double value = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
1350   addValue(coo, value, indices, is_symmetric_value);
1351 }
1352 
1353 /// Reads a sparse tensor with the given filename into a memory-resident
1354 /// sparse tensor in coordinate scheme.
1355 template <typename V>
1356 static SparseTensorCOO<V> *
1357 openSparseTensorCOO(char *filename, uint64_t rank, const uint64_t *shape,
1358                     const uint64_t *perm, PrimaryType valTp) {
1359   SparseTensorFile stfile(filename);
1360   stfile.openFile();
1361   stfile.readHeader();
1362   // Check tensor element type against the value type in the input file.
1363   SparseTensorFile::ValueKind valueKind = stfile.getValueKind();
1364   bool tensorIsInteger =
1365       (valTp >= PrimaryType::kI64 && valTp <= PrimaryType::kI8);
1366   bool tensorIsReal = (valTp >= PrimaryType::kF64 && valTp <= PrimaryType::kI8);
1367   if ((valueKind == SparseTensorFile::ValueKind::kReal && tensorIsInteger) ||
1368       (valueKind == SparseTensorFile::ValueKind::kComplex && tensorIsReal)) {
1369     FATAL("Tensor element type %d not compatible with values in file %s\n",
1370           static_cast<int>(valTp), filename);
1371   }
1372   stfile.assertMatchesShape(rank, shape);
1373   // Prepare sparse tensor object with per-dimension sizes
1374   // and the number of nonzeros as initial capacity.
1375   uint64_t nnz = stfile.getNNZ();
1376   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, stfile.getDimSizes(),
1377                                                      perm, nnz);
1378   // Read all nonzero elements.
1379   std::vector<uint64_t> indices(rank);
1380   for (uint64_t k = 0; k < nnz; k++) {
1381     char *linePtr = stfile.readLine();
1382     for (uint64_t r = 0; r < rank; r++) {
1383       uint64_t idx = strtoul(linePtr, &linePtr, 10);
1384       // Add 0-based index.
1385       indices[perm[r]] = idx - 1;
1386     }
1387     readCOOValue(coo, indices, &linePtr, stfile.isPattern(),
1388                  stfile.isSymmetric() && indices[0] != indices[1]);
1389   }
1390   // Close the file and return tensor.
1391   stfile.closeFile();
1392   return coo;
1393 }
1394 
1395 /// Writes the sparse tensor to `dest` in extended FROSTT format.
1396 template <typename V>
1397 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1398   assert(tensor && dest);
1399   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1400   if (sort)
1401     coo->sort();
1402   char *filename = static_cast<char *>(dest);
1403   auto &dimSizes = coo->getDimSizes();
1404   auto &elements = coo->getElements();
1405   uint64_t rank = coo->getRank();
1406   uint64_t nnz = elements.size();
1407   std::fstream file;
1408   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1409   assert(file.is_open());
1410   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1411   for (uint64_t r = 0; r < rank - 1; r++)
1412     file << dimSizes[r] << " ";
1413   file << dimSizes[rank - 1] << std::endl;
1414   for (uint64_t i = 0; i < nnz; i++) {
1415     auto &idx = elements[i].indices;
1416     for (uint64_t r = 0; r < rank; r++)
1417       file << (idx[r] + 1) << " ";
1418     file << elements[i].value << std::endl;
1419   }
1420   file.flush();
1421   file.close();
1422   assert(file.good());
1423 }
1424 
1425 /// Initializes sparse tensor from an external COO-flavored format.
1426 template <typename V>
1427 static SparseTensorStorage<uint64_t, uint64_t, V> *
1428 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1429                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
1430   const DimLevelType *sparsity = (DimLevelType *)(sparse);
1431 #ifndef NDEBUG
1432   // Verify that perm is a permutation of 0..(rank-1).
1433   std::vector<uint64_t> order(perm, perm + rank);
1434   std::sort(order.begin(), order.end());
1435   for (uint64_t i = 0; i < rank; ++i)
1436     if (i != order[i])
1437       FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
1438 
1439   // Verify that the sparsity values are supported.
1440   for (uint64_t i = 0; i < rank; ++i)
1441     if (sparsity[i] != DimLevelType::kDense &&
1442         sparsity[i] != DimLevelType::kCompressed)
1443       FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
1444 #endif
1445 
1446   // Convert external format to internal COO.
1447   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1448   std::vector<uint64_t> idx(rank);
1449   for (uint64_t i = 0, base = 0; i < nse; i++) {
1450     for (uint64_t r = 0; r < rank; r++)
1451       idx[perm[r]] = indices[base + r];
1452     coo->add(idx, values[i]);
1453     base += rank;
1454   }
1455   // Return sparse tensor storage format as opaque pointer.
1456   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1457       rank, shape, perm, sparsity, coo);
1458   delete coo;
1459   return tensor;
1460 }
1461 
1462 /// Converts a sparse tensor to an external COO-flavored format.
1463 template <typename V>
1464 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1465                                  uint64_t **pShape, V **pValues,
1466                                  uint64_t **pIndices) {
1467   assert(tensor);
1468   auto sparseTensor =
1469       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1470   uint64_t rank = sparseTensor->getRank();
1471   std::vector<uint64_t> perm(rank);
1472   std::iota(perm.begin(), perm.end(), 0);
1473   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1474 
1475   const std::vector<Element<V>> &elements = coo->getElements();
1476   uint64_t nse = elements.size();
1477 
1478   uint64_t *shape = new uint64_t[rank];
1479   for (uint64_t i = 0; i < rank; i++)
1480     shape[i] = coo->getDimSizes()[i];
1481 
1482   V *values = new V[nse];
1483   uint64_t *indices = new uint64_t[rank * nse];
1484 
1485   for (uint64_t i = 0, base = 0; i < nse; i++) {
1486     values[i] = elements[i].value;
1487     for (uint64_t j = 0; j < rank; j++)
1488       indices[base + j] = elements[i].indices[j];
1489     base += rank;
1490   }
1491 
1492   delete coo;
1493   *pRank = rank;
1494   *pNse = nse;
1495   *pShape = shape;
1496   *pValues = values;
1497   *pIndices = indices;
1498 }
1499 
1500 } // anonymous namespace
1501 
1502 extern "C" {
1503 
1504 //===----------------------------------------------------------------------===//
1505 //
1506 // Public functions which operate on MLIR buffers (memrefs) to interact
1507 // with sparse tensors (which are only visible as opaque pointers externally).
1508 //
1509 //===----------------------------------------------------------------------===//
1510 
// Dispatch helper for `_mlir_ciface_newSparseTensor`. Expands to an `if`
// statement that compares the enum values `p` (pointer overhead), `i` (index
// overhead) and `v` (primary/value type) against the locals `ptrTp`, `indTp`
// and `valTp`. On a match it handles `action` using the corresponding C++
// types `P`, `I` and `V`:
//   - `action <= Action::kFromCOO` (assumed to be exactly kEmpty, kFromFile
//     and kFromCOO; the else branch asserts kEmpty): returns a new
//     `SparseTensorStorage<P, I, V>` built from no COO (kEmpty), from a COO
//     read from the file named by `ptr` (kFromFile; that COO is deleted
//     afterwards), or from the COO passed in `ptr` (kFromCOO; not deleted
//     here).
//   - kSparseToSparse: returns a new storage built from the storage in `ptr`.
//   - kEmptyCOO: returns a new empty `SparseTensorCOO<V>`.
//   - otherwise (asserted to be kToCOO or kToIterator): converts the storage
//     in `ptr` to a new COO, starts its iterator for kToIterator, and returns
//     the COO.
// The expansion site must also have `ptr`, `rank`, `shape`, `perm` and
// `sparsity` in scope.
#define CASE(p, i, v, P, I, V)                                                 \
  if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
    SparseTensorCOO<V> *coo = nullptr;                                         \
    if (action <= Action::kFromCOO) {                                          \
      if (action == Action::kFromFile) {                                       \
        char *filename = static_cast<char *>(ptr);                             \
        coo = openSparseTensorCOO<V>(filename, rank, shape, perm, v);          \
      } else if (action == Action::kFromCOO) {                                 \
        coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
      } else {                                                                 \
        assert(action == Action::kEmpty);                                      \
      }                                                                        \
      auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
          rank, shape, perm, sparsity, coo);                                   \
      if (action == Action::kFromFile)                                         \
        delete coo;                                                            \
      return tensor;                                                           \
    }                                                                          \
    if (action == Action::kSparseToSparse) {                                   \
      auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
      return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
                                                           sparsity, tensor);  \
    }                                                                          \
    if (action == Action::kEmptyCOO)                                           \
      return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
    coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
    if (action == Action::kToIterator) {                                       \
      coo->startIterator();                                                    \
    } else {                                                                   \
      assert(action == Action::kToCOO);                                        \
    }                                                                          \
    return coo;                                                                \
  }

// Shorthand for CASE when the pointer and index overhead types coincide.
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
1546 
// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
// can safely rewrite kIndex to kU64.  We make this assertion to guarantee
// that this file cannot get out of sync with its header.
static_assert(std::is_same<index_type, uint64_t>::value,
              "Expected index_type == uint64_t");

/// Creates a sparse tensor storage or COO according to `action`, dispatching
/// on the (ptrTp, indTp, valTp) type combination via the CASE macro.
///
/// \param aref    Per-dimension level types; its length is the rank.
/// \param sref    Dimension sizes; asserted to have the same length as `aref`.
/// \param pref    Dimension ordering; asserted to have the same length as
///                `aref`.
/// \param ptrTp   Overhead type of the pointers (kIndex is treated as kU64).
/// \param indTp   Overhead type of the indices (kIndex is treated as kU64).
/// \param valTp   Primary (element) type.
/// \param action  What to create; interpreted by CASE.
/// \param ptr     Action-dependent input: a filename (kFromFile), a COO
///                (kFromCOO), or a storage (kSparseToSparse, kToCOO,
///                kToIterator); unused for kEmpty and kEmptyCOO.
/// \return Opaque pointer to the new storage or COO. Type combinations that
///         are not listed below are reported via FATAL.
void *
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
                             StridedMemRefType<index_type, 1> *sref,
                             StridedMemRefType<index_type, 1> *pref,
                             OverheadType ptrTp, OverheadType indTp,
                             PrimaryType valTp, Action action, void *ptr) {
  assert(aref && sref && pref);
  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
         pref->strides[0] == 1);
  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
  const DimLevelType *sparsity = aref->data + aref->offset;
  const index_type *shape = sref->data + sref->offset;
  const index_type *perm = pref->data + pref->offset;
  uint64_t rank = aref->sizes[0];

  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
  // This is safe because of the static_assert above.
  if (ptrTp == OverheadType::kIndex)
    ptrTp = OverheadType::kU64;
  if (indTp == OverheadType::kIndex)
    indTp = OverheadType::kU64;

  // Each CASE below returns if it matches the requested type combination;
  // control only reaches the FATAL at the end when none matched.

  // Double matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
       uint64_t, double);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
       uint32_t, double);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
       uint16_t, double);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
       uint8_t, double);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
       uint64_t, double);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
       uint32_t, double);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
       uint16_t, double);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
       uint8_t, double);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
       uint64_t, double);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
       uint32_t, double);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
       uint16_t, double);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
       uint8_t, double);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
       uint64_t, double);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
       uint32_t, double);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
       uint16_t, double);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
       uint8_t, double);

  // Float matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
       uint64_t, float);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
       uint32_t, float);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
       uint16_t, float);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
       uint8_t, float);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
       uint64_t, float);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
       uint32_t, float);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
       uint16_t, float);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
       uint8_t, float);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
       uint64_t, float);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
       uint32_t, float);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
       uint16_t, float);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
       uint8_t, float);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
       uint64_t, float);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
       uint32_t, float);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
       uint16_t, float);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
       uint8_t, float);

  // Two-byte floats with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);

  // Integral matrices with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);

  // Complex matrices with wide overhead.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);

  // Unsupported case (add above if needed).
  // TODO: better pretty-printing of enum values!
  FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
        static_cast<int>(ptrTp), static_cast<int>(indTp),
        static_cast<int>(valTp));
}
#undef CASE
#undef CASE_SECSAME
1683 
1684 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1685   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1686                                         void *tensor) {                        \
1687     assert(ref &&tensor);                                                      \
1688     std::vector<V> *v;                                                         \
1689     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1690     ref->basePtr = ref->data = v->data();                                      \
1691     ref->offset = 0;                                                           \
1692     ref->sizes[0] = v->size();                                                 \
1693     ref->strides[0] = 1;                                                       \
1694   }
1695 FOREVERY_V(IMPL_SPARSEVALUES)
1696 #undef IMPL_SPARSEVALUES
1697 
1698 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1699   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1700                            index_type d) {                                     \
1701     assert(ref &&tensor);                                                      \
1702     std::vector<TYPE> *v;                                                      \
1703     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1704     ref->basePtr = ref->data = v->data();                                      \
1705     ref->offset = 0;                                                           \
1706     ref->sizes[0] = v->size();                                                 \
1707     ref->strides[0] = 1;                                                       \
1708   }
1709 #define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
1710   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
1711 FOREVERY_O(IMPL_SPARSEPOINTERS)
1712 #undef IMPL_SPARSEPOINTERS
1713 
1714 #define IMPL_SPARSEINDICES(INAME, I)                                           \
1715   IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
1716 FOREVERY_O(IMPL_SPARSEINDICES)
1717 #undef IMPL_SPARSEINDICES
1718 #undef IMPL_GETOVERHEAD
1719 
1720 #define IMPL_ADDELT(VNAME, V)                                                  \
1721   void *_mlir_ciface_addElt##VNAME(void *coo, StridedMemRefType<V, 0> *vref,   \
1722                                    StridedMemRefType<index_type, 1> *iref,     \
1723                                    StridedMemRefType<index_type, 1> *pref) {   \
1724     assert(coo &&vref &&iref &&pref);                                          \
1725     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1726     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1727     const index_type *indx = iref->data + iref->offset;                        \
1728     const index_type *perm = pref->data + pref->offset;                        \
1729     uint64_t isize = iref->sizes[0];                                           \
1730     std::vector<index_type> indices(isize);                                    \
1731     for (uint64_t r = 0; r < isize; r++)                                       \
1732       indices[perm[r]] = indx[r];                                              \
1733     V *value = vref->data + vref->offset;                                      \
1734     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, *value);              \
1735     return coo;                                                                \
1736   }
1737 FOREVERY_V(IMPL_ADDELT)
1738 #undef IMPL_ADDELT
1739 
1740 #define IMPL_GETNEXT(VNAME, V)                                                 \
1741   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1742                                    StridedMemRefType<index_type, 1> *iref,     \
1743                                    StridedMemRefType<V, 0> *vref) {            \
1744     assert(coo &&iref &&vref);                                                 \
1745     assert(iref->strides[0] == 1);                                             \
1746     index_type *indx = iref->data + iref->offset;                              \
1747     V *value = vref->data + vref->offset;                                      \
1748     const uint64_t isize = iref->sizes[0];                                     \
1749     const Element<V> *elem =                                                   \
1750         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1751     if (elem == nullptr)                                                       \
1752       return false;                                                            \
1753     for (uint64_t r = 0; r < isize; r++)                                       \
1754       indx[r] = elem->indices[r];                                              \
1755     *value = elem->value;                                                      \
1756     return true;                                                               \
1757   }
1758 FOREVERY_V(IMPL_GETNEXT)
1759 #undef IMPL_GETNEXT
1760 
1761 #define IMPL_LEXINSERT(VNAME, V)                                               \
1762   void _mlir_ciface_lexInsert##VNAME(void *tensor,                             \
1763                                      StridedMemRefType<index_type, 1> *cref,   \
1764                                      StridedMemRefType<V, 0> *vref) {          \
1765     assert(tensor &&cref &&vref);                                              \
1766     assert(cref->strides[0] == 1);                                             \
1767     index_type *cursor = cref->data + cref->offset;                            \
1768     assert(cursor);                                                            \
1769     V *value = vref->data + vref->offset;                                      \
1770     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, *value); \
1771   }
1772 FOREVERY_V(IMPL_LEXINSERT)
1773 #undef IMPL_LEXINSERT
1774 
1775 #define IMPL_EXPINSERT(VNAME, V)                                               \
1776   void _mlir_ciface_expInsert##VNAME(                                          \
1777       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1778       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1779       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1780     assert(tensor &&cref &&vref &&fref &&aref);                                \
1781     assert(cref->strides[0] == 1);                                             \
1782     assert(vref->strides[0] == 1);                                             \
1783     assert(fref->strides[0] == 1);                                             \
1784     assert(aref->strides[0] == 1);                                             \
1785     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1786     index_type *cursor = cref->data + cref->offset;                            \
1787     V *values = vref->data + vref->offset;                                     \
1788     bool *filled = fref->data + fref->offset;                                  \
1789     index_type *added = aref->data + aref->offset;                             \
1790     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1791         cursor, values, filled, added, count);                                 \
1792   }
1793 FOREVERY_V(IMPL_EXPINSERT)
1794 #undef IMPL_EXPINSERT
1795 
1796 //===----------------------------------------------------------------------===//
1797 //
1798 // Public functions which accept only C-style data structures to interact
1799 // with sparse tensors (which are only visible as opaque pointers externally).
1800 //
1801 //===----------------------------------------------------------------------===//
1802 
1803 index_type sparseDimSize(void *tensor, index_type d) {
1804   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1805 }
1806 
1807 void endInsert(void *tensor) {
1808   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1809 }
1810 
1811 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
1812   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
1813     return outSparseTensor<V>(coo, dest, sort);                                \
1814   }
1815 FOREVERY_V(IMPL_OUTSPARSETENSOR)
1816 #undef IMPL_OUTSPARSETENSOR
1817 
1818 void delSparseTensor(void *tensor) {
1819   delete static_cast<SparseTensorStorageBase *>(tensor);
1820 }
1821 
1822 #define IMPL_DELCOO(VNAME, V)                                                  \
1823   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1824     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1825   }
1826 FOREVERY_V(IMPL_DELCOO)
1827 #undef IMPL_DELCOO
1828 
1829 char *getTensorFilename(index_type id) {
1830   char var[80];
1831   sprintf(var, "TENSOR%" PRIu64, id);
1832   char *env = getenv(var);
1833   if (!env)
1834     FATAL("Environment variable %s is not set\n", var);
1835   return env;
1836 }
1837 
1838 void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
1839   assert(out && "Received nullptr for out-parameter");
1840   SparseTensorFile stfile(filename);
1841   stfile.openFile();
1842   stfile.readHeader();
1843   stfile.closeFile();
1844   const uint64_t rank = stfile.getRank();
1845   const uint64_t *dimSizes = stfile.getDimSizes();
1846   out->reserve(rank);
1847   out->assign(dimSizes, dimSizes + rank);
1848 }
1849 
1850 // TODO: generalize beyond 64-bit indices.
1851 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
1852   void *convertToMLIRSparseTensor##VNAME(                                      \
1853       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
1854       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
1855     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
1856                                  sparse);                                      \
1857   }
1858 FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
1859 #undef IMPL_CONVERTTOMLIRSPARSETENSOR
1860 
1861 // TODO: Currently, values are copied from SparseTensorStorage to
1862 // SparseTensorCOO, then to the output.  We may want to reduce the number
1863 // of copies.
1864 //
1865 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1866 // compressed
1867 #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
1868   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
1869                                           uint64_t *pNse, uint64_t **pShape,   \
1870                                           V **pValues, uint64_t **pIndices) {  \
1871     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
1872   }
1873 FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
1874 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
1875 
1876 } // extern "C"
1877 
1878 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1879