1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 
19 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
20 
21 #include <algorithm>
22 #include <cassert>
23 #include <cctype>
24 #include <cstdio>
25 #include <cstdlib>
26 #include <cstring>
27 #include <fstream>
28 #include <functional>
29 #include <iostream>
30 #include <limits>
31 #include <numeric>
32 
33 //===----------------------------------------------------------------------===//
34 //
35 // Internal support for storing and reading sparse tensors.
36 //
37 // The following memory-resident sparse storage schemes are supported:
38 //
39 // (a) A coordinate scheme for temporarily storing and lexicographically
40 //     sorting a sparse tensor by index (SparseTensorCOO).
41 //
42 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
43 //     per-dimension sparse/dense annotations together with a dimension
44 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
45 //
46 // The following external formats are supported:
47 //
48 // (1) Matrix Market Exchange (MME): *.mtx
49 //     https://math.nist.gov/MatrixMarket/formats.html
50 //
51 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
52 //     http://frostt.io/tensors/file-formats.html
53 //
54 // Two public APIs are supported:
55 //
56 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
57 //     tensors. These methods should be used exclusively by MLIR
58 //     compiler-generated code.
59 //
60 // (II) Methods that accept C-style data structures to interact with sparse
61 //      tensors. These methods can be used by any external runtime that wants
62 //      to interact with MLIR compiler-generated code.
63 //
64 // In both cases (I) and (II), the SparseTensorStorage format is visible
65 // externally only as an opaque pointer.
66 //
67 //===----------------------------------------------------------------------===//
68 
69 namespace {
70 
71 static constexpr int kColWidth = 1025;
72 
73 /// A version of `operator*` on `uint64_t` which checks for overflows.
74 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
75   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
76          "Integer overflow");
77   return lhs * rhs;
78 }
79 
80 // This macro helps minimize repetition of the print-and-exit idiom, and it
81 // ensures some additional output indicating where the error is coming from.
82 // (Since `fprintf` doesn't provide a stacktrace, this makes it easier to
83 // track down whether an error is coming from our code or from somewhere
84 // else in MLIR.)
85 #define FATAL(...)                                                             \
86   do {                                                                         \
87     fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                        \
88     exit(1);                                                                   \
89   } while (0)
90 
91 // TODO: try to unify this with `SparseTensorFile::assertMatchesShape`
92 // which is used by `openSparseTensorCOO`.  It's easy enough to resolve
93 // the `std::vector` vs pointer mismatch for `dimSizes`; but it's trickier
94 // to resolve the presence/absence of `perm` (without introducing extra
95 // overhead), so perhaps the code duplication is unavoidable.
96 //
97 /// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
98 /// semantic-order to target-order) are a refinement of the desired `shape`
99 /// (in semantic-order).
100 ///
101 /// Precondition: `perm` and `shape` must be valid for `rank`.
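///
/// Illustrative example (sizes chosen arbitrarily for this sketch): for
/// `rank = 2`, `perm = {1, 0}` (semantic dimension `r` is stored as target
/// dimension `perm[r]`), and semantic `shape = {3, 4}`, the target-order
/// sizes must be `dimSizes = {4, 3}`, since `shape[r] == dimSizes[perm[r]]`
/// is required for every `r` (a zero in `shape` acts as a wildcard).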
102 static inline void
103 assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
104                               uint64_t rank, const uint64_t *perm,
105                               const uint64_t *shape) {
106   assert(perm && shape);
107   assert(rank == dimSizes.size() && "Rank mismatch");
108   for (uint64_t r = 0; r < rank; r++)
109     assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
110            "Dimension size mismatch");
111 }
112 
113 /// A sparse tensor element in coordinate scheme (value and indices).
114 /// For example, a rank-1 vector element would look like
115 ///   ({i}, a[i])
116 /// and a rank-5 tensor element like
117 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
118 /// We use a pointer into a shared index pool rather than, e.g., a direct
119 /// vector, since that (1) reduces the per-element memory footprint, and
120 /// (2) centralizes the memory reservation and (re)allocation in one place.
121 template <typename V>
122 struct Element final {
123   Element(uint64_t *ind, V val) : indices(ind), value(val) {}
124   uint64_t *indices; // pointer into shared index pool
125   V value;
126 };
127 
128 /// The type of callback functions which receive an element.  We avoid
129 /// packaging the coordinates and value together as an `Element` object
130 /// because this helps keep code somewhat cleaner.
131 template <typename V>
132 using ElementConsumer =
133     const std::function<void(const std::vector<uint64_t> &, V)> &;
134 
135 /// A memory-resident sparse tensor in coordinate scheme (collection of
136 /// elements). This data structure is used to read a sparse tensor from
137 /// any external format into memory and sort the elements lexicographically
138 /// by indices before passing it back to the client (most packed storage
139 /// formats require the elements to appear in lexicographic index order).
140 template <typename V>
141 struct SparseTensorCOO final {
142 public:
143   SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
144       : dimSizes(dimSizes) {
145     if (capacity) {
146       elements.reserve(capacity);
147       indices.reserve(capacity * getRank());
148     }
149   }
150 
151   /// Adds element as indices and value.
152   void add(const std::vector<uint64_t> &ind, V val) {
153     assert(!iteratorLocked && "Attempt to add() after startIterator()");
154     uint64_t *base = indices.data();
155     uint64_t size = indices.size();
156     uint64_t rank = getRank();
157     assert(ind.size() == rank && "Element rank mismatch");
158     for (uint64_t r = 0; r < rank; r++) {
159       assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
160       indices.push_back(ind[r]);
161     }
162     // The base pointer only changes if `indices` was reallocated. In that
163     // case, we need to correct all previous pointers into the vector. Note
164     // that this only happens if the initial capacity was not set correctly,
165     // and even then only on internal vector reallocations (which, under the
166     // usual doubling growth rule, incur only amortized linear overhead).
167     uint64_t *newBase = indices.data();
168     if (newBase != base) {
169       for (uint64_t i = 0, n = elements.size(); i < n; i++)
170         elements[i].indices = newBase + (elements[i].indices - base);
171       base = newBase;
172     }
173     // Add element as (pointer into shared index pool, value) pair.
174     elements.emplace_back(base + size, val);
175   }
176 
177   /// Sorts elements lexicographically by index.
178   void sort() {
179     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
180     // TODO: we may want to cache an `isSorted` bit, to avoid
181     // unnecessary/redundant sorting.
182     uint64_t rank = getRank();
183     std::sort(elements.begin(), elements.end(),
184               [rank](const Element<V> &e1, const Element<V> &e2) {
185                 for (uint64_t r = 0; r < rank; r++) {
186                   if (e1.indices[r] == e2.indices[r])
187                     continue;
188                   return e1.indices[r] < e2.indices[r];
189                 }
190                 return false;
191               });
192   }
193 
194   /// Get the rank of the tensor.
195   uint64_t getRank() const { return dimSizes.size(); }
196 
197   /// Getter for the dimension-sizes array.
198   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
199 
200   /// Getter for the elements array.
201   const std::vector<Element<V>> &getElements() const { return elements; }
202 
203   /// Switch into iterator mode.
204   void startIterator() {
205     iteratorLocked = true;
206     iteratorPos = 0;
207   }
208 
209   /// Get the next element.
210   const Element<V> *getNext() {
211     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
212     if (iteratorPos < elements.size())
213       return &(elements[iteratorPos++]);
214     iteratorLocked = false;
215     return nullptr;
216   }
217 
218   /// Factory method. Permutes the original dimensions according to
219   /// the given ordering and expects subsequent add() calls to honor
220   /// that same ordering for the given indices. The result is a
221   /// fully permuted coordinate scheme.
222   ///
223   /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
224   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
225                                                 const uint64_t *dimSizes,
226                                                 const uint64_t *perm,
227                                                 uint64_t capacity = 0) {
228     std::vector<uint64_t> permsz(rank);
229     for (uint64_t r = 0; r < rank; r++) {
230       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
231       permsz[perm[r]] = dimSizes[r];
232     }
233     return new SparseTensorCOO<V>(permsz, capacity);
234   }
235 
236 private:
237   const std::vector<uint64_t> dimSizes; // per-dimension sizes
238   std::vector<Element<V>> elements;     // all COO elements
239   std::vector<uint64_t> indices;        // shared index pool
240   bool iteratorLocked = false;
241   unsigned iteratorPos = 0;
242 };
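
// Example usage (an illustrative sketch; the value type, sizes, and the
// hypothetical `process` sink are arbitrary choices, not part of the API):
//
//   SparseTensorCOO<double> coo({2, 3}, /*capacity=*/3);
//   coo.add({0, 2}, 1.0);
//   coo.add({1, 0}, 2.0);
//   coo.add({0, 0}, 3.0);
//   coo.sort();  // Elements now ordered as (0,0), (0,2), (1,0).
//   coo.startIterator();
//   while (const Element<double> *e = coo.getNext())
//     process(e->indices[0], e->indices[1], e->value);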
243 
244 // Forward.
245 template <typename V>
246 class SparseTensorEnumeratorBase;
247 
248 // Helper macro for generating error messages when some
249 // `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
250 // and then the wrong "partial method specialization" is called.
251 #define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
252 
253 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
254 /// takes responsibility for all the `<P,I,V>`-independent aspects
255 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
256 /// we use function overloading to implement "partial" method
257 /// specialization, which the C-API relies on to catch type errors
258 /// arising from our use of opaque pointers.
259 class SparseTensorStorageBase {
260 public:
261   /// Constructs a new storage object.  The `perm` maps the tensor's
262   /// semantic-ordering of dimensions to this object's storage-order.
263   /// The `dimSizes` and `sparsity` arrays are already in storage-order.
264   ///
265   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
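  ///
  /// Illustrative example: with `perm = {1, 2, 0}` (semantic dimension `r`
  /// is stored as dimension `perm[r]`), this constructor computes the
  /// inverse permutation `rev = {2, 0, 1}`, which maps storage-order
  /// dimensions back to semantic-order dimensions.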
266   SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
267                           const uint64_t *perm, const DimLevelType *sparsity)
268       : dimSizes(dimSizes), rev(getRank()),
269         dimTypes(sparsity, sparsity + getRank()) {
270     assert(perm && sparsity);
271     const uint64_t rank = getRank();
272     // Validate parameters.
273     assert(rank > 0 && "Trivial shape is unsupported");
274     for (uint64_t r = 0; r < rank; r++) {
275       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
276       assert((dimTypes[r] == DimLevelType::kDense ||
277               dimTypes[r] == DimLevelType::kCompressed) &&
278              "Unsupported DimLevelType");
279     }
280     // Construct the "reverse" (i.e., inverse) permutation.
281     for (uint64_t r = 0; r < rank; r++)
282       rev[perm[r]] = r;
283   }
284 
285   virtual ~SparseTensorStorageBase() = default;
286 
287   /// Get the rank of the tensor.
288   uint64_t getRank() const { return dimSizes.size(); }
289 
290   /// Getter for the dimension-sizes array, in storage-order.
291   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
292 
293   /// Safely looks up the size of the given (storage-order) dimension.
294   uint64_t getDimSize(uint64_t d) const {
295     assert(d < getRank());
296     return dimSizes[d];
297   }
298 
299   /// Getter for the "reverse" permutation, which maps this object's
300   /// storage-order to the tensor's semantic-order.
301   const std::vector<uint64_t> &getRev() const { return rev; }
302 
303   /// Getter for the dimension-types array, in storage-order.
304   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
305 
306   /// Safely check if the (storage-order) dimension uses compressed storage.
307   bool isCompressedDim(uint64_t d) const {
308     assert(d < getRank());
309     return (dimTypes[d] == DimLevelType::kCompressed);
310   }
311 
312   /// Allocate a new enumerator.
313 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
314   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
315                              const uint64_t *) const {                         \
316     FATAL_PIV("newEnumerator" #VNAME);                                         \
317   }
318   FOREVERY_V(DECL_NEWENUMERATOR)
319 #undef DECL_NEWENUMERATOR
320 
321   /// Overhead storage.
322 #define DECL_GETPOINTERS(PNAME, P)                                             \
323   virtual void getPointers(std::vector<P> **, uint64_t) {                      \
324     FATAL_PIV("getPointers" #PNAME);                                           \
325   }
326   FOREVERY_FIXED_O(DECL_GETPOINTERS)
327 #undef DECL_GETPOINTERS
328 #define DECL_GETINDICES(INAME, I)                                              \
329   virtual void getIndices(std::vector<I> **, uint64_t) {                       \
330     FATAL_PIV("getIndices" #INAME);                                            \
331   }
332   FOREVERY_FIXED_O(DECL_GETINDICES)
333 #undef DECL_GETINDICES
334 
335   /// Primary storage.
336 #define DECL_GETVALUES(VNAME, V)                                               \
337   virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
338   FOREVERY_V(DECL_GETVALUES)
339 #undef DECL_GETVALUES
340 
341   /// Element-wise insertion in lexicographic index order.
342 #define DECL_LEXINSERT(VNAME, V)                                               \
343   virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
344   FOREVERY_V(DECL_LEXINSERT)
345 #undef DECL_LEXINSERT
346 
347   /// Expanded insertion.
348 #define DECL_EXPINSERT(VNAME, V)                                               \
349   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
350     FATAL_PIV("expInsert" #VNAME);                                             \
351   }
352   FOREVERY_V(DECL_EXPINSERT)
353 #undef DECL_EXPINSERT
354 
355   /// Finishes insertion.
356   virtual void endInsert() = 0;
357 
358 protected:
359   // Since this class is polymorphic (it has virtual methods), we must
360   // disallow public copying in order to avoid "slicing".  Since this class
361   // has data members, that means making the copy constructor protected.
362   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
363   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
364   // Copy-assignment would be implicitly deleted (because `dimSizes`
365   // is const), so we explicitly delete it for clarity.
366   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
367 
368 private:
369   const std::vector<uint64_t> dimSizes;
370   std::vector<uint64_t> rev;
371   const std::vector<DimLevelType> dimTypes;
372 };
373 
374 #undef FATAL_PIV
375 
376 // Forward.
377 template <typename P, typename I, typename V>
378 class SparseTensorEnumerator;
379 
380 /// A memory-resident sparse tensor using a storage scheme based on
381 /// per-dimension sparse/dense annotations. This data structure provides a
382 /// bufferized form of a sparse tensor type. In contrast to generating setup
383 /// methods for each differently annotated sparse tensor, this method provides
384 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
385 /// and annotations to implement all required setup in a general manner.
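///
/// Illustrative example (a sketch, not normative): with rank-2 annotations
/// (dense, compressed), i.e., CSR, the 4x4 matrix
///
///   [ 1 0 0 2 ]
///   [ 0 0 0 0 ]
///   [ 0 3 0 0 ]
///   [ 0 0 4 5 ]
///
/// is stored as
///
///   pointers[1] = [ 0, 2, 2, 3, 5 ]
///   indices[1]  = [ 0, 3, 1, 2, 3 ]
///   values      = [ 1, 2, 3, 4, 5 ]
///
/// (the dense dimension 0 needs no overhead arrays).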
386 template <typename P, typename I, typename V>
387 class SparseTensorStorage final : public SparseTensorStorageBase {
388   /// Private constructor to share code between the other constructors.
389   /// Beware that the object is not guaranteed to be in a valid state
390   /// after this constructor alone; e.g., `isCompressedDim(d)`
391   /// doesn't entail `!(pointers[d].empty())`.
392   ///
393   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
394   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
395                       const uint64_t *perm, const DimLevelType *sparsity)
396       : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
397         indices(getRank()), idx(getRank()) {}
398 
399 public:
400   /// Constructs a sparse tensor storage scheme with the given dimensions,
401   /// permutation, and per-dimension dense/sparse annotations, using
402   /// the coordinate scheme tensor for the initial contents if provided.
403   ///
404   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
405   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
406                       const uint64_t *perm, const DimLevelType *sparsity,
407                       SparseTensorCOO<V> *coo)
408       : SparseTensorStorage(dimSizes, perm, sparsity) {
409     // Provide hints on capacity of pointers and indices.
410     // TODO: needs much fine-tuning based on actual sparsity; currently
411     //       we reserve pointer/index space based on all previous dense
412     //       dimensions, which works well up to the first sparse dimension;
413     //       but we should really use nnz and the dense/sparse distribution.
414     bool allDense = true;
415     uint64_t sz = 1;
416     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
417       if (isCompressedDim(r)) {
418         // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
419         // `sz` by that before reserving. (For now we just use 1.)
420         pointers[r].reserve(sz + 1);
421         pointers[r].push_back(0);
422         indices[r].reserve(sz);
423         sz = 1;
424         allDense = false;
425       } else { // Dense dimension.
426         sz = checkedMul(sz, getDimSizes()[r]);
427       }
428     }
429     // Then assign contents from coordinate scheme tensor if provided.
430     if (coo) {
431       // Ensure both preconditions of `fromCOO`.
432       assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
433       coo->sort();
434       // Now actually insert the `elements`.
435       const std::vector<Element<V>> &elements = coo->getElements();
436       uint64_t nnz = elements.size();
437       values.reserve(nnz);
438       fromCOO(elements, 0, nnz, 0);
439     } else if (allDense) {
440       values.resize(sz, 0);
441     }
442   }
443 
444   /// Constructs a sparse tensor storage scheme with the given dimensions,
445   /// permutation, and per-dimension dense/sparse annotations, using
446   /// the given sparse tensor for the initial contents.
447   ///
448   /// Preconditions:
449   /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
450   /// * The `tensor` must have the same value type `V`.
451   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
452                       const uint64_t *perm, const DimLevelType *sparsity,
453                       const SparseTensorStorageBase &tensor);
454 
455   ~SparseTensorStorage() final = default;
456 
457   /// Partially specialize these getter methods based on template types.
458   void getPointers(std::vector<P> **out, uint64_t d) final {
459     assert(d < getRank());
460     *out = &pointers[d];
461   }
462   void getIndices(std::vector<I> **out, uint64_t d) final {
463     assert(d < getRank());
464     *out = &indices[d];
465   }
466   void getValues(std::vector<V> **out) final { *out = &values; }
467 
468   /// Partially specialize lexicographical insertions based on template types.
469   void lexInsert(const uint64_t *cursor, V val) final {
470     // First, wrap up pending insertion path.
471     uint64_t diff = 0;
472     uint64_t top = 0;
473     if (!values.empty()) {
474       diff = lexDiff(cursor);
475       endPath(diff + 1);
476       top = idx[diff] + 1;
477     }
478     // Then continue with insertion path.
479     insPath(cursor, diff, top, val);
480   }
481 
482   /// Partially specialize expanded insertions based on template types.
483   /// Note that this method resets the values/filled-switch array back
484   /// to all-zero/false while only iterating over the nonzero elements.
485   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
486                  uint64_t count) final {
487     if (count == 0)
488       return;
489     // Sort.
490     std::sort(added, added + count);
491     // Restore insertion path for first insert.
492     const uint64_t lastDim = getRank() - 1;
493     uint64_t index = added[0];
494     cursor[lastDim] = index;
495     lexInsert(cursor, values[index]);
496     assert(filled[index]);
497     values[index] = 0;
498     filled[index] = false;
499     // Subsequent insertions are quick.
500     for (uint64_t i = 1; i < count; i++) {
501       assert(index < added[i] && "non-lexicographic insertion");
502       index = added[i];
503       cursor[lastDim] = index;
504       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
505       assert(filled[index]);
506       values[index] = 0;
507       filled[index] = false;
508     }
509   }
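
  // Note (illustrative): for a rank-2 tensor, `values` and `filled` act as
  // dense scratch arrays indexed by the innermost dimension (the current
  // "row"), `added` lists which innermost indices were written this round,
  // and `count` says how many; on return those scratch entries have been
  // reset to zero/false again.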
510 
511   /// Finalizes lexicographic insertions.
512   void endInsert() final {
513     if (values.empty())
514       finalizeSegment(0);
515     else
516       endPath(0);
517   }
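
  // Example usage (an illustrative sketch; `tensor` stands for an instance
  // of this class and the indices/values are arbitrary):
  //
  //   uint64_t cursor[2];
  //   cursor[0] = 0; cursor[1] = 0;
  //   tensor->lexInsert(cursor, 1.0);  // insert a[0][0] = 1.0
  //   cursor[0] = 0; cursor[1] = 3;
  //   tensor->lexInsert(cursor, 2.0);  // insert a[0][3] = 2.0
  //   tensor->endInsert();             // wrap up the final insertion path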
518 
519   void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
520                      const uint64_t *perm) const final {
521     *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
522   }
523 
524   /// Returns this sparse tensor storage scheme as a new memory-resident
525   /// sparse tensor in coordinate scheme with the given dimension order.
526   ///
527   /// Precondition: `perm` must be valid for `getRank()`.
528   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
529     SparseTensorEnumeratorBase<V> *enumerator;
530     newEnumerator(&enumerator, getRank(), perm);
531     SparseTensorCOO<V> *coo =
532         new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
533     enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
534       coo->add(ind, val);
535     });
536     // TODO: This assertion assumes there are no stored zeros,
537     // or if there are then that we don't filter them out.
538     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
539     assert(coo->getElements().size() == values.size());
540     delete enumerator;
541     return coo;
542   }
543 
544   /// Factory method. Constructs a sparse tensor storage scheme with the given
545   /// dimensions, permutation, and per-dimension dense/sparse annotations,
546   /// using the coordinate scheme tensor for the initial contents if provided.
547   /// If a coordinate scheme is provided, it must respect the same
548   /// permutation as is desired for the new sparse tensor storage.
549   ///
550   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
551   static SparseTensorStorage<P, I, V> *
552   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
553                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
554     SparseTensorStorage<P, I, V> *n = nullptr;
555     if (coo) {
556       const auto &coosz = coo->getDimSizes();
557       assertPermutedSizesMatchShape(coosz, rank, perm, shape);
558       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
559     } else {
560       std::vector<uint64_t> permsz(rank);
561       for (uint64_t r = 0; r < rank; r++) {
562         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
563         permsz[perm[r]] = shape[r];
564       }
565       // We pass the null `coo` to ensure we select the intended constructor.
566       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
567     }
568     return n;
569   }
570 
571   /// Factory method. Constructs a sparse tensor storage scheme with
572   /// the given dimensions, permutation, and per-dimension dense/sparse
573   /// annotations, using the sparse tensor for the initial contents.
574   ///
575   /// Preconditions:
576   /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
577   /// * The `tensor` must have the same value type `V`.
578   static SparseTensorStorage<P, I, V> *
579   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
580                   const DimLevelType *sparsity,
581                   const SparseTensorStorageBase *source) {
582     assert(source && "Got nullptr for source");
583     SparseTensorEnumeratorBase<V> *enumerator;
584     source->newEnumerator(&enumerator, rank, perm);
585     const auto &permsz = enumerator->permutedSizes();
586     assertPermutedSizesMatchShape(permsz, rank, perm, shape);
587     auto *tensor =
588         new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
589     delete enumerator;
590     return tensor;
591   }
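
  // Example usage (an illustrative sketch; the identity permutation, the
  // CSR-style annotations, and the `<uint64_t, uint64_t, double>` template
  // arguments are arbitrary choices for the example):
  //
  //   uint64_t shape[] = {4, 4};
  //   uint64_t perm[] = {0, 1};
  //   DimLevelType sparsity[] = {DimLevelType::kDense,
  //                              DimLevelType::kCompressed};
  //   SparseTensorCOO<double> *coo =
  //       SparseTensorCOO<double>::newSparseTensorCOO(2, shape, perm);
  //   coo->add({0, 3}, 2.0);
  //   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, double>::
  //       newSparseTensor(2, shape, perm, sparsity, coo);
  //   SparseTensorCOO<double> *roundtrip = tensor->toCOO(perm);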
592 
593 private:
594   /// Appends an arbitrary new position to `pointers[d]`.  This method
595   /// checks that `pos` is representable in the `P` type; however, it
596   /// does not check that `pos` is semantically valid (i.e., larger than
597   /// the previous position and smaller than `indices[d].capacity()`).
598   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
599     assert(isCompressedDim(d));
600     assert(pos <= std::numeric_limits<P>::max() &&
601            "Pointer value is too large for the P-type");
602     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
603   }
604 
605   /// Appends index `i` to dimension `d`, in the semantically general
606   /// sense.  For non-dense dimensions, that means appending to the
607   /// `indices[d]` array, checking that `i` is representable in the `I`
608   /// type; however, we do not verify other semantic requirements (e.g.,
609   /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
610   /// in the same segment).  For dense dimensions, this method instead
611   /// appends the appropriate number of zeros to the `values` array,
612   /// where `full` is the number of "entries" already written to `values`
613   /// for this segment (aka one after the highest index previously appended).
614   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
615     if (isCompressedDim(d)) {
616       assert(i <= std::numeric_limits<I>::max() &&
617              "Index value is too large for the I-type");
618       indices[d].push_back(static_cast<I>(i));
619     } else { // Dense dimension.
620       assert(i >= full && "Index was already filled");
621       if (i == full)
622         return; // Short-circuit, since it'll be a nop.
623       if (d + 1 == getRank())
624         values.insert(values.end(), i - full, 0);
625       else
626         finalizeSegment(d + 1, 0, i - full);
627     }
628   }
629 
630   /// Writes the given coordinate to `indices[d][pos]`.  This method
631   /// checks that `i` is representable in the `I` type; however, it
632   /// does not check that `i` is semantically valid (i.e., in bounds
633   /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
634   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
635     assert(isCompressedDim(d));
636     // Subscript assignment to `std::vector` requires that the `pos`-th
637     // entry has been initialized; thus we must be sure to check `size()`
638     // here, instead of `capacity()` as would be ideal.
639     assert(pos < indices[d].size() && "Index position is out of bounds");
640     assert(i <= std::numeric_limits<I>::max() &&
641            "Index value is too large for the I-type");
642     indices[d][pos] = static_cast<I>(i);
643   }
644 
645   /// Computes the assembled-size associated with the `d`-th dimension,
646   /// given the assembled-size associated with the `(d-1)`-th dimension.
647   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
648   /// storage, as opposed to "dimension-sizes" which are the cardinality
649   /// of coordinates for that dimension.
650   ///
651   /// Precondition: the `pointers[d]` array must be fully initialized
652   /// before calling this method.
653   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
654     if (isCompressedDim(d))
655       return pointers[d][parentSz];
656     // else if dense:
657     return parentSz * getDimSizes()[d];
658   }
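
  // Illustrative example (continuing the CSR sketch in the class comment
  // above): for the dense dimension 0, `assembledSize(1, 0) == 4`; for the
  // compressed dimension 1, `assembledSize(4, 1) == pointers[1][4] == 5`,
  // i.e., the total number of stored entries.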
659 
660   /// Initializes sparse tensor storage scheme from a memory-resident sparse
661   /// tensor in coordinate scheme. This method prepares the pointers and
662   /// indices arrays under the given per-dimension dense/sparse annotations.
663   ///
664   /// Preconditions:
665   /// (1) the `elements` must be lexicographically sorted.
666   /// (2) the indices of every element are valid for `dimSizes` (equal rank
667   ///     and pointwise less-than).
668   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
669                uint64_t hi, uint64_t d) {
670     uint64_t rank = getRank();
671     assert(d <= rank && hi <= elements.size());
672     // Once dimensions are exhausted, insert the numerical values.
673     if (d == rank) {
674       assert(lo < hi);
675       values.push_back(elements[lo].value);
676       return;
677     }
678     // Visit all elements in this interval.
679     uint64_t full = 0;
680     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
681       // Find segment in interval with same index elements in this dimension.
682       uint64_t i = elements[lo].indices[d];
683       uint64_t seg = lo + 1;
684       while (seg < hi && elements[seg].indices[d] == i)
685         seg++;
686       // Handle segment in interval for sparse or dense dimension.
687       appendIndex(d, full, i);
688       full = i + 1;
689       fromCOO(elements, lo, seg, d + 1);
690       // And move on to next segment in interval.
691       lo = seg;
692     }
693     // Finalize the sparse pointer structure at this dimension.
694     finalizeSegment(d, full);
695   }
696 
697   /// Finalize the sparse pointer structure at this dimension.
698   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
699     if (count == 0)
700       return; // Short-circuit, since it'll be a nop.
701     if (isCompressedDim(d)) {
702       appendPointer(d, indices[d].size(), count);
703     } else { // Dense dimension.
704       const uint64_t sz = getDimSizes()[d];
705       assert(sz >= full && "Segment is overfull");
706       count = checkedMul(count, sz - full);
707       // For dense storage we must enumerate all the remaining coordinates
708       // in this dimension (i.e., coordinates after the last non-zero
709       // element), and either fill in their zero values or else recurse
710       // to finalize some deeper dimension.
711       if (d + 1 == getRank())
712         values.insert(values.end(), count, 0);
713       else
714         finalizeSegment(d + 1, 0, count);
715     }
716   }
717 
718   /// Wraps up a single insertion path, inner to outer.
719   void endPath(uint64_t diff) {
720     uint64_t rank = getRank();
721     assert(diff <= rank);
722     for (uint64_t i = 0; i < rank - diff; i++) {
723       const uint64_t d = rank - i - 1;
724       finalizeSegment(d, idx[d] + 1);
725     }
726   }
727 
728   /// Continues a single insertion path, outer to inner.
729   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
730     uint64_t rank = getRank();
731     assert(diff < rank);
732     for (uint64_t d = diff; d < rank; d++) {
733       uint64_t i = cursor[d];
734       appendIndex(d, top, i);
735       top = 0;
736       idx[d] = i;
737     }
738     values.push_back(val);
739   }
740 
741   /// Finds the lexicographic differing dimension.
742   uint64_t lexDiff(const uint64_t *cursor) const {
743     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
744       if (cursor[r] > idx[r])
745         return r;
746       else
747         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
748     assert(0 && "duplicate insertion");
749     return -1u;
750   }
751 
752   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
753   // the cost of virtual-function dispatch in inner loops), without
754   // making them public to other client code.
755   friend class SparseTensorEnumerator<P, I, V>;
756 
757   std::vector<std::vector<P>> pointers;
758   std::vector<std::vector<I>> indices;
759   std::vector<V> values;
760   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
761 };
762 
763 /// A (higher-order) function object for enumerating the elements of some
764 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
765 /// method encapsulates the loop-nest for enumerating the elements of
766 /// the source tensor (in whatever order is best for the source tensor),
767 /// and applies a permutation to the coordinates/indices before handing
768 /// each element to the callback.  A single enumerator object can be
769 /// freely reused for several calls to `forallElements`, just so long
770 /// as each call is sequential with respect to one another.
771 ///
772 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
773 /// passed to the constructor; thus, objects of this class must not
774 /// outlive the sparse tensor they depend on.
775 ///
776 /// Design Note: The reason we define this class instead of simply using
777 /// `SparseTensorEnumerator<P,I,V>` is that we need to hide/generalize
778 /// the `<P,I>` template parameters from MLIR client code (to simplify the
779 /// type parameters used for direct sparse-to-sparse conversion).  And the
780 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
781 /// than simply using this class is to avoid the cost of virtual-method
782 /// dispatch within the loop-nest.
783 template <typename V>
784 class SparseTensorEnumeratorBase {
785 public:
786   /// Constructs an enumerator with the given permutation for mapping
787   /// the semantic-ordering of dimensions to the desired target-ordering.
788   ///
789   /// Preconditions:
790   /// * the `tensor` must have the same `V` value type.
791   /// * `perm` must be valid for `rank`.
792   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
793                              uint64_t rank, const uint64_t *perm)
794       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
795         cursor(getRank()) {
796     assert(perm && "Received nullptr for permutation");
797     assert(rank == getRank() && "Permutation rank mismatch");
798     const auto &rev = src.getRev();           // source-order -> semantic-order
799     const auto &dimSizes = src.getDimSizes(); // in source storage-order
800     for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
801       uint64_t t = perm[rev[s]];              // `t` target-order
802       reord[s] = t;
803       permsz[t] = dimSizes[s];
804     }
805   }
806 
807   virtual ~SparseTensorEnumeratorBase() = default;
808 
809   // We disallow copying to help avoid leaking the `src` reference.
810   // (In addition to avoiding the problem of slicing.)
811   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
812   SparseTensorEnumeratorBase &
813   operator=(const SparseTensorEnumeratorBase &) = delete;
814 
815   /// Returns the source/target tensor's rank.  (The source-rank and
816   /// target-rank are always equal since we only support permutations.
817   /// Though once we add support for other dimension mappings, this
818   /// method will have to be split in two.)
819   uint64_t getRank() const { return permsz.size(); }
820 
821   /// Returns the target tensor's dimension sizes.
822   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
823 
824   /// Enumerates all elements of the source tensor, permutes their
825   /// indices, and passes the permuted element to the callback.
826   /// The callback must not store the cursor reference directly,
827   /// since this function reuses the storage.  Instead, a callback that
828   /// needs to keep the coordinates must copy them.
829   virtual void forallElements(ElementConsumer<V> yield) = 0;
830 
831 protected:
832   const SparseTensorStorageBase &src;
833   std::vector<uint64_t> permsz; // in target order.
834   std::vector<uint64_t> reord;  // source storage-order -> target order.
835   std::vector<uint64_t> cursor; // in target order.
836 };
837 
838 template <typename P, typename I, typename V>
839 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
840   using Base = SparseTensorEnumeratorBase<V>;
841 
842 public:
843   /// Constructs an enumerator with the given permutation for mapping
844   /// the semantic-ordering of dimensions to the desired target-ordering.
845   ///
846   /// Precondition: `perm` must be valid for `rank`.
847   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
848                          uint64_t rank, const uint64_t *perm)
849       : Base(tensor, rank, perm) {}
850 
851   ~SparseTensorEnumerator() final = default;
852 
853   void forallElements(ElementConsumer<V> yield) final {
854     forallElements(yield, 0, 0);
855   }
856 
857 private:
858   /// The recursive component of the public `forallElements`.
859   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
860                       uint64_t d) {
861     // Recover the `<P,I,V>` type parameters of `src`.
862     const auto &src =
863         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
864     if (d == Base::getRank()) {
865       assert(parentPos < src.values.size() &&
866              "Value position is out of bounds");
867       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
868       yield(this->cursor, src.values[parentPos]);
869     } else if (src.isCompressedDim(d)) {
870       // Look up the bounds of the `d`-level segment determined by the
871       // `d-1`-level position `parentPos`.
872       const std::vector<P> &pointersD = src.pointers[d];
873       assert(parentPos + 1 < pointersD.size() &&
874              "Parent pointer position is out of bounds");
875       const uint64_t pstart = static_cast<uint64_t>(pointersD[parentPos]);
876       const uint64_t pstop = static_cast<uint64_t>(pointersD[parentPos + 1]);
877       // Loop-invariant code for looking up the `d`-level coordinates/indices.
878       const std::vector<I> &indicesD = src.indices[d];
879       assert(pstop <= indicesD.size() && "Index position is out of bounds");
880       uint64_t &cursorReordD = this->cursor[this->reord[d]];
881       for (uint64_t pos = pstart; pos < pstop; pos++) {
882         cursorReordD = static_cast<uint64_t>(indicesD[pos]);
883         forallElements(yield, pos, d + 1);
884       }
885     } else { // Dense dimension.
886       const uint64_t sz = src.getDimSizes()[d];
887       const uint64_t pstart = parentPos * sz;
888       uint64_t &cursorReordD = this->cursor[this->reord[d]];
889       for (uint64_t i = 0; i < sz; i++) {
890         cursorReordD = i;
891         forallElements(yield, pstart + i, d + 1);
892       }
893     }
894   }
895 };
896 
897 /// Statistics regarding the number of nonzero subtensors in
898 /// a source tensor, for direct sparse=>sparse conversion a la
899 /// <https://arxiv.org/abs/2001.02609>.
900 ///
901 /// N.B., this class stores references to the parameters passed to
902 /// the constructor; thus, objects of this class must not outlive
903 /// those parameters.
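///
/// Illustrative example: for the 4x4 CSR matrix sketched in the
/// `SparseTensorStorage` comment above, dimension 1 is the only compressed
/// dimension, and the statistics would be `nnz[1] = {2, 0, 1, 2}`, i.e.,
/// the number of stored entries per row.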
904 class SparseTensorNNZ final {
905 public:
906   /// Allocate the statistics structure for the desired sizes and
907   /// sparsity (in the target tensor's storage-order).  This constructor
908   /// does not actually populate the statistics, however; for that see
909   /// `initialize`.
910   ///
911   /// Precondition: `dimSizes` must not contain zeros.
912   SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
913                   const std::vector<DimLevelType> &sparsity)
914       : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
915     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
916     bool uncompressed = true;
917     uint64_t sz = 1; // the product of `dimSizes[s]` for all `s < r`.
918     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
919       switch (dimTypes[r]) {
920       case DimLevelType::kCompressed:
921         assert(uncompressed &&
922                "Multiple compressed layers not currently supported");
923         uncompressed = false;
924         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
925         break;
926       case DimLevelType::kDense:
927         assert(uncompressed &&
928                "Dense after compressed not currently supported");
929         break;
930       case DimLevelType::kSingleton:
931         // Singleton after Compressed causes no problems for allocating
932         // `nnz` or for the yieldPos loop.  This remains true even
933         // when adding support for multiple compressed dimensions or
934         // for dense-after-compressed.
935         break;
936       }
937       sz = checkedMul(sz, dimSizes[r]);
938     }
939   }
940 
941   // We disallow copying to help avoid leaking the stored references.
942   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
943   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
944 
945   /// Returns the rank of the target tensor.
946   uint64_t getRank() const { return dimSizes.size(); }
947 
948   /// Enumerate the source tensor to fill in the statistics.  The
949   /// enumerator should already incorporate the permutation (from
950   /// semantic-order to the target storage-order).
951   template <typename V>
952   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
953     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
954     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
955     enumerator.forallElements(
956         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
957   }
958 
959   /// The type of callback functions which receive an nnz-statistic.
960   using NNZConsumer = const std::function<void(uint64_t)> &;
961 
962   /// Lexicographically enumerates all indices for dimensions strictly
963   /// less than `stopDim`, and passes their nnz statistic to the callback.
964   /// Since our use-case only requires the statistic not the coordinates
965   /// themselves, we do not bother to construct those coordinates.
966   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
967     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
968     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
969            "Cannot look up non-compressed dimensions");
970     forallIndices(yield, stopDim, 0, 0);
971   }
972 
973 private:
974   /// Adds a new element (i.e., increments its statistics).  We use
975   /// a method rather than inlining into the lambda in `initialize`,
976   /// to avoid spurious templating over `V`.  And this method is private
977   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
978   /// by `forallElements`).
979   void add(const std::vector<uint64_t> &ind) {
980     uint64_t parentPos = 0;
981     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
982       if (dimTypes[r] == DimLevelType::kCompressed)
983         nnz[r][parentPos]++;
984       parentPos = parentPos * dimSizes[r] + ind[r];
985     }
986   }
987 
988   /// Recursive component of the public `forallIndices`.
989   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
990                      uint64_t d) const {
991     assert(d <= stopDim);
992     if (d == stopDim) {
993       assert(parentPos < nnz[d].size() && "Cursor is out of range");
994       yield(nnz[d][parentPos]);
995     } else {
996       const uint64_t sz = dimSizes[d];
997       const uint64_t pstart = parentPos * sz;
998       for (uint64_t i = 0; i < sz; i++)
999         forallIndices(yield, stopDim, pstart + i, d + 1);
1000     }
1001   }
1002 
1003   // All of these are in the target storage-order.
1004   const std::vector<uint64_t> &dimSizes;
1005   const std::vector<DimLevelType> &dimTypes;
1006   std::vector<std::vector<uint64_t>> nnz;
1007 };
1008 
1009 template <typename P, typename I, typename V>
1010 SparseTensorStorage<P, I, V>::SparseTensorStorage(
1011     const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
1012     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1013     : SparseTensorStorage(dimSizes, perm, sparsity) {
1014   SparseTensorEnumeratorBase<V> *enumerator;
1015   tensor.newEnumerator(&enumerator, getRank(), perm);
1016   {
1017     // Initialize the statistics structure.
1018     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1019     nnz.initialize(*enumerator);
1020     // Initialize "pointers" overhead (and allocate "indices", "values").
1021     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1022     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1023       if (isCompressedDim(r)) {
1024         pointers[r].reserve(parentSz + 1);
1025         pointers[r].push_back(0);
1026         uint64_t currentPos = 0;
1027         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1028           currentPos += n;
1029           appendPointer(r, currentPos);
1030         });
1031         assert(pointers[r].size() == parentSz + 1 &&
1032                "Final pointers size doesn't match allocated size");
1033         // That assertion entails `assembledSize(parentSz, r)`
1034         // is now in a valid state.  That is, `pointers[r][parentSz]`
1035         // equals the present value of `currentPos`, which is the
1036         // correct assembled-size for `indices[r]`.
1037       }
1038       // Update assembled-size for the next iteration.
1039       parentSz = assembledSize(parentSz, r);
1040       // Ideally we would need only `indices[r].reserve(parentSz)`; however,
1041       // the `std::vector` implementation forces us to initialize it too.
1042       // That is, in the yieldPos loop we need random-access assignment
1043       // to `indices[r]`; however, `std::vector`'s subscript-assignment
1044       // only allows assigning to already-initialized positions.
1045       if (isCompressedDim(r))
1046         indices[r].resize(parentSz, 0);
1047     }
1048     values.resize(parentSz, 0); // Both allocate and zero-initialize.
1049   }
1050   // The yieldPos loop
1051   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1052     uint64_t parentSz = 1, parentPos = 0;
1053     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1054       if (isCompressedDim(r)) {
1055         // If `parentPos == parentSz` then it's valid as an array-lookup;
1056         // however, it's semantically invalid here since that entry
1057         // does not represent a segment of `indices[r]`.  Moreover, that
1058         // entry must be immutable for `assembledSize` to remain valid.
1059         assert(parentPos < parentSz && "Pointers position is out of bounds");
1060         const uint64_t currentPos = pointers[r][parentPos];
1061         // This increment won't overflow the `P` type, since it can't
1062         // exceed the original value of `pointers[r][parentPos+1]`
1063         // which was already verified to be within bounds for `P`
1064         // when it was written to the array.
1065         pointers[r][parentPos]++;
1066         writeIndex(r, currentPos, ind[r]);
1067         parentPos = currentPos;
1068       } else { // Dense dimension.
1069         parentPos = parentPos * getDimSizes()[r] + ind[r];
1070       }
1071       parentSz = assembledSize(parentSz, r);
1072     }
1073     assert(parentPos < values.size() && "Value position is out of bounds");
1074     values[parentPos] = val;
1075   });
1076   // No longer need the enumerator, so we'll delete it ASAP.
1077   delete enumerator;
1078   // The finalizeYieldPos loop
1079   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1080     if (isCompressedDim(r)) {
1081       assert(parentSz == pointers[r].size() - 1 &&
1082              "Actual pointers size doesn't match the expected size");
1083       // Can't check all of them, but at least we can check the last one.
1084       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1085              "Pointers got corrupted");
1086       // TODO: optimize this by using `memmove` or similar.
1087       for (uint64_t n = 0; n < parentSz; n++) {
1088         const uint64_t parentPos = parentSz - n;
1089         pointers[r][parentPos] = pointers[r][parentPos - 1];
1090       }
1091       pointers[r][0] = 0;
1092     }
1093     parentSz = assembledSize(parentSz, r);
1094   }
1095 }
1096 
1097 /// Helper to convert string to lower case.
1098 static char *toLower(char *token) {
1099   for (char *c = token; *c; c++)
1100     *c = tolower(*c);
1101   return token;
1102 }
1103 
1104 /// This class abstracts over the information stored in file headers,
1105 /// as well as providing the buffers and methods for parsing those headers.
1106 class SparseTensorFile final {
1107 public:
1108   explicit SparseTensorFile(char *filename) : filename(filename) {
1109     assert(filename && "Received nullptr for filename");
1110   }
1111 
1112   // Disallows copying, to avoid duplicating the `file` pointer.
1113   SparseTensorFile(const SparseTensorFile &) = delete;
1114   SparseTensorFile &operator=(const SparseTensorFile &) = delete;
1115 
1116   // This dtor tries to avoid leaking the `file`.  (Though it's better
1117   // to call `closeFile` explicitly when possible, since there are
1118   // circumstances where dtors are not called reliably.)
1119   ~SparseTensorFile() { closeFile(); }
1120 
1121   /// Opens the file for reading.
1122   void openFile() {
1123     if (file)
1124       FATAL("Already opened file %s\n", filename);
1125     file = fopen(filename, "r");
1126     if (!file)
1127       FATAL("Cannot find file %s\n", filename);
1128   }
1129 
1130   /// Closes the file.
1131   void closeFile() {
1132     if (file) {
1133       fclose(file);
1134       file = nullptr;
1135     }
1136   }
1137 
1138   // TODO(wrengr/bixia): figure out how to reorganize the element-parsing
1139   // loop of `openSparseTensorCOO` into methods of this class, so we can
1140   // avoid leaking access to the `line` pointer (both for general hygiene
1141   // and because we can't mark it const due to the second argument of
1142   // `strtoul`/`strtod` being `char * *restrict` rather than
1143   // `char const* *restrict`).
1144   //
1145   /// Reads a line from the file, exiting fatally on failure.
1146   char *readLine() {
1147     if (fgets(line, kColWidth, file))
1148       return line;
1149     FATAL("Cannot read next line of %s\n", filename);
1150   }
1151 
1152   /// Reads and parses the file's header.
1153   void readHeader() {
1154     assert(file && "Attempt to readHeader() before openFile()");
1155     if (strstr(filename, ".mtx"))
1156       readMMEHeader();
1157     else if (strstr(filename, ".tns"))
1158       readExtFROSTTHeader();
1159     else
1160       FATAL("Unknown format %s\n", filename);
1161     assert(isValid && "Failed to read the header");
1162   }
1163 
1164   /// Gets the MME "pattern" property setting.  Is only valid after
1165   /// parsing the header.
1166   bool isPattern() const {
1167     assert(isValid && "Attempt to isPattern() before readHeader()");
1168     return isPattern_;
1169   }
1170 
1171   /// Gets the MME "symmetric" property setting.  Is only valid after
1172   /// parsing the header.
1173   bool isSymmetric() const {
1174     assert(isValid && "Attempt to isSymmetric() before readHeader()");
1175     return isSymmetric_;
1176   }
1177 
1178   /// Gets the rank of the tensor.  Is only valid after parsing the header.
1179   uint64_t getRank() const {
1180     assert(isValid && "Attempt to getRank() before readHeader()");
1181     return idata[0];
1182   }
1183 
1184   /// Gets the number of non-zeros.  Is only valid after parsing the header.
1185   uint64_t getNNZ() const {
1186     assert(isValid && "Attempt to getNNZ() before readHeader()");
1187     return idata[1];
1188   }
1189 
1190   /// Gets the dimension-sizes array.  The pointer itself is always
1191   /// valid; however, the values stored therein are only valid after
1192   /// parsing the header.
1193   const uint64_t *getDimSizes() const { return idata + 2; }
1194 
1195   /// Safely gets the size of the given dimension.  Is only valid
1196   /// after parsing the header.
1197   uint64_t getDimSize(uint64_t d) const {
1198     assert(d < getRank());
1199     return idata[2 + d];
1200   }
1201 
1202   /// Asserts the shape subsumes the actual dimension sizes.  Is only
1203   /// valid after parsing the header.
1204   void assertMatchesShape(uint64_t rank, const uint64_t *shape) const {
1205     assert(rank == getRank() && "Rank mismatch");
1206     for (uint64_t r = 0; r < rank; r++)
1207       assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1208              "Dimension size mismatch");
1209   }
1210 
1211 private:
1212   void readMMEHeader();
1213   void readExtFROSTTHeader();
1214 
1215   const char *filename;
1216   FILE *file = nullptr;
1217   bool isValid = false;
1218   bool isPattern_ = false;
1219   bool isSymmetric_ = false;
1220   uint64_t idata[512];
1221   char line[kColWidth];
1222 };
1223 
1224 /// Read the MME header of a general sparse matrix of type real.
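///
/// For example (an illustrative file, with arbitrary sizes), the header
/// accepted here looks like:
///
///   %%MatrixMarket matrix coordinate real general
///   % optional comment lines
///   5 5 3
///
/// where the last line gives the number of rows, columns, and nonzeros.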
1225 void SparseTensorFile::readMMEHeader() {
1226   char header[64];
1227   char object[64];
1228   char format[64];
1229   char field[64];
1230   char symmetry[64];
1231   // Read header line.
1232   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1233              symmetry) != 5)
1234     FATAL("Corrupt header in %s\n", filename);
1235   // Set properties
1236   isPattern_ = (strcmp(toLower(field), "pattern") == 0);
1237   isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0);
1238   // Make sure this is a general sparse matrix.
1239   if (strcmp(toLower(header), "%%matrixmarket") ||
1240       strcmp(toLower(object), "matrix") ||
1241       strcmp(toLower(format), "coordinate") ||
1242       (strcmp(toLower(field), "real") && !isPattern_) ||
1243       (strcmp(toLower(symmetry), "general") && !isSymmetric_))
1244     FATAL("Cannot find a general sparse matrix in %s\n", filename);
1245   // Skip comments.
1246   while (true) {
1247     readLine();
1248     if (line[0] != '%')
1249       break;
1250   }
1251   // Next line contains M N NNZ.
1252   idata[0] = 2; // rank
1253   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1254              idata + 1) != 3)
1255     FATAL("Cannot find size in %s\n", filename);
1256   isValid = true;
1257 }
1258 
1259 /// Read the "extended" FROSTT header. Although not part of the documented
1260 /// format, we assume that the file starts with optional comments followed
1261 /// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
1263 void SparseTensorFile::readExtFROSTTHeader() {
1264   // Skip comments.
1265   while (true) {
1266     readLine();
1267     if (line[0] != '#')
1268       break;
1269   }
1270   // Next line contains RANK and NNZ.
1271   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
1272     FATAL("Cannot find metadata in %s\n", filename);
1273   // Followed by a line with the dimension sizes (one per rank).
1274   for (uint64_t r = 0; r < idata[0]; r++)
1275     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
      FATAL("Cannot find dimension size in %s\n", filename);
1277   readLine(); // end of line
1278   isValid = true;
1279 }
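
// For illustration, the layout assumed here for a 3-dimensional tensor with
// five nonzeros is (the coordinate/value lines then follow):
//
//   # optional comment lines
//   3 5
//   10 20 30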
1280 
/// Reads a sparse tensor from the file with the given name into a
/// memory-resident sparse tensor in coordinate scheme.
1283 template <typename V>
1284 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
1285                                                const uint64_t *shape,
1286                                                const uint64_t *perm) {
1287   SparseTensorFile stfile(filename);
1288   stfile.openFile();
1289   stfile.readHeader();
1290   stfile.assertMatchesShape(rank, shape);
1291   // Prepare sparse tensor object with per-dimension sizes
1292   // and the number of nonzeros as initial capacity.
1293   uint64_t nnz = stfile.getNNZ();
1294   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, stfile.getDimSizes(),
1295                                                      perm, nnz);
1296   // Read all nonzero elements.
1297   std::vector<uint64_t> indices(rank);
1298   for (uint64_t k = 0; k < nnz; k++) {
1299     char *linePtr = stfile.readLine();
1300     for (uint64_t r = 0; r < rank; r++) {
1301       uint64_t idx = strtoul(linePtr, &linePtr, 10);
      // External formats use 1-based indices; convert to 0-based.
1303       indices[perm[r]] = idx - 1;
1304     }
1305     // The external formats always store the numerical values with the type
1306     // double, but we cast these values to the sparse tensor object type.
1307     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1308     double value = stfile.isPattern() ? 1.0 : strtod(linePtr, &linePtr);
1309     // TODO: <https://github.com/llvm/llvm-project/issues/54179>
1310     coo->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // materializing them. In the future, we may want to make symmetry
    // implicit for storage reasons.
1314     if (stfile.isSymmetric() && indices[0] != indices[1])
1315       coo->add({indices[1], indices[0]}, value);
1316   }
1317   // Close the file and return tensor.
1318   stfile.closeFile();
1319   return coo;
1320 }
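
// A minimal sketch of using this helper directly (the library itself reaches
// it through the CASE macro in _mlir_ciface_newSparseTensor below); the
// filename, shape, and permutation here are hypothetical:
//
//   uint64_t shape[] = {0, 0}; // 0 means accept the file's size
//   uint64_t perm[] = {0, 1};  // identity dimension ordering
//   SparseTensorCOO<double> *coo =
//       openSparseTensorCOO<double>(filename, 2, shape, perm);
//   ...
//   delete coo;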
1321 
1322 /// Writes the sparse tensor to `dest` in extended FROSTT format.
1323 template <typename V>
1324 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1325   assert(tensor && dest);
1326   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1327   if (sort)
1328     coo->sort();
1329   char *filename = static_cast<char *>(dest);
1330   auto &dimSizes = coo->getDimSizes();
1331   auto &elements = coo->getElements();
1332   uint64_t rank = coo->getRank();
1333   uint64_t nnz = elements.size();
1334   std::fstream file;
1335   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1336   assert(file.is_open());
1337   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1338   for (uint64_t r = 0; r < rank - 1; r++)
1339     file << dimSizes[r] << " ";
1340   file << dimSizes[rank - 1] << std::endl;
1341   for (uint64_t i = 0; i < nnz; i++) {
1342     auto &idx = elements[i].indices;
1343     for (uint64_t r = 0; r < rank; r++)
1344       file << (idx[r] + 1) << " ";
1345     file << elements[i].value << std::endl;
1346   }
1347   file.flush();
1348   file.close();
1349   assert(file.good());
1350 }
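
// For a 2x3 matrix with two nonzeros, the file written above would read as
// follows (indices are shifted back to 1-based on output):
//
//   ; extended FROSTT format
//   2 2
//   2 3
//   1 1 1.5
//   2 3 2.5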
1351 
/// Initializes a sparse tensor from an external COO-flavored format.
1353 template <typename V>
1354 static SparseTensorStorage<uint64_t, uint64_t, V> *
1355 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1356                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
  const DimLevelType *sparsity =
      reinterpret_cast<const DimLevelType *>(sparse);
1358 #ifndef NDEBUG
1359   // Verify that perm is a permutation of 0..(rank-1).
1360   std::vector<uint64_t> order(perm, perm + rank);
1361   std::sort(order.begin(), order.end());
1362   for (uint64_t i = 0; i < rank; ++i)
1363     if (i != order[i])
      FATAL("Not a permutation of 0..%" PRIu64 "\n", rank - 1);
1365 
1366   // Verify that the sparsity values are supported.
1367   for (uint64_t i = 0; i < rank; ++i)
1368     if (sparsity[i] != DimLevelType::kDense &&
1369         sparsity[i] != DimLevelType::kCompressed)
1370       FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
1371 #endif
1372 
1373   // Convert external format to internal COO.
1374   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1375   std::vector<uint64_t> idx(rank);
1376   for (uint64_t i = 0, base = 0; i < nse; i++) {
1377     for (uint64_t r = 0; r < rank; r++)
1378       idx[perm[r]] = indices[base + r];
1379     coo->add(idx, values[i]);
1380     base += rank;
1381   }
1382   // Return sparse tensor storage format as opaque pointer.
1383   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1384       rank, shape, perm, sparsity, coo);
1385   delete coo;
1386   return tensor;
1387 }
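
// A hedged usage sketch (the data is made up): a 2x2 matrix with two nonzeros,
// stored with both dimensions compressed and identity dimension ordering.
//
//   uint64_t shape[] = {2, 2}, perm[] = {0, 1};
//   uint64_t indices[] = {0, 0, 1, 1}; // flattened (i, j) pairs
//   double values[] = {1.0, 2.0};
//   uint8_t sparse[] = {static_cast<uint8_t>(DimLevelType::kCompressed),
//                       static_cast<uint8_t>(DimLevelType::kCompressed)};
//   auto *tensor = toMLIRSparseTensor<double>(2, 2, shape, values, indices,
//                                             perm, sparse);
//   ...
//   delSparseTensor(tensor);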
1388 
1389 /// Converts a sparse tensor to an external COO-flavored format.
1390 template <typename V>
1391 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1392                                  uint64_t **pShape, V **pValues,
1393                                  uint64_t **pIndices) {
1394   assert(tensor);
1395   auto sparseTensor =
1396       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1397   uint64_t rank = sparseTensor->getRank();
1398   std::vector<uint64_t> perm(rank);
1399   std::iota(perm.begin(), perm.end(), 0);
1400   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1401 
1402   const std::vector<Element<V>> &elements = coo->getElements();
1403   uint64_t nse = elements.size();
1404 
1405   uint64_t *shape = new uint64_t[rank];
1406   for (uint64_t i = 0; i < rank; i++)
1407     shape[i] = coo->getDimSizes()[i];
1408 
1409   V *values = new V[nse];
1410   uint64_t *indices = new uint64_t[rank * nse];
1411 
1412   for (uint64_t i = 0, base = 0; i < nse; i++) {
1413     values[i] = elements[i].value;
1414     for (uint64_t j = 0; j < rank; j++)
1415       indices[base + j] = elements[i].indices[j];
1416     base += rank;
1417   }
1418 
1419   delete coo;
1420   *pRank = rank;
1421   *pNse = nse;
1422   *pShape = shape;
1423   *pValues = values;
1424   *pIndices = indices;
1425 }
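
// Note that the caller takes ownership of the three arrays returned through
// the out-parameters; a sketch of the expected cleanup for the double case:
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   fromMLIRSparseTensor<double>(tensor, &rank, &nse, &shape, &values,
//                                &indices);
//   ...
//   delete[] shape;
//   delete[] values;
//   delete[] indices;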
1426 
1427 } // anonymous namespace
1428 
1429 extern "C" {
1430 
1431 //===----------------------------------------------------------------------===//
1432 //
1433 // Public functions which operate on MLIR buffers (memrefs) to interact
1434 // with sparse tensors (which are only visible as opaque pointers externally).
1435 //
1436 //===----------------------------------------------------------------------===//
1437 
1438 #define CASE(p, i, v, P, I, V)                                                 \
1439   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1440     SparseTensorCOO<V> *coo = nullptr;                                         \
1441     if (action <= Action::kFromCOO) {                                          \
1442       if (action == Action::kFromFile) {                                       \
1443         char *filename = static_cast<char *>(ptr);                             \
1444         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1445       } else if (action == Action::kFromCOO) {                                 \
1446         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1447       } else {                                                                 \
1448         assert(action == Action::kEmpty);                                      \
1449       }                                                                        \
1450       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1451           rank, shape, perm, sparsity, coo);                                   \
1452       if (action == Action::kFromFile)                                         \
1453         delete coo;                                                            \
1454       return tensor;                                                           \
1455     }                                                                          \
1456     if (action == Action::kSparseToSparse) {                                   \
1457       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
1458       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
1459                                                            sparsity, tensor);  \
1460     }                                                                          \
1461     if (action == Action::kEmptyCOO)                                           \
1462       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1463     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1464     if (action == Action::kToIterator) {                                       \
1465       coo->startIterator();                                                    \
1466     } else {                                                                   \
1467       assert(action == Action::kToCOO);                                        \
1468     }                                                                          \
1469     return coo;                                                                \
1470   }
1471 
1472 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
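
// Each CASE line below expands to a block that handles one concrete
// <pointer, index, value> type combination: it either materializes a new
// SparseTensorStorage (kEmpty, kFromFile, kFromCOO, kSparseToSparse), creates
// an empty SparseTensorCOO (kEmptyCOO), or converts storage back to COO form
// (kToCOO, kToIterator). CASE_SECSAME simply reuses one overhead type for both
// pointers and indices.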
1473 
1474 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1475 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1476 // that this file cannot get out of sync with its header.
1477 static_assert(std::is_same<index_type, uint64_t>::value,
1478               "Expected index_type == uint64_t");
1479 
1480 void *
1481 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1482                              StridedMemRefType<index_type, 1> *sref,
1483                              StridedMemRefType<index_type, 1> *pref,
1484                              OverheadType ptrTp, OverheadType indTp,
1485                              PrimaryType valTp, Action action, void *ptr) {
1486   assert(aref && sref && pref);
1487   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1488          pref->strides[0] == 1);
1489   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1490   const DimLevelType *sparsity = aref->data + aref->offset;
1491   const index_type *shape = sref->data + sref->offset;
1492   const index_type *perm = pref->data + pref->offset;
1493   uint64_t rank = aref->sizes[0];
1494 
1495   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1496   // This is safe because of the static_assert above.
1497   if (ptrTp == OverheadType::kIndex)
1498     ptrTp = OverheadType::kU64;
1499   if (indTp == OverheadType::kIndex)
1500     indTp = OverheadType::kU64;
1501 
1502   // Double matrices with all combinations of overhead storage.
1503   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1504        uint64_t, double);
1505   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1506        uint32_t, double);
1507   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1508        uint16_t, double);
1509   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1510        uint8_t, double);
1511   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1512        uint64_t, double);
1513   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1514        uint32_t, double);
1515   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1516        uint16_t, double);
1517   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1518        uint8_t, double);
1519   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1520        uint64_t, double);
1521   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1522        uint32_t, double);
1523   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1524        uint16_t, double);
1525   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1526        uint8_t, double);
1527   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1528        uint64_t, double);
1529   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1530        uint32_t, double);
1531   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1532        uint16_t, double);
1533   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1534        uint8_t, double);
1535 
1536   // Float matrices with all combinations of overhead storage.
1537   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1538        uint64_t, float);
1539   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1540        uint32_t, float);
1541   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1542        uint16_t, float);
1543   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1544        uint8_t, float);
1545   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1546        uint64_t, float);
1547   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1548        uint32_t, float);
1549   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1550        uint16_t, float);
1551   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1552        uint8_t, float);
1553   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1554        uint64_t, float);
1555   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1556        uint32_t, float);
1557   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1558        uint16_t, float);
1559   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1560        uint8_t, float);
1561   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1562        uint64_t, float);
1563   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1564        uint32_t, float);
1565   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1566        uint16_t, float);
1567   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1568        uint8_t, float);
1569 
1570   // Two-byte floats with both overheads of the same type.
1571   CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
1572   CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
1573   CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
1574   CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
1575   CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
1576   CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
1577   CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
1578   CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);
1579 
1580   // Integral matrices with both overheads of the same type.
1581   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1582   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1583   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1584   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1585   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
1586   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1587   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1588   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1589   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
1590   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1591   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1592   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1593   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
1594   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1595   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1596   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1597 
1598   // Complex matrices with wide overhead.
1599   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1600   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1601 
1602   // Unsupported case (add above if needed).
1603   // TODO: better pretty-printing of enum values!
1604   FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
1605         static_cast<int>(ptrTp), static_cast<int>(indTp),
1606         static_cast<int>(valTp));
1607 }
1608 #undef CASE
1609 #undef CASE_SECSAME
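
// A hedged sketch of a by-hand call that creates an empty 4x8 dense-by-
// compressed f64 tensor, assuming the usual CRunnerUtils field layout of
// StridedMemRefType (MLIR compiler-generated code normally builds these
// memref descriptors itself):
//
//   DimLevelType lvl[2] = {DimLevelType::kDense, DimLevelType::kCompressed};
//   index_type shape[2] = {4, 8};
//   index_type perm[2] = {0, 1};
//   StridedMemRefType<DimLevelType, 1> aref{lvl, lvl, 0, {2}, {1}};
//   StridedMemRefType<index_type, 1> sref{shape, shape, 0, {2}, {1}};
//   StridedMemRefType<index_type, 1> pref{perm, perm, 0, {2}, {1}};
//   void *tensor = _mlir_ciface_newSparseTensor(
//       &aref, &sref, &pref, OverheadType::kIndex, OverheadType::kIndex,
//       PrimaryType::kF64, Action::kEmpty, /*ptr=*/nullptr);
//   ...
//   delSparseTensor(tensor);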
1610 
1611 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1612   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1613                                         void *tensor) {                        \
1614     assert(ref &&tensor);                                                      \
1615     std::vector<V> *v;                                                         \
1616     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1617     ref->basePtr = ref->data = v->data();                                      \
1618     ref->offset = 0;                                                           \
1619     ref->sizes[0] = v->size();                                                 \
1620     ref->strides[0] = 1;                                                       \
1621   }
1622 FOREVERY_V(IMPL_SPARSEVALUES)
1623 #undef IMPL_SPARSEVALUES
1624 
1625 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1626   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1627                            index_type d) {                                     \
1628     assert(ref &&tensor);                                                      \
1629     std::vector<TYPE> *v;                                                      \
1630     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1631     ref->basePtr = ref->data = v->data();                                      \
1632     ref->offset = 0;                                                           \
1633     ref->sizes[0] = v->size();                                                 \
1634     ref->strides[0] = 1;                                                       \
1635   }
1636 #define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
1637   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
1638 FOREVERY_O(IMPL_SPARSEPOINTERS)
1639 #undef IMPL_SPARSEPOINTERS
1640 
1641 #define IMPL_SPARSEINDICES(INAME, I)                                           \
1642   IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
1643 FOREVERY_O(IMPL_SPARSEINDICES)
1644 #undef IMPL_SPARSEINDICES
1645 #undef IMPL_GETOVERHEAD
1646 
1647 #define IMPL_ADDELT(VNAME, V)                                                  \
1648   void *_mlir_ciface_addElt##VNAME(void *coo, V value,                         \
1649                                    StridedMemRefType<index_type, 1> *iref,     \
1650                                    StridedMemRefType<index_type, 1> *pref) {   \
1651     assert(coo &&iref &&pref);                                                 \
1652     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1653     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1654     const index_type *indx = iref->data + iref->offset;                        \
1655     const index_type *perm = pref->data + pref->offset;                        \
1656     uint64_t isize = iref->sizes[0];                                           \
1657     std::vector<index_type> indices(isize);                                    \
1658     for (uint64_t r = 0; r < isize; r++)                                       \
1659       indices[perm[r]] = indx[r];                                              \
1660     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value);               \
1661     return coo;                                                                \
1662   }
1663 FOREVERY_SIMPLEX_V(IMPL_ADDELT)
1664 IMPL_ADDELT(C64, complex64)
1665 // Marked static because it's not part of the public API.
1666 // NOTE: the `static` keyword confuses clang-format here, causing
1667 // the strange indentation of the `_mlir_ciface_addEltC32` prototype.
1668 // In C++11 we can add a semicolon after the call to `IMPL_ADDELT`
1669 // and that will correct clang-format.  Alas, this file is compiled
1670 // in C++98 mode where that semicolon is illegal (and there's no portable
1671 // macro magic to license a no-op semicolon at the top level).
1672 static IMPL_ADDELT(C32ABI, complex32)
1673 #undef IMPL_ADDELT
1674     void *_mlir_ciface_addEltC32(void *coo, float r, float i,
1675                                  StridedMemRefType<index_type, 1> *iref,
1676                                  StridedMemRefType<index_type, 1> *pref) {
1677   return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
1678 }
1679 
1680 #define IMPL_GETNEXT(VNAME, V)                                                 \
1681   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1682                                    StridedMemRefType<index_type, 1> *iref,     \
1683                                    StridedMemRefType<V, 0> *vref) {            \
1684     assert(coo &&iref &&vref);                                                 \
1685     assert(iref->strides[0] == 1);                                             \
1686     index_type *indx = iref->data + iref->offset;                              \
1687     V *value = vref->data + vref->offset;                                      \
1688     const uint64_t isize = iref->sizes[0];                                     \
1689     const Element<V> *elem =                                                   \
1690         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1691     if (elem == nullptr)                                                       \
1692       return false;                                                            \
1693     for (uint64_t r = 0; r < isize; r++)                                       \
1694       indx[r] = elem->indices[r];                                              \
1695     *value = elem->value;                                                      \
1696     return true;                                                               \
1697   }
1698 FOREVERY_V(IMPL_GETNEXT)
1699 #undef IMPL_GETNEXT
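
// A hedged sketch of the iteration protocol these wrappers implement,
// assuming `coo` was obtained via Action::kToIterator (which calls
// startIterator()) and the F64 instantiation generated by FOREVERY_V:
//
//   index_type idx[2];
//   double val;
//   StridedMemRefType<index_type, 1> iref{idx, idx, 0, {2}, {1}};
//   StridedMemRefType<double, 0> vref{&val, &val, 0};
//   while (_mlir_ciface_getNextF64(coo, &iref, &vref)) {
//     /* consume idx[0], idx[1], and val */
//   }
//   delSparseTensorCOOF64(coo);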
1700 
1701 #define IMPL_LEXINSERT(VNAME, V)                                               \
1702   void _mlir_ciface_lexInsert##VNAME(                                          \
1703       void *tensor, StridedMemRefType<index_type, 1> *cref, V val) {           \
1704     assert(tensor &&cref);                                                     \
1705     assert(cref->strides[0] == 1);                                             \
1706     index_type *cursor = cref->data + cref->offset;                            \
1707     assert(cursor);                                                            \
1708     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1709   }
1710 FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
1711 IMPL_LEXINSERT(C64, complex64)
1712 // Marked static because it's not part of the public API.
1713 // NOTE: see the note for `_mlir_ciface_addEltC32ABI`
1714 static IMPL_LEXINSERT(C32ABI, complex32)
1715 #undef IMPL_LEXINSERT
1716     void _mlir_ciface_lexInsertC32(void *tensor,
1717                                    StridedMemRefType<index_type, 1> *cref,
1718                                    float r, float i) {
1719   _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
1720 }
1721 
1722 #define IMPL_EXPINSERT(VNAME, V)                                               \
1723   void _mlir_ciface_expInsert##VNAME(                                          \
1724       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1725       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1726       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1727     assert(tensor &&cref &&vref &&fref &&aref);                                \
1728     assert(cref->strides[0] == 1);                                             \
1729     assert(vref->strides[0] == 1);                                             \
1730     assert(fref->strides[0] == 1);                                             \
1731     assert(aref->strides[0] == 1);                                             \
1732     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1733     index_type *cursor = cref->data + cref->offset;                            \
1734     V *values = vref->data + vref->offset;                                     \
1735     bool *filled = fref->data + fref->offset;                                  \
1736     index_type *added = aref->data + aref->offset;                             \
1737     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1738         cursor, values, filled, added, count);                                 \
1739   }
1740 FOREVERY_V(IMPL_EXPINSERT)
1741 #undef IMPL_EXPINSERT
1742 
1743 //===----------------------------------------------------------------------===//
1744 //
1745 // Public functions which accept only C-style data structures to interact
1746 // with sparse tensors (which are only visible as opaque pointers externally).
1747 //
1748 //===----------------------------------------------------------------------===//
1749 
1750 index_type sparseDimSize(void *tensor, index_type d) {
1751   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1752 }
1753 
1754 void endInsert(void *tensor) {
1755   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1756 }
1757 
1758 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
1759   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
1760     return outSparseTensor<V>(coo, dest, sort);                                \
1761   }
1762 FOREVERY_V(IMPL_OUTSPARSETENSOR)
1763 #undef IMPL_OUTSPARSETENSOR
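
// For example, assuming the F64 instantiation generated by FOREVERY_V above,
// a caller holding a SparseTensorCOO<double>* can write it out sorted with:
//
//   outSparseTensorF64(coo, const_cast<char *>("out.tns"), /*sort=*/true);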
1764 
1765 void delSparseTensor(void *tensor) {
1766   delete static_cast<SparseTensorStorageBase *>(tensor);
1767 }
1768 
1769 #define IMPL_DELCOO(VNAME, V)                                                  \
1770   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1771     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1772   }
1773 FOREVERY_V(IMPL_DELCOO)
1774 #undef IMPL_DELCOO
1775 
1776 char *getTensorFilename(index_type id) {
  char var[80];
  snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
1779   char *env = getenv(var);
1780   if (!env)
1781     FATAL("Environment variable %s is not set\n", var);
1782   return env;
1783 }
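
// For example (the environment variable value is illustrative):
//
//   setenv("TENSOR0", "wide.mtx", /*overwrite=*/1);
//   char *filename = getTensorFilename(0); // yields "wide.mtx"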
1784 
1785 void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
1786   assert(out && "Received nullptr for out-parameter");
1787   SparseTensorFile stfile(filename);
1788   stfile.openFile();
1789   stfile.readHeader();
1790   stfile.closeFile();
1791   const uint64_t rank = stfile.getRank();
1792   const uint64_t *dimSizes = stfile.getDimSizes();
1793   out->reserve(rank);
1794   out->assign(dimSizes, dimSizes + rank);
1795 }
1796 
1797 // TODO: generalize beyond 64-bit indices.
1798 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
1799   void *convertToMLIRSparseTensor##VNAME(                                      \
1800       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
1801       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
1802     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
1803                                  sparse);                                      \
1804   }
1805 FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
1806 #undef IMPL_CONVERTTOMLIRSPARSETENSOR
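
// For example, assuming the F64 instantiation generated by FOREVERY_V above,
// the toMLIRSparseTensor<double> sketch shown earlier is reachable externally
// as:
//
//   void *tensor = convertToMLIRSparseTensorF64(rank, nse, shape, values,
//                                               indices, perm, sparse);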
1807 
1808 // TODO: Currently, values are copied from SparseTensorStorage to
1809 // SparseTensorCOO, then to the output.  We may want to reduce the number
1810 // of copies.
1811 //
1812 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1813 // compressed
1814 #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
1815   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
1816                                           uint64_t *pNse, uint64_t **pShape,   \
1817                                           V **pValues, uint64_t **pIndices) {  \
1818     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
1819   }
1820 FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
1821 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
1822 
1823 } // extern "C"
1824 
1825 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1826