1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <complex>
25 #include <cctype>
26 #include <cinttypes>
27 #include <cstdio>
28 #include <cstdlib>
29 #include <cstring>
30 #include <fstream>
31 #include <functional>
32 #include <iostream>
33 #include <limits>
34 #include <numeric>
35 #include <vector>
36 
37 using complex64 = std::complex<double>;
38 using complex32 = std::complex<float>;
39 
40 //===----------------------------------------------------------------------===//
41 //
42 // Internal support for storing and reading sparse tensors.
43 //
44 // The following memory-resident sparse storage schemes are supported:
45 //
46 // (a) A coordinate scheme for temporarily storing and lexicographically
47 //     sorting a sparse tensor by index (SparseTensorCOO).
48 //
49 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
50 //     per-dimension sparse/dense annnotations together with a dimension
51 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
52 //
53 // The following external formats are supported:
54 //
55 // (1) Matrix Market Exchange (MME): *.mtx
56 //     https://math.nist.gov/MatrixMarket/formats.html
57 //
58 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
59 //     http://frostt.io/tensors/file-formats.html
60 //
61 // Two public APIs are supported:
62 //
63 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
64 //     tensors. These methods should be used exclusively by MLIR
65 //     compiler-generated code.
66 //
67 // (II) Methods that accept C-style data structures to interact with sparse
68 //      tensors. These methods can be used by any external runtime that wants
69 //      to interact with MLIR compiler-generated code.
70 //
71 // In both cases (I) and (II), the SparseTensorStorage format is externally
72 // only visible as an opaque pointer.
73 //
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 
78 static constexpr int kColWidth = 1025;
79 
80 /// A version of `operator*` on `uint64_t` which checks for overflows.
81 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
82   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
83          "Integer overflow");
84   return lhs * rhs;
85 }
86 
87 // This macro helps minimize repetition of this idiom, as well as ensuring
88 // we have some additional output indicating where the error is coming from.
89 // (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
90 // to track down whether an error is coming from our code vs somewhere else
91 // in MLIR.)
92 #define FATAL(...)                                                             \
93   {                                                                            \
94     fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                        \
95     exit(1);                                                                   \
96   }
97 
98 // TODO: adjust this so it can be used by `openSparseTensorCOO` too.
99 // That version doesn't have the permutation, and the `dimSizes` are
100 // a pointer/C-array rather than `std::vector`.
101 //
102 /// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
103 /// semantic-order to target-order) are a refinement of the desired `shape`
104 /// (in semantic-order).
105 ///
106 /// Precondition: `perm` and `shape` must be valid for `rank`.
107 static inline void
108 assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
109                               uint64_t rank, const uint64_t *perm,
110                               const uint64_t *shape) {
111   assert(perm && shape);
112   assert(rank == dimSizes.size() && "Rank mismatch");
113   for (uint64_t r = 0; r < rank; r++)
114     assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
115            "Dimension size mismatch");
116 }
117 
118 /// A sparse tensor element in coordinate scheme (value and indices).
119 /// For example, a rank-1 vector element would look like
120 ///   ({i}, a[i])
121 /// and a rank-5 tensor element like
122 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
123 /// We use pointer to a shared index pool rather than e.g. a direct
124 /// vector since that (1) reduces the per-element memory footprint, and
125 /// (2) centralizes the memory reservation and (re)allocation to one place.
126 template <typename V>
127 struct Element final {
128   Element(uint64_t *ind, V val) : indices(ind), value(val){};
129   uint64_t *indices; // pointer into shared index pool
130   V value;
131 };
132 
133 /// The type of callback functions which receive an element.  We avoid
134 /// packaging the coordinates and value together as an `Element` object
135 /// because this helps keep code somewhat cleaner.
136 template <typename V>
137 using ElementConsumer =
138     const std::function<void(const std::vector<uint64_t> &, V)> &;
139 
140 /// A memory-resident sparse tensor in coordinate scheme (collection of
141 /// elements). This data structure is used to read a sparse tensor from
142 /// any external format into memory and sort the elements lexicographically
143 /// by indices before passing it back to the client (most packed storage
144 /// formats require the elements to appear in lexicographic index order).
145 template <typename V>
146 struct SparseTensorCOO final {
147 public:
148   SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
149       : dimSizes(dimSizes) {
150     if (capacity) {
151       elements.reserve(capacity);
152       indices.reserve(capacity * getRank());
153     }
154   }
155 
156   /// Adds element as indices and value.
157   void add(const std::vector<uint64_t> &ind, V val) {
158     assert(!iteratorLocked && "Attempt to add() after startIterator()");
159     uint64_t *base = indices.data();
160     uint64_t size = indices.size();
161     uint64_t rank = getRank();
162     assert(ind.size() == rank && "Element rank mismatch");
163     for (uint64_t r = 0; r < rank; r++) {
164       assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
165       indices.push_back(ind[r]);
166     }
167     // This base only changes if indices were reallocated. In that case, we
168     // need to correct all previous pointers into the vector. Note that this
169     // only happens if we did not set the initial capacity right, and then only
170     // for every internal vector reallocation (which with the doubling rule
171     // should only incur an amortized linear overhead).
172     uint64_t *newBase = indices.data();
173     if (newBase != base) {
174       for (uint64_t i = 0, n = elements.size(); i < n; i++)
175         elements[i].indices = newBase + (elements[i].indices - base);
176       base = newBase;
177     }
178     // Add element as (pointer into shared index pool, value) pair.
179     elements.emplace_back(base + size, val);
180   }
181 
182   /// Sorts elements lexicographically by index.
183   void sort() {
184     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
185     // TODO: we may want to cache an `isSorted` bit, to avoid
186     // unnecessary/redundant sorting.
187     uint64_t rank = getRank();
188     std::sort(elements.begin(), elements.end(),
189               [rank](const Element<V> &e1, const Element<V> &e2) {
190                 for (uint64_t r = 0; r < rank; r++) {
191                   if (e1.indices[r] == e2.indices[r])
192                     continue;
193                   return e1.indices[r] < e2.indices[r];
194                 }
195                 return false;
196               });
197   }
198 
199   /// Get the rank of the tensor.
200   uint64_t getRank() const { return dimSizes.size(); }
201 
202   /// Getter for the dimension-sizes array.
203   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
204 
205   /// Getter for the elements array.
206   const std::vector<Element<V>> &getElements() const { return elements; }
207 
208   /// Switch into iterator mode.
209   void startIterator() {
210     iteratorLocked = true;
211     iteratorPos = 0;
212   }
213 
214   /// Get the next element.
215   const Element<V> *getNext() {
216     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
217     if (iteratorPos < elements.size())
218       return &(elements[iteratorPos++]);
219     iteratorLocked = false;
220     return nullptr;
221   }
222 
223   /// Factory method. Permutes the original dimensions according to
224   /// the given ordering and expects subsequent add() calls to honor
225   /// that same ordering for the given indices. The result is a
226   /// fully permuted coordinate scheme.
227   ///
228   /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
229   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
230                                                 const uint64_t *dimSizes,
231                                                 const uint64_t *perm,
232                                                 uint64_t capacity = 0) {
233     std::vector<uint64_t> permsz(rank);
234     for (uint64_t r = 0; r < rank; r++) {
235       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
236       permsz[perm[r]] = dimSizes[r];
237     }
238     return new SparseTensorCOO<V>(permsz, capacity);
239   }
240 
241 private:
242   const std::vector<uint64_t> dimSizes; // per-dimension sizes
243   std::vector<Element<V>> elements;     // all COO elements
244   std::vector<uint64_t> indices;        // shared index pool
245   bool iteratorLocked = false;
246   unsigned iteratorPos = 0;
247 };
248 
249 // See <https://en.wikipedia.org/wiki/X_Macro>
250 //
251 // `FOREVERY_SIMPLEX_V` only specifies the non-complex `V` types, because
252 // the ABI for complex types has compiler/architecture dependent complexities
253 // we need to work around.  Namely, when a function takes a parameter of
254 // C/C++ type `complex32` (per se), then there is additional padding that
255 // causes it not to match the LLVM type `!llvm.struct<(f32, f32)>`.  This
256 // only happens with the `complex32` type itself, not with pointers/arrays
257 // of complex values.  So far `complex64` doesn't exhibit this ABI
258 // incompatibility, but we exclude it anyways just to be safe.
259 #define FOREVERY_SIMPLEX_V(DO)                                                 \
260   DO(F64, double)                                                              \
261   DO(F32, float)                                                               \
262   DO(I64, int64_t)                                                             \
263   DO(I32, int32_t)                                                             \
264   DO(I16, int16_t)                                                             \
265   DO(I8, int8_t)
266 
267 #define FOREVERY_V(DO)                                                         \
268   FOREVERY_SIMPLEX_V(DO)                                                       \
269   DO(C64, complex64)                                                           \
270   DO(C32, complex32)
271 
272 // Forward.
273 template <typename V>
274 class SparseTensorEnumeratorBase;
275 
276 // Helper macro for generating error messages when some
277 // `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
278 // and then the wrong "partial method specialization" is called.
279 #define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
280 
281 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
282 /// takes responsibility for all the `<P,I,V>`-independent aspects
283 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
284 /// we use function overloading to implement "partial" method
285 /// specialization, which the C-API relies on to catch type errors
286 /// arising from our use of opaque pointers.
287 class SparseTensorStorageBase {
288 public:
289   /// Constructs a new storage object.  The `perm` maps the tensor's
290   /// semantic-ordering of dimensions to this object's storage-order.
291   /// The `dimSizes` and `sparsity` arrays are already in storage-order.
292   ///
293   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
294   SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
295                           const uint64_t *perm, const DimLevelType *sparsity)
296       : dimSizes(dimSizes), rev(getRank()),
297         dimTypes(sparsity, sparsity + getRank()) {
298     assert(perm && sparsity);
299     const uint64_t rank = getRank();
300     // Validate parameters.
301     assert(rank > 0 && "Trivial shape is unsupported");
302     for (uint64_t r = 0; r < rank; r++) {
303       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
304       assert((dimTypes[r] == DimLevelType::kDense ||
305               dimTypes[r] == DimLevelType::kCompressed) &&
306              "Unsupported DimLevelType");
307     }
308     // Construct the "reverse" (i.e., inverse) permutation.
309     for (uint64_t r = 0; r < rank; r++)
310       rev[perm[r]] = r;
311   }
312 
313   virtual ~SparseTensorStorageBase() = default;
314 
315   /// Get the rank of the tensor.
316   uint64_t getRank() const { return dimSizes.size(); }
317 
318   /// Getter for the dimension-sizes array, in storage-order.
319   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
320 
321   /// Safely lookup the size of the given (storage-order) dimension.
322   uint64_t getDimSize(uint64_t d) const {
323     assert(d < getRank());
324     return dimSizes[d];
325   }
326 
327   /// Getter for the "reverse" permutation, which maps this object's
328   /// storage-order to the tensor's semantic-order.
329   const std::vector<uint64_t> &getRev() const { return rev; }
330 
331   /// Getter for the dimension-types array, in storage-order.
332   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
333 
334   /// Safely check if the (storage-order) dimension uses compressed storage.
335   bool isCompressedDim(uint64_t d) const {
336     assert(d < getRank());
337     return (dimTypes[d] == DimLevelType::kCompressed);
338   }
339 
340   /// Allocate a new enumerator.
341 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
342   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
343                              const uint64_t *) const {                         \
344     FATAL_PIV("newEnumerator" #VNAME);                                         \
345   }
346   FOREVERY_V(DECL_NEWENUMERATOR)
347 #undef DECL_NEWENUMERATOR
348 
349   /// Overhead storage.
350   virtual void getPointers(std::vector<uint64_t> **, uint64_t) {
351     FATAL_PIV("p64");
352   }
353   virtual void getPointers(std::vector<uint32_t> **, uint64_t) {
354     FATAL_PIV("p32");
355   }
356   virtual void getPointers(std::vector<uint16_t> **, uint64_t) {
357     FATAL_PIV("p16");
358   }
359   virtual void getPointers(std::vector<uint8_t> **, uint64_t) {
360     FATAL_PIV("p8");
361   }
362   virtual void getIndices(std::vector<uint64_t> **, uint64_t) {
363     FATAL_PIV("i64");
364   }
365   virtual void getIndices(std::vector<uint32_t> **, uint64_t) {
366     FATAL_PIV("i32");
367   }
368   virtual void getIndices(std::vector<uint16_t> **, uint64_t) {
369     FATAL_PIV("i16");
370   }
371   virtual void getIndices(std::vector<uint8_t> **, uint64_t) {
372     FATAL_PIV("i8");
373   }
374 
375   /// Primary storage.
376 #define DECL_GETVALUES(VNAME, V)                                               \
377   virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
378   FOREVERY_V(DECL_GETVALUES)
379 #undef DECL_GETVALUES
380 
381   /// Element-wise insertion in lexicographic index order.
382 #define DECL_LEXINSERT(VNAME, V)                                               \
383   virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
384   FOREVERY_V(DECL_LEXINSERT)
385 #undef DECL_LEXINSERT
386 
387   /// Expanded insertion.
388 #define DECL_EXPINSERT(VNAME, V)                                               \
389   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
390     FATAL_PIV("expInsert" #VNAME);                                             \
391   }
392   FOREVERY_V(DECL_EXPINSERT)
393 #undef DECL_EXPINSERT
394 
395   /// Finishes insertion.
396   virtual void endInsert() = 0;
397 
398 protected:
399   // Since this class is virtual, we must disallow public copying in
400   // order to avoid "slicing".  Since this class has data members,
401   // that means making copying protected.
402   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
403   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
404   // Copy-assignment would be implicitly deleted (because `dimSizes`
405   // is const), so we explicitly delete it for clarity.
406   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
407 
408 private:
409   const std::vector<uint64_t> dimSizes;
410   std::vector<uint64_t> rev;
411   const std::vector<DimLevelType> dimTypes;
412 };
413 
414 #undef FATAL_PIV
415 
416 // Forward.
417 template <typename P, typename I, typename V>
418 class SparseTensorEnumerator;
419 
420 /// A memory-resident sparse tensor using a storage scheme based on
421 /// per-dimension sparse/dense annotations. This data structure provides a
422 /// bufferized form of a sparse tensor type. In contrast to generating setup
423 /// methods for each differently annotated sparse tensor, this method provides
424 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
425 /// and annotations to implement all required setup in a general manner.
426 template <typename P, typename I, typename V>
427 class SparseTensorStorage final : public SparseTensorStorageBase {
428   /// Private constructor to share code between the other constructors.
429   /// Beware that the object is not necessarily guaranteed to be in a
430   /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
431   /// doesn't entail `!(pointers[d].empty())`.
432   ///
433   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
434   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
435                       const uint64_t *perm, const DimLevelType *sparsity)
436       : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
437         indices(getRank()), idx(getRank()) {}
438 
439 public:
440   /// Constructs a sparse tensor storage scheme with the given dimensions,
441   /// permutation, and per-dimension dense/sparse annotations, using
442   /// the coordinate scheme tensor for the initial contents if provided.
443   ///
444   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
445   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
446                       const uint64_t *perm, const DimLevelType *sparsity,
447                       SparseTensorCOO<V> *coo)
448       : SparseTensorStorage(dimSizes, perm, sparsity) {
449     // Provide hints on capacity of pointers and indices.
450     // TODO: needs much fine-tuning based on actual sparsity; currently
451     //       we reserve pointer/index space based on all previous dense
452     //       dimensions, which works well up to first sparse dim; but
453     //       we should really use nnz and dense/sparse distribution.
454     bool allDense = true;
455     uint64_t sz = 1;
456     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
457       if (isCompressedDim(r)) {
458         // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
459         // `sz` by that before reserving. (For now we just use 1.)
460         pointers[r].reserve(sz + 1);
461         pointers[r].push_back(0);
462         indices[r].reserve(sz);
463         sz = 1;
464         allDense = false;
465       } else { // Dense dimension.
466         sz = checkedMul(sz, getDimSizes()[r]);
467       }
468     }
469     // Then assign contents from coordinate scheme tensor if provided.
470     if (coo) {
471       // Ensure both preconditions of `fromCOO`.
472       assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
473       coo->sort();
474       // Now actually insert the `elements`.
475       const std::vector<Element<V>> &elements = coo->getElements();
476       uint64_t nnz = elements.size();
477       values.reserve(nnz);
478       fromCOO(elements, 0, nnz, 0);
479     } else if (allDense) {
480       values.resize(sz, 0);
481     }
482   }
483 
484   /// Constructs a sparse tensor storage scheme with the given dimensions,
485   /// permutation, and per-dimension dense/sparse annotations, using
486   /// the given sparse tensor for the initial contents.
487   ///
488   /// Preconditions:
489   /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
490   /// * The `tensor` must have the same value type `V`.
491   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
492                       const uint64_t *perm, const DimLevelType *sparsity,
493                       const SparseTensorStorageBase &tensor);
494 
495   ~SparseTensorStorage() final override = default;
496 
497   /// Partially specialize these getter methods based on template types.
498   void getPointers(std::vector<P> **out, uint64_t d) final override {
499     assert(d < getRank());
500     *out = &pointers[d];
501   }
502   void getIndices(std::vector<I> **out, uint64_t d) final override {
503     assert(d < getRank());
504     *out = &indices[d];
505   }
506   void getValues(std::vector<V> **out) final override { *out = &values; }
507 
508   /// Partially specialize lexicographical insertions based on template types.
509   void lexInsert(const uint64_t *cursor, V val) final override {
510     // First, wrap up pending insertion path.
511     uint64_t diff = 0;
512     uint64_t top = 0;
513     if (!values.empty()) {
514       diff = lexDiff(cursor);
515       endPath(diff + 1);
516       top = idx[diff] + 1;
517     }
518     // Then continue with insertion path.
519     insPath(cursor, diff, top, val);
520   }
521 
522   /// Partially specialize expanded insertions based on template types.
523   /// Note that this method resets the values/filled-switch array back
524   /// to all-zero/false while only iterating over the nonzero elements.
525   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
526                  uint64_t count) final override {
527     if (count == 0)
528       return;
529     // Sort.
530     std::sort(added, added + count);
531     // Restore insertion path for first insert.
532     const uint64_t lastDim = getRank() - 1;
533     uint64_t index = added[0];
534     cursor[lastDim] = index;
535     lexInsert(cursor, values[index]);
536     assert(filled[index]);
537     values[index] = 0;
538     filled[index] = false;
539     // Subsequent insertions are quick.
540     for (uint64_t i = 1; i < count; i++) {
541       assert(index < added[i] && "non-lexicographic insertion");
542       index = added[i];
543       cursor[lastDim] = index;
544       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
545       assert(filled[index]);
546       values[index] = 0;
547       filled[index] = false;
548     }
549   }
550 
551   /// Finalizes lexicographic insertions.
552   void endInsert() final override {
553     if (values.empty())
554       finalizeSegment(0);
555     else
556       endPath(0);
557   }
558 
559   void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
560                      const uint64_t *perm) const final override {
561     *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
562   }
563 
564   /// Returns this sparse tensor storage scheme as a new memory-resident
565   /// sparse tensor in coordinate scheme with the given dimension order.
566   ///
567   /// Precondition: `perm` must be valid for `getRank()`.
568   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
569     SparseTensorEnumeratorBase<V> *enumerator;
570     newEnumerator(&enumerator, getRank(), perm);
571     SparseTensorCOO<V> *coo =
572         new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
573     enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
574       coo->add(ind, val);
575     });
576     // TODO: This assertion assumes there are no stored zeros,
577     // or if there are then that we don't filter them out.
578     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
579     assert(coo->getElements().size() == values.size());
580     delete enumerator;
581     return coo;
582   }
583 
584   /// Factory method. Constructs a sparse tensor storage scheme with the given
585   /// dimensions, permutation, and per-dimension dense/sparse annotations,
586   /// using the coordinate scheme tensor for the initial contents if provided.
587   /// In the latter case, the coordinate scheme must respect the same
588   /// permutation as is desired for the new sparse tensor storage.
589   ///
590   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
591   static SparseTensorStorage<P, I, V> *
592   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
593                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
594     SparseTensorStorage<P, I, V> *n = nullptr;
595     if (coo) {
596       const auto &coosz = coo->getDimSizes();
597       assertPermutedSizesMatchShape(coosz, rank, perm, shape);
598       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
599     } else {
600       std::vector<uint64_t> permsz(rank);
601       for (uint64_t r = 0; r < rank; r++) {
602         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
603         permsz[perm[r]] = shape[r];
604       }
605       // We pass the null `coo` to ensure we select the intended constructor.
606       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
607     }
608     return n;
609   }
610 
611   /// Factory method. Constructs a sparse tensor storage scheme with
612   /// the given dimensions, permutation, and per-dimension dense/sparse
613   /// annotations, using the sparse tensor for the initial contents.
614   ///
615   /// Preconditions:
616   /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
617   /// * The `tensor` must have the same value type `V`.
618   static SparseTensorStorage<P, I, V> *
619   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
620                   const DimLevelType *sparsity,
621                   const SparseTensorStorageBase *source) {
622     assert(source && "Got nullptr for source");
623     SparseTensorEnumeratorBase<V> *enumerator;
624     source->newEnumerator(&enumerator, rank, perm);
625     const auto &permsz = enumerator->permutedSizes();
626     assertPermutedSizesMatchShape(permsz, rank, perm, shape);
627     auto *tensor =
628         new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
629     delete enumerator;
630     return tensor;
631   }
632 
633 private:
634   /// Appends an arbitrary new position to `pointers[d]`.  This method
635   /// checks that `pos` is representable in the `P` type; however, it
636   /// does not check that `pos` is semantically valid (i.e., larger than
637   /// the previous position and smaller than `indices[d].capacity()`).
638   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
639     assert(isCompressedDim(d));
640     assert(pos <= std::numeric_limits<P>::max() &&
641            "Pointer value is too large for the P-type");
642     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
643   }
644 
645   /// Appends index `i` to dimension `d`, in the semantically general
646   /// sense.  For non-dense dimensions, that means appending to the
647   /// `indices[d]` array, checking that `i` is representable in the `I`
648   /// type; however, we do not verify other semantic requirements (e.g.,
649   /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
650   /// in the same segment).  For dense dimensions, this method instead
651   /// appends the appropriate number of zeros to the `values` array,
652   /// where `full` is the number of "entries" already written to `values`
653   /// for this segment (aka one after the highest index previously appended).
654   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
655     if (isCompressedDim(d)) {
656       assert(i <= std::numeric_limits<I>::max() &&
657              "Index value is too large for the I-type");
658       indices[d].push_back(static_cast<I>(i));
659     } else { // Dense dimension.
660       assert(i >= full && "Index was already filled");
661       if (i == full)
662         return; // Short-circuit, since it'll be a nop.
663       if (d + 1 == getRank())
664         values.insert(values.end(), i - full, 0);
665       else
666         finalizeSegment(d + 1, 0, i - full);
667     }
668   }
669 
670   /// Writes the given coordinate to `indices[d][pos]`.  This method
671   /// checks that `i` is representable in the `I` type; however, it
672   /// does not check that `i` is semantically valid (i.e., in bounds
673   /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
674   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
675     assert(isCompressedDim(d));
676     // Subscript assignment to `std::vector` requires that the `pos`-th
677     // entry has been initialized; thus we must be sure to check `size()`
678     // here, instead of `capacity()` as would be ideal.
679     assert(pos < indices[d].size() && "Index position is out of bounds");
680     assert(i <= std::numeric_limits<I>::max() &&
681            "Index value is too large for the I-type");
682     indices[d][pos] = static_cast<I>(i);
683   }
684 
685   /// Computes the assembled-size associated with the `d`-th dimension,
686   /// given the assembled-size associated with the `(d-1)`-th dimension.
687   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
688   /// storage, as opposed to "dimension-sizes" which are the cardinality
689   /// of coordinates for that dimension.
690   ///
691   /// Precondition: the `pointers[d]` array must be fully initialized
692   /// before calling this method.
693   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
694     if (isCompressedDim(d))
695       return pointers[d][parentSz];
696     // else if dense:
697     return parentSz * getDimSizes()[d];
698   }
699 
700   /// Initializes sparse tensor storage scheme from a memory-resident sparse
701   /// tensor in coordinate scheme. This method prepares the pointers and
702   /// indices arrays under the given per-dimension dense/sparse annotations.
703   ///
704   /// Preconditions:
705   /// (1) the `elements` must be lexicographically sorted.
706   /// (2) the indices of every element are valid for `dimSizes` (equal rank
707   ///     and pointwise less-than).
708   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
709                uint64_t hi, uint64_t d) {
710     uint64_t rank = getRank();
711     assert(d <= rank && hi <= elements.size());
712     // Once dimensions are exhausted, insert the numerical values.
713     if (d == rank) {
714       assert(lo < hi);
715       values.push_back(elements[lo].value);
716       return;
717     }
718     // Visit all elements in this interval.
719     uint64_t full = 0;
720     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
721       // Find segment in interval with same index elements in this dimension.
722       uint64_t i = elements[lo].indices[d];
723       uint64_t seg = lo + 1;
724       while (seg < hi && elements[seg].indices[d] == i)
725         seg++;
726       // Handle segment in interval for sparse or dense dimension.
727       appendIndex(d, full, i);
728       full = i + 1;
729       fromCOO(elements, lo, seg, d + 1);
730       // And move on to next segment in interval.
731       lo = seg;
732     }
733     // Finalize the sparse pointer structure at this dimension.
734     finalizeSegment(d, full);
735   }
736 
737   /// Finalize the sparse pointer structure at this dimension.
738   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
739     if (count == 0)
740       return; // Short-circuit, since it'll be a nop.
741     if (isCompressedDim(d)) {
742       appendPointer(d, indices[d].size(), count);
743     } else { // Dense dimension.
744       const uint64_t sz = getDimSizes()[d];
745       assert(sz >= full && "Segment is overfull");
746       count = checkedMul(count, sz - full);
747       // For dense storage we must enumerate all the remaining coordinates
748       // in this dimension (i.e., coordinates after the last non-zero
749       // element), and either fill in their zero values or else recurse
750       // to finalize some deeper dimension.
751       if (d + 1 == getRank())
752         values.insert(values.end(), count, 0);
753       else
754         finalizeSegment(d + 1, 0, count);
755     }
756   }
757 
758   /// Wraps up a single insertion path, inner to outer.
759   void endPath(uint64_t diff) {
760     uint64_t rank = getRank();
761     assert(diff <= rank);
762     for (uint64_t i = 0; i < rank - diff; i++) {
763       const uint64_t d = rank - i - 1;
764       finalizeSegment(d, idx[d] + 1);
765     }
766   }
767 
768   /// Continues a single insertion path, outer to inner.
769   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
770     uint64_t rank = getRank();
771     assert(diff < rank);
772     for (uint64_t d = diff; d < rank; d++) {
773       uint64_t i = cursor[d];
774       appendIndex(d, top, i);
775       top = 0;
776       idx[d] = i;
777     }
778     values.push_back(val);
779   }
780 
781   /// Finds the lexicographic differing dimension.
782   uint64_t lexDiff(const uint64_t *cursor) const {
783     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
784       if (cursor[r] > idx[r])
785         return r;
786       else
787         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
788     assert(0 && "duplication insertion");
789     return -1u;
790   }
791 
792   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
793   // the cost of virtual-function dispatch in inner loops), without
794   // making them public to other client code.
795   friend class SparseTensorEnumerator<P, I, V>;
796 
797   std::vector<std::vector<P>> pointers;
798   std::vector<std::vector<I>> indices;
799   std::vector<V> values;
800   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
801 };
802 
803 /// A (higher-order) function object for enumerating the elements of some
804 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
805 /// method encapsulates the loop-nest for enumerating the elements of
806 /// the source tensor (in whatever order is best for the source tensor),
807 /// and applies a permutation to the coordinates/indices before handing
808 /// each element to the callback.  A single enumerator object can be
809 /// freely reused for several calls to `forallElements`, just so long
810 /// as each call is sequential with respect to one another.
811 ///
812 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
813 /// passed to the constructor; thus, objects of this class must not
814 /// outlive the sparse tensor they depend on.
815 ///
816 /// Design Note: The reason we define this class instead of simply using
817 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
818 /// the `<P,I>` template parameters from MLIR client code (to simplify the
819 /// type parameters used for direct sparse-to-sparse conversion).  And the
820 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
821 /// than simply using this class, is to avoid the cost of virtual-method
822 /// dispatch within the loop-nest.
823 template <typename V>
824 class SparseTensorEnumeratorBase {
825 public:
826   /// Constructs an enumerator with the given permutation for mapping
827   /// the semantic-ordering of dimensions to the desired target-ordering.
828   ///
829   /// Preconditions:
830   /// * the `tensor` must have the same `V` value type.
831   /// * `perm` must be valid for `rank`.
832   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
833                              uint64_t rank, const uint64_t *perm)
834       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
835         cursor(getRank()) {
836     assert(perm && "Received nullptr for permutation");
837     assert(rank == getRank() && "Permutation rank mismatch");
838     const auto &rev = src.getRev();           // source-order -> semantic-order
839     const auto &dimSizes = src.getDimSizes(); // in source storage-order
840     for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
841       uint64_t t = perm[rev[s]];              // `t` target-order
842       reord[s] = t;
843       permsz[t] = dimSizes[s];
844     }
845   }
846 
847   virtual ~SparseTensorEnumeratorBase() = default;
848 
849   // We disallow copying to help avoid leaking the `src` reference.
850   // (In addition to avoiding the problem of slicing.)
851   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
852   SparseTensorEnumeratorBase &
853   operator=(const SparseTensorEnumeratorBase &) = delete;
854 
855   /// Returns the source/target tensor's rank.  (The source-rank and
856   /// target-rank are always equal since we only support permutations.
857   /// Though once we add support for other dimension mappings, this
858   /// method will have to be split in two.)
859   uint64_t getRank() const { return permsz.size(); }
860 
861   /// Returns the target tensor's dimension sizes.
862   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
863 
864   /// Enumerates all elements of the source tensor, permutes their
865   /// indices, and passes the permuted element to the callback.
866   /// The callback must not store the cursor reference directly,
867   /// since this function reuses the storage.  Instead, the callback
868   /// must copy it if they want to keep it.
869   virtual void forallElements(ElementConsumer<V> yield) = 0;
870 
871 protected:
872   const SparseTensorStorageBase &src;
873   std::vector<uint64_t> permsz; // in target order.
874   std::vector<uint64_t> reord;  // source storage-order -> target order.
875   std::vector<uint64_t> cursor; // in target order.
876 };
877 
878 template <typename P, typename I, typename V>
879 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
880   using Base = SparseTensorEnumeratorBase<V>;
881 
882 public:
883   /// Constructs an enumerator with the given permutation for mapping
884   /// the semantic-ordering of dimensions to the desired target-ordering.
885   ///
886   /// Precondition: `perm` must be valid for `rank`.
887   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
888                          uint64_t rank, const uint64_t *perm)
889       : Base(tensor, rank, perm) {}
890 
891   ~SparseTensorEnumerator() final override = default;
892 
893   void forallElements(ElementConsumer<V> yield) final override {
894     forallElements(yield, 0, 0);
895   }
896 
897 private:
898   /// The recursive component of the public `forallElements`.
899   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
900                       uint64_t d) {
901     // Recover the `<P,I,V>` type parameters of `src`.
902     const auto &src =
903         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
904     if (d == Base::getRank()) {
905       assert(parentPos < src.values.size() &&
906              "Value position is out of bounds");
907       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
908       yield(this->cursor, src.values[parentPos]);
909     } else if (src.isCompressedDim(d)) {
910       // Look up the bounds of the `d`-level segment determined by the
911       // `d-1`-level position `parentPos`.
912       const std::vector<P> &pointers_d = src.pointers[d];
913       assert(parentPos + 1 < pointers_d.size() &&
914              "Parent pointer position is out of bounds");
915       const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
916       const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
917       // Loop-invariant code for looking up the `d`-level coordinates/indices.
918       const std::vector<I> &indices_d = src.indices[d];
919       assert(pstop <= indices_d.size() && "Index position is out of bounds");
920       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
921       for (uint64_t pos = pstart; pos < pstop; pos++) {
922         cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
923         forallElements(yield, pos, d + 1);
924       }
925     } else { // Dense dimension.
926       const uint64_t sz = src.getDimSizes()[d];
927       const uint64_t pstart = parentPos * sz;
928       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
929       for (uint64_t i = 0; i < sz; i++) {
930         cursor_reord_d = i;
931         forallElements(yield, pstart + i, d + 1);
932       }
933     }
934   }
935 };
936 
937 /// Statistics regarding the number of nonzero subtensors in
938 /// a source tensor, for direct sparse=>sparse conversion a la
939 /// <https://arxiv.org/abs/2001.02609>.
940 ///
941 /// N.B., this class stores references to the parameters passed to
942 /// the constructor; thus, objects of this class must not outlive
943 /// those parameters.
944 class SparseTensorNNZ final {
945 public:
946   /// Allocate the statistics structure for the desired sizes and
947   /// sparsity (in the target tensor's storage-order).  This constructor
948   /// does not actually populate the statistics, however; for that see
949   /// `initialize`.
950   ///
951   /// Precondition: `dimSizes` must not contain zeros.
952   SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
953                   const std::vector<DimLevelType> &sparsity)
954       : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
955     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
956     bool uncompressed = true;
957     uint64_t sz = 1; // the product of all `dimSizes` strictly less than `r`.
958     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
959       switch (dimTypes[r]) {
960       case DimLevelType::kCompressed:
961         assert(uncompressed &&
962                "Multiple compressed layers not currently supported");
963         uncompressed = false;
964         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
965         break;
966       case DimLevelType::kDense:
967         assert(uncompressed &&
968                "Dense after compressed not currently supported");
969         break;
970       case DimLevelType::kSingleton:
971         // Singleton after Compressed causes no problems for allocating
972         // `nnz` nor for the yieldPos loop.  This remains true even
973         // when adding support for multiple compressed dimensions or
974         // for dense-after-compressed.
975         break;
976       }
977       sz = checkedMul(sz, dimSizes[r]);
978     }
979   }
980 
981   // We disallow copying to help avoid leaking the stored references.
982   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
983   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
984 
985   /// Returns the rank of the target tensor.
986   uint64_t getRank() const { return dimSizes.size(); }
987 
988   /// Enumerate the source tensor to fill in the statistics.  The
989   /// enumerator should already incorporate the permutation (from
990   /// semantic-order to the target storage-order).
991   template <typename V>
992   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
993     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
994     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
995     enumerator.forallElements(
996         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
997   }
998 
999   /// The type of callback functions which receive an nnz-statistic.
1000   using NNZConsumer = const std::function<void(uint64_t)> &;
1001 
1002   /// Lexicographically enumerates all indicies for dimensions strictly
1003   /// less than `stopDim`, and passes their nnz statistic to the callback.
1004   /// Since our use-case only requires the statistic not the coordinates
1005   /// themselves, we do not bother to construct those coordinates.
1006   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
1007     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
1008     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
1009            "Cannot look up non-compressed dimensions");
1010     forallIndices(yield, stopDim, 0, 0);
1011   }
1012 
1013 private:
1014   /// Adds a new element (i.e., increment its statistics).  We use
1015   /// a method rather than inlining into the lambda in `initialize`,
1016   /// to avoid spurious templating over `V`.  And this method is private
1017   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
1018   /// by `forallElements`).
1019   void add(const std::vector<uint64_t> &ind) {
1020     uint64_t parentPos = 0;
1021     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1022       if (dimTypes[r] == DimLevelType::kCompressed)
1023         nnz[r][parentPos]++;
1024       parentPos = parentPos * dimSizes[r] + ind[r];
1025     }
1026   }
1027 
1028   /// Recursive component of the public `forallIndices`.
1029   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
1030                      uint64_t d) const {
1031     assert(d <= stopDim);
1032     if (d == stopDim) {
1033       assert(parentPos < nnz[d].size() && "Cursor is out of range");
1034       yield(nnz[d][parentPos]);
1035     } else {
1036       const uint64_t sz = dimSizes[d];
1037       const uint64_t pstart = parentPos * sz;
1038       for (uint64_t i = 0; i < sz; i++)
1039         forallIndices(yield, stopDim, pstart + i, d + 1);
1040     }
1041   }
1042 
1043   // All of these are in the target storage-order.
1044   const std::vector<uint64_t> &dimSizes;
1045   const std::vector<DimLevelType> &dimTypes;
1046   std::vector<std::vector<uint64_t>> nnz;
1047 };
1048 
1049 template <typename P, typename I, typename V>
1050 SparseTensorStorage<P, I, V>::SparseTensorStorage(
1051     const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
1052     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1053     : SparseTensorStorage(dimSizes, perm, sparsity) {
1054   SparseTensorEnumeratorBase<V> *enumerator;
1055   tensor.newEnumerator(&enumerator, getRank(), perm);
1056   {
1057     // Initialize the statistics structure.
1058     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1059     nnz.initialize(*enumerator);
1060     // Initialize "pointers" overhead (and allocate "indices", "values").
1061     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1062     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1063       if (isCompressedDim(r)) {
1064         pointers[r].reserve(parentSz + 1);
1065         pointers[r].push_back(0);
1066         uint64_t currentPos = 0;
1067         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1068           currentPos += n;
1069           appendPointer(r, currentPos);
1070         });
1071         assert(pointers[r].size() == parentSz + 1 &&
1072                "Final pointers size doesn't match allocated size");
1073         // That assertion entails `assembledSize(parentSz, r)`
1074         // is now in a valid state.  That is, `pointers[r][parentSz]`
1075         // equals the present value of `currentPos`, which is the
1076         // correct assembled-size for `indices[r]`.
1077       }
1078       // Update assembled-size for the next iteration.
1079       parentSz = assembledSize(parentSz, r);
1080       // Ideally we need only `indices[r].reserve(parentSz)`, however
1081       // the `std::vector` implementation forces us to initialize it too.
1082       // That is, in the yieldPos loop we need random-access assignment
1083       // to `indices[r]`; however, `std::vector`'s subscript-assignment
1084       // only allows assigning to already-initialized positions.
1085       if (isCompressedDim(r))
1086         indices[r].resize(parentSz, 0);
1087     }
1088     values.resize(parentSz, 0); // Both allocate and zero-initialize.
1089   }
1090   // The yieldPos loop
1091   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1092     uint64_t parentSz = 1, parentPos = 0;
1093     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1094       if (isCompressedDim(r)) {
1095         // If `parentPos == parentSz` then it's valid as an array-lookup;
1096         // however, it's semantically invalid here since that entry
1097         // does not represent a segment of `indices[r]`.  Moreover, that
1098         // entry must be immutable for `assembledSize` to remain valid.
1099         assert(parentPos < parentSz && "Pointers position is out of bounds");
1100         const uint64_t currentPos = pointers[r][parentPos];
1101         // This increment won't overflow the `P` type, since it can't
1102         // exceed the original value of `pointers[r][parentPos+1]`
1103         // which was already verified to be within bounds for `P`
1104         // when it was written to the array.
1105         pointers[r][parentPos]++;
1106         writeIndex(r, currentPos, ind[r]);
1107         parentPos = currentPos;
1108       } else { // Dense dimension.
1109         parentPos = parentPos * getDimSizes()[r] + ind[r];
1110       }
1111       parentSz = assembledSize(parentSz, r);
1112     }
1113     assert(parentPos < values.size() && "Value position is out of bounds");
1114     values[parentPos] = val;
1115   });
1116   // No longer need the enumerator, so we'll delete it ASAP.
1117   delete enumerator;
1118   // The finalizeYieldPos loop
1119   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1120     if (isCompressedDim(r)) {
1121       assert(parentSz == pointers[r].size() - 1 &&
1122              "Actual pointers size doesn't match the expected size");
1123       // Can't check all of them, but at least we can check the last one.
1124       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1125              "Pointers got corrupted");
1126       // TODO: optimize this by using `memmove` or similar.
1127       for (uint64_t n = 0; n < parentSz; n++) {
1128         const uint64_t parentPos = parentSz - n;
1129         pointers[r][parentPos] = pointers[r][parentPos - 1];
1130       }
1131       pointers[r][0] = 0;
1132     }
1133     parentSz = assembledSize(parentSz, r);
1134   }
1135 }
1136 
1137 /// Helper to convert string to lower case.
1138 static char *toLower(char *token) {
1139   for (char *c = token; *c; c++)
1140     *c = tolower(*c);
1141   return token;
1142 }
1143 
1144 /// Read the MME header of a general sparse matrix of type real.
1145 static void readMMEHeader(FILE *file, char *filename, char *line,
1146                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
1147   char header[64];
1148   char object[64];
1149   char format[64];
1150   char field[64];
1151   char symmetry[64];
1152   // Read header line.
1153   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1154              symmetry) != 5)
1155     FATAL("Corrupt header in %s\n", filename);
1156   // Set properties
1157   *isPattern = (strcmp(toLower(field), "pattern") == 0);
1158   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
1159   // Make sure this is a general sparse matrix.
1160   if (strcmp(toLower(header), "%%matrixmarket") ||
1161       strcmp(toLower(object), "matrix") ||
1162       strcmp(toLower(format), "coordinate") ||
1163       (strcmp(toLower(field), "real") && !(*isPattern)) ||
1164       (strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
1165     FATAL("Cannot find a general sparse matrix in %s\n", filename);
1166   // Skip comments.
1167   while (true) {
1168     if (!fgets(line, kColWidth, file))
1169       FATAL("Cannot find data in %s\n", filename);
1170     if (line[0] != '%')
1171       break;
1172   }
1173   // Next line contains M N NNZ.
1174   idata[0] = 2; // rank
1175   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1176              idata + 1) != 3)
1177     FATAL("Cannot find size in %s\n", filename);
1178 }
1179 
1180 /// Read the "extended" FROSTT header. Although not part of the documented
1181 /// format, we assume that the file starts with optional comments followed
1182 /// by two lines that define the rank, the number of nonzeros, and the
1183 /// dimensions sizes (one per rank) of the sparse tensor.
1184 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
1185                                 uint64_t *idata) {
1186   // Skip comments.
1187   while (true) {
1188     if (!fgets(line, kColWidth, file))
1189       FATAL("Cannot find data in %s\n", filename);
1190     if (line[0] != '#')
1191       break;
1192   }
1193   // Next line contains RANK and NNZ.
1194   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
1195     FATAL("Cannot find metadata in %s\n", filename);
1196   // Followed by a line with the dimension sizes (one per rank).
1197   for (uint64_t r = 0; r < idata[0]; r++)
1198     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
1199       FATAL("Cannot find dimension size %s\n", filename);
1200   fgets(line, kColWidth, file); // end of line
1201 }
1202 
1203 /// Reads a sparse tensor with the given filename into a memory-resident
1204 /// sparse tensor in coordinate scheme.
1205 template <typename V>
1206 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
1207                                                const uint64_t *shape,
1208                                                const uint64_t *perm) {
1209   // Open the file.
1210   assert(filename && "Received nullptr for filename");
1211   FILE *file = fopen(filename, "r");
1212   if (!file)
1213     FATAL("Cannot find file %s\n", filename);
1214   // Perform some file format dependent set up.
1215   char line[kColWidth];
1216   uint64_t idata[512];
1217   bool isPattern = false;
1218   bool isSymmetric = false;
1219   if (strstr(filename, ".mtx")) {
1220     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
1221   } else if (strstr(filename, ".tns")) {
1222     readExtFROSTTHeader(file, filename, line, idata);
1223   } else {
1224     FATAL("Unknown format %s\n", filename);
1225   }
1226   // Prepare sparse tensor object with per-dimension sizes
1227   // and the number of nonzeros as initial capacity.
1228   assert(rank == idata[0] && "rank mismatch");
1229   uint64_t nnz = idata[1];
1230   for (uint64_t r = 0; r < rank; r++)
1231     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1232            "dimension size mismatch");
1233   SparseTensorCOO<V> *tensor =
1234       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
1235   // Read all nonzero elements.
1236   std::vector<uint64_t> indices(rank);
1237   for (uint64_t k = 0; k < nnz; k++) {
1238     if (!fgets(line, kColWidth, file))
1239       FATAL("Cannot find next line of data in %s\n", filename);
1240     char *linePtr = line;
1241     for (uint64_t r = 0; r < rank; r++) {
1242       uint64_t idx = strtoul(linePtr, &linePtr, 10);
1243       // Add 0-based index.
1244       indices[perm[r]] = idx - 1;
1245     }
1246     // The external formats always store the numerical values with the type
1247     // double, but we cast these values to the sparse tensor object type.
1248     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1249     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
1250     tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing them. In the future, we may want to make symmetry implicit
    // for storage reasons.
1254     if (isSymmetric && indices[0] != indices[1])
1255       tensor->add({indices[1], indices[0]}, value);
1256   }
1257   // Close the file and return tensor.
1258   fclose(file);
1259   return tensor;
1260 }
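
// A hypothetical direct caller (illustrative only; in practice this routine
// is reached via _mlir_ciface_newSparseTensor with Action::kFromFile) could
// read a 3 x 4 matrix of doubles with the identity permutation as follows,
// where a zero in `shape` accepts whatever size the file declares:
//
//   char filename[] = "matrix.mtx"; // hypothetical path
//   uint64_t shape[] = {3, 4};
//   uint64_t perm[] = {0, 1};
//   SparseTensorCOO<double> *coo =
//       openSparseTensorCOO<double>(filename, 2, shape, perm);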
1261 
1262 /// Writes the sparse tensor to extended FROSTT format.
1263 template <typename V>
1264 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1265   assert(tensor && dest);
1266   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1267   if (sort)
1268     coo->sort();
1269   char *filename = static_cast<char *>(dest);
1270   auto &dimSizes = coo->getDimSizes();
1271   auto &elements = coo->getElements();
1272   uint64_t rank = coo->getRank();
1273   uint64_t nnz = elements.size();
1274   std::fstream file;
1275   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1276   assert(file.is_open());
1277   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1278   for (uint64_t r = 0; r < rank - 1; r++)
1279     file << dimSizes[r] << " ";
1280   file << dimSizes[rank - 1] << std::endl;
1281   for (uint64_t i = 0; i < nnz; i++) {
1282     auto &idx = elements[i].indices;
1283     for (uint64_t r = 0; r < rank; r++)
1284       file << (idx[r] + 1) << " ";
1285     file << elements[i].value << std::endl;
1286   }
1287   file.flush();
1288   file.close();
1289   assert(file.good());
1290 }
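
// For example, a 2 x 3 matrix storing 1.0 at (0,0) and 5.0 at (1,2) would be
// written roughly as follows (indices are printed 1-based, and the default
// stream formatting prints 1.0 as "1"):
//
//   ; extended FROSTT format
//   2 2
//   2 3
//   1 1 1
//   2 3 5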
1291 
1292 /// Initializes sparse tensor from an external COO-flavored format.
1293 template <typename V>
1294 static SparseTensorStorage<uint64_t, uint64_t, V> *
1295 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1296                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
1297   const DimLevelType *sparsity = (DimLevelType *)(sparse);
1298 #ifndef NDEBUG
1299   // Verify that perm is a permutation of 0..(rank-1).
1300   std::vector<uint64_t> order(perm, perm + rank);
1301   std::sort(order.begin(), order.end());
1302   for (uint64_t i = 0; i < rank; ++i)
1303     if (i != order[i])
1304       FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
1305 
1306   // Verify that the sparsity values are supported.
1307   for (uint64_t i = 0; i < rank; ++i)
1308     if (sparsity[i] != DimLevelType::kDense &&
1309         sparsity[i] != DimLevelType::kCompressed)
1310       FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
1311 #endif
1312 
1313   // Convert external format to internal COO.
1314   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1315   std::vector<uint64_t> idx(rank);
1316   for (uint64_t i = 0, base = 0; i < nse; i++) {
1317     for (uint64_t r = 0; r < rank; r++)
1318       idx[perm[r]] = indices[base + r];
1319     coo->add(idx, values[i]);
1320     base += rank;
1321   }
1322   // Return sparse tensor storage format as opaque pointer.
1323   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1324       rank, shape, perm, sparsity, coo);
1325   delete coo;
1326   return tensor;
1327 }
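
// As an illustration (a sketch mirroring the convertToMLIRSparseTensor
// documentation further below), the 2x3 matrix with stored values 1.0, 5.0,
// and 3.0 could be materialized with a dense-then-compressed annotation:
//
//   uint64_t shape[] = {2, 3};
//   double values[] = {1.0, 5.0, 3.0};
//   uint64_t indices[] = {0, 0, 1, 1, 1, 2};
//   uint64_t perm[] = {0, 1};
//   uint8_t sparse[] = {static_cast<uint8_t>(DimLevelType::kDense),
//                       static_cast<uint8_t>(DimLevelType::kCompressed)};
//   auto *tensor = toMLIRSparseTensor<double>(2, 3, shape, values, indices,
//                                             perm, sparse);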
1328 
1329 /// Converts a sparse tensor to an external COO-flavored format.
1330 template <typename V>
1331 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1332                                  uint64_t **pShape, V **pValues,
1333                                  uint64_t **pIndices) {
1334   assert(tensor);
1335   auto sparseTensor =
1336       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1337   uint64_t rank = sparseTensor->getRank();
1338   std::vector<uint64_t> perm(rank);
1339   std::iota(perm.begin(), perm.end(), 0);
1340   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1341 
1342   const std::vector<Element<V>> &elements = coo->getElements();
1343   uint64_t nse = elements.size();
1344 
1345   uint64_t *shape = new uint64_t[rank];
1346   for (uint64_t i = 0; i < rank; i++)
1347     shape[i] = coo->getDimSizes()[i];
1348 
1349   V *values = new V[nse];
1350   uint64_t *indices = new uint64_t[rank * nse];
1351 
1352   for (uint64_t i = 0, base = 0; i < nse; i++) {
1353     values[i] = elements[i].value;
1354     for (uint64_t j = 0; j < rank; j++)
1355       indices[base + j] = elements[i].indices[j];
1356     base += rank;
1357   }
1358 
1359   delete coo;
1360   *pRank = rank;
1361   *pNse = nse;
1362   *pShape = shape;
1363   *pValues = values;
1364   *pIndices = indices;
1365 }
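
// Usage sketch: the shape, values, and indices buffers handed back through
// the output parameters are allocated with new[] above, so the caller is
// expected to release them when done, e.g.:
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   fromMLIRSparseTensor<double>(tensor, &rank, &nse, &shape, &values,
//                                &indices);
//   // ... consume the COO-flavored data ...
//   delete[] shape;
//   delete[] values;
//   delete[] indices;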
1366 
1367 } // namespace
1368 
1369 extern "C" {
1370 
1371 //===----------------------------------------------------------------------===//
1372 //
1373 // Public API with methods that operate on MLIR buffers (memrefs) to interact
1374 // with sparse tensors, which are only visible as opaque pointers externally.
1375 // These methods should be used exclusively by MLIR compiler-generated code.
1376 //
1377 // Some macro magic is used to generate implementations for all required type
1378 // combinations that can be called from MLIR compiler-generated code.
1379 //
1380 //===----------------------------------------------------------------------===//
1381 
1382 #define CASE(p, i, v, P, I, V)                                                 \
1383   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1384     SparseTensorCOO<V> *coo = nullptr;                                         \
1385     if (action <= Action::kFromCOO) {                                          \
1386       if (action == Action::kFromFile) {                                       \
1387         char *filename = static_cast<char *>(ptr);                             \
1388         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1389       } else if (action == Action::kFromCOO) {                                 \
1390         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1391       } else {                                                                 \
1392         assert(action == Action::kEmpty);                                      \
1393       }                                                                        \
1394       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1395           rank, shape, perm, sparsity, coo);                                   \
1396       if (action == Action::kFromFile)                                         \
1397         delete coo;                                                            \
1398       return tensor;                                                           \
1399     }                                                                          \
1400     if (action == Action::kSparseToSparse) {                                   \
1401       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
1402       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
1403                                                            sparsity, tensor);  \
1404     }                                                                          \
1405     if (action == Action::kEmptyCOO)                                           \
1406       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1407     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1408     if (action == Action::kToIterator) {                                       \
1409       coo->startIterator();                                                    \
1410     } else {                                                                   \
1411       assert(action == Action::kToCOO);                                        \
1412     }                                                                          \
1413     return coo;                                                                \
1414   }
1415 
1416 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
1417 
1418 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1419 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1420 // that this file cannot get out of sync with its header.
1421 static_assert(std::is_same<index_type, uint64_t>::value,
1422               "Expected index_type == uint64_t");
1423 
/// Constructs a new sparse tensor. This is the "Swiss army knife"
1425 /// method for materializing sparse tensors into the computation.
1426 ///
1427 /// Action:
1428 /// kEmpty = returns empty storage to fill later
1429 /// kFromFile = returns storage, where ptr contains filename to read
1430 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
1431 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
1432 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1433 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
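///
/// For example, a hypothetical C++ caller (in practice these descriptors are
/// set up by MLIR compiler-generated code) could request an empty coordinate
/// scheme for a 2x3 matrix of doubles as follows:
///
///   DimLevelType lvl[] = {DimLevelType::kDense, DimLevelType::kCompressed};
///   index_type shape[] = {2, 3};
///   index_type perm[] = {0, 1};
///   StridedMemRefType<DimLevelType, 1> aref{lvl, lvl, 0, {2}, {1}};
///   StridedMemRefType<index_type, 1> sref{shape, shape, 0, {2}, {1}};
///   StridedMemRefType<index_type, 1> pref{perm, perm, 0, {2}, {1}};
///   void *coo = _mlir_ciface_newSparseTensor(
///       &aref, &sref, &pref, OverheadType::kIndex, OverheadType::kIndex,
///       PrimaryType::kF64, Action::kEmptyCOO, /*ptr=*/nullptr);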
1434 void *
1435 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1436                              StridedMemRefType<index_type, 1> *sref,
1437                              StridedMemRefType<index_type, 1> *pref,
1438                              OverheadType ptrTp, OverheadType indTp,
1439                              PrimaryType valTp, Action action, void *ptr) {
1440   assert(aref && sref && pref);
1441   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1442          pref->strides[0] == 1);
1443   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1444   const DimLevelType *sparsity = aref->data + aref->offset;
1445   const index_type *shape = sref->data + sref->offset;
1446   const index_type *perm = pref->data + pref->offset;
1447   uint64_t rank = aref->sizes[0];
1448 
1449   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1450   // This is safe because of the static_assert above.
1451   if (ptrTp == OverheadType::kIndex)
1452     ptrTp = OverheadType::kU64;
1453   if (indTp == OverheadType::kIndex)
1454     indTp = OverheadType::kU64;
1455 
1456   // Double matrices with all combinations of overhead storage.
1457   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1458        uint64_t, double);
1459   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1460        uint32_t, double);
1461   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1462        uint16_t, double);
1463   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1464        uint8_t, double);
1465   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1466        uint64_t, double);
1467   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1468        uint32_t, double);
1469   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1470        uint16_t, double);
1471   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1472        uint8_t, double);
1473   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1474        uint64_t, double);
1475   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1476        uint32_t, double);
1477   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1478        uint16_t, double);
1479   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1480        uint8_t, double);
1481   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1482        uint64_t, double);
1483   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1484        uint32_t, double);
1485   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1486        uint16_t, double);
1487   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1488        uint8_t, double);
1489 
1490   // Float matrices with all combinations of overhead storage.
1491   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1492        uint64_t, float);
1493   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1494        uint32_t, float);
1495   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1496        uint16_t, float);
1497   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1498        uint8_t, float);
1499   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1500        uint64_t, float);
1501   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1502        uint32_t, float);
1503   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1504        uint16_t, float);
1505   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1506        uint8_t, float);
1507   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1508        uint64_t, float);
1509   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1510        uint32_t, float);
1511   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1512        uint16_t, float);
1513   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1514        uint8_t, float);
1515   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1516        uint64_t, float);
1517   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1518        uint32_t, float);
1519   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1520        uint16_t, float);
1521   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1522        uint8_t, float);
1523 
1524   // Integral matrices with both overheads of the same type.
1525   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1526   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1527   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1528   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1529   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1530   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1531   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1532   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1533   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1534   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1535   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1536   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1537   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1538 
1539   // Complex matrices with wide overhead.
1540   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1541   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1542 
1543   // Unsupported case (add above if needed).
1544   // TODO: better pretty-printing of enum values!
1545   FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
1546         static_cast<int>(ptrTp), static_cast<int>(indTp),
1547         static_cast<int>(valTp));
1548 }
1549 #undef CASE
1550 #undef CASE_SECSAME
1551 
1552 /// Methods that provide direct access to values.
1553 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1554   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1555                                         void *tensor) {                        \
    assert(ref && tensor);                                                     \
1557     std::vector<V> *v;                                                         \
1558     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1559     ref->basePtr = ref->data = v->data();                                      \
1560     ref->offset = 0;                                                           \
1561     ref->sizes[0] = v->size();                                                 \
1562     ref->strides[0] = 1;                                                       \
1563   }
1564 FOREVERY_V(IMPL_SPARSEVALUES)
1565 #undef IMPL_SPARSEVALUES
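
// For instance, assuming FOREVERY_V pairs the suffix F64 with the C++ type
// double (as in the header), the expansion above provides:
//
//   void _mlir_ciface_sparseValuesF64(StridedMemRefType<double, 1> *ref,
//                                     void *tensor);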
1566 
1567 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1568   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1569                            index_type d) {                                     \
    assert(ref && tensor);                                                     \
1571     std::vector<TYPE> *v;                                                      \
1572     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1573     ref->basePtr = ref->data = v->data();                                      \
1574     ref->offset = 0;                                                           \
1575     ref->sizes[0] = v->size();                                                 \
1576     ref->strides[0] = 1;                                                       \
1577   }
1578 /// Methods that provide direct access to pointers.
1579 IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
1580 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
1581 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
1582 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
1583 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
1584 
1585 /// Methods that provide direct access to indices.
1586 IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1587 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
1588 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
1589 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
1590 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
1591 #undef IMPL_GETOVERHEAD
1592 
1593 /// Helper to add value to coordinate scheme, one per value type.
1594 #define IMPL_ADDELT(VNAME, V)                                                  \
1595   void *_mlir_ciface_addElt##VNAME(void *coo, V value,                         \
1596                                    StridedMemRefType<index_type, 1> *iref,     \
1597                                    StridedMemRefType<index_type, 1> *pref) {   \
    assert(coo && iref && pref);                                               \
1599     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1600     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1601     const index_type *indx = iref->data + iref->offset;                        \
1602     const index_type *perm = pref->data + pref->offset;                        \
1603     uint64_t isize = iref->sizes[0];                                           \
1604     std::vector<index_type> indices(isize);                                    \
1605     for (uint64_t r = 0; r < isize; r++)                                       \
1606       indices[perm[r]] = indx[r];                                              \
1607     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value);               \
1608     return coo;                                                                \
1609   }
1610 FOREVERY_SIMPLEX_V(IMPL_ADDELT)
1611 // `complex64` apparently doesn't encounter any ABI issues (yet).
1612 IMPL_ADDELT(C64, complex64)
1613 // TODO: cleaner way to avoid ABI padding problem?
1614 IMPL_ADDELT(C32ABI, complex32)
1615 void *_mlir_ciface_addEltC32(void *coo, float r, float i,
1616                              StridedMemRefType<index_type, 1> *iref,
1617                              StridedMemRefType<index_type, 1> *pref) {
1618   return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
1619 }
1620 #undef IMPL_ADDELT
1621 
1622 /// Helper to enumerate elements of coordinate scheme, one per value type.
1623 #define IMPL_GETNEXT(VNAME, V)                                                 \
1624   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1625                                    StridedMemRefType<index_type, 1> *iref,     \
1626                                    StridedMemRefType<V, 0> *vref) {            \
    assert(coo && iref && vref);                                               \
1628     assert(iref->strides[0] == 1);                                             \
1629     index_type *indx = iref->data + iref->offset;                              \
1630     V *value = vref->data + vref->offset;                                      \
1631     const uint64_t isize = iref->sizes[0];                                     \
1632     const Element<V> *elem =                                                   \
1633         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1634     if (elem == nullptr)                                                       \
1635       return false;                                                            \
1636     for (uint64_t r = 0; r < isize; r++)                                       \
1637       indx[r] = elem->indices[r];                                              \
1638     *value = elem->value;                                                      \
1639     return true;                                                               \
1640   }
1641 FOREVERY_V(IMPL_GETNEXT)
1642 #undef IMPL_GETNEXT
1643 
1644 /// Insert elements in lexicographical index order, one per value type.
1645 #define IMPL_LEXINSERT(VNAME, V)                                               \
1646   void _mlir_ciface_lexInsert##VNAME(                                          \
1647       void *tensor, StridedMemRefType<index_type, 1> *cref, V val) {           \
    assert(tensor && cref);                                                    \
1649     assert(cref->strides[0] == 1);                                             \
1650     index_type *cursor = cref->data + cref->offset;                            \
1651     assert(cursor);                                                            \
1652     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1653   }
1654 FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
1655 // `complex64` apparently doesn't encounter any ABI issues (yet).
1656 IMPL_LEXINSERT(C64, complex64)
1657 // TODO: cleaner way to avoid ABI padding problem?
1658 IMPL_LEXINSERT(C32ABI, complex32)
1659 void _mlir_ciface_lexInsertC32(void *tensor,
1660                                StridedMemRefType<index_type, 1> *cref, float r,
1661                                float i) {
1662   _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
1663 }
1664 #undef IMPL_LEXINSERT
1665 
1666 /// Insert using expansion, one per value type.
1667 #define IMPL_EXPINSERT(VNAME, V)                                               \
1668   void _mlir_ciface_expInsert##VNAME(                                          \
1669       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1670       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1671       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
    assert(tensor && cref && vref && fref && aref);                            \
1673     assert(cref->strides[0] == 1);                                             \
1674     assert(vref->strides[0] == 1);                                             \
1675     assert(fref->strides[0] == 1);                                             \
1676     assert(aref->strides[0] == 1);                                             \
1677     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1678     index_type *cursor = cref->data + cref->offset;                            \
1679     V *values = vref->data + vref->offset;                                     \
1680     bool *filled = fref->data + fref->offset;                                  \
1681     index_type *added = aref->data + aref->offset;                             \
1682     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1683         cursor, values, filled, added, count);                                 \
1684   }
1685 FOREVERY_V(IMPL_EXPINSERT)
1686 #undef IMPL_EXPINSERT
1687 
1688 /// Output a sparse tensor, one per value type.
1689 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
1690   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
1691     return outSparseTensor<V>(coo, dest, sort);                                \
1692   }
1693 FOREVERY_V(IMPL_OUTSPARSETENSOR)
1694 #undef IMPL_OUTSPARSETENSOR
1695 
1696 //===----------------------------------------------------------------------===//
1697 //
1698 // Public API with methods that accept C-style data structures to interact
1699 // with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used both by MLIR compiler-generated code and by an
// external runtime that wants to interact with MLIR compiler-generated code.
1702 //
1703 //===----------------------------------------------------------------------===//
1704 
/// Helper method to read a sparse tensor filename from the environment,
/// using the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1707 char *getTensorFilename(index_type id) {
1708   char var[80];
1709   sprintf(var, "TENSOR%" PRIu64, id);
1710   char *env = getenv(var);
1711   if (!env)
1712     FATAL("Environment variable %s is not set\n", var);
1713   return env;
1714 }
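
// For example, with the environment variable TENSOR0 set to the path of a
// *.mtx or *.tns file, that path can be retrieved as:
//
//   char *filename = getTensorFilename(0);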
1715 
1716 /// Returns size of sparse tensor in given dimension.
1717 index_type sparseDimSize(void *tensor, index_type d) {
1718   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1719 }
1720 
1721 /// Finalizes lexicographic insertions.
1722 void endInsert(void *tensor) {
1723   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1724 }
1725 
1726 /// Releases sparse tensor storage.
1727 void delSparseTensor(void *tensor) {
1728   delete static_cast<SparseTensorStorageBase *>(tensor);
1729 }
1730 
1731 /// Releases sparse tensor coordinate scheme.
1732 #define IMPL_DELCOO(VNAME, V)                                                  \
1733   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1734     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1735   }
1736 FOREVERY_V(IMPL_DELCOO)
1737 #undef IMPL_DELCOO
1738 
1739 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
1740 /// data structures. The expected parameters are:
1741 ///
1742 ///   rank:    rank of tensor
1743 ///   nse:     number of specified elements (usually the nonzeros)
1744 ///   shape:   array with dimension size for each rank
1745 ///   values:  a "nse" array with values for all specified elements
1746 ///   indices: a flat "nse x rank" array with indices for all specified elements
1747 ///   perm:    the permutation of the dimensions in the storage
1748 ///   sparse:  the sparsity for the dimensions
1749 ///
1750 /// For example, the sparse matrix
1751 ///     | 1.0 0.0 0.0 |
1752 ///     | 0.0 5.0 3.0 |
1753 /// can be passed as
1754 ///      rank    = 2
1755 ///      nse     = 3
1756 ///      shape   = [2, 3]
1757 ///      values  = [1.0, 5.0, 3.0]
1758 ///      indices = [ 0, 0,  1, 1,  1, 2]
1759 //
1760 // TODO: generalize beyond 64-bit indices.
1761 //
1762 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
1763   void *convertToMLIRSparseTensor##VNAME(                                      \
1764       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
1765       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
1766     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
1767                                  sparse);                                      \
1768   }
1769 FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
1770 #undef IMPL_CONVERTTOMLIRSPARSETENSOR
1771 
1772 /// Converts a sparse tensor to COO-flavored format expressed using C-style
1773 /// data structures. The expected output parameters are pointers for these
1774 /// values:
1775 ///
1776 ///   rank:    rank of tensor
1777 ///   nse:     number of specified elements (usually the nonzeros)
1778 ///   shape:   array with dimension size for each rank
1779 ///   values:  a "nse" array with values for all specified elements
1780 ///   indices: a flat "nse x rank" array with indices for all specified elements
1781 ///
1782 /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
1783 /// from convertToMLIRSparseTensor.
1784 ///
1785 //  TODO: Currently, values are copied from SparseTensorStorage to
1786 //  SparseTensorCOO, then to the output. We may want to reduce the number of
1787 //  copies.
1788 //
1789 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1790 // compressed
1791 //
1792 #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
1793   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
1794                                           uint64_t *pNse, uint64_t **pShape,   \
1795                                           V **pValues, uint64_t **pIndices) {  \
1796     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
1797   }
1798 FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
1799 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
1800 
1801 } // extern "C"
1802 
1803 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1804