1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <complex>
25 #include <cctype>
26 #include <cinttypes>
27 #include <cstdio>
28 #include <cstdlib>
29 #include <cstring>
30 #include <fstream>
31 #include <functional>
32 #include <iostream>
33 #include <limits>
34 #include <numeric>
35 #include <vector>
36 
37 using complex64 = std::complex<double>;
38 using complex32 = std::complex<float>;
39 
40 //===----------------------------------------------------------------------===//
41 //
42 // Internal support for storing and reading sparse tensors.
43 //
44 // The following memory-resident sparse storage schemes are supported:
45 //
46 // (a) A coordinate scheme for temporarily storing and lexicographically
47 //     sorting a sparse tensor by index (SparseTensorCOO).
48 //
49 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
50 //     per-dimension sparse/dense annnotations together with a dimension
51 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
52 //
53 // The following external formats are supported:
54 //
55 // (1) Matrix Market Exchange (MME): *.mtx
56 //     https://math.nist.gov/MatrixMarket/formats.html
57 //
58 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
59 //     http://frostt.io/tensors/file-formats.html
60 //
61 // Two public APIs are supported:
62 //
63 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
64 //     tensors. These methods should be used exclusively by MLIR
65 //     compiler-generated code.
66 //
67 // (II) Methods that accept C-style data structures to interact with sparse
68 //      tensors. These methods can be used by any external runtime that wants
69 //      to interact with MLIR compiler-generated code.
70 //
71 // In both cases (I) and (II), the SparseTensorStorage format is externally
72 // only visible as an opaque pointer.
73 //
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 
78 static constexpr int kColWidth = 1025;
79 
80 /// A version of `operator*` on `uint64_t` which checks for overflows.
81 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
82   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
83          "Integer overflow");
84   return lhs * rhs;
85 }
86 
87 // TODO: adjust this so it can be used by `openSparseTensorCOO` too.
88 // That version doesn't have the permutation, and the `sizes` are
89 // a pointer/C-array rather than `std::vector`.
90 //
91 /// Asserts that the `sizes` (in target-order) under the `perm` (mapping
92 /// semantic-order to target-order) are a refinement of the desired `shape`
93 /// (in semantic-order).
94 ///
95 /// Precondition: `perm` and `shape` must be valid for `rank`.
96 static inline void
97 assertPermutedSizesMatchShape(const std::vector<uint64_t> &sizes, uint64_t rank,
98                               const uint64_t *perm, const uint64_t *shape) {
99   assert(perm && shape);
100   assert(rank == sizes.size() && "Rank mismatch");
101   for (uint64_t r = 0; r < rank; r++)
102     assert((shape[r] == 0 || shape[r] == sizes[perm[r]]) &&
103            "Dimension size mismatch");
104 }
105 
106 /// A sparse tensor element in coordinate scheme (value and indices).
107 /// For example, a rank-1 vector element would look like
108 ///   ({i}, a[i])
109 /// and a rank-5 tensor element like
110 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
111 /// We use pointer to a shared index pool rather than e.g. a direct
112 /// vector since that (1) reduces the per-element memory footprint, and
113 /// (2) centralizes the memory reservation and (re)allocation to one place.
114 template <typename V>
115 struct Element {
116   Element(uint64_t *ind, V val) : indices(ind), value(val){};
117   uint64_t *indices; // pointer into shared index pool
118   V value;
119 };
120 
121 /// The type of callback functions which receive an element.  We avoid
122 /// packaging the coordinates and value together as an `Element` object
123 /// because this helps keep code somewhat cleaner.
124 template <typename V>
125 using ElementConsumer =
126     const std::function<void(const std::vector<uint64_t> &, V)> &;
127 
128 /// A memory-resident sparse tensor in coordinate scheme (collection of
129 /// elements). This data structure is used to read a sparse tensor from
130 /// any external format into memory and sort the elements lexicographically
131 /// by indices before passing it back to the client (most packed storage
132 /// formats require the elements to appear in lexicographic index order).
133 template <typename V>
134 struct SparseTensorCOO {
135 public:
136   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
137       : sizes(szs) {
138     if (capacity) {
139       elements.reserve(capacity);
140       indices.reserve(capacity * getRank());
141     }
142   }
143 
144   /// Adds element as indices and value.
145   void add(const std::vector<uint64_t> &ind, V val) {
146     assert(!iteratorLocked && "Attempt to add() after startIterator()");
147     uint64_t *base = indices.data();
148     uint64_t size = indices.size();
149     uint64_t rank = getRank();
150     assert(rank == ind.size());
151     for (uint64_t r = 0; r < rank; r++) {
152       assert(ind[r] < sizes[r]); // within bounds
153       indices.push_back(ind[r]);
154     }
155     // This base only changes if indices were reallocated. In that case, we
156     // need to correct all previous pointers into the vector. Note that this
157     // only happens if we did not set the initial capacity right, and then only
158     // for every internal vector reallocation (which with the doubling rule
159     // should only incur an amortized linear overhead).
160     uint64_t *newBase = indices.data();
161     if (newBase != base) {
162       for (uint64_t i = 0, n = elements.size(); i < n; i++)
163         elements[i].indices = newBase + (elements[i].indices - base);
164       base = newBase;
165     }
166     // Add element as (pointer into shared index pool, value) pair.
167     elements.emplace_back(base + size, val);
168   }
169 
170   /// Sorts elements lexicographically by index.
171   void sort() {
172     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
173     // TODO: we may want to cache an `isSorted` bit, to avoid
174     // unnecessary/redundant sorting.
175     std::sort(elements.begin(), elements.end(),
176               [this](const Element<V> &e1, const Element<V> &e2) {
177                 uint64_t rank = getRank();
178                 for (uint64_t r = 0; r < rank; r++) {
179                   if (e1.indices[r] == e2.indices[r])
180                     continue;
181                   return e1.indices[r] < e2.indices[r];
182                 }
183                 return false;
184               });
185   }
186 
187   /// Returns rank.
188   uint64_t getRank() const { return sizes.size(); }
189 
190   /// Getter for sizes array.
191   const std::vector<uint64_t> &getSizes() const { return sizes; }
192 
193   /// Getter for elements array.
194   const std::vector<Element<V>> &getElements() const { return elements; }
195 
196   /// Switch into iterator mode.
197   void startIterator() {
198     iteratorLocked = true;
199     iteratorPos = 0;
200   }
201 
202   /// Get the next element.
203   const Element<V> *getNext() {
204     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
205     if (iteratorPos < elements.size())
206       return &(elements[iteratorPos++]);
207     iteratorLocked = false;
208     return nullptr;
209   }
210 
211   /// Factory method. Permutes the original dimensions according to
212   /// the given ordering and expects subsequent add() calls to honor
213   /// that same ordering for the given indices. The result is a
214   /// fully permuted coordinate scheme.
215   ///
216   /// Precondition: `sizes` and `perm` must be valid for `rank`.
217   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
218                                                 const uint64_t *sizes,
219                                                 const uint64_t *perm,
220                                                 uint64_t capacity = 0) {
221     std::vector<uint64_t> permsz(rank);
222     for (uint64_t r = 0; r < rank; r++) {
223       assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
224       permsz[perm[r]] = sizes[r];
225     }
226     return new SparseTensorCOO<V>(permsz, capacity);
227   }
228 
229 private:
230   const std::vector<uint64_t> sizes; // per-dimension sizes
231   std::vector<Element<V>> elements;  // all COO elements
232   std::vector<uint64_t> indices;     // shared index pool
233   bool iteratorLocked = false;
234   unsigned iteratorPos = 0;
235 };
236 
237 // Forward.
238 template <typename V>
239 class SparseTensorEnumeratorBase;
240 
241 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
242 /// takes responsibility for all the `<P,I,V>`-independent aspects
243 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
244 /// we use function overloading to implement "partial" method
245 /// specialization, which the C-API relies on to catch type errors
246 /// arising from our use of opaque pointers.
247 class SparseTensorStorageBase {
248 public:
249   /// Constructs a new storage object.  The `perm` maps the tensor's
250   /// semantic-ordering of dimensions to this object's storage-order.
251   /// The `szs` and `sparsity` arrays are already in storage-order.
252   ///
253   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
254   SparseTensorStorageBase(const std::vector<uint64_t> &szs,
255                           const uint64_t *perm, const DimLevelType *sparsity)
256       : dimSizes(szs), rev(getRank()),
257         dimTypes(sparsity, sparsity + getRank()) {
258     assert(perm && sparsity);
259     const uint64_t rank = getRank();
260     // Validate parameters.
261     assert(rank > 0 && "Trivial shape is unsupported");
262     for (uint64_t r = 0; r < rank; r++) {
263       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
264       assert((dimTypes[r] == DimLevelType::kDense ||
265               dimTypes[r] == DimLevelType::kCompressed) &&
266              "Unsupported DimLevelType");
267     }
268     // Construct the "reverse" (i.e., inverse) permutation.
269     for (uint64_t r = 0; r < rank; r++)
270       rev[perm[r]] = r;
271   }
272 
273   virtual ~SparseTensorStorageBase() = default;
274 
275   /// Get the rank of the tensor.
276   uint64_t getRank() const { return dimSizes.size(); }
277 
278   /// Getter for the dimension-sizes array, in storage-order.
279   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
280 
281   /// Safely lookup the size of the given (storage-order) dimension.
282   uint64_t getDimSize(uint64_t d) const {
283     assert(d < getRank());
284     return dimSizes[d];
285   }
286 
287   /// Getter for the "reverse" permutation, which maps this object's
288   /// storage-order to the tensor's semantic-order.
289   const std::vector<uint64_t> &getRev() const { return rev; }
290 
291   /// Getter for the dimension-types array, in storage-order.
292   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
293 
294   /// Safely check if the (storage-order) dimension uses compressed storage.
295   bool isCompressedDim(uint64_t d) const {
296     assert(d < getRank());
297     return (dimTypes[d] == DimLevelType::kCompressed);
298   }
299 
300   /// Allocate a new enumerator.
301   virtual void newEnumerator(SparseTensorEnumeratorBase<double> **, uint64_t,
302                              const uint64_t *) const {
303     fatal("enumf64");
304   }
305   virtual void newEnumerator(SparseTensorEnumeratorBase<float> **, uint64_t,
306                              const uint64_t *) const {
307     fatal("enumf32");
308   }
309   virtual void newEnumerator(SparseTensorEnumeratorBase<int64_t> **, uint64_t,
310                              const uint64_t *) const {
311     fatal("enumi64");
312   }
313   virtual void newEnumerator(SparseTensorEnumeratorBase<int32_t> **, uint64_t,
314                              const uint64_t *) const {
315     fatal("enumi32");
316   }
317   virtual void newEnumerator(SparseTensorEnumeratorBase<int16_t> **, uint64_t,
318                              const uint64_t *) const {
319     fatal("enumi16");
320   }
321   virtual void newEnumerator(SparseTensorEnumeratorBase<int8_t> **, uint64_t,
322                              const uint64_t *) const {
323     fatal("enumi8");
324   }
325   virtual void newEnumerator(SparseTensorEnumeratorBase<complex64> **, uint64_t,
326                              const uint64_t *) const {
327     fatal("enumc64");
328   }
329   virtual void newEnumerator(SparseTensorEnumeratorBase<complex32> **, uint64_t,
330                              const uint64_t *) const {
331     fatal("enumc32");
332   }
333 
334   /// Overhead storage.
335   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
336   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
337   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
338   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
339   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
340   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
341   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
342   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
343 
344   /// Primary storage.
345   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
346   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
347   virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
348   virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
349   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
350   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
351   virtual void getValues(std::vector<complex64> **) { fatal("valc64"); }
352   virtual void getValues(std::vector<complex32> **) { fatal("valc32"); }
353 
354   /// Element-wise insertion in lexicographic index order.
355   virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
356   virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
357   virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
358   virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
359   virtual void lexInsert(const uint64_t *, int16_t) { fatal("ins16"); }
360   virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
361   virtual void lexInsert(const uint64_t *, complex64) { fatal("insc64"); }
362   virtual void lexInsert(const uint64_t *, complex32) { fatal("insc32"); }
363 
364   /// Expanded insertion.
365   virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
366     fatal("expf64");
367   }
368   virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
369     fatal("expf32");
370   }
371   virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
372     fatal("expi64");
373   }
374   virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
375     fatal("expi32");
376   }
377   virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
378     fatal("expi16");
379   }
380   virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
381     fatal("expi8");
382   }
383   virtual void expInsert(uint64_t *, complex64 *, bool *, uint64_t *,
384                          uint64_t) {
385     fatal("expc64");
386   }
387   virtual void expInsert(uint64_t *, complex32 *, bool *, uint64_t *,
388                          uint64_t) {
389     fatal("expc32");
390   }
391 
392   /// Finishes insertion.
393   virtual void endInsert() = 0;
394 
395 protected:
396   // Since this class is virtual, we must disallow public copying in
397   // order to avoid "slicing".  Since this class has data members,
398   // that means making copying protected.
399   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
400   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
401   // Copy-assignment would be implicitly deleted (because `dimSizes`
402   // is const), so we explicitly delete it for clarity.
403   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
404 
405 private:
406   static void fatal(const char *tp) {
407     fprintf(stderr, "unsupported %s\n", tp);
408     exit(1);
409   }
410 
411   const std::vector<uint64_t> dimSizes;
412   std::vector<uint64_t> rev;
413   const std::vector<DimLevelType> dimTypes;
414 };
415 
416 // Forward.
417 template <typename P, typename I, typename V>
418 class SparseTensorEnumerator;
419 
420 /// A memory-resident sparse tensor using a storage scheme based on
421 /// per-dimension sparse/dense annotations. This data structure provides a
422 /// bufferized form of a sparse tensor type. In contrast to generating setup
423 /// methods for each differently annotated sparse tensor, this method provides
424 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
425 /// and annotations to implement all required setup in a general manner.
426 template <typename P, typename I, typename V>
427 class SparseTensorStorage : public SparseTensorStorageBase {
428   /// Private constructor to share code between the other constructors.
429   /// Beware that the object is not necessarily guaranteed to be in a
430   /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
431   /// doesn't entail `!(pointers[d].empty())`.
432   ///
433   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
434   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
435                       const DimLevelType *sparsity)
436       : SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
437         indices(getRank()), idx(getRank()) {}
438 
439 public:
440   /// Constructs a sparse tensor storage scheme with the given dimensions,
441   /// permutation, and per-dimension dense/sparse annotations, using
442   /// the coordinate scheme tensor for the initial contents if provided.
443   ///
444   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
445   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
446                       const DimLevelType *sparsity, SparseTensorCOO<V> *coo)
447       : SparseTensorStorage(szs, perm, sparsity) {
448     // Provide hints on capacity of pointers and indices.
449     // TODO: needs much fine-tuning based on actual sparsity; currently
450     //       we reserve pointer/index space based on all previous dense
451     //       dimensions, which works well up to first sparse dim; but
452     //       we should really use nnz and dense/sparse distribution.
453     bool allDense = true;
454     uint64_t sz = 1;
455     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
456       if (isCompressedDim(r)) {
457         // TODO: Take a parameter between 1 and `sizes[r]`, and multiply
458         // `sz` by that before reserving. (For now we just use 1.)
459         pointers[r].reserve(sz + 1);
460         pointers[r].push_back(0);
461         indices[r].reserve(sz);
462         sz = 1;
463         allDense = false;
464       } else { // Dense dimension.
465         sz = checkedMul(sz, getDimSizes()[r]);
466       }
467     }
468     // Then assign contents from coordinate scheme tensor if provided.
469     if (coo) {
470       // Ensure both preconditions of `fromCOO`.
471       assert(coo->getSizes() == getDimSizes() && "Tensor size mismatch");
472       coo->sort();
473       // Now actually insert the `elements`.
474       const std::vector<Element<V>> &elements = coo->getElements();
475       uint64_t nnz = elements.size();
476       values.reserve(nnz);
477       fromCOO(elements, 0, nnz, 0);
478     } else if (allDense) {
479       values.resize(sz, 0);
480     }
481   }
482 
483   /// Constructs a sparse tensor storage scheme with the given dimensions,
484   /// permutation, and per-dimension dense/sparse annotations, using
485   /// the given sparse tensor for the initial contents.
486   ///
487   /// Preconditions:
488   /// * `perm` and `sparsity` must be valid for `szs.size()`.
489   /// * The `tensor` must have the same value type `V`.
490   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
491                       const DimLevelType *sparsity,
492                       const SparseTensorStorageBase &tensor);
493 
494   ~SparseTensorStorage() override = default;
495 
496   /// Partially specialize these getter methods based on template types.
497   void getPointers(std::vector<P> **out, uint64_t d) override {
498     assert(d < getRank());
499     *out = &pointers[d];
500   }
501   void getIndices(std::vector<I> **out, uint64_t d) override {
502     assert(d < getRank());
503     *out = &indices[d];
504   }
505   void getValues(std::vector<V> **out) override { *out = &values; }
506 
507   /// Partially specialize lexicographical insertions based on template types.
508   void lexInsert(const uint64_t *cursor, V val) override {
509     // First, wrap up pending insertion path.
510     uint64_t diff = 0;
511     uint64_t top = 0;
512     if (!values.empty()) {
513       diff = lexDiff(cursor);
514       endPath(diff + 1);
515       top = idx[diff] + 1;
516     }
517     // Then continue with insertion path.
518     insPath(cursor, diff, top, val);
519   }
520 
521   /// Partially specialize expanded insertions based on template types.
522   /// Note that this method resets the values/filled-switch array back
523   /// to all-zero/false while only iterating over the nonzero elements.
524   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
525                  uint64_t count) override {
526     if (count == 0)
527       return;
528     // Sort.
529     std::sort(added, added + count);
530     // Restore insertion path for first insert.
531     const uint64_t lastDim = getRank() - 1;
532     uint64_t index = added[0];
533     cursor[lastDim] = index;
534     lexInsert(cursor, values[index]);
535     assert(filled[index]);
536     values[index] = 0;
537     filled[index] = false;
538     // Subsequent insertions are quick.
539     for (uint64_t i = 1; i < count; i++) {
540       assert(index < added[i] && "non-lexicographic insertion");
541       index = added[i];
542       cursor[lastDim] = index;
543       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
544       assert(filled[index]);
545       values[index] = 0;
546       filled[index] = false;
547     }
548   }
549 
550   /// Finalizes lexicographic insertions.
551   void endInsert() override {
552     if (values.empty())
553       finalizeSegment(0);
554     else
555       endPath(0);
556   }
557 
558   void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
559                      const uint64_t *perm) const override {
560     *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
561   }
562 
563   /// Returns this sparse tensor storage scheme as a new memory-resident
564   /// sparse tensor in coordinate scheme with the given dimension order.
565   ///
566   /// Precondition: `perm` must be valid for `getRank()`.
567   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
568     SparseTensorEnumeratorBase<V> *enumerator;
569     newEnumerator(&enumerator, getRank(), perm);
570     SparseTensorCOO<V> *coo =
571         new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
572     enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
573       coo->add(ind, val);
574     });
575     // TODO: This assertion assumes there are no stored zeros,
576     // or if there are then that we don't filter them out.
577     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
578     assert(coo->getElements().size() == values.size());
579     delete enumerator;
580     return coo;
581   }
582 
583   /// Factory method. Constructs a sparse tensor storage scheme with the given
584   /// dimensions, permutation, and per-dimension dense/sparse annotations,
585   /// using the coordinate scheme tensor for the initial contents if provided.
586   /// In the latter case, the coordinate scheme must respect the same
587   /// permutation as is desired for the new sparse tensor storage.
588   ///
589   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
590   static SparseTensorStorage<P, I, V> *
591   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
592                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
593     SparseTensorStorage<P, I, V> *n = nullptr;
594     if (coo) {
595       const auto &coosz = coo->getSizes();
596       assertPermutedSizesMatchShape(coosz, rank, perm, shape);
597       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
598     } else {
599       std::vector<uint64_t> permsz(rank);
600       for (uint64_t r = 0; r < rank; r++) {
601         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
602         permsz[perm[r]] = shape[r];
603       }
604       // We pass the null `coo` to ensure we select the intended constructor.
605       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
606     }
607     return n;
608   }
609 
610   /// Factory method. Constructs a sparse tensor storage scheme with
611   /// the given dimensions, permutation, and per-dimension dense/sparse
612   /// annotations, using the sparse tensor for the initial contents.
613   ///
614   /// Preconditions:
615   /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
616   /// * The `tensor` must have the same value type `V`.
617   static SparseTensorStorage<P, I, V> *
618   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
619                   const DimLevelType *sparsity,
620                   const SparseTensorStorageBase *source) {
621     assert(source && "Got nullptr for source");
622     SparseTensorEnumeratorBase<V> *enumerator;
623     source->newEnumerator(&enumerator, rank, perm);
624     const auto &permsz = enumerator->permutedSizes();
625     assertPermutedSizesMatchShape(permsz, rank, perm, shape);
626     auto *tensor =
627         new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
628     delete enumerator;
629     return tensor;
630   }
631 
632 private:
633   /// Appends an arbitrary new position to `pointers[d]`.  This method
634   /// checks that `pos` is representable in the `P` type; however, it
635   /// does not check that `pos` is semantically valid (i.e., larger than
636   /// the previous position and smaller than `indices[d].capacity()`).
637   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
638     assert(isCompressedDim(d));
639     assert(pos <= std::numeric_limits<P>::max() &&
640            "Pointer value is too large for the P-type");
641     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
642   }
643 
644   /// Appends index `i` to dimension `d`, in the semantically general
645   /// sense.  For non-dense dimensions, that means appending to the
646   /// `indices[d]` array, checking that `i` is representable in the `I`
647   /// type; however, we do not verify other semantic requirements (e.g.,
648   /// that `i` is in bounds for `sizes[d]`, and not previously occurring
649   /// in the same segment).  For dense dimensions, this method instead
650   /// appends the appropriate number of zeros to the `values` array,
651   /// where `full` is the number of "entries" already written to `values`
652   /// for this segment (aka one after the highest index previously appended).
653   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
654     if (isCompressedDim(d)) {
655       assert(i <= std::numeric_limits<I>::max() &&
656              "Index value is too large for the I-type");
657       indices[d].push_back(static_cast<I>(i));
658     } else { // Dense dimension.
659       assert(i >= full && "Index was already filled");
660       if (i == full)
661         return; // Short-circuit, since it'll be a nop.
662       if (d + 1 == getRank())
663         values.insert(values.end(), i - full, 0);
664       else
665         finalizeSegment(d + 1, 0, i - full);
666     }
667   }
668 
669   /// Writes the given coordinate to `indices[d][pos]`.  This method
670   /// checks that `i` is representable in the `I` type; however, it
671   /// does not check that `i` is semantically valid (i.e., in bounds
672   /// for `sizes[d]` and not elsewhere occurring in the same segment).
673   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
674     assert(isCompressedDim(d));
675     // Subscript assignment to `std::vector` requires that the `pos`-th
676     // entry has been initialized; thus we must be sure to check `size()`
677     // here, instead of `capacity()` as would be ideal.
678     assert(pos < indices[d].size() && "Index position is out of bounds");
679     assert(i <= std::numeric_limits<I>::max() &&
680            "Index value is too large for the I-type");
681     indices[d][pos] = static_cast<I>(i);
682   }
683 
684   /// Computes the assembled-size associated with the `d`-th dimension,
685   /// given the assembled-size associated with the `(d-1)`-th dimension.
686   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
687   /// storage, as opposed to "dimension-sizes" which are the cardinality
688   /// of coordinates for that dimension.
689   ///
690   /// Precondition: the `pointers[d]` array must be fully initialized
691   /// before calling this method.
692   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
693     if (isCompressedDim(d))
694       return pointers[d][parentSz];
695     // else if dense:
696     return parentSz * getDimSizes()[d];
697   }
698 
699   /// Initializes sparse tensor storage scheme from a memory-resident sparse
700   /// tensor in coordinate scheme. This method prepares the pointers and
701   /// indices arrays under the given per-dimension dense/sparse annotations.
702   ///
703   /// Preconditions:
704   /// (1) the `elements` must be lexicographically sorted.
705   /// (2) the indices of every element are valid for `sizes` (equal rank
706   ///     and pointwise less-than).
707   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
708                uint64_t hi, uint64_t d) {
709     uint64_t rank = getRank();
710     assert(d <= rank && hi <= elements.size());
711     // Once dimensions are exhausted, insert the numerical values.
712     if (d == rank) {
713       assert(lo < hi);
714       values.push_back(elements[lo].value);
715       return;
716     }
717     // Visit all elements in this interval.
718     uint64_t full = 0;
719     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
720       // Find segment in interval with same index elements in this dimension.
721       uint64_t i = elements[lo].indices[d];
722       uint64_t seg = lo + 1;
723       while (seg < hi && elements[seg].indices[d] == i)
724         seg++;
725       // Handle segment in interval for sparse or dense dimension.
726       appendIndex(d, full, i);
727       full = i + 1;
728       fromCOO(elements, lo, seg, d + 1);
729       // And move on to next segment in interval.
730       lo = seg;
731     }
732     // Finalize the sparse pointer structure at this dimension.
733     finalizeSegment(d, full);
734   }
735 
736   /// Finalize the sparse pointer structure at this dimension.
737   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
738     if (count == 0)
739       return; // Short-circuit, since it'll be a nop.
740     if (isCompressedDim(d)) {
741       appendPointer(d, indices[d].size(), count);
742     } else { // Dense dimension.
743       const uint64_t sz = getDimSizes()[d];
744       assert(sz >= full && "Segment is overfull");
745       count = checkedMul(count, sz - full);
746       // For dense storage we must enumerate all the remaining coordinates
747       // in this dimension (i.e., coordinates after the last non-zero
748       // element), and either fill in their zero values or else recurse
749       // to finalize some deeper dimension.
750       if (d + 1 == getRank())
751         values.insert(values.end(), count, 0);
752       else
753         finalizeSegment(d + 1, 0, count);
754     }
755   }
756 
757   /// Wraps up a single insertion path, inner to outer.
758   void endPath(uint64_t diff) {
759     uint64_t rank = getRank();
760     assert(diff <= rank);
761     for (uint64_t i = 0; i < rank - diff; i++) {
762       const uint64_t d = rank - i - 1;
763       finalizeSegment(d, idx[d] + 1);
764     }
765   }
766 
767   /// Continues a single insertion path, outer to inner.
768   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
769     uint64_t rank = getRank();
770     assert(diff < rank);
771     for (uint64_t d = diff; d < rank; d++) {
772       uint64_t i = cursor[d];
773       appendIndex(d, top, i);
774       top = 0;
775       idx[d] = i;
776     }
777     values.push_back(val);
778   }
779 
780   /// Finds the lexicographic differing dimension.
781   uint64_t lexDiff(const uint64_t *cursor) const {
782     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
783       if (cursor[r] > idx[r])
784         return r;
785       else
786         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
787     assert(0 && "duplication insertion");
788     return -1u;
789   }
790 
791   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
792   // the cost of virtual-function dispatch in inner loops), without
793   // making them public to other client code.
794   friend class SparseTensorEnumerator<P, I, V>;
795 
796   std::vector<std::vector<P>> pointers;
797   std::vector<std::vector<I>> indices;
798   std::vector<V> values;
799   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
800 };
801 
802 /// A (higher-order) function object for enumerating the elements of some
803 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
804 /// method encapsulates the loop-nest for enumerating the elements of
805 /// the source tensor (in whatever order is best for the source tensor),
806 /// and applies a permutation to the coordinates/indices before handing
807 /// each element to the callback.  A single enumerator object can be
808 /// freely reused for several calls to `forallElements`, just so long
809 /// as each call is sequential with respect to one another.
810 ///
811 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
812 /// passed to the constructor; thus, objects of this class must not
813 /// outlive the sparse tensor they depend on.
814 ///
815 /// Design Note: The reason we define this class instead of simply using
816 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
817 /// the `<P,I>` template parameters from MLIR client code (to simplify the
818 /// type parameters used for direct sparse-to-sparse conversion).  And the
819 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
820 /// than simply using this class, is to avoid the cost of virtual-method
821 /// dispatch within the loop-nest.
822 template <typename V>
823 class SparseTensorEnumeratorBase {
824 public:
825   /// Constructs an enumerator with the given permutation for mapping
826   /// the semantic-ordering of dimensions to the desired target-ordering.
827   ///
828   /// Preconditions:
829   /// * the `tensor` must have the same `V` value type.
830   /// * `perm` must be valid for `rank`.
831   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
832                              uint64_t rank, const uint64_t *perm)
833       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
834         cursor(getRank()) {
835     assert(perm && "Received nullptr for permutation");
836     assert(rank == getRank() && "Permutation rank mismatch");
837     const auto &rev = src.getRev();        // source stg-order -> semantic-order
838     const auto &sizes = src.getDimSizes(); // in source storage-order
839     for (uint64_t s = 0; s < rank; s++) {  // `s` source storage-order
840       uint64_t t = perm[rev[s]];           // `t` target-order
841       reord[s] = t;
842       permsz[t] = sizes[s];
843     }
844   }
845 
846   virtual ~SparseTensorEnumeratorBase() = default;
847 
848   // We disallow copying to help avoid leaking the `src` reference.
849   // (In addition to avoiding the problem of slicing.)
850   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
851   SparseTensorEnumeratorBase &
852   operator=(const SparseTensorEnumeratorBase &) = delete;
853 
854   /// Returns the source/target tensor's rank.  (The source-rank and
855   /// target-rank are always equal since we only support permutations.
856   /// Though once we add support for other dimension mappings, this
857   /// method will have to be split in two.)
858   uint64_t getRank() const { return permsz.size(); }
859 
860   /// Returns the target tensor's dimension sizes.
861   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
862 
863   /// Enumerates all elements of the source tensor, permutes their
864   /// indices, and passes the permuted element to the callback.
865   /// The callback must not store the cursor reference directly,
866   /// since this function reuses the storage.  Instead, the callback
867   /// must copy it if they want to keep it.
868   virtual void forallElements(ElementConsumer<V> yield) = 0;
869 
870 protected:
871   const SparseTensorStorageBase &src;
872   std::vector<uint64_t> permsz; // in target order.
873   std::vector<uint64_t> reord;  // source storage-order -> target order.
874   std::vector<uint64_t> cursor; // in target order.
875 };
876 
877 template <typename P, typename I, typename V>
878 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
879   using Base = SparseTensorEnumeratorBase<V>;
880 
881 public:
882   /// Constructs an enumerator with the given permutation for mapping
883   /// the semantic-ordering of dimensions to the desired target-ordering.
884   ///
885   /// Precondition: `perm` must be valid for `rank`.
886   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
887                          uint64_t rank, const uint64_t *perm)
888       : Base(tensor, rank, perm) {}
889 
890   ~SparseTensorEnumerator() final override = default;
891 
892   void forallElements(ElementConsumer<V> yield) final override {
893     forallElements(yield, 0, 0);
894   }
895 
896 private:
897   /// The recursive component of the public `forallElements`.
898   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
899                       uint64_t d) {
900     // Recover the `<P,I,V>` type parameters of `src`.
901     const auto &src =
902         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
903     if (d == Base::getRank()) {
904       assert(parentPos < src.values.size() &&
905              "Value position is out of bounds");
906       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
907       yield(this->cursor, src.values[parentPos]);
908     } else if (src.isCompressedDim(d)) {
909       // Look up the bounds of the `d`-level segment determined by the
910       // `d-1`-level position `parentPos`.
911       const std::vector<P> &pointers_d = src.pointers[d];
912       assert(parentPos + 1 < pointers_d.size() &&
913              "Parent pointer position is out of bounds");
914       const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
915       const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
916       // Loop-invariant code for looking up the `d`-level coordinates/indices.
917       const std::vector<I> &indices_d = src.indices[d];
918       assert(pstop - 1 < indices_d.size() && "Index position is out of bounds");
919       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
920       for (uint64_t pos = pstart; pos < pstop; pos++) {
921         cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
922         forallElements(yield, pos, d + 1);
923       }
924     } else { // Dense dimension.
925       const uint64_t sz = src.getDimSizes()[d];
926       const uint64_t pstart = parentPos * sz;
927       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
928       for (uint64_t i = 0; i < sz; i++) {
929         cursor_reord_d = i;
930         forallElements(yield, pstart + i, d + 1);
931       }
932     }
933   }
934 };
935 
936 /// Statistics regarding the number of nonzero subtensors in
937 /// a source tensor, for direct sparse=>sparse conversion a la
938 /// <https://arxiv.org/abs/2001.02609>.
939 ///
940 /// N.B., this class stores references to the parameters passed to
941 /// the constructor; thus, objects of this class must not outlive
942 /// those parameters.
943 class SparseTensorNNZ {
944 public:
945   /// Allocate the statistics structure for the desired sizes and
946   /// sparsity (in the target tensor's storage-order).  This constructor
947   /// does not actually populate the statistics, however; for that see
948   /// `initialize`.
949   ///
950   /// Precondition: `szs` must not contain zeros.
951   SparseTensorNNZ(const std::vector<uint64_t> &szs,
952                   const std::vector<DimLevelType> &sparsity)
953       : dimSizes(szs), dimTypes(sparsity), nnz(getRank()) {
954     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
955     bool uncompressed = true;
956     uint64_t sz = 1; // the product of all `dimSizes` strictly less than `r`.
957     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
958       switch (dimTypes[r]) {
959       case DimLevelType::kCompressed:
960         assert(uncompressed &&
961                "Multiple compressed layers not currently supported");
962         uncompressed = false;
963         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
964         break;
965       case DimLevelType::kDense:
966         assert(uncompressed &&
967                "Dense after compressed not currently supported");
968         break;
969       case DimLevelType::kSingleton:
970         // Singleton after Compressed causes no problems for allocating
971         // `nnz` nor for the yieldPos loop.  This remains true even
972         // when adding support for multiple compressed dimensions or
973         // for dense-after-compressed.
974         break;
975       }
976       sz = checkedMul(sz, dimSizes[r]);
977     }
978   }
979 
980   // We disallow copying to help avoid leaking the stored references.
981   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
982   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
983 
984   /// Returns the rank of the target tensor.
985   uint64_t getRank() const { return dimSizes.size(); }
986 
987   /// Enumerate the source tensor to fill in the statistics.  The
988   /// enumerator should already incorporate the permutation (from
989   /// semantic-order to the target storage-order).
990   template <typename V>
991   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
992     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
993     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
994     enumerator.forallElements(
995         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
996   }
997 
998   /// The type of callback functions which receive an nnz-statistic.
999   using NNZConsumer = const std::function<void(uint64_t)> &;
1000 
1001   /// Lexicographically enumerates all indicies for dimensions strictly
1002   /// less than `stopDim`, and passes their nnz statistic to the callback.
1003   /// Since our use-case only requires the statistic not the coordinates
1004   /// themselves, we do not bother to construct those coordinates.
1005   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
1006     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
1007     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
1008            "Cannot look up non-compressed dimensions");
1009     forallIndices(yield, stopDim, 0, 0);
1010   }
1011 
1012 private:
1013   /// Adds a new element (i.e., increment its statistics).  We use
1014   /// a method rather than inlining into the lambda in `initialize`,
1015   /// to avoid spurious templating over `V`.  And this method is private
1016   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
1017   /// by `forallElements`).
1018   void add(const std::vector<uint64_t> &ind) {
1019     uint64_t parentPos = 0;
1020     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1021       if (dimTypes[r] == DimLevelType::kCompressed)
1022         nnz[r][parentPos]++;
1023       parentPos = parentPos * dimSizes[r] + ind[r];
1024     }
1025   }
1026 
1027   /// Recursive component of the public `forallIndices`.
1028   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
1029                      uint64_t d) const {
1030     assert(d <= stopDim);
1031     if (d == stopDim) {
1032       assert(parentPos < nnz[d].size() && "Cursor is out of range");
1033       yield(nnz[d][parentPos]);
1034     } else {
1035       const uint64_t sz = dimSizes[d];
1036       const uint64_t pstart = parentPos * sz;
1037       for (uint64_t i = 0; i < sz; i++)
1038         forallIndices(yield, stopDim, pstart + i, d + 1);
1039     }
1040   }
1041 
1042   // All of these are in the target storage-order.
1043   const std::vector<uint64_t> &dimSizes;
1044   const std::vector<DimLevelType> &dimTypes;
1045   std::vector<std::vector<uint64_t>> nnz;
1046 };
1047 
1048 template <typename P, typename I, typename V>
1049 SparseTensorStorage<P, I, V>::SparseTensorStorage(
1050     const std::vector<uint64_t> &szs, const uint64_t *perm,
1051     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1052     : SparseTensorStorage(szs, perm, sparsity) {
1053   SparseTensorEnumeratorBase<V> *enumerator;
1054   tensor.newEnumerator(&enumerator, getRank(), perm);
1055   {
1056     // Initialize the statistics structure.
1057     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1058     nnz.initialize(*enumerator);
1059     // Initialize "pointers" overhead (and allocate "indices", "values").
1060     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1061     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1062       if (isCompressedDim(r)) {
1063         pointers[r].reserve(parentSz + 1);
1064         pointers[r].push_back(0);
1065         uint64_t currentPos = 0;
1066         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1067           currentPos += n;
1068           appendPointer(r, currentPos);
1069         });
1070         assert(pointers[r].size() == parentSz + 1 &&
1071                "Final pointers size doesn't match allocated size");
1072         // That assertion entails `assembledSize(parentSz, r)`
1073         // is now in a valid state.  That is, `pointers[r][parentSz]`
1074         // equals the present value of `currentPos`, which is the
1075         // correct assembled-size for `indices[r]`.
1076       }
1077       // Update assembled-size for the next iteration.
1078       parentSz = assembledSize(parentSz, r);
1079       // Ideally we need only `indices[r].reserve(parentSz)`, however
1080       // the `std::vector` implementation forces us to initialize it too.
1081       // That is, in the yieldPos loop we need random-access assignment
1082       // to `indices[r]`; however, `std::vector`'s subscript-assignment
1083       // only allows assigning to already-initialized positions.
1084       if (isCompressedDim(r))
1085         indices[r].resize(parentSz, 0);
1086     }
1087     values.resize(parentSz, 0); // Both allocate and zero-initialize.
1088   }
1089   // The yieldPos loop
1090   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1091     uint64_t parentSz = 1, parentPos = 0;
1092     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1093       if (isCompressedDim(r)) {
1094         // If `parentPos == parentSz` then it's valid as an array-lookup;
1095         // however, it's semantically invalid here since that entry
1096         // does not represent a segment of `indices[r]`.  Moreover, that
1097         // entry must be immutable for `assembledSize` to remain valid.
1098         assert(parentPos < parentSz && "Pointers position is out of bounds");
1099         const uint64_t currentPos = pointers[r][parentPos];
1100         // This increment won't overflow the `P` type, since it can't
1101         // exceed the original value of `pointers[r][parentPos+1]`
1102         // which was already verified to be within bounds for `P`
1103         // when it was written to the array.
1104         pointers[r][parentPos]++;
1105         writeIndex(r, currentPos, ind[r]);
1106         parentPos = currentPos;
1107       } else { // Dense dimension.
1108         parentPos = parentPos * getDimSizes()[r] + ind[r];
1109       }
1110       parentSz = assembledSize(parentSz, r);
1111     }
1112     assert(parentPos < values.size() && "Value position is out of bounds");
1113     values[parentPos] = val;
1114   });
1115   // No longer need the enumerator, so we'll delete it ASAP.
1116   delete enumerator;
1117   // The finalizeYieldPos loop
1118   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1119     if (isCompressedDim(r)) {
1120       assert(parentSz == pointers[r].size() - 1 &&
1121              "Actual pointers size doesn't match the expected size");
1122       // Can't check all of them, but at least we can check the last one.
1123       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1124              "Pointers got corrupted");
1125       // TODO: optimize this by using `memmove` or similar.
1126       for (uint64_t n = 0; n < parentSz; n++) {
1127         const uint64_t parentPos = parentSz - n;
1128         pointers[r][parentPos] = pointers[r][parentPos - 1];
1129       }
1130       pointers[r][0] = 0;
1131     }
1132     parentSz = assembledSize(parentSz, r);
1133   }
1134 }
1135 
1136 /// Helper to convert string to lower case.
1137 static char *toLower(char *token) {
1138   for (char *c = token; *c; c++)
1139     *c = tolower(*c);
1140   return token;
1141 }
1142 
1143 /// Read the MME header of a general sparse matrix of type real.
1144 static void readMMEHeader(FILE *file, char *filename, char *line,
1145                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
1146   char header[64];
1147   char object[64];
1148   char format[64];
1149   char field[64];
1150   char symmetry[64];
1151   // Read header line.
1152   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1153              symmetry) != 5) {
1154     fprintf(stderr, "Corrupt header in %s\n", filename);
1155     exit(1);
1156   }
1157   // Set properties
1158   *isPattern = (strcmp(toLower(field), "pattern") == 0);
1159   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
1160   // Make sure this is a general sparse matrix.
1161   if (strcmp(toLower(header), "%%matrixmarket") ||
1162       strcmp(toLower(object), "matrix") ||
1163       strcmp(toLower(format), "coordinate") ||
1164       (strcmp(toLower(field), "real") && !(*isPattern)) ||
1165       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
1166     fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
1167     exit(1);
1168   }
1169   // Skip comments.
1170   while (true) {
1171     if (!fgets(line, kColWidth, file)) {
1172       fprintf(stderr, "Cannot find data in %s\n", filename);
1173       exit(1);
1174     }
1175     if (line[0] != '%')
1176       break;
1177   }
1178   // Next line contains M N NNZ.
1179   idata[0] = 2; // rank
1180   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1181              idata + 1) != 3) {
1182     fprintf(stderr, "Cannot find size in %s\n", filename);
1183     exit(1);
1184   }
1185 }
1186 
1187 /// Read the "extended" FROSTT header. Although not part of the documented
1188 /// format, we assume that the file starts with optional comments followed
1189 /// by two lines that define the rank, the number of nonzeros, and the
1190 /// dimensions sizes (one per rank) of the sparse tensor.
1191 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
1192                                 uint64_t *idata) {
1193   // Skip comments.
1194   while (true) {
1195     if (!fgets(line, kColWidth, file)) {
1196       fprintf(stderr, "Cannot find data in %s\n", filename);
1197       exit(1);
1198     }
1199     if (line[0] != '#')
1200       break;
1201   }
1202   // Next line contains RANK and NNZ.
1203   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
1204     fprintf(stderr, "Cannot find metadata in %s\n", filename);
1205     exit(1);
1206   }
1207   // Followed by a line with the dimension sizes (one per rank).
1208   for (uint64_t r = 0; r < idata[0]; r++) {
1209     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
1210       fprintf(stderr, "Cannot find dimension size %s\n", filename);
1211       exit(1);
1212     }
1213   }
1214   fgets(line, kColWidth, file); // end of line
1215 }
1216 
1217 /// Reads a sparse tensor with the given filename into a memory-resident
1218 /// sparse tensor in coordinate scheme.
1219 template <typename V>
1220 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
1221                                                const uint64_t *shape,
1222                                                const uint64_t *perm) {
1223   // Open the file.
1224   FILE *file = fopen(filename, "r");
1225   if (!file) {
1226     assert(filename && "Received nullptr for filename");
1227     fprintf(stderr, "Cannot find file %s\n", filename);
1228     exit(1);
1229   }
1230   // Perform some file format dependent set up.
1231   char line[kColWidth];
1232   uint64_t idata[512];
1233   bool isPattern = false;
1234   bool isSymmetric = false;
1235   if (strstr(filename, ".mtx")) {
1236     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
1237   } else if (strstr(filename, ".tns")) {
1238     readExtFROSTTHeader(file, filename, line, idata);
1239   } else {
1240     fprintf(stderr, "Unknown format %s\n", filename);
1241     exit(1);
1242   }
1243   // Prepare sparse tensor object with per-dimension sizes
1244   // and the number of nonzeros as initial capacity.
1245   assert(rank == idata[0] && "rank mismatch");
1246   uint64_t nnz = idata[1];
1247   for (uint64_t r = 0; r < rank; r++)
1248     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1249            "dimension size mismatch");
1250   SparseTensorCOO<V> *tensor =
1251       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
1252   // Read all nonzero elements.
1253   std::vector<uint64_t> indices(rank);
1254   for (uint64_t k = 0; k < nnz; k++) {
1255     if (!fgets(line, kColWidth, file)) {
1256       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
1257       exit(1);
1258     }
1259     char *linePtr = line;
1260     for (uint64_t r = 0; r < rank; r++) {
1261       uint64_t idx = strtoul(linePtr, &linePtr, 10);
1262       // Add 0-based index.
1263       indices[perm[r]] = idx - 1;
1264     }
1265     // The external formats always store the numerical values with the type
1266     // double, but we cast these values to the sparse tensor object type.
1267     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1268     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
1269     tensor->add(indices, value);
1270     // We currently chose to deal with symmetric matrices by fully constructing
1271     // them. In the future, we may want to make symmetry implicit for storage
1272     // reasons.
1273     if (isSymmetric && indices[0] != indices[1])
1274       tensor->add({indices[1], indices[0]}, value);
1275   }
1276   // Close the file and return tensor.
1277   fclose(file);
1278   return tensor;
1279 }
1280 
1281 /// Writes the sparse tensor to extended FROSTT format.
1282 template <typename V>
1283 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1284   assert(tensor && dest);
1285   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1286   if (sort)
1287     coo->sort();
1288   char *filename = static_cast<char *>(dest);
1289   auto &sizes = coo->getSizes();
1290   auto &elements = coo->getElements();
1291   uint64_t rank = coo->getRank();
1292   uint64_t nnz = elements.size();
1293   std::fstream file;
1294   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1295   assert(file.is_open());
1296   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1297   for (uint64_t r = 0; r < rank - 1; r++)
1298     file << sizes[r] << " ";
1299   file << sizes[rank - 1] << std::endl;
1300   for (uint64_t i = 0; i < nnz; i++) {
1301     auto &idx = elements[i].indices;
1302     for (uint64_t r = 0; r < rank; r++)
1303       file << (idx[r] + 1) << " ";
1304     file << elements[i].value << std::endl;
1305   }
1306   file.flush();
1307   file.close();
1308   assert(file.good());
1309 }
1310 
1311 /// Initializes sparse tensor from an external COO-flavored format.
1312 template <typename V>
1313 static SparseTensorStorage<uint64_t, uint64_t, V> *
1314 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1315                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
1316   const DimLevelType *sparsity = (DimLevelType *)(sparse);
1317 #ifndef NDEBUG
1318   // Verify that perm is a permutation of 0..(rank-1).
1319   std::vector<uint64_t> order(perm, perm + rank);
1320   std::sort(order.begin(), order.end());
1321   for (uint64_t i = 0; i < rank; ++i) {
1322     if (i != order[i]) {
1323       fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
1324       exit(1);
1325     }
1326   }
1327 
1328   // Verify that the sparsity values are supported.
1329   for (uint64_t i = 0; i < rank; ++i) {
1330     if (sparsity[i] != DimLevelType::kDense &&
1331         sparsity[i] != DimLevelType::kCompressed) {
1332       fprintf(stderr, "Unsupported sparsity value %d\n",
1333               static_cast<int>(sparsity[i]));
1334       exit(1);
1335     }
1336   }
1337 #endif
1338 
1339   // Convert external format to internal COO.
1340   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1341   std::vector<uint64_t> idx(rank);
1342   for (uint64_t i = 0, base = 0; i < nse; i++) {
1343     for (uint64_t r = 0; r < rank; r++)
1344       idx[perm[r]] = indices[base + r];
1345     coo->add(idx, values[i]);
1346     base += rank;
1347   }
1348   // Return sparse tensor storage format as opaque pointer.
1349   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1350       rank, shape, perm, sparsity, coo);
1351   delete coo;
1352   return tensor;
1353 }
1354 
1355 /// Converts a sparse tensor to an external COO-flavored format.
1356 template <typename V>
1357 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1358                                  uint64_t **pShape, V **pValues,
1359                                  uint64_t **pIndices) {
1360   assert(tensor);
1361   auto sparseTensor =
1362       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1363   uint64_t rank = sparseTensor->getRank();
1364   std::vector<uint64_t> perm(rank);
1365   std::iota(perm.begin(), perm.end(), 0);
1366   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1367 
1368   const std::vector<Element<V>> &elements = coo->getElements();
1369   uint64_t nse = elements.size();
1370 
1371   uint64_t *shape = new uint64_t[rank];
1372   for (uint64_t i = 0; i < rank; i++)
1373     shape[i] = coo->getSizes()[i];
1374 
1375   V *values = new V[nse];
1376   uint64_t *indices = new uint64_t[rank * nse];
1377 
1378   for (uint64_t i = 0, base = 0; i < nse; i++) {
1379     values[i] = elements[i].value;
1380     for (uint64_t j = 0; j < rank; j++)
1381       indices[base + j] = elements[i].indices[j];
1382     base += rank;
1383   }
1384 
1385   delete coo;
1386   *pRank = rank;
1387   *pNse = nse;
1388   *pShape = shape;
1389   *pValues = values;
1390   *pIndices = indices;
1391 }
1392 
1393 } // namespace
1394 
1395 extern "C" {
1396 
1397 //===----------------------------------------------------------------------===//
1398 //
1399 // Public API with methods that operate on MLIR buffers (memrefs) to interact
1400 // with sparse tensors, which are only visible as opaque pointers externally.
1401 // These methods should be used exclusively by MLIR compiler-generated code.
1402 //
1403 // Some macro magic is used to generate implementations for all required type
1404 // combinations that can be called from MLIR compiler-generated code.
1405 //
1406 //===----------------------------------------------------------------------===//
1407 
1408 #define CASE(p, i, v, P, I, V)                                                 \
1409   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1410     SparseTensorCOO<V> *coo = nullptr;                                         \
1411     if (action <= Action::kFromCOO) {                                          \
1412       if (action == Action::kFromFile) {                                       \
1413         char *filename = static_cast<char *>(ptr);                             \
1414         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1415       } else if (action == Action::kFromCOO) {                                 \
1416         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1417       } else {                                                                 \
1418         assert(action == Action::kEmpty);                                      \
1419       }                                                                        \
1420       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1421           rank, shape, perm, sparsity, coo);                                   \
1422       if (action == Action::kFromFile)                                         \
1423         delete coo;                                                            \
1424       return tensor;                                                           \
1425     }                                                                          \
1426     if (action == Action::kSparseToSparse) {                                   \
1427       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
1428       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
1429                                                            sparsity, tensor);  \
1430     }                                                                          \
1431     if (action == Action::kEmptyCOO)                                           \
1432       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1433     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1434     if (action == Action::kToIterator) {                                       \
1435       coo->startIterator();                                                    \
1436     } else {                                                                   \
1437       assert(action == Action::kToCOO);                                        \
1438     }                                                                          \
1439     return coo;                                                                \
1440   }
1441 
1442 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
1443 
1444 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
1445   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
1446     assert(ref &&tensor);                                                      \
1447     std::vector<TYPE> *v;                                                      \
1448     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
1449     ref->basePtr = ref->data = v->data();                                      \
1450     ref->offset = 0;                                                           \
1451     ref->sizes[0] = v->size();                                                 \
1452     ref->strides[0] = 1;                                                       \
1453   }
1454 
1455 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1456   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1457                            index_type d) {                                     \
1458     assert(ref &&tensor);                                                      \
1459     std::vector<TYPE> *v;                                                      \
1460     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1461     ref->basePtr = ref->data = v->data();                                      \
1462     ref->offset = 0;                                                           \
1463     ref->sizes[0] = v->size();                                                 \
1464     ref->strides[0] = 1;                                                       \
1465   }
1466 
1467 #define IMPL_ADDELT(NAME, TYPE)                                                \
1468   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
1469                             StridedMemRefType<index_type, 1> *iref,            \
1470                             StridedMemRefType<index_type, 1> *pref) {          \
1471     assert(tensor &&iref &&pref);                                              \
1472     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1473     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1474     const index_type *indx = iref->data + iref->offset;                        \
1475     const index_type *perm = pref->data + pref->offset;                        \
1476     uint64_t isize = iref->sizes[0];                                           \
1477     std::vector<index_type> indices(isize);                                    \
1478     for (uint64_t r = 0; r < isize; r++)                                       \
1479       indices[perm[r]] = indx[r];                                              \
1480     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
1481     return tensor;                                                             \
1482   }
1483 
1484 #define IMPL_GETNEXT(NAME, V)                                                  \
1485   bool _mlir_ciface_##NAME(void *tensor,                                       \
1486                            StridedMemRefType<index_type, 1> *iref,             \
1487                            StridedMemRefType<V, 0> *vref) {                    \
1488     assert(tensor &&iref &&vref);                                              \
1489     assert(iref->strides[0] == 1);                                             \
1490     index_type *indx = iref->data + iref->offset;                              \
1491     V *value = vref->data + vref->offset;                                      \
1492     const uint64_t isize = iref->sizes[0];                                     \
1493     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
1494     const Element<V> *elem = iter->getNext();                                  \
1495     if (elem == nullptr)                                                       \
1496       return false;                                                            \
1497     for (uint64_t r = 0; r < isize; r++)                                       \
1498       indx[r] = elem->indices[r];                                              \
1499     *value = elem->value;                                                      \
1500     return true;                                                               \
1501   }
1502 
1503 #define IMPL_LEXINSERT(NAME, V)                                                \
1504   void _mlir_ciface_##NAME(void *tensor,                                       \
1505                            StridedMemRefType<index_type, 1> *cref, V val) {    \
1506     assert(tensor &&cref);                                                     \
1507     assert(cref->strides[0] == 1);                                             \
1508     index_type *cursor = cref->data + cref->offset;                            \
1509     assert(cursor);                                                            \
1510     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1511   }
1512 
1513 #define IMPL_EXPINSERT(NAME, V)                                                \
1514   void _mlir_ciface_##NAME(                                                    \
1515       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1516       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1517       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1518     assert(tensor &&cref &&vref &&fref &&aref);                                \
1519     assert(cref->strides[0] == 1);                                             \
1520     assert(vref->strides[0] == 1);                                             \
1521     assert(fref->strides[0] == 1);                                             \
1522     assert(aref->strides[0] == 1);                                             \
1523     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1524     index_type *cursor = cref->data + cref->offset;                            \
1525     V *values = vref->data + vref->offset;                                     \
1526     bool *filled = fref->data + fref->offset;                                  \
1527     index_type *added = aref->data + aref->offset;                             \
1528     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1529         cursor, values, filled, added, count);                                 \
1530   }
1531 
1532 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1533 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1534 // that this file cannot get out of sync with its header.
1535 static_assert(std::is_same<index_type, uint64_t>::value,
1536               "Expected index_type == uint64_t");
1537 
1538 /// Constructs a new sparse tensor. This is the "swiss army knife"
1539 /// method for materializing sparse tensors into the computation.
1540 ///
1541 /// Action:
1542 /// kEmpty = returns empty storage to fill later
1543 /// kFromFile = returns storage, where ptr contains filename to read
1544 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
1545 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
1546 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1547 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
1548 void *
1549 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1550                              StridedMemRefType<index_type, 1> *sref,
1551                              StridedMemRefType<index_type, 1> *pref,
1552                              OverheadType ptrTp, OverheadType indTp,
1553                              PrimaryType valTp, Action action, void *ptr) {
1554   assert(aref && sref && pref);
1555   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1556          pref->strides[0] == 1);
1557   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1558   const DimLevelType *sparsity = aref->data + aref->offset;
1559   const index_type *shape = sref->data + sref->offset;
1560   const index_type *perm = pref->data + pref->offset;
1561   uint64_t rank = aref->sizes[0];
1562 
1563   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1564   // This is safe because of the static_assert above.
1565   if (ptrTp == OverheadType::kIndex)
1566     ptrTp = OverheadType::kU64;
1567   if (indTp == OverheadType::kIndex)
1568     indTp = OverheadType::kU64;
1569 
1570   // Double matrices with all combinations of overhead storage.
1571   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1572        uint64_t, double);
1573   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1574        uint32_t, double);
1575   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1576        uint16_t, double);
1577   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1578        uint8_t, double);
1579   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1580        uint64_t, double);
1581   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1582        uint32_t, double);
1583   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1584        uint16_t, double);
1585   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1586        uint8_t, double);
1587   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1588        uint64_t, double);
1589   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1590        uint32_t, double);
1591   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1592        uint16_t, double);
1593   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1594        uint8_t, double);
1595   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1596        uint64_t, double);
1597   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1598        uint32_t, double);
1599   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1600        uint16_t, double);
1601   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1602        uint8_t, double);
1603 
1604   // Float matrices with all combinations of overhead storage.
1605   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1606        uint64_t, float);
1607   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1608        uint32_t, float);
1609   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1610        uint16_t, float);
1611   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1612        uint8_t, float);
1613   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1614        uint64_t, float);
1615   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1616        uint32_t, float);
1617   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1618        uint16_t, float);
1619   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1620        uint8_t, float);
1621   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1622        uint64_t, float);
1623   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1624        uint32_t, float);
1625   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1626        uint16_t, float);
1627   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1628        uint8_t, float);
1629   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1630        uint64_t, float);
1631   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1632        uint32_t, float);
1633   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1634        uint16_t, float);
1635   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1636        uint8_t, float);
1637 
1638   // Integral matrices with both overheads of the same type.
1639   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1640   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1641   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1642   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1643   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1644   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1645   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1646   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1647   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1648   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1649   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1650   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1651   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1652 
1653   // Complex matrices with wide overhead.
1654   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1655   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1656 
1657   // Unsupported case (add above if needed).
1658   fputs("unsupported combination of types\n", stderr);
1659   exit(1);
1660 }
1661 
1662 /// Methods that provide direct access to pointers.
1663 IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
1664 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
1665 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
1666 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
1667 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
1668 
1669 /// Methods that provide direct access to indices.
1670 IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1671 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
1672 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
1673 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
1674 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
1675 
1676 /// Methods that provide direct access to values.
1677 IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
1678 IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
1679 IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
1680 IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
1681 IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
1682 IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
1683 IMPL_SPARSEVALUES(sparseValuesC64, complex64, getValues)
1684 IMPL_SPARSEVALUES(sparseValuesC32, complex32, getValues)
1685 
1686 /// Helper to add value to coordinate scheme, one per value type.
1687 IMPL_ADDELT(addEltF64, double)
1688 IMPL_ADDELT(addEltF32, float)
1689 IMPL_ADDELT(addEltI64, int64_t)
1690 IMPL_ADDELT(addEltI32, int32_t)
1691 IMPL_ADDELT(addEltI16, int16_t)
1692 IMPL_ADDELT(addEltI8, int8_t)
1693 IMPL_ADDELT(addEltC64, complex64)
1694 IMPL_ADDELT(addEltC32ABI, complex32)
1695 // Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
1696 // any padding (which seem to happen for complex32 when passed as scalar;
1697 // all other cases, e.g. pointer to array, work as expected).
1698 // TODO: cleaner way to avoid ABI padding problem?
1699 void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
1700                              StridedMemRefType<index_type, 1> *iref,
1701                              StridedMemRefType<index_type, 1> *pref) {
1702   return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
1703 }
1704 
1705 /// Helper to enumerate elements of coordinate scheme, one per value type.
1706 IMPL_GETNEXT(getNextF64, double)
1707 IMPL_GETNEXT(getNextF32, float)
1708 IMPL_GETNEXT(getNextI64, int64_t)
1709 IMPL_GETNEXT(getNextI32, int32_t)
1710 IMPL_GETNEXT(getNextI16, int16_t)
1711 IMPL_GETNEXT(getNextI8, int8_t)
1712 IMPL_GETNEXT(getNextC64, complex64)
1713 IMPL_GETNEXT(getNextC32, complex32)
1714 
1715 /// Insert elements in lexicographical index order, one per value type.
1716 IMPL_LEXINSERT(lexInsertF64, double)
1717 IMPL_LEXINSERT(lexInsertF32, float)
1718 IMPL_LEXINSERT(lexInsertI64, int64_t)
1719 IMPL_LEXINSERT(lexInsertI32, int32_t)
1720 IMPL_LEXINSERT(lexInsertI16, int16_t)
1721 IMPL_LEXINSERT(lexInsertI8, int8_t)
1722 IMPL_LEXINSERT(lexInsertC64, complex64)
1723 IMPL_LEXINSERT(lexInsertC32ABI, complex32)
1724 // Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
1725 // any padding (which seem to happen for complex32 when passed as scalar;
1726 // all other cases, e.g. pointer to array, work as expected).
1727 // TODO: cleaner way to avoid ABI padding problem?
1728 void _mlir_ciface_lexInsertC32(void *tensor,
1729                                StridedMemRefType<index_type, 1> *cref, float r,
1730                                float i) {
1731   _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
1732 }
1733 
1734 /// Insert using expansion, one per value type.
1735 IMPL_EXPINSERT(expInsertF64, double)
1736 IMPL_EXPINSERT(expInsertF32, float)
1737 IMPL_EXPINSERT(expInsertI64, int64_t)
1738 IMPL_EXPINSERT(expInsertI32, int32_t)
1739 IMPL_EXPINSERT(expInsertI16, int16_t)
1740 IMPL_EXPINSERT(expInsertI8, int8_t)
1741 IMPL_EXPINSERT(expInsertC64, complex64)
1742 IMPL_EXPINSERT(expInsertC32, complex32)
1743 
1744 #undef CASE
1745 #undef IMPL_SPARSEVALUES
1746 #undef IMPL_GETOVERHEAD
1747 #undef IMPL_ADDELT
1748 #undef IMPL_GETNEXT
1749 #undef IMPL_LEXINSERT
1750 #undef IMPL_EXPINSERT
1751 
1752 /// Output a sparse tensor, one per value type.
1753 void outSparseTensorF64(void *tensor, void *dest, bool sort) {
1754   return outSparseTensor<double>(tensor, dest, sort);
1755 }
1756 void outSparseTensorF32(void *tensor, void *dest, bool sort) {
1757   return outSparseTensor<float>(tensor, dest, sort);
1758 }
1759 void outSparseTensorI64(void *tensor, void *dest, bool sort) {
1760   return outSparseTensor<int64_t>(tensor, dest, sort);
1761 }
1762 void outSparseTensorI32(void *tensor, void *dest, bool sort) {
1763   return outSparseTensor<int32_t>(tensor, dest, sort);
1764 }
1765 void outSparseTensorI16(void *tensor, void *dest, bool sort) {
1766   return outSparseTensor<int16_t>(tensor, dest, sort);
1767 }
1768 void outSparseTensorI8(void *tensor, void *dest, bool sort) {
1769   return outSparseTensor<int8_t>(tensor, dest, sort);
1770 }
1771 void outSparseTensorC64(void *tensor, void *dest, bool sort) {
1772   return outSparseTensor<complex64>(tensor, dest, sort);
1773 }
1774 void outSparseTensorC32(void *tensor, void *dest, bool sort) {
1775   return outSparseTensor<complex32>(tensor, dest, sort);
1776 }
1777 
1778 //===----------------------------------------------------------------------===//
1779 //
1780 // Public API with methods that accept C-style data structures to interact
1781 // with sparse tensors, which are only visible as opaque pointers externally.
1782 // These methods can be used both by MLIR compiler-generated code as well as by
1783 // an external runtime that wants to interact with MLIR compiler-generated code.
1784 //
1785 //===----------------------------------------------------------------------===//
1786 
1787 /// Helper method to read a sparse tensor filename from the environment,
1788 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1789 char *getTensorFilename(index_type id) {
1790   char var[80];
1791   sprintf(var, "TENSOR%" PRIu64, id);
1792   char *env = getenv(var);
1793   if (!env) {
1794     fprintf(stderr, "Environment variable %s is not set\n", var);
1795     exit(1);
1796   }
1797   return env;
1798 }
1799 
1800 /// Returns size of sparse tensor in given dimension.
1801 index_type sparseDimSize(void *tensor, index_type d) {
1802   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1803 }
1804 
1805 /// Finalizes lexicographic insertions.
1806 void endInsert(void *tensor) {
1807   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1808 }
1809 
1810 /// Releases sparse tensor storage.
1811 void delSparseTensor(void *tensor) {
1812   delete static_cast<SparseTensorStorageBase *>(tensor);
1813 }
1814 
1815 /// Releases sparse tensor coordinate scheme.
1816 #define IMPL_DELCOO(VNAME, V)                                                  \
1817   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1818     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1819   }
1820 IMPL_DELCOO(F64, double)
1821 IMPL_DELCOO(F32, float)
1822 IMPL_DELCOO(I64, int64_t)
1823 IMPL_DELCOO(I32, int32_t)
1824 IMPL_DELCOO(I16, int16_t)
1825 IMPL_DELCOO(I8, int8_t)
1826 IMPL_DELCOO(C64, complex64)
1827 IMPL_DELCOO(C32, complex32)
1828 #undef IMPL_DELCOO
1829 
1830 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
1831 /// data structures. The expected parameters are:
1832 ///
1833 ///   rank:    rank of tensor
1834 ///   nse:     number of specified elements (usually the nonzeros)
1835 ///   shape:   array with dimension size for each rank
1836 ///   values:  a "nse" array with values for all specified elements
1837 ///   indices: a flat "nse x rank" array with indices for all specified elements
1838 ///   perm:    the permutation of the dimensions in the storage
1839 ///   sparse:  the sparsity for the dimensions
1840 ///
1841 /// For example, the sparse matrix
1842 ///     | 1.0 0.0 0.0 |
1843 ///     | 0.0 5.0 3.0 |
1844 /// can be passed as
1845 ///      rank    = 2
1846 ///      nse     = 3
1847 ///      shape   = [2, 3]
1848 ///      values  = [1.0, 5.0, 3.0]
1849 ///      indices = [ 0, 0,  1, 1,  1, 2]
1850 //
1851 // TODO: generalize beyond 64-bit indices.
1852 //
1853 void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
1854                                    double *values, uint64_t *indices,
1855                                    uint64_t *perm, uint8_t *sparse) {
1856   return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
1857                                     sparse);
1858 }
1859 void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
1860                                    float *values, uint64_t *indices,
1861                                    uint64_t *perm, uint8_t *sparse) {
1862   return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
1863                                    sparse);
1864 }
1865 void *convertToMLIRSparseTensorI64(uint64_t rank, uint64_t nse, uint64_t *shape,
1866                                    int64_t *values, uint64_t *indices,
1867                                    uint64_t *perm, uint8_t *sparse) {
1868   return toMLIRSparseTensor<int64_t>(rank, nse, shape, values, indices, perm,
1869                                      sparse);
1870 }
1871 void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse, uint64_t *shape,
1872                                    int32_t *values, uint64_t *indices,
1873                                    uint64_t *perm, uint8_t *sparse) {
1874   return toMLIRSparseTensor<int32_t>(rank, nse, shape, values, indices, perm,
1875                                      sparse);
1876 }
1877 void *convertToMLIRSparseTensorI16(uint64_t rank, uint64_t nse, uint64_t *shape,
1878                                    int16_t *values, uint64_t *indices,
1879                                    uint64_t *perm, uint8_t *sparse) {
1880   return toMLIRSparseTensor<int16_t>(rank, nse, shape, values, indices, perm,
1881                                      sparse);
1882 }
1883 void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
1884                                   int8_t *values, uint64_t *indices,
1885                                   uint64_t *perm, uint8_t *sparse) {
1886   return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
1887                                     sparse);
1888 }
1889 void *convertToMLIRSparseTensorC64(uint64_t rank, uint64_t nse, uint64_t *shape,
1890                                    complex64 *values, uint64_t *indices,
1891                                    uint64_t *perm, uint8_t *sparse) {
1892   return toMLIRSparseTensor<complex64>(rank, nse, shape, values, indices, perm,
1893                                        sparse);
1894 }
1895 void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
1896                                    complex32 *values, uint64_t *indices,
1897                                    uint64_t *perm, uint8_t *sparse) {
1898   return toMLIRSparseTensor<complex32>(rank, nse, shape, values, indices, perm,
1899                                        sparse);
1900 }
1901 
1902 /// Converts a sparse tensor to COO-flavored format expressed using C-style
1903 /// data structures. The expected output parameters are pointers for these
1904 /// values:
1905 ///
1906 ///   rank:    rank of tensor
1907 ///   nse:     number of specified elements (usually the nonzeros)
1908 ///   shape:   array with dimension size for each rank
1909 ///   values:  a "nse" array with values for all specified elements
1910 ///   indices: a flat "nse x rank" array with indices for all specified elements
1911 ///
1912 /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
1913 /// from convertToMLIRSparseTensor.
1914 ///
1915 //  TODO: Currently, values are copied from SparseTensorStorage to
1916 //  SparseTensorCOO, then to the output. We may want to reduce the number of
1917 //  copies.
1918 //
1919 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1920 // compressed
1921 //
1922 void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
1923                                     uint64_t *pNse, uint64_t **pShape,
1924                                     double **pValues, uint64_t **pIndices) {
1925   fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
1926 }
1927 void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
1928                                     uint64_t *pNse, uint64_t **pShape,
1929                                     float **pValues, uint64_t **pIndices) {
1930   fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
1931 }
1932 void convertFromMLIRSparseTensorI64(void *tensor, uint64_t *pRank,
1933                                     uint64_t *pNse, uint64_t **pShape,
1934                                     int64_t **pValues, uint64_t **pIndices) {
1935   fromMLIRSparseTensor<int64_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
1936 }
1937 void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
1938                                     uint64_t *pNse, uint64_t **pShape,
1939                                     int32_t **pValues, uint64_t **pIndices) {
1940   fromMLIRSparseTensor<int32_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
1941 }
1942 void convertFromMLIRSparseTensorI16(void *tensor, uint64_t *pRank,
1943                                     uint64_t *pNse, uint64_t **pShape,
1944                                     int16_t **pValues, uint64_t **pIndices) {
1945   fromMLIRSparseTensor<int16_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
1946 }
1947 void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
1948                                    uint64_t *pNse, uint64_t **pShape,
1949                                    int8_t **pValues, uint64_t **pIndices) {
1950   fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
1951 }
1952 void convertFromMLIRSparseTensorC64(void *tensor, uint64_t *pRank,
1953                                     uint64_t *pNse, uint64_t **pShape,
1954                                     complex64 **pValues, uint64_t **pIndices) {
1955   fromMLIRSparseTensor<complex64>(tensor, pRank, pNse, pShape, pValues,
1956                                   pIndices);
1957 }
1958 void convertFromMLIRSparseTensorC32(void *tensor, uint64_t *pRank,
1959                                     uint64_t *pNse, uint64_t **pShape,
1960                                     complex32 **pValues, uint64_t **pIndices) {
1961   fromMLIRSparseTensor<complex32>(tensor, pRank, pNse, pShape, pValues,
1962                                   pIndices);
1963 }
1964 
1965 } // extern "C"
1966 
1967 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1968