1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 
19 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
20 
21 #include <algorithm>
22 #include <cassert>
23 #include <cctype>
24 #include <cstdio>
25 #include <cstdlib>
26 #include <cstring>
27 #include <fstream>
28 #include <functional>
29 #include <iostream>
30 #include <limits>
31 #include <numeric>
32 
33 //===----------------------------------------------------------------------===//
34 //
35 // Internal support for storing and reading sparse tensors.
36 //
37 // The following memory-resident sparse storage schemes are supported:
38 //
39 // (a) A coordinate scheme for temporarily storing and lexicographically
40 //     sorting a sparse tensor by index (SparseTensorCOO).
41 //
42 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
43 //     per-dimension sparse/dense annnotations together with a dimension
44 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
45 //
46 // The following external formats are supported:
47 //
48 // (1) Matrix Market Exchange (MME): *.mtx
49 //     https://math.nist.gov/MatrixMarket/formats.html
50 //
51 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
52 //     http://frostt.io/tensors/file-formats.html
53 //
54 // Two public APIs are supported:
55 //
56 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
57 //     tensors. These methods should be used exclusively by MLIR
58 //     compiler-generated code.
59 //
60 // (II) Methods that accept C-style data structures to interact with sparse
61 //      tensors. These methods can be used by any external runtime that wants
62 //      to interact with MLIR compiler-generated code.
63 //
64 // In both cases (I) and (II), the SparseTensorStorage format is externally
65 // only visible as an opaque pointer.
66 //
67 //===----------------------------------------------------------------------===//
68 
69 namespace {
70 
71 static constexpr int kColWidth = 1025;
72 
73 /// A version of `operator*` on `uint64_t` which checks for overflows.
74 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
75   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
76          "Integer overflow");
77   return lhs * rhs;
78 }
79 
80 // This macro helps minimize repetition of this idiom, as well as ensuring
81 // we have some additional output indicating where the error is coming from.
82 // (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
83 // to track down whether an error is coming from our code vs somewhere else
84 // in MLIR.)
85 #define FATAL(...)                                                             \
86   {                                                                            \
87     fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                        \
88     exit(1);                                                                   \
89   }
90 
91 // TODO: adjust this so it can be used by `openSparseTensorCOO` too.
92 // That version doesn't have the permutation, and the `dimSizes` are
93 // a pointer/C-array rather than `std::vector`.
94 //
95 /// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
96 /// semantic-order to target-order) are a refinement of the desired `shape`
97 /// (in semantic-order).
98 ///
99 /// Precondition: `perm` and `shape` must be valid for `rank`.
100 static inline void
101 assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
102                               uint64_t rank, const uint64_t *perm,
103                               const uint64_t *shape) {
104   assert(perm && shape);
105   assert(rank == dimSizes.size() && "Rank mismatch");
106   for (uint64_t r = 0; r < rank; r++)
107     assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
108            "Dimension size mismatch");
109 }
110 
111 /// A sparse tensor element in coordinate scheme (value and indices).
112 /// For example, a rank-1 vector element would look like
113 ///   ({i}, a[i])
114 /// and a rank-5 tensor element like
115 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
116 /// We use pointer to a shared index pool rather than e.g. a direct
117 /// vector since that (1) reduces the per-element memory footprint, and
118 /// (2) centralizes the memory reservation and (re)allocation to one place.
119 template <typename V>
120 struct Element final {
121   Element(uint64_t *ind, V val) : indices(ind), value(val){};
122   uint64_t *indices; // pointer into shared index pool
123   V value;
124 };
125 
126 /// The type of callback functions which receive an element.  We avoid
127 /// packaging the coordinates and value together as an `Element` object
128 /// because this helps keep code somewhat cleaner.
129 template <typename V>
130 using ElementConsumer =
131     const std::function<void(const std::vector<uint64_t> &, V)> &;
132 
133 /// A memory-resident sparse tensor in coordinate scheme (collection of
134 /// elements). This data structure is used to read a sparse tensor from
135 /// any external format into memory and sort the elements lexicographically
136 /// by indices before passing it back to the client (most packed storage
137 /// formats require the elements to appear in lexicographic index order).
138 template <typename V>
139 struct SparseTensorCOO final {
140 public:
141   SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
142       : dimSizes(dimSizes) {
143     if (capacity) {
144       elements.reserve(capacity);
145       indices.reserve(capacity * getRank());
146     }
147   }
148 
149   /// Adds element as indices and value.
150   void add(const std::vector<uint64_t> &ind, V val) {
151     assert(!iteratorLocked && "Attempt to add() after startIterator()");
152     uint64_t *base = indices.data();
153     uint64_t size = indices.size();
154     uint64_t rank = getRank();
155     assert(ind.size() == rank && "Element rank mismatch");
156     for (uint64_t r = 0; r < rank; r++) {
157       assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
158       indices.push_back(ind[r]);
159     }
160     // This base only changes if indices were reallocated. In that case, we
161     // need to correct all previous pointers into the vector. Note that this
162     // only happens if we did not set the initial capacity right, and then only
163     // for every internal vector reallocation (which with the doubling rule
164     // should only incur an amortized linear overhead).
165     uint64_t *newBase = indices.data();
166     if (newBase != base) {
167       for (uint64_t i = 0, n = elements.size(); i < n; i++)
168         elements[i].indices = newBase + (elements[i].indices - base);
169       base = newBase;
170     }
171     // Add element as (pointer into shared index pool, value) pair.
172     elements.emplace_back(base + size, val);
173   }
174 
175   /// Sorts elements lexicographically by index.
176   void sort() {
177     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
178     // TODO: we may want to cache an `isSorted` bit, to avoid
179     // unnecessary/redundant sorting.
180     uint64_t rank = getRank();
181     std::sort(elements.begin(), elements.end(),
182               [rank](const Element<V> &e1, const Element<V> &e2) {
183                 for (uint64_t r = 0; r < rank; r++) {
184                   if (e1.indices[r] == e2.indices[r])
185                     continue;
186                   return e1.indices[r] < e2.indices[r];
187                 }
188                 return false;
189               });
190   }
191 
192   /// Get the rank of the tensor.
193   uint64_t getRank() const { return dimSizes.size(); }
194 
195   /// Getter for the dimension-sizes array.
196   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
197 
198   /// Getter for the elements array.
199   const std::vector<Element<V>> &getElements() const { return elements; }
200 
201   /// Switch into iterator mode.
202   void startIterator() {
203     iteratorLocked = true;
204     iteratorPos = 0;
205   }
206 
207   /// Get the next element.
208   const Element<V> *getNext() {
209     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
210     if (iteratorPos < elements.size())
211       return &(elements[iteratorPos++]);
212     iteratorLocked = false;
213     return nullptr;
214   }
215 
216   /// Factory method. Permutes the original dimensions according to
217   /// the given ordering and expects subsequent add() calls to honor
218   /// that same ordering for the given indices. The result is a
219   /// fully permuted coordinate scheme.
220   ///
221   /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
222   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
223                                                 const uint64_t *dimSizes,
224                                                 const uint64_t *perm,
225                                                 uint64_t capacity = 0) {
226     std::vector<uint64_t> permsz(rank);
227     for (uint64_t r = 0; r < rank; r++) {
228       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
229       permsz[perm[r]] = dimSizes[r];
230     }
231     return new SparseTensorCOO<V>(permsz, capacity);
232   }
233 
234 private:
235   const std::vector<uint64_t> dimSizes; // per-dimension sizes
236   std::vector<Element<V>> elements;     // all COO elements
237   std::vector<uint64_t> indices;        // shared index pool
238   bool iteratorLocked = false;
239   unsigned iteratorPos = 0;
240 };
241 
242 // Forward.
243 template <typename V>
244 class SparseTensorEnumeratorBase;
245 
246 // Helper macro for generating error messages when some
247 // `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
248 // and then the wrong "partial method specialization" is called.
249 #define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
250 
251 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
252 /// takes responsibility for all the `<P,I,V>`-independent aspects
253 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
254 /// we use function overloading to implement "partial" method
255 /// specialization, which the C-API relies on to catch type errors
256 /// arising from our use of opaque pointers.
257 class SparseTensorStorageBase {
258 public:
259   /// Constructs a new storage object.  The `perm` maps the tensor's
260   /// semantic-ordering of dimensions to this object's storage-order.
261   /// The `dimSizes` and `sparsity` arrays are already in storage-order.
262   ///
263   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
264   SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
265                           const uint64_t *perm, const DimLevelType *sparsity)
266       : dimSizes(dimSizes), rev(getRank()),
267         dimTypes(sparsity, sparsity + getRank()) {
268     assert(perm && sparsity);
269     const uint64_t rank = getRank();
270     // Validate parameters.
271     assert(rank > 0 && "Trivial shape is unsupported");
272     for (uint64_t r = 0; r < rank; r++) {
273       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
274       assert((dimTypes[r] == DimLevelType::kDense ||
275               dimTypes[r] == DimLevelType::kCompressed) &&
276              "Unsupported DimLevelType");
277     }
278     // Construct the "reverse" (i.e., inverse) permutation.
279     for (uint64_t r = 0; r < rank; r++)
280       rev[perm[r]] = r;
281   }
282 
283   virtual ~SparseTensorStorageBase() = default;
284 
285   /// Get the rank of the tensor.
286   uint64_t getRank() const { return dimSizes.size(); }
287 
288   /// Getter for the dimension-sizes array, in storage-order.
289   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
290 
291   /// Safely lookup the size of the given (storage-order) dimension.
292   uint64_t getDimSize(uint64_t d) const {
293     assert(d < getRank());
294     return dimSizes[d];
295   }
296 
297   /// Getter for the "reverse" permutation, which maps this object's
298   /// storage-order to the tensor's semantic-order.
299   const std::vector<uint64_t> &getRev() const { return rev; }
300 
301   /// Getter for the dimension-types array, in storage-order.
302   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
303 
304   /// Safely check if the (storage-order) dimension uses compressed storage.
305   bool isCompressedDim(uint64_t d) const {
306     assert(d < getRank());
307     return (dimTypes[d] == DimLevelType::kCompressed);
308   }
309 
310   /// Allocate a new enumerator.
311 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
312   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
313                              const uint64_t *) const {                         \
314     FATAL_PIV("newEnumerator" #VNAME);                                         \
315   }
316   FOREVERY_V(DECL_NEWENUMERATOR)
317 #undef DECL_NEWENUMERATOR
318 
319   /// Overhead storage.
320 #define DECL_GETPOINTERS(PNAME, P)                                             \
321   virtual void getPointers(std::vector<P> **, uint64_t) {                      \
322     FATAL_PIV("getPointers" #PNAME);                                           \
323   }
324   FOREVERY_FIXED_O(DECL_GETPOINTERS)
325 #undef DECL_GETPOINTERS
326 #define DECL_GETINDICES(INAME, I)                                              \
327   virtual void getIndices(std::vector<I> **, uint64_t) {                       \
328     FATAL_PIV("getIndices" #INAME);                                            \
329   }
330   FOREVERY_FIXED_O(DECL_GETINDICES)
331 #undef DECL_GETINDICES
332 
333   /// Primary storage.
334 #define DECL_GETVALUES(VNAME, V)                                               \
335   virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
336   FOREVERY_V(DECL_GETVALUES)
337 #undef DECL_GETVALUES
338 
339   /// Element-wise insertion in lexicographic index order.
340 #define DECL_LEXINSERT(VNAME, V)                                               \
341   virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
342   FOREVERY_V(DECL_LEXINSERT)
343 #undef DECL_LEXINSERT
344 
345   /// Expanded insertion.
346 #define DECL_EXPINSERT(VNAME, V)                                               \
347   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
348     FATAL_PIV("expInsert" #VNAME);                                             \
349   }
350   FOREVERY_V(DECL_EXPINSERT)
351 #undef DECL_EXPINSERT
352 
353   /// Finishes insertion.
354   virtual void endInsert() = 0;
355 
356 protected:
357   // Since this class is virtual, we must disallow public copying in
358   // order to avoid "slicing".  Since this class has data members,
359   // that means making copying protected.
360   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
361   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
362   // Copy-assignment would be implicitly deleted (because `dimSizes`
363   // is const), so we explicitly delete it for clarity.
364   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
365 
366 private:
367   const std::vector<uint64_t> dimSizes;
368   std::vector<uint64_t> rev;
369   const std::vector<DimLevelType> dimTypes;
370 };
371 
372 #undef FATAL_PIV
373 
374 // Forward.
375 template <typename P, typename I, typename V>
376 class SparseTensorEnumerator;
377 
378 /// A memory-resident sparse tensor using a storage scheme based on
379 /// per-dimension sparse/dense annotations. This data structure provides a
380 /// bufferized form of a sparse tensor type. In contrast to generating setup
381 /// methods for each differently annotated sparse tensor, this method provides
382 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
383 /// and annotations to implement all required setup in a general manner.
384 template <typename P, typename I, typename V>
385 class SparseTensorStorage final : public SparseTensorStorageBase {
386   /// Private constructor to share code between the other constructors.
387   /// Beware that the object is not necessarily guaranteed to be in a
388   /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
389   /// doesn't entail `!(pointers[d].empty())`.
390   ///
391   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
392   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
393                       const uint64_t *perm, const DimLevelType *sparsity)
394       : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
395         indices(getRank()), idx(getRank()) {}
396 
397 public:
398   /// Constructs a sparse tensor storage scheme with the given dimensions,
399   /// permutation, and per-dimension dense/sparse annotations, using
400   /// the coordinate scheme tensor for the initial contents if provided.
401   ///
402   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
403   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
404                       const uint64_t *perm, const DimLevelType *sparsity,
405                       SparseTensorCOO<V> *coo)
406       : SparseTensorStorage(dimSizes, perm, sparsity) {
407     // Provide hints on capacity of pointers and indices.
408     // TODO: needs much fine-tuning based on actual sparsity; currently
409     //       we reserve pointer/index space based on all previous dense
410     //       dimensions, which works well up to first sparse dim; but
411     //       we should really use nnz and dense/sparse distribution.
412     bool allDense = true;
413     uint64_t sz = 1;
414     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
415       if (isCompressedDim(r)) {
416         // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
417         // `sz` by that before reserving. (For now we just use 1.)
418         pointers[r].reserve(sz + 1);
419         pointers[r].push_back(0);
420         indices[r].reserve(sz);
421         sz = 1;
422         allDense = false;
423       } else { // Dense dimension.
424         sz = checkedMul(sz, getDimSizes()[r]);
425       }
426     }
427     // Then assign contents from coordinate scheme tensor if provided.
428     if (coo) {
429       // Ensure both preconditions of `fromCOO`.
430       assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
431       coo->sort();
432       // Now actually insert the `elements`.
433       const std::vector<Element<V>> &elements = coo->getElements();
434       uint64_t nnz = elements.size();
435       values.reserve(nnz);
436       fromCOO(elements, 0, nnz, 0);
437     } else if (allDense) {
438       values.resize(sz, 0);
439     }
440   }
441 
442   /// Constructs a sparse tensor storage scheme with the given dimensions,
443   /// permutation, and per-dimension dense/sparse annotations, using
444   /// the given sparse tensor for the initial contents.
445   ///
446   /// Preconditions:
447   /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
448   /// * The `tensor` must have the same value type `V`.
449   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
450                       const uint64_t *perm, const DimLevelType *sparsity,
451                       const SparseTensorStorageBase &tensor);
452 
453   ~SparseTensorStorage() final override = default;
454 
455   /// Partially specialize these getter methods based on template types.
456   void getPointers(std::vector<P> **out, uint64_t d) final override {
457     assert(d < getRank());
458     *out = &pointers[d];
459   }
460   void getIndices(std::vector<I> **out, uint64_t d) final override {
461     assert(d < getRank());
462     *out = &indices[d];
463   }
464   void getValues(std::vector<V> **out) final override { *out = &values; }
465 
466   /// Partially specialize lexicographical insertions based on template types.
467   void lexInsert(const uint64_t *cursor, V val) final override {
468     // First, wrap up pending insertion path.
469     uint64_t diff = 0;
470     uint64_t top = 0;
471     if (!values.empty()) {
472       diff = lexDiff(cursor);
473       endPath(diff + 1);
474       top = idx[diff] + 1;
475     }
476     // Then continue with insertion path.
477     insPath(cursor, diff, top, val);
478   }
479 
480   /// Partially specialize expanded insertions based on template types.
481   /// Note that this method resets the values/filled-switch array back
482   /// to all-zero/false while only iterating over the nonzero elements.
483   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
484                  uint64_t count) final override {
485     if (count == 0)
486       return;
487     // Sort.
488     std::sort(added, added + count);
489     // Restore insertion path for first insert.
490     const uint64_t lastDim = getRank() - 1;
491     uint64_t index = added[0];
492     cursor[lastDim] = index;
493     lexInsert(cursor, values[index]);
494     assert(filled[index]);
495     values[index] = 0;
496     filled[index] = false;
497     // Subsequent insertions are quick.
498     for (uint64_t i = 1; i < count; i++) {
499       assert(index < added[i] && "non-lexicographic insertion");
500       index = added[i];
501       cursor[lastDim] = index;
502       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
503       assert(filled[index]);
504       values[index] = 0;
505       filled[index] = false;
506     }
507   }
508 
509   /// Finalizes lexicographic insertions.
510   void endInsert() final override {
511     if (values.empty())
512       finalizeSegment(0);
513     else
514       endPath(0);
515   }
516 
517   void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
518                      const uint64_t *perm) const final override {
519     *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
520   }
521 
522   /// Returns this sparse tensor storage scheme as a new memory-resident
523   /// sparse tensor in coordinate scheme with the given dimension order.
524   ///
525   /// Precondition: `perm` must be valid for `getRank()`.
526   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
527     SparseTensorEnumeratorBase<V> *enumerator;
528     newEnumerator(&enumerator, getRank(), perm);
529     SparseTensorCOO<V> *coo =
530         new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
531     enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
532       coo->add(ind, val);
533     });
534     // TODO: This assertion assumes there are no stored zeros,
535     // or if there are then that we don't filter them out.
536     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
537     assert(coo->getElements().size() == values.size());
538     delete enumerator;
539     return coo;
540   }
541 
542   /// Factory method. Constructs a sparse tensor storage scheme with the given
543   /// dimensions, permutation, and per-dimension dense/sparse annotations,
544   /// using the coordinate scheme tensor for the initial contents if provided.
545   /// In the latter case, the coordinate scheme must respect the same
546   /// permutation as is desired for the new sparse tensor storage.
547   ///
548   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
549   static SparseTensorStorage<P, I, V> *
550   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
551                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
552     SparseTensorStorage<P, I, V> *n = nullptr;
553     if (coo) {
554       const auto &coosz = coo->getDimSizes();
555       assertPermutedSizesMatchShape(coosz, rank, perm, shape);
556       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
557     } else {
558       std::vector<uint64_t> permsz(rank);
559       for (uint64_t r = 0; r < rank; r++) {
560         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
561         permsz[perm[r]] = shape[r];
562       }
563       // We pass the null `coo` to ensure we select the intended constructor.
564       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
565     }
566     return n;
567   }
568 
569   /// Factory method. Constructs a sparse tensor storage scheme with
570   /// the given dimensions, permutation, and per-dimension dense/sparse
571   /// annotations, using the sparse tensor for the initial contents.
572   ///
573   /// Preconditions:
574   /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
575   /// * The `tensor` must have the same value type `V`.
576   static SparseTensorStorage<P, I, V> *
577   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
578                   const DimLevelType *sparsity,
579                   const SparseTensorStorageBase *source) {
580     assert(source && "Got nullptr for source");
581     SparseTensorEnumeratorBase<V> *enumerator;
582     source->newEnumerator(&enumerator, rank, perm);
583     const auto &permsz = enumerator->permutedSizes();
584     assertPermutedSizesMatchShape(permsz, rank, perm, shape);
585     auto *tensor =
586         new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
587     delete enumerator;
588     return tensor;
589   }
590 
591 private:
592   /// Appends an arbitrary new position to `pointers[d]`.  This method
593   /// checks that `pos` is representable in the `P` type; however, it
594   /// does not check that `pos` is semantically valid (i.e., larger than
595   /// the previous position and smaller than `indices[d].capacity()`).
596   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
597     assert(isCompressedDim(d));
598     assert(pos <= std::numeric_limits<P>::max() &&
599            "Pointer value is too large for the P-type");
600     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
601   }
602 
603   /// Appends index `i` to dimension `d`, in the semantically general
604   /// sense.  For non-dense dimensions, that means appending to the
605   /// `indices[d]` array, checking that `i` is representable in the `I`
606   /// type; however, we do not verify other semantic requirements (e.g.,
607   /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
608   /// in the same segment).  For dense dimensions, this method instead
609   /// appends the appropriate number of zeros to the `values` array,
610   /// where `full` is the number of "entries" already written to `values`
611   /// for this segment (aka one after the highest index previously appended).
612   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
613     if (isCompressedDim(d)) {
614       assert(i <= std::numeric_limits<I>::max() &&
615              "Index value is too large for the I-type");
616       indices[d].push_back(static_cast<I>(i));
617     } else { // Dense dimension.
618       assert(i >= full && "Index was already filled");
619       if (i == full)
620         return; // Short-circuit, since it'll be a nop.
621       if (d + 1 == getRank())
622         values.insert(values.end(), i - full, 0);
623       else
624         finalizeSegment(d + 1, 0, i - full);
625     }
626   }
627 
628   /// Writes the given coordinate to `indices[d][pos]`.  This method
629   /// checks that `i` is representable in the `I` type; however, it
630   /// does not check that `i` is semantically valid (i.e., in bounds
631   /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
632   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
633     assert(isCompressedDim(d));
634     // Subscript assignment to `std::vector` requires that the `pos`-th
635     // entry has been initialized; thus we must be sure to check `size()`
636     // here, instead of `capacity()` as would be ideal.
637     assert(pos < indices[d].size() && "Index position is out of bounds");
638     assert(i <= std::numeric_limits<I>::max() &&
639            "Index value is too large for the I-type");
640     indices[d][pos] = static_cast<I>(i);
641   }
642 
643   /// Computes the assembled-size associated with the `d`-th dimension,
644   /// given the assembled-size associated with the `(d-1)`-th dimension.
645   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
646   /// storage, as opposed to "dimension-sizes" which are the cardinality
647   /// of coordinates for that dimension.
648   ///
649   /// Precondition: the `pointers[d]` array must be fully initialized
650   /// before calling this method.
651   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
652     if (isCompressedDim(d))
653       return pointers[d][parentSz];
654     // else if dense:
655     return parentSz * getDimSizes()[d];
656   }
657 
658   /// Initializes sparse tensor storage scheme from a memory-resident sparse
659   /// tensor in coordinate scheme. This method prepares the pointers and
660   /// indices arrays under the given per-dimension dense/sparse annotations.
661   ///
662   /// Preconditions:
663   /// (1) the `elements` must be lexicographically sorted.
664   /// (2) the indices of every element are valid for `dimSizes` (equal rank
665   ///     and pointwise less-than).
666   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
667                uint64_t hi, uint64_t d) {
668     uint64_t rank = getRank();
669     assert(d <= rank && hi <= elements.size());
670     // Once dimensions are exhausted, insert the numerical values.
671     if (d == rank) {
672       assert(lo < hi);
673       values.push_back(elements[lo].value);
674       return;
675     }
676     // Visit all elements in this interval.
677     uint64_t full = 0;
678     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
679       // Find segment in interval with same index elements in this dimension.
680       uint64_t i = elements[lo].indices[d];
681       uint64_t seg = lo + 1;
682       while (seg < hi && elements[seg].indices[d] == i)
683         seg++;
684       // Handle segment in interval for sparse or dense dimension.
685       appendIndex(d, full, i);
686       full = i + 1;
687       fromCOO(elements, lo, seg, d + 1);
688       // And move on to next segment in interval.
689       lo = seg;
690     }
691     // Finalize the sparse pointer structure at this dimension.
692     finalizeSegment(d, full);
693   }
694 
695   /// Finalize the sparse pointer structure at this dimension.
696   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
697     if (count == 0)
698       return; // Short-circuit, since it'll be a nop.
699     if (isCompressedDim(d)) {
700       appendPointer(d, indices[d].size(), count);
701     } else { // Dense dimension.
702       const uint64_t sz = getDimSizes()[d];
703       assert(sz >= full && "Segment is overfull");
704       count = checkedMul(count, sz - full);
705       // For dense storage we must enumerate all the remaining coordinates
706       // in this dimension (i.e., coordinates after the last non-zero
707       // element), and either fill in their zero values or else recurse
708       // to finalize some deeper dimension.
709       if (d + 1 == getRank())
710         values.insert(values.end(), count, 0);
711       else
712         finalizeSegment(d + 1, 0, count);
713     }
714   }
715 
716   /// Wraps up a single insertion path, inner to outer.
717   void endPath(uint64_t diff) {
718     uint64_t rank = getRank();
719     assert(diff <= rank);
720     for (uint64_t i = 0; i < rank - diff; i++) {
721       const uint64_t d = rank - i - 1;
722       finalizeSegment(d, idx[d] + 1);
723     }
724   }
725 
726   /// Continues a single insertion path, outer to inner.
727   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
728     uint64_t rank = getRank();
729     assert(diff < rank);
730     for (uint64_t d = diff; d < rank; d++) {
731       uint64_t i = cursor[d];
732       appendIndex(d, top, i);
733       top = 0;
734       idx[d] = i;
735     }
736     values.push_back(val);
737   }
738 
739   /// Finds the lexicographic differing dimension.
740   uint64_t lexDiff(const uint64_t *cursor) const {
741     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
742       if (cursor[r] > idx[r])
743         return r;
744       else
745         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
746     assert(0 && "duplication insertion");
747     return -1u;
748   }
749 
750   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
751   // the cost of virtual-function dispatch in inner loops), without
752   // making them public to other client code.
753   friend class SparseTensorEnumerator<P, I, V>;
754 
755   std::vector<std::vector<P>> pointers;
756   std::vector<std::vector<I>> indices;
757   std::vector<V> values;
758   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
759 };
760 
761 /// A (higher-order) function object for enumerating the elements of some
762 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
763 /// method encapsulates the loop-nest for enumerating the elements of
764 /// the source tensor (in whatever order is best for the source tensor),
765 /// and applies a permutation to the coordinates/indices before handing
766 /// each element to the callback.  A single enumerator object can be
767 /// freely reused for several calls to `forallElements`, just so long
768 /// as each call is sequential with respect to one another.
769 ///
770 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
771 /// passed to the constructor; thus, objects of this class must not
772 /// outlive the sparse tensor they depend on.
773 ///
774 /// Design Note: The reason we define this class instead of simply using
775 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
776 /// the `<P,I>` template parameters from MLIR client code (to simplify the
777 /// type parameters used for direct sparse-to-sparse conversion).  And the
778 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
779 /// than simply using this class, is to avoid the cost of virtual-method
780 /// dispatch within the loop-nest.
781 template <typename V>
782 class SparseTensorEnumeratorBase {
783 public:
784   /// Constructs an enumerator with the given permutation for mapping
785   /// the semantic-ordering of dimensions to the desired target-ordering.
786   ///
787   /// Preconditions:
788   /// * the `tensor` must have the same `V` value type.
789   /// * `perm` must be valid for `rank`.
790   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
791                              uint64_t rank, const uint64_t *perm)
792       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
793         cursor(getRank()) {
794     assert(perm && "Received nullptr for permutation");
795     assert(rank == getRank() && "Permutation rank mismatch");
796     const auto &rev = src.getRev();           // source-order -> semantic-order
797     const auto &dimSizes = src.getDimSizes(); // in source storage-order
798     for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
799       uint64_t t = perm[rev[s]];              // `t` target-order
800       reord[s] = t;
801       permsz[t] = dimSizes[s];
802     }
803   }
804 
805   virtual ~SparseTensorEnumeratorBase() = default;
806 
807   // We disallow copying to help avoid leaking the `src` reference.
808   // (In addition to avoiding the problem of slicing.)
809   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
810   SparseTensorEnumeratorBase &
811   operator=(const SparseTensorEnumeratorBase &) = delete;
812 
813   /// Returns the source/target tensor's rank.  (The source-rank and
814   /// target-rank are always equal since we only support permutations.
815   /// Though once we add support for other dimension mappings, this
816   /// method will have to be split in two.)
817   uint64_t getRank() const { return permsz.size(); }
818 
819   /// Returns the target tensor's dimension sizes.
820   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
821 
822   /// Enumerates all elements of the source tensor, permutes their
823   /// indices, and passes the permuted element to the callback.
824   /// The callback must not store the cursor reference directly,
825   /// since this function reuses the storage.  Instead, the callback
826   /// must copy it if they want to keep it.
827   virtual void forallElements(ElementConsumer<V> yield) = 0;
828 
829 protected:
830   const SparseTensorStorageBase &src;
831   std::vector<uint64_t> permsz; // in target order.
832   std::vector<uint64_t> reord;  // source storage-order -> target order.
833   std::vector<uint64_t> cursor; // in target order.
834 };
835 
836 template <typename P, typename I, typename V>
837 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
838   using Base = SparseTensorEnumeratorBase<V>;
839 
840 public:
841   /// Constructs an enumerator with the given permutation for mapping
842   /// the semantic-ordering of dimensions to the desired target-ordering.
843   ///
844   /// Precondition: `perm` must be valid for `rank`.
845   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
846                          uint64_t rank, const uint64_t *perm)
847       : Base(tensor, rank, perm) {}
848 
849   ~SparseTensorEnumerator() final = default;
850 
851   void forallElements(ElementConsumer<V> yield) final {
852     forallElements(yield, 0, 0);
853   }
854 
855 private:
856   /// The recursive component of the public `forallElements`.
857   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
858                       uint64_t d) {
859     // Recover the `<P,I,V>` type parameters of `src`.
860     const auto &src =
861         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
862     if (d == Base::getRank()) {
863       assert(parentPos < src.values.size() &&
864              "Value position is out of bounds");
865       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
866       yield(this->cursor, src.values[parentPos]);
867     } else if (src.isCompressedDim(d)) {
868       // Look up the bounds of the `d`-level segment determined by the
869       // `d-1`-level position `parentPos`.
870       const std::vector<P> &pointers_d = src.pointers[d];
871       assert(parentPos + 1 < pointers_d.size() &&
872              "Parent pointer position is out of bounds");
873       const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
874       const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
875       // Loop-invariant code for looking up the `d`-level coordinates/indices.
876       const std::vector<I> &indices_d = src.indices[d];
877       assert(pstop <= indices_d.size() && "Index position is out of bounds");
878       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
879       for (uint64_t pos = pstart; pos < pstop; pos++) {
880         cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
881         forallElements(yield, pos, d + 1);
882       }
883     } else { // Dense dimension.
884       const uint64_t sz = src.getDimSizes()[d];
885       const uint64_t pstart = parentPos * sz;
886       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
887       for (uint64_t i = 0; i < sz; i++) {
888         cursor_reord_d = i;
889         forallElements(yield, pstart + i, d + 1);
890       }
891     }
892   }
893 };
894 
895 /// Statistics regarding the number of nonzero subtensors in
896 /// a source tensor, for direct sparse=>sparse conversion a la
897 /// <https://arxiv.org/abs/2001.02609>.
898 ///
899 /// N.B., this class stores references to the parameters passed to
900 /// the constructor; thus, objects of this class must not outlive
901 /// those parameters.
902 class SparseTensorNNZ final {
903 public:
904   /// Allocate the statistics structure for the desired sizes and
905   /// sparsity (in the target tensor's storage-order).  This constructor
906   /// does not actually populate the statistics, however; for that see
907   /// `initialize`.
908   ///
909   /// Precondition: `dimSizes` must not contain zeros.
910   SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
911                   const std::vector<DimLevelType> &sparsity)
912       : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
913     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
914     bool uncompressed = true;
915     uint64_t sz = 1; // the product of all `dimSizes` strictly less than `r`.
916     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
917       switch (dimTypes[r]) {
918       case DimLevelType::kCompressed:
919         assert(uncompressed &&
920                "Multiple compressed layers not currently supported");
921         uncompressed = false;
922         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
923         break;
924       case DimLevelType::kDense:
925         assert(uncompressed &&
926                "Dense after compressed not currently supported");
927         break;
928       case DimLevelType::kSingleton:
929         // Singleton after Compressed causes no problems for allocating
930         // `nnz` nor for the yieldPos loop.  This remains true even
931         // when adding support for multiple compressed dimensions or
932         // for dense-after-compressed.
933         break;
934       }
935       sz = checkedMul(sz, dimSizes[r]);
936     }
937   }
938 
939   // We disallow copying to help avoid leaking the stored references.
940   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
941   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
942 
943   /// Returns the rank of the target tensor.
944   uint64_t getRank() const { return dimSizes.size(); }
945 
946   /// Enumerate the source tensor to fill in the statistics.  The
947   /// enumerator should already incorporate the permutation (from
948   /// semantic-order to the target storage-order).
949   template <typename V>
950   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
951     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
952     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
953     enumerator.forallElements(
954         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
955   }
956 
957   /// The type of callback functions which receive an nnz-statistic.
958   using NNZConsumer = const std::function<void(uint64_t)> &;
959 
960   /// Lexicographically enumerates all indicies for dimensions strictly
961   /// less than `stopDim`, and passes their nnz statistic to the callback.
962   /// Since our use-case only requires the statistic not the coordinates
963   /// themselves, we do not bother to construct those coordinates.
964   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
965     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
966     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
967            "Cannot look up non-compressed dimensions");
968     forallIndices(yield, stopDim, 0, 0);
969   }
970 
971 private:
972   /// Adds a new element (i.e., increment its statistics).  We use
973   /// a method rather than inlining into the lambda in `initialize`,
974   /// to avoid spurious templating over `V`.  And this method is private
975   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
976   /// by `forallElements`).
977   void add(const std::vector<uint64_t> &ind) {
978     uint64_t parentPos = 0;
979     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
980       if (dimTypes[r] == DimLevelType::kCompressed)
981         nnz[r][parentPos]++;
982       parentPos = parentPos * dimSizes[r] + ind[r];
983     }
984   }
985 
986   /// Recursive component of the public `forallIndices`.
987   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
988                      uint64_t d) const {
989     assert(d <= stopDim);
990     if (d == stopDim) {
991       assert(parentPos < nnz[d].size() && "Cursor is out of range");
992       yield(nnz[d][parentPos]);
993     } else {
994       const uint64_t sz = dimSizes[d];
995       const uint64_t pstart = parentPos * sz;
996       for (uint64_t i = 0; i < sz; i++)
997         forallIndices(yield, stopDim, pstart + i, d + 1);
998     }
999   }
1000 
1001   // All of these are in the target storage-order.
1002   const std::vector<uint64_t> &dimSizes;
1003   const std::vector<DimLevelType> &dimTypes;
1004   std::vector<std::vector<uint64_t>> nnz;
1005 };
1006 
1007 template <typename P, typename I, typename V>
1008 SparseTensorStorage<P, I, V>::SparseTensorStorage(
1009     const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
1010     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1011     : SparseTensorStorage(dimSizes, perm, sparsity) {
1012   SparseTensorEnumeratorBase<V> *enumerator;
1013   tensor.newEnumerator(&enumerator, getRank(), perm);
1014   {
1015     // Initialize the statistics structure.
1016     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1017     nnz.initialize(*enumerator);
1018     // Initialize "pointers" overhead (and allocate "indices", "values").
1019     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1020     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1021       if (isCompressedDim(r)) {
1022         pointers[r].reserve(parentSz + 1);
1023         pointers[r].push_back(0);
1024         uint64_t currentPos = 0;
1025         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1026           currentPos += n;
1027           appendPointer(r, currentPos);
1028         });
1029         assert(pointers[r].size() == parentSz + 1 &&
1030                "Final pointers size doesn't match allocated size");
1031         // That assertion entails `assembledSize(parentSz, r)`
1032         // is now in a valid state.  That is, `pointers[r][parentSz]`
1033         // equals the present value of `currentPos`, which is the
1034         // correct assembled-size for `indices[r]`.
1035       }
1036       // Update assembled-size for the next iteration.
1037       parentSz = assembledSize(parentSz, r);
1038       // Ideally we need only `indices[r].reserve(parentSz)`, however
1039       // the `std::vector` implementation forces us to initialize it too.
1040       // That is, in the yieldPos loop we need random-access assignment
1041       // to `indices[r]`; however, `std::vector`'s subscript-assignment
1042       // only allows assigning to already-initialized positions.
1043       if (isCompressedDim(r))
1044         indices[r].resize(parentSz, 0);
1045     }
1046     values.resize(parentSz, 0); // Both allocate and zero-initialize.
1047   }
1048   // The yieldPos loop
1049   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1050     uint64_t parentSz = 1, parentPos = 0;
1051     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1052       if (isCompressedDim(r)) {
1053         // If `parentPos == parentSz` then it's valid as an array-lookup;
1054         // however, it's semantically invalid here since that entry
1055         // does not represent a segment of `indices[r]`.  Moreover, that
1056         // entry must be immutable for `assembledSize` to remain valid.
1057         assert(parentPos < parentSz && "Pointers position is out of bounds");
1058         const uint64_t currentPos = pointers[r][parentPos];
1059         // This increment won't overflow the `P` type, since it can't
1060         // exceed the original value of `pointers[r][parentPos+1]`
1061         // which was already verified to be within bounds for `P`
1062         // when it was written to the array.
1063         pointers[r][parentPos]++;
1064         writeIndex(r, currentPos, ind[r]);
1065         parentPos = currentPos;
1066       } else { // Dense dimension.
1067         parentPos = parentPos * getDimSizes()[r] + ind[r];
1068       }
1069       parentSz = assembledSize(parentSz, r);
1070     }
1071     assert(parentPos < values.size() && "Value position is out of bounds");
1072     values[parentPos] = val;
1073   });
1074   // No longer need the enumerator, so we'll delete it ASAP.
1075   delete enumerator;
1076   // The finalizeYieldPos loop
1077   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1078     if (isCompressedDim(r)) {
1079       assert(parentSz == pointers[r].size() - 1 &&
1080              "Actual pointers size doesn't match the expected size");
1081       // Can't check all of them, but at least we can check the last one.
1082       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1083              "Pointers got corrupted");
1084       // TODO: optimize this by using `memmove` or similar.
1085       for (uint64_t n = 0; n < parentSz; n++) {
1086         const uint64_t parentPos = parentSz - n;
1087         pointers[r][parentPos] = pointers[r][parentPos - 1];
1088       }
1089       pointers[r][0] = 0;
1090     }
1091     parentSz = assembledSize(parentSz, r);
1092   }
1093 }
1094 
1095 /// Helper to convert string to lower case.
1096 static char *toLower(char *token) {
1097   for (char *c = token; *c; c++)
1098     *c = tolower(*c);
1099   return token;
1100 }
1101 
1102 /// Read the MME header of a general sparse matrix of type real.
1103 static void readMMEHeader(FILE *file, char *filename, char *line,
1104                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
1105   char header[64];
1106   char object[64];
1107   char format[64];
1108   char field[64];
1109   char symmetry[64];
1110   // Read header line.
1111   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1112              symmetry) != 5)
1113     FATAL("Corrupt header in %s\n", filename);
1114   // Set properties
1115   *isPattern = (strcmp(toLower(field), "pattern") == 0);
1116   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
1117   // Make sure this is a general sparse matrix.
1118   if (strcmp(toLower(header), "%%matrixmarket") ||
1119       strcmp(toLower(object), "matrix") ||
1120       strcmp(toLower(format), "coordinate") ||
1121       (strcmp(toLower(field), "real") && !(*isPattern)) ||
1122       (strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
1123     FATAL("Cannot find a general sparse matrix in %s\n", filename);
1124   // Skip comments.
1125   while (true) {
1126     if (!fgets(line, kColWidth, file))
1127       FATAL("Cannot find data in %s\n", filename);
1128     if (line[0] != '%')
1129       break;
1130   }
1131   // Next line contains M N NNZ.
1132   idata[0] = 2; // rank
1133   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1134              idata + 1) != 3)
1135     FATAL("Cannot find size in %s\n", filename);
1136 }
1137 
1138 /// Read the "extended" FROSTT header. Although not part of the documented
1139 /// format, we assume that the file starts with optional comments followed
1140 /// by two lines that define the rank, the number of nonzeros, and the
1141 /// dimensions sizes (one per rank) of the sparse tensor.
1142 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
1143                                 uint64_t *idata) {
1144   // Skip comments.
1145   while (true) {
1146     if (!fgets(line, kColWidth, file))
1147       FATAL("Cannot find data in %s\n", filename);
1148     if (line[0] != '#')
1149       break;
1150   }
1151   // Next line contains RANK and NNZ.
1152   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
1153     FATAL("Cannot find metadata in %s\n", filename);
1154   // Followed by a line with the dimension sizes (one per rank).
1155   for (uint64_t r = 0; r < idata[0]; r++)
1156     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
1157       FATAL("Cannot find dimension size %s\n", filename);
1158   fgets(line, kColWidth, file); // end of line
1159 }
1160 
1161 /// Reads a sparse tensor with the given filename into a memory-resident
1162 /// sparse tensor in coordinate scheme.
1163 template <typename V>
1164 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
1165                                                const uint64_t *shape,
1166                                                const uint64_t *perm) {
1167   // Open the file.
1168   assert(filename && "Received nullptr for filename");
1169   FILE *file = fopen(filename, "r");
1170   if (!file)
1171     FATAL("Cannot find file %s\n", filename);
1172   // Perform some file format dependent set up.
1173   char line[kColWidth];
1174   uint64_t idata[512];
1175   bool isPattern = false;
1176   bool isSymmetric = false;
1177   if (strstr(filename, ".mtx")) {
1178     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
1179   } else if (strstr(filename, ".tns")) {
1180     readExtFROSTTHeader(file, filename, line, idata);
1181   } else {
1182     FATAL("Unknown format %s\n", filename);
1183   }
1184   // Prepare sparse tensor object with per-dimension sizes
1185   // and the number of nonzeros as initial capacity.
1186   assert(rank == idata[0] && "rank mismatch");
1187   uint64_t nnz = idata[1];
1188   for (uint64_t r = 0; r < rank; r++)
1189     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1190            "dimension size mismatch");
1191   SparseTensorCOO<V> *tensor =
1192       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
1193   // Read all nonzero elements.
1194   std::vector<uint64_t> indices(rank);
1195   for (uint64_t k = 0; k < nnz; k++) {
1196     if (!fgets(line, kColWidth, file))
1197       FATAL("Cannot find next line of data in %s\n", filename);
1198     char *linePtr = line;
1199     for (uint64_t r = 0; r < rank; r++) {
1200       uint64_t idx = strtoul(linePtr, &linePtr, 10);
1201       // Add 0-based index.
1202       indices[perm[r]] = idx - 1;
1203     }
1204     // The external formats always store the numerical values with the type
1205     // double, but we cast these values to the sparse tensor object type.
1206     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1207     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
1208     tensor->add(indices, value);
1209     // We currently chose to deal with symmetric matrices by fully constructing
1210     // them. In the future, we may want to make symmetry implicit for storage
1211     // reasons.
1212     if (isSymmetric && indices[0] != indices[1])
1213       tensor->add({indices[1], indices[0]}, value);
1214   }
1215   // Close the file and return tensor.
1216   fclose(file);
1217   return tensor;
1218 }
1219 
1220 /// Writes the sparse tensor to `dest` in extended FROSTT format.
1221 template <typename V>
1222 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1223   assert(tensor && dest);
1224   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1225   if (sort)
1226     coo->sort();
1227   char *filename = static_cast<char *>(dest);
1228   auto &dimSizes = coo->getDimSizes();
1229   auto &elements = coo->getElements();
1230   uint64_t rank = coo->getRank();
1231   uint64_t nnz = elements.size();
1232   std::fstream file;
1233   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1234   assert(file.is_open());
1235   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1236   for (uint64_t r = 0; r < rank - 1; r++)
1237     file << dimSizes[r] << " ";
1238   file << dimSizes[rank - 1] << std::endl;
1239   for (uint64_t i = 0; i < nnz; i++) {
1240     auto &idx = elements[i].indices;
1241     for (uint64_t r = 0; r < rank; r++)
1242       file << (idx[r] + 1) << " ";
1243     file << elements[i].value << std::endl;
1244   }
1245   file.flush();
1246   file.close();
1247   assert(file.good());
1248 }
1249 
1250 /// Initializes sparse tensor from an external COO-flavored format.
1251 template <typename V>
1252 static SparseTensorStorage<uint64_t, uint64_t, V> *
1253 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1254                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
1255   const DimLevelType *sparsity = (DimLevelType *)(sparse);
1256 #ifndef NDEBUG
1257   // Verify that perm is a permutation of 0..(rank-1).
1258   std::vector<uint64_t> order(perm, perm + rank);
1259   std::sort(order.begin(), order.end());
1260   for (uint64_t i = 0; i < rank; ++i)
1261     if (i != order[i])
1262       FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
1263 
1264   // Verify that the sparsity values are supported.
1265   for (uint64_t i = 0; i < rank; ++i)
1266     if (sparsity[i] != DimLevelType::kDense &&
1267         sparsity[i] != DimLevelType::kCompressed)
1268       FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
1269 #endif
1270 
1271   // Convert external format to internal COO.
1272   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1273   std::vector<uint64_t> idx(rank);
1274   for (uint64_t i = 0, base = 0; i < nse; i++) {
1275     for (uint64_t r = 0; r < rank; r++)
1276       idx[perm[r]] = indices[base + r];
1277     coo->add(idx, values[i]);
1278     base += rank;
1279   }
1280   // Return sparse tensor storage format as opaque pointer.
1281   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1282       rank, shape, perm, sparsity, coo);
1283   delete coo;
1284   return tensor;
1285 }
1286 
1287 /// Converts a sparse tensor to an external COO-flavored format.
1288 template <typename V>
1289 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1290                                  uint64_t **pShape, V **pValues,
1291                                  uint64_t **pIndices) {
1292   assert(tensor);
1293   auto sparseTensor =
1294       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1295   uint64_t rank = sparseTensor->getRank();
1296   std::vector<uint64_t> perm(rank);
1297   std::iota(perm.begin(), perm.end(), 0);
1298   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1299 
1300   const std::vector<Element<V>> &elements = coo->getElements();
1301   uint64_t nse = elements.size();
1302 
1303   uint64_t *shape = new uint64_t[rank];
1304   for (uint64_t i = 0; i < rank; i++)
1305     shape[i] = coo->getDimSizes()[i];
1306 
1307   V *values = new V[nse];
1308   uint64_t *indices = new uint64_t[rank * nse];
1309 
1310   for (uint64_t i = 0, base = 0; i < nse; i++) {
1311     values[i] = elements[i].value;
1312     for (uint64_t j = 0; j < rank; j++)
1313       indices[base + j] = elements[i].indices[j];
1314     base += rank;
1315   }
1316 
1317   delete coo;
1318   *pRank = rank;
1319   *pNse = nse;
1320   *pShape = shape;
1321   *pValues = values;
1322   *pIndices = indices;
1323 }
1324 
1325 } // anonymous namespace
1326 
1327 extern "C" {
1328 
1329 //===----------------------------------------------------------------------===//
1330 //
1331 // Public functions which operate on MLIR buffers (memrefs) to interact
1332 // with sparse tensors (which are only visible as opaque pointers externally).
1333 //
1334 //===----------------------------------------------------------------------===//
1335 
1336 #define CASE(p, i, v, P, I, V)                                                 \
1337   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1338     SparseTensorCOO<V> *coo = nullptr;                                         \
1339     if (action <= Action::kFromCOO) {                                          \
1340       if (action == Action::kFromFile) {                                       \
1341         char *filename = static_cast<char *>(ptr);                             \
1342         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1343       } else if (action == Action::kFromCOO) {                                 \
1344         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1345       } else {                                                                 \
1346         assert(action == Action::kEmpty);                                      \
1347       }                                                                        \
1348       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1349           rank, shape, perm, sparsity, coo);                                   \
1350       if (action == Action::kFromFile)                                         \
1351         delete coo;                                                            \
1352       return tensor;                                                           \
1353     }                                                                          \
1354     if (action == Action::kSparseToSparse) {                                   \
1355       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
1356       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
1357                                                            sparsity, tensor);  \
1358     }                                                                          \
1359     if (action == Action::kEmptyCOO)                                           \
1360       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1361     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1362     if (action == Action::kToIterator) {                                       \
1363       coo->startIterator();                                                    \
1364     } else {                                                                   \
1365       assert(action == Action::kToCOO);                                        \
1366     }                                                                          \
1367     return coo;                                                                \
1368   }
1369 
1370 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
1371 
1372 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1373 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1374 // that this file cannot get out of sync with its header.
1375 static_assert(std::is_same<index_type, uint64_t>::value,
1376               "Expected index_type == uint64_t");
1377 
1378 void *
1379 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1380                              StridedMemRefType<index_type, 1> *sref,
1381                              StridedMemRefType<index_type, 1> *pref,
1382                              OverheadType ptrTp, OverheadType indTp,
1383                              PrimaryType valTp, Action action, void *ptr) {
1384   assert(aref && sref && pref);
1385   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1386          pref->strides[0] == 1);
1387   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1388   const DimLevelType *sparsity = aref->data + aref->offset;
1389   const index_type *shape = sref->data + sref->offset;
1390   const index_type *perm = pref->data + pref->offset;
1391   uint64_t rank = aref->sizes[0];
1392 
1393   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1394   // This is safe because of the static_assert above.
1395   if (ptrTp == OverheadType::kIndex)
1396     ptrTp = OverheadType::kU64;
1397   if (indTp == OverheadType::kIndex)
1398     indTp = OverheadType::kU64;
1399 
1400   // Double matrices with all combinations of overhead storage.
1401   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1402        uint64_t, double);
1403   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1404        uint32_t, double);
1405   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1406        uint16_t, double);
1407   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1408        uint8_t, double);
1409   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1410        uint64_t, double);
1411   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1412        uint32_t, double);
1413   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1414        uint16_t, double);
1415   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1416        uint8_t, double);
1417   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1418        uint64_t, double);
1419   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1420        uint32_t, double);
1421   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1422        uint16_t, double);
1423   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1424        uint8_t, double);
1425   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1426        uint64_t, double);
1427   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1428        uint32_t, double);
1429   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1430        uint16_t, double);
1431   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1432        uint8_t, double);
1433 
1434   // Float matrices with all combinations of overhead storage.
1435   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1436        uint64_t, float);
1437   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1438        uint32_t, float);
1439   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1440        uint16_t, float);
1441   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1442        uint8_t, float);
1443   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1444        uint64_t, float);
1445   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1446        uint32_t, float);
1447   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1448        uint16_t, float);
1449   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1450        uint8_t, float);
1451   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1452        uint64_t, float);
1453   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1454        uint32_t, float);
1455   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1456        uint16_t, float);
1457   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1458        uint8_t, float);
1459   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1460        uint64_t, float);
1461   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1462        uint32_t, float);
1463   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1464        uint16_t, float);
1465   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1466        uint8_t, float);
1467 
1468   // Integral matrices with both overheads of the same type.
1469   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1470   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1471   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1472   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1473   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
1474   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1475   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1476   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1477   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
1478   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1479   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1480   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1481   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
1482   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1483   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1484   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1485 
1486   // Complex matrices with wide overhead.
1487   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1488   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1489 
1490   // Unsupported case (add above if needed).
1491   // TODO: better pretty-printing of enum values!
1492   FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
1493         static_cast<int>(ptrTp), static_cast<int>(indTp),
1494         static_cast<int>(valTp));
1495 }
1496 #undef CASE
1497 #undef CASE_SECSAME
1498 
1499 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1500   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1501                                         void *tensor) {                        \
1502     assert(ref &&tensor);                                                      \
1503     std::vector<V> *v;                                                         \
1504     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1505     ref->basePtr = ref->data = v->data();                                      \
1506     ref->offset = 0;                                                           \
1507     ref->sizes[0] = v->size();                                                 \
1508     ref->strides[0] = 1;                                                       \
1509   }
1510 FOREVERY_V(IMPL_SPARSEVALUES)
1511 #undef IMPL_SPARSEVALUES
1512 
1513 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1514   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1515                            index_type d) {                                     \
1516     assert(ref &&tensor);                                                      \
1517     std::vector<TYPE> *v;                                                      \
1518     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1519     ref->basePtr = ref->data = v->data();                                      \
1520     ref->offset = 0;                                                           \
1521     ref->sizes[0] = v->size();                                                 \
1522     ref->strides[0] = 1;                                                       \
1523   }
1524 #define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
1525   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
1526 FOREVERY_O(IMPL_SPARSEPOINTERS)
1527 #undef IMPL_SPARSEPOINTERS
1528 
1529 #define IMPL_SPARSEINDICES(INAME, I)                                           \
1530   IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
1531 FOREVERY_O(IMPL_SPARSEINDICES)
1532 #undef IMPL_SPARSEINDICES
1533 #undef IMPL_GETOVERHEAD
1534 
1535 #define IMPL_ADDELT(VNAME, V)                                                  \
1536   void *_mlir_ciface_addElt##VNAME(void *coo, V value,                         \
1537                                    StridedMemRefType<index_type, 1> *iref,     \
1538                                    StridedMemRefType<index_type, 1> *pref) {   \
1539     assert(coo &&iref &&pref);                                                 \
1540     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1541     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1542     const index_type *indx = iref->data + iref->offset;                        \
1543     const index_type *perm = pref->data + pref->offset;                        \
1544     uint64_t isize = iref->sizes[0];                                           \
1545     std::vector<index_type> indices(isize);                                    \
1546     for (uint64_t r = 0; r < isize; r++)                                       \
1547       indices[perm[r]] = indx[r];                                              \
1548     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value);               \
1549     return coo;                                                                \
1550   }
1551 FOREVERY_SIMPLEX_V(IMPL_ADDELT)
1552 IMPL_ADDELT(C64, complex64)
1553 // Marked static because it's not part of the public API.
1554 // NOTE: the `static` keyword confuses clang-format here, causing
1555 // the strange indentation of the `_mlir_ciface_addEltC32` prototype.
1556 // In C++11 we can add a semicolon after the call to `IMPL_ADDELT`
1557 // and that will correct clang-format.  Alas, this file is compiled
1558 // in C++98 mode where that semicolon is illegal (and there's no portable
1559 // macro magic to license a no-op semicolon at the top level).
1560 static IMPL_ADDELT(C32ABI, complex32)
1561 #undef IMPL_ADDELT
1562     void *_mlir_ciface_addEltC32(void *coo, float r, float i,
1563                                  StridedMemRefType<index_type, 1> *iref,
1564                                  StridedMemRefType<index_type, 1> *pref) {
1565   return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
1566 }
1567 
1568 #define IMPL_GETNEXT(VNAME, V)                                                 \
1569   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1570                                    StridedMemRefType<index_type, 1> *iref,     \
1571                                    StridedMemRefType<V, 0> *vref) {            \
1572     assert(coo &&iref &&vref);                                                 \
1573     assert(iref->strides[0] == 1);                                             \
1574     index_type *indx = iref->data + iref->offset;                              \
1575     V *value = vref->data + vref->offset;                                      \
1576     const uint64_t isize = iref->sizes[0];                                     \
1577     const Element<V> *elem =                                                   \
1578         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1579     if (elem == nullptr)                                                       \
1580       return false;                                                            \
1581     for (uint64_t r = 0; r < isize; r++)                                       \
1582       indx[r] = elem->indices[r];                                              \
1583     *value = elem->value;                                                      \
1584     return true;                                                               \
1585   }
1586 FOREVERY_V(IMPL_GETNEXT)
1587 #undef IMPL_GETNEXT
1588 
1589 #define IMPL_LEXINSERT(VNAME, V)                                               \
1590   void _mlir_ciface_lexInsert##VNAME(                                          \
1591       void *tensor, StridedMemRefType<index_type, 1> *cref, V val) {           \
1592     assert(tensor &&cref);                                                     \
1593     assert(cref->strides[0] == 1);                                             \
1594     index_type *cursor = cref->data + cref->offset;                            \
1595     assert(cursor);                                                            \
1596     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1597   }
1598 FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
1599 IMPL_LEXINSERT(C64, complex64)
1600 // Marked static because it's not part of the public API.
1601 // NOTE: see the note for `_mlir_ciface_addEltC32ABI`
1602 static IMPL_LEXINSERT(C32ABI, complex32)
1603 #undef IMPL_LEXINSERT
1604     void _mlir_ciface_lexInsertC32(void *tensor,
1605                                    StridedMemRefType<index_type, 1> *cref,
1606                                    float r, float i) {
1607   _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
1608 }
1609 
1610 #define IMPL_EXPINSERT(VNAME, V)                                               \
1611   void _mlir_ciface_expInsert##VNAME(                                          \
1612       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1613       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1614       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1615     assert(tensor &&cref &&vref &&fref &&aref);                                \
1616     assert(cref->strides[0] == 1);                                             \
1617     assert(vref->strides[0] == 1);                                             \
1618     assert(fref->strides[0] == 1);                                             \
1619     assert(aref->strides[0] == 1);                                             \
1620     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1621     index_type *cursor = cref->data + cref->offset;                            \
1622     V *values = vref->data + vref->offset;                                     \
1623     bool *filled = fref->data + fref->offset;                                  \
1624     index_type *added = aref->data + aref->offset;                             \
1625     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1626         cursor, values, filled, added, count);                                 \
1627   }
1628 FOREVERY_V(IMPL_EXPINSERT)
1629 #undef IMPL_EXPINSERT
1630 
1631 //===----------------------------------------------------------------------===//
1632 //
1633 // Public functions which accept only C-style data structures to interact
1634 // with sparse tensors (which are only visible as opaque pointers externally).
1635 //
1636 //===----------------------------------------------------------------------===//
1637 
1638 index_type sparseDimSize(void *tensor, index_type d) {
1639   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1640 }
1641 
1642 void endInsert(void *tensor) {
1643   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1644 }
1645 
1646 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
1647   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
1648     return outSparseTensor<V>(coo, dest, sort);                                \
1649   }
1650 FOREVERY_V(IMPL_OUTSPARSETENSOR)
1651 #undef IMPL_OUTSPARSETENSOR
1652 
1653 void delSparseTensor(void *tensor) {
1654   delete static_cast<SparseTensorStorageBase *>(tensor);
1655 }
1656 
1657 #define IMPL_DELCOO(VNAME, V)                                                  \
1658   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1659     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1660   }
1661 FOREVERY_V(IMPL_DELCOO)
1662 #undef IMPL_DELCOO
1663 
1664 char *getTensorFilename(index_type id) {
1665   char var[80];
1666   sprintf(var, "TENSOR%" PRIu64, id);
1667   char *env = getenv(var);
1668   if (!env)
1669     FATAL("Environment variable %s is not set\n", var);
1670   return env;
1671 }
1672 
1673 // TODO: generalize beyond 64-bit indices.
1674 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
1675   void *convertToMLIRSparseTensor##VNAME(                                      \
1676       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
1677       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
1678     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
1679                                  sparse);                                      \
1680   }
1681 FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
1682 #undef IMPL_CONVERTTOMLIRSPARSETENSOR
1683 
1684 // TODO: Currently, values are copied from SparseTensorStorage to
1685 // SparseTensorCOO, then to the output.  We may want to reduce the number
1686 // of copies.
1687 //
1688 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1689 // compressed
1690 #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
1691   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
1692                                           uint64_t *pNse, uint64_t **pShape,   \
1693                                           V **pValues, uint64_t **pIndices) {  \
1694     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
1695   }
1696 FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
1697 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
1698 
1699 } // extern "C"
1700 
1701 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1702