1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cctype>
25 #include <cinttypes>
26 #include <complex>
27 #include <cstdio>
28 #include <cstdlib>
29 #include <cstring>
30 #include <fstream>
31 #include <functional>
32 #include <iostream>
33 #include <limits>
34 #include <numeric>
35 #include <vector>
36 
37 using complex64 = std::complex<double>;
38 using complex32 = std::complex<float>;
39 
40 //===----------------------------------------------------------------------===//
41 //
42 // Internal support for storing and reading sparse tensors.
43 //
44 // The following memory-resident sparse storage schemes are supported:
45 //
46 // (a) A coordinate scheme for temporarily storing and lexicographically
47 //     sorting a sparse tensor by index (SparseTensorCOO).
48 //
49 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
50 //     per-dimension sparse/dense annotations together with a dimension
51 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
52 //
53 // The following external formats are supported:
54 //
55 // (1) Matrix Market Exchange (MME): *.mtx
56 //     https://math.nist.gov/MatrixMarket/formats.html
57 //
58 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
59 //     http://frostt.io/tensors/file-formats.html
60 //
61 // Two public APIs are supported:
62 //
63 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
64 //     tensors. These methods should be used exclusively by MLIR
65 //     compiler-generated code.
66 //
67 // (II) Methods that accept C-style data structures to interact with sparse
68 //      tensors. These methods can be used by any external runtime that wants
69 //      to interact with MLIR compiler-generated code.
70 //
71 // In both cases (I) and (II), the SparseTensorStorage format is visible
72 // externally only as an opaque pointer.
73 //
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 
78 static constexpr int kColWidth = 1025;
79 
80 /// A version of `operator*` on `uint64_t` which checks for overflows.
81 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
82   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
83          "Integer overflow");
84   return lhs * rhs;
85 }
86 
87 // TODO: adjust this so it can be used by `openSparseTensorCOO` too.
88 // That version doesn't have the permutation, and the `sizes` are
89 // a pointer/C-array rather than `std::vector`.
90 //
91 /// Asserts that the `sizes` (in target-order) under the `perm` (mapping
92 /// semantic-order to target-order) are a refinement of the desired `shape`
93 /// (in semantic-order).
94 ///
95 /// Precondition: `perm` and `shape` must be valid for `rank`.
96 static inline void
97 assertPermutedSizesMatchShape(const std::vector<uint64_t> &sizes, uint64_t rank,
98                               const uint64_t *perm, const uint64_t *shape) {
99   assert(perm && shape);
100   assert(rank == sizes.size() && "Rank mismatch");
101   for (uint64_t r = 0; r < rank; r++)
102     assert((shape[r] == 0 || shape[r] == sizes[perm[r]]) &&
103            "Dimension size mismatch");
104 }
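
// A small worked example of the convention checked above (the concrete
// numbers are illustrative only; a `shape[r]` of zero would mean the size
// is dynamic and matches anything):
//
//   uint64_t shape[] = {10, 20, 30};          // semantic-order
//   uint64_t perm[] = {2, 0, 1};  // semantic dim r -> target dim perm[r]
//   std::vector<uint64_t> sizes{20, 30, 10};  // target-order
//   assertPermutedSizesMatchShape(sizes, 3, perm, shape); // passes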
105 
106 /// A sparse tensor element in coordinate scheme (value and indices).
107 /// For example, a rank-1 vector element would look like
108 ///   ({i}, a[i])
109 /// and a rank-5 tensor element like
110 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
111 /// We use a pointer into a shared index pool rather than, e.g., a direct
112 /// vector, since that (1) reduces the per-element memory footprint and
113 /// (2) centralizes the memory reservation and (re)allocation in one place.
114 template <typename V>
115 struct Element final {
116   Element(uint64_t *ind, V val) : indices(ind), value(val) {}
117   uint64_t *indices; // pointer into shared index pool
118   V value;
119 };
120 
121 /// The type of callback functions which receive an element.  We avoid
122 /// packaging the coordinates and value together as an `Element` object
123 /// because this helps keep code somewhat cleaner.
124 template <typename V>
125 using ElementConsumer =
126     const std::function<void(const std::vector<uint64_t> &, V)> &;
127 
128 /// A memory-resident sparse tensor in coordinate scheme (collection of
129 /// elements). This data structure is used to read a sparse tensor from
130 /// any external format into memory and sort the elements lexicographically
131 /// by indices before passing it back to the client (most packed storage
132 /// formats require the elements to appear in lexicographic index order).
133 template <typename V>
134 struct SparseTensorCOO final {
135 public:
136   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
137       : sizes(szs) {
138     if (capacity) {
139       elements.reserve(capacity);
140       indices.reserve(capacity * getRank());
141     }
142   }
143 
144   /// Adds an element given its indices and value.
145   void add(const std::vector<uint64_t> &ind, V val) {
146     assert(!iteratorLocked && "Attempt to add() after startIterator()");
147     uint64_t *base = indices.data();
148     uint64_t size = indices.size();
149     uint64_t rank = getRank();
150     assert(rank == ind.size());
151     for (uint64_t r = 0; r < rank; r++) {
152       assert(ind[r] < sizes[r]); // within bounds
153       indices.push_back(ind[r]);
154     }
155     // The base pointer changes only if `indices` was reallocated. In that
156     // case, we need to correct all previous pointers into the vector. Note
157     // that this happens only if the initial capacity was set too small, and
158     // then only on each internal vector reallocation (which, with the usual
159     // doubling growth, incurs just an amortized linear overhead).
160     uint64_t *newBase = indices.data();
161     if (newBase != base) {
162       for (uint64_t i = 0, n = elements.size(); i < n; i++)
163         elements[i].indices = newBase + (elements[i].indices - base);
164       base = newBase;
165     }
166     // Add element as (pointer into shared index pool, value) pair.
167     elements.emplace_back(base + size, val);
168   }
169 
170   /// Sorts elements lexicographically by index.
171   void sort() {
172     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
173     // TODO: we may want to cache an `isSorted` bit, to avoid
174     // unnecessary/redundant sorting.
175     std::sort(elements.begin(), elements.end(),
176               [this](const Element<V> &e1, const Element<V> &e2) {
177                 uint64_t rank = getRank();
178                 for (uint64_t r = 0; r < rank; r++) {
179                   if (e1.indices[r] == e2.indices[r])
180                     continue;
181                   return e1.indices[r] < e2.indices[r];
182                 }
183                 return false;
184               });
185   }
186 
187   /// Returns rank.
188   uint64_t getRank() const { return sizes.size(); }
189 
190   /// Getter for sizes array.
191   const std::vector<uint64_t> &getSizes() const { return sizes; }
192 
193   /// Getter for elements array.
194   const std::vector<Element<V>> &getElements() const { return elements; }
195 
196   /// Switch into iterator mode.
197   void startIterator() {
198     iteratorLocked = true;
199     iteratorPos = 0;
200   }
201 
202   /// Get the next element.
203   const Element<V> *getNext() {
204     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
205     if (iteratorPos < elements.size())
206       return &(elements[iteratorPos++]);
207     iteratorLocked = false;
208     return nullptr;
209   }
210 
211   /// Factory method. Permutes the original dimensions according to
212   /// the given ordering and expects subsequent add() calls to honor
213   /// that same ordering for the given indices. The result is a
214   /// fully permuted coordinate scheme.
215   ///
216   /// Precondition: `sizes` and `perm` must be valid for `rank`.
217   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
218                                                 const uint64_t *sizes,
219                                                 const uint64_t *perm,
220                                                 uint64_t capacity = 0) {
221     std::vector<uint64_t> permsz(rank);
222     for (uint64_t r = 0; r < rank; r++) {
223       assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
224       permsz[perm[r]] = sizes[r];
225     }
226     return new SparseTensorCOO<V>(permsz, capacity);
227   }
228 
229 private:
230   const std::vector<uint64_t> sizes; // per-dimension sizes
231   std::vector<Element<V>> elements;  // all COO elements
232   std::vector<uint64_t> indices;     // shared index pool
233   bool iteratorLocked = false;
234   unsigned iteratorPos = 0;
235 };
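
// A minimal usage sketch of `SparseTensorCOO` (illustrative only; the
// variable names below are hypothetical and not part of this library):
//
//   uint64_t sizes[] = {3, 4};
//   uint64_t perm[] = {0, 1}; // identity permutation
//   SparseTensorCOO<double> *coo =
//       SparseTensorCOO<double>::newSparseTensorCOO(2, sizes, perm,
//                                                   /*capacity=*/2);
//   coo->add({0, 1}, 2.5);
//   coo->add({2, 3}, 1.0);
//   coo->sort();
//   coo->startIterator();
//   while (const Element<double> *e = coo->getNext())
//     printf("(%" PRIu64 ", %" PRIu64 ") = %f\n", e->indices[0],
//            e->indices[1], e->value);
//   delete coo;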
236 
237 // See <https://en.wikipedia.org/wiki/X_Macro>
238 //
239 // `FOREVERY_SIMPLEX_V` only specifies the non-complex `V` types, because
240 // the ABI for complex types has compiler- and architecture-dependent
241 // complexities we need to work around.  Namely, when a function takes a
242 // parameter of C/C++ type `complex32` (per se), then there is additional
243 // padding causing it not to match the LLVM type `!llvm.struct<(f32, f32)>`.
244 // This only happens with the `complex32` type itself, not with
245 // pointers/arrays of complex values.  So far `complex64` doesn't exhibit
246 // this ABI incompatibility, but we exclude it anyway just to be safe.
247 #define FOREVERY_SIMPLEX_V(DO)                                                 \
248   DO(F64, double)                                                              \
249   DO(F32, float)                                                               \
250   DO(I64, int64_t)                                                             \
251   DO(I32, int32_t)                                                             \
252   DO(I16, int16_t)                                                             \
253   DO(I8, int8_t)
254 
255 #define FOREVERY_V(DO)                                                         \
256   FOREVERY_SIMPLEX_V(DO)                                                       \
257   DO(C64, complex64)                                                           \
258   DO(C32, complex32)
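
// As an illustration of how these X-macros are used below, expanding
// `FOREVERY_SIMPLEX_V(DECL_GETVALUES)` with the `DECL_GETVALUES` macro
// defined further down effectively produces one overload per value type:
//
//   virtual void getValues(std::vector<double> **) { fatal("getValuesF64"); }
//   virtual void getValues(std::vector<float> **) { fatal("getValuesF32"); }
//   ... and so on for I64, I32, I16, and I8;
//
// while `FOREVERY_V` additionally covers the two complex types.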
259 
260 // Forward.
261 template <typename V>
262 class SparseTensorEnumeratorBase;
263 
264 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
265 /// takes responsibility for all the `<P,I,V>`-independent aspects
266 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
267 /// we use function overloading to implement "partial" method
268 /// specialization, which the C-API relies on to catch type errors
269 /// arising from our use of opaque pointers.
270 class SparseTensorStorageBase {
271 public:
272   /// Constructs a new storage object.  The `perm` maps the tensor's
273   /// semantic-ordering of dimensions to this object's storage-order.
274   /// The `szs` and `sparsity` arrays are already in storage-order.
275   ///
276   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
277   SparseTensorStorageBase(const std::vector<uint64_t> &szs,
278                           const uint64_t *perm, const DimLevelType *sparsity)
279       : dimSizes(szs), rev(getRank()),
280         dimTypes(sparsity, sparsity + getRank()) {
281     assert(perm && sparsity);
282     const uint64_t rank = getRank();
283     // Validate parameters.
284     assert(rank > 0 && "Trivial shape is unsupported");
285     for (uint64_t r = 0; r < rank; r++) {
286       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
287       assert((dimTypes[r] == DimLevelType::kDense ||
288               dimTypes[r] == DimLevelType::kCompressed) &&
289              "Unsupported DimLevelType");
290     }
291     // Construct the "reverse" (i.e., inverse) permutation.
292     for (uint64_t r = 0; r < rank; r++)
293       rev[perm[r]] = r;
294   }
295 
296   virtual ~SparseTensorStorageBase() = default;
297 
298   /// Get the rank of the tensor.
299   uint64_t getRank() const { return dimSizes.size(); }
300 
301   /// Getter for the dimension-sizes array, in storage-order.
302   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
303 
304   /// Safely looks up the size of the given (storage-order) dimension.
305   uint64_t getDimSize(uint64_t d) const {
306     assert(d < getRank());
307     return dimSizes[d];
308   }
309 
310   /// Getter for the "reverse" permutation, which maps this object's
311   /// storage-order to the tensor's semantic-order.
312   const std::vector<uint64_t> &getRev() const { return rev; }
313 
314   /// Getter for the dimension-types array, in storage-order.
315   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
316 
317   /// Safely check if the (storage-order) dimension uses compressed storage.
318   bool isCompressedDim(uint64_t d) const {
319     assert(d < getRank());
320     return (dimTypes[d] == DimLevelType::kCompressed);
321   }
322 
323   /// Allocate a new enumerator.
324 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
325   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
326                              const uint64_t *) const {                         \
327     fatal("newEnumerator" #VNAME);                                             \
328   }
329   FOREVERY_V(DECL_NEWENUMERATOR)
330 #undef DECL_NEWENUMERATOR
331 
332   /// Overhead storage.
333   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
334   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
335   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
336   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
337   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
338   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
339   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
340   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
341 
342   /// Primary storage.
343 #define DECL_GETVALUES(VNAME, V)                                               \
344   virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
345   FOREVERY_V(DECL_GETVALUES)
346 #undef DECL_GETVALUES
347 
348   /// Element-wise insertion in lexicographic index order.
349 #define DECL_LEXINSERT(VNAME, V)                                               \
350   virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
351   FOREVERY_V(DECL_LEXINSERT)
352 #undef DECL_LEXINSERT
353 
354   /// Expanded insertion.
355 #define DECL_EXPINSERT(VNAME, V)                                               \
356   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
357     fatal("expInsert" #VNAME);                                                 \
358   }
359   FOREVERY_V(DECL_EXPINSERT)
360 #undef DECL_EXPINSERT
361 
362   /// Finishes insertion.
363   virtual void endInsert() = 0;
364 
365 protected:
366   // Since this class has virtual methods, we must disallow public copying in
367   // order to avoid "slicing".  Since this class has data members,
368   // that means making copying protected.
369   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
370   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
371   // Copy-assignment would be implicitly deleted (because `dimSizes`
372   // is const), so we explicitly delete it for clarity.
373   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
374 
375 private:
376   static void fatal(const char *tp) {
377     fprintf(stderr, "unsupported %s\n", tp);
378     exit(1);
379   }
380 
381   const std::vector<uint64_t> dimSizes;
382   std::vector<uint64_t> rev;
383   const std::vector<DimLevelType> dimTypes;
384 };
385 
386 // Forward.
387 template <typename P, typename I, typename V>
388 class SparseTensorEnumerator;
389 
390 /// A memory-resident sparse tensor using a storage scheme based on
391 /// per-dimension sparse/dense annotations. This data structure provides a
392 /// bufferized form of a sparse tensor type. In contrast to generating setup
393 /// methods for each differently annotated sparse tensor, this method provides
394 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
395 /// and annotations to implement all required setup in a general manner.
396 template <typename P, typename I, typename V>
397 class SparseTensorStorage final : public SparseTensorStorageBase {
398   /// Private constructor to share code between the other constructors.
399   /// Beware that the object is not guaranteed to be in a valid state
400   /// after this constructor alone; e.g., `isCompressedDim(d)`
401   /// doesn't entail `!(pointers[d].empty())`.
402   ///
403   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
404   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
405                       const DimLevelType *sparsity)
406       : SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
407         indices(getRank()), idx(getRank()) {}
408 
409 public:
410   /// Constructs a sparse tensor storage scheme with the given dimensions,
411   /// permutation, and per-dimension dense/sparse annotations, using
412   /// the coordinate scheme tensor for the initial contents if provided.
413   ///
414   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
415   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
416                       const DimLevelType *sparsity, SparseTensorCOO<V> *coo)
417       : SparseTensorStorage(szs, perm, sparsity) {
418     // Provide hints on capacity of pointers and indices.
419     // TODO: needs much fine-tuning based on actual sparsity; currently
420     //       we reserve pointer/index space based on all previous dense
421     //       dimensions, which works well up to first sparse dim; but
422     //       we should really use nnz and dense/sparse distribution.
423     bool allDense = true;
424     uint64_t sz = 1;
425     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
426       if (isCompressedDim(r)) {
427         // TODO: Take a parameter between 1 and `sizes[r]`, and multiply
428         // `sz` by that before reserving. (For now we just use 1.)
429         pointers[r].reserve(sz + 1);
430         pointers[r].push_back(0);
431         indices[r].reserve(sz);
432         sz = 1;
433         allDense = false;
434       } else { // Dense dimension.
435         sz = checkedMul(sz, getDimSizes()[r]);
436       }
437     }
438     // Then assign contents from coordinate scheme tensor if provided.
439     if (coo) {
440       // Ensure both preconditions of `fromCOO`.
441       assert(coo->getSizes() == getDimSizes() && "Tensor size mismatch");
442       coo->sort();
443       // Now actually insert the `elements`.
444       const std::vector<Element<V>> &elements = coo->getElements();
445       uint64_t nnz = elements.size();
446       values.reserve(nnz);
447       fromCOO(elements, 0, nnz, 0);
448     } else if (allDense) {
449       values.resize(sz, 0);
450     }
451   }
452 
453   /// Constructs a sparse tensor storage scheme with the given dimensions,
454   /// permutation, and per-dimension dense/sparse annotations, using
455   /// the given sparse tensor for the initial contents.
456   ///
457   /// Preconditions:
458   /// * `perm` and `sparsity` must be valid for `szs.size()`.
459   /// * The `tensor` must have the same value type `V`.
460   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
461                       const DimLevelType *sparsity,
462                       const SparseTensorStorageBase &tensor);
463 
464   ~SparseTensorStorage() final override = default;
465 
466   /// Partially specialize these getter methods based on template types.
467   void getPointers(std::vector<P> **out, uint64_t d) final override {
468     assert(d < getRank());
469     *out = &pointers[d];
470   }
471   void getIndices(std::vector<I> **out, uint64_t d) final override {
472     assert(d < getRank());
473     *out = &indices[d];
474   }
475   void getValues(std::vector<V> **out) final override { *out = &values; }
476 
477   /// Partially specialize lexicographical insertions based on template types.
478   void lexInsert(const uint64_t *cursor, V val) final override {
479     // First, wrap up pending insertion path.
480     uint64_t diff = 0;
481     uint64_t top = 0;
482     if (!values.empty()) {
483       diff = lexDiff(cursor);
484       endPath(diff + 1);
485       top = idx[diff] + 1;
486     }
487     // Then continue with insertion path.
488     insPath(cursor, diff, top, val);
489   }
490 
491   /// Partially specialize expanded insertions based on template types.
492   /// Note that this method resets the values/filled-switch array back
493   /// to all-zero/false while only iterating over the nonzero elements.
494   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
495                  uint64_t count) final override {
496     if (count == 0)
497       return;
498     // Sort.
499     std::sort(added, added + count);
500     // Restore insertion path for first insert.
501     const uint64_t lastDim = getRank() - 1;
502     uint64_t index = added[0];
503     cursor[lastDim] = index;
504     lexInsert(cursor, values[index]);
505     assert(filled[index]);
506     values[index] = 0;
507     filled[index] = false;
508     // Subsequent insertions are quick.
509     for (uint64_t i = 1; i < count; i++) {
510       assert(index < added[i] && "non-lexicographic insertion");
511       index = added[i];
512       cursor[lastDim] = index;
513       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
514       assert(filled[index]);
515       values[index] = 0;
516       filled[index] = false;
517     }
518   }
519 
520   /// Finalizes lexicographic insertions.
521   void endInsert() final override {
522     if (values.empty())
523       finalizeSegment(0);
524     else
525       endPath(0);
526   }
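
  // A minimal sketch of how the insertion methods above are driven
  // (illustrative only; the concrete values are hypothetical): for a
  // 3x4 tensor with dimension types (kDense, kCompressed), i.e. CSR,
  // the call sequence
  //
  //   uint64_t cursor[2];
  //   cursor[0] = 0; cursor[1] = 1; lexInsert(cursor, 1.0);
  //   cursor[0] = 0; cursor[1] = 3; lexInsert(cursor, 2.0);
  //   cursor[0] = 2; cursor[1] = 2; lexInsert(cursor, 3.0);
  //   endInsert();
  //
  // builds the same storage layout as the `fromCOO` example shown
  // further below.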
527 
528   void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
529                      const uint64_t *perm) const final override {
530     *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
531   }
532 
533   /// Returns this sparse tensor storage scheme as a new memory-resident
534   /// sparse tensor in coordinate scheme with the given dimension order.
535   ///
536   /// Precondition: `perm` must be valid for `getRank()`.
537   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
538     SparseTensorEnumeratorBase<V> *enumerator;
539     newEnumerator(&enumerator, getRank(), perm);
540     SparseTensorCOO<V> *coo =
541         new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
542     enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
543       coo->add(ind, val);
544     });
545     // TODO: This assertion assumes there are no stored zeros,
546     // or if there are then that we don't filter them out.
547     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
548     assert(coo->getElements().size() == values.size());
549     delete enumerator;
550     return coo;
551   }
552 
553   /// Factory method. Constructs a sparse tensor storage scheme with the given
554   /// dimensions, permutation, and per-dimension dense/sparse annotations,
555   /// using the coordinate scheme tensor for the initial contents if provided.
556   /// In the latter case, the coordinate scheme must respect the same
557   /// permutation as is desired for the new sparse tensor storage.
558   ///
559   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
560   static SparseTensorStorage<P, I, V> *
561   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
562                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
563     SparseTensorStorage<P, I, V> *n = nullptr;
564     if (coo) {
565       const auto &coosz = coo->getSizes();
566       assertPermutedSizesMatchShape(coosz, rank, perm, shape);
567       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
568     } else {
569       std::vector<uint64_t> permsz(rank);
570       for (uint64_t r = 0; r < rank; r++) {
571         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
572         permsz[perm[r]] = shape[r];
573       }
574       // We pass the null `coo` to ensure we select the intended constructor.
575       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
576     }
577     return n;
578   }
579 
580   /// Factory method. Constructs a sparse tensor storage scheme with
581   /// the given dimensions, permutation, and per-dimension dense/sparse
582   /// annotations, using the sparse tensor for the initial contents.
583   ///
584   /// Preconditions:
585   /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
586   /// * The `tensor` must have the same value type `V`.
587   static SparseTensorStorage<P, I, V> *
588   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
589                   const DimLevelType *sparsity,
590                   const SparseTensorStorageBase *source) {
591     assert(source && "Got nullptr for source");
592     SparseTensorEnumeratorBase<V> *enumerator;
593     source->newEnumerator(&enumerator, rank, perm);
594     const auto &permsz = enumerator->permutedSizes();
595     assertPermutedSizesMatchShape(permsz, rank, perm, shape);
596     auto *tensor =
597         new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
598     delete enumerator;
599     return tensor;
600   }
601 
602 private:
603   /// Appends an arbitrary new position to `pointers[d]`.  This method
604   /// checks that `pos` is representable in the `P` type; however, it
605   /// does not check that `pos` is semantically valid (i.e., larger than
606   /// the previous position and smaller than `indices[d].capacity()`).
607   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
608     assert(isCompressedDim(d));
609     assert(pos <= std::numeric_limits<P>::max() &&
610            "Pointer value is too large for the P-type");
611     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
612   }
613 
614   /// Appends index `i` to dimension `d`, in the semantically general
615   /// sense.  For non-dense dimensions, that means appending to the
616   /// `indices[d]` array, checking that `i` is representable in the `I`
617   /// type; however, we do not verify other semantic requirements (e.g.,
618   /// that `i` is in bounds for `sizes[d]`, and not previously occurring
619   /// in the same segment).  For dense dimensions, this method instead
620   /// appends the appropriate number of zeros to the `values` array,
621   /// where `full` is the number of "entries" already written to `values`
622   /// for this segment (aka one after the highest index previously appended).
623   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
624     if (isCompressedDim(d)) {
625       assert(i <= std::numeric_limits<I>::max() &&
626              "Index value is too large for the I-type");
627       indices[d].push_back(static_cast<I>(i));
628     } else { // Dense dimension.
629       assert(i >= full && "Index was already filled");
630       if (i == full)
631         return; // Short-circuit, since it'll be a nop.
632       if (d + 1 == getRank())
633         values.insert(values.end(), i - full, 0);
634       else
635         finalizeSegment(d + 1, 0, i - full);
636     }
637   }
638 
639   /// Writes the given coordinate to `indices[d][pos]`.  This method
640   /// checks that `i` is representable in the `I` type; however, it
641   /// does not check that `i` is semantically valid (i.e., in bounds
642   /// for `sizes[d]` and not elsewhere occurring in the same segment).
643   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
644     assert(isCompressedDim(d));
645     // Subscript assignment to `std::vector` requires that the `pos`-th
646     // entry has been initialized; thus we must be sure to check `size()`
647     // here, instead of `capacity()` as would be ideal.
648     assert(pos < indices[d].size() && "Index position is out of bounds");
649     assert(i <= std::numeric_limits<I>::max() &&
650            "Index value is too large for the I-type");
651     indices[d][pos] = static_cast<I>(i);
652   }
653 
654   /// Computes the assembled-size associated with the `d`-th dimension,
655   /// given the assembled-size associated with the `(d-1)`-th dimension.
656   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
657   /// storage, as opposed to "dimension-sizes" which are the cardinality
658   /// of coordinates for that dimension.
659   ///
660   /// Precondition: the `pointers[d]` array must be fully initialized
661   /// before calling this method.
662   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
663     if (isCompressedDim(d))
664       return pointers[d][parentSz];
665     // else if dense:
666     return parentSz * getDimSizes()[d];
667   }
668 
669   /// Initializes sparse tensor storage scheme from a memory-resident sparse
670   /// tensor in coordinate scheme. This method prepares the pointers and
671   /// indices arrays under the given per-dimension dense/sparse annotations.
672   ///
673   /// Preconditions:
674   /// (1) the `elements` must be lexicographically sorted.
675   /// (2) the indices of every element are valid for `sizes` (equal rank
676   ///     and pointwise less-than).
677   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
678                uint64_t hi, uint64_t d) {
679     uint64_t rank = getRank();
680     assert(d <= rank && hi <= elements.size());
681     // Once dimensions are exhausted, insert the numerical values.
682     if (d == rank) {
683       assert(lo < hi);
684       values.push_back(elements[lo].value);
685       return;
686     }
687     // Visit all elements in this interval.
688     uint64_t full = 0;
689     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
690       // Find segment in interval with same index elements in this dimension.
691       uint64_t i = elements[lo].indices[d];
692       uint64_t seg = lo + 1;
693       while (seg < hi && elements[seg].indices[d] == i)
694         seg++;
695       // Handle segment in interval for sparse or dense dimension.
696       appendIndex(d, full, i);
697       full = i + 1;
698       fromCOO(elements, lo, seg, d + 1);
699       // And move on to next segment in interval.
700       lo = seg;
701     }
702     // Finalize the sparse pointer structure at this dimension.
703     finalizeSegment(d, full);
704   }
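
  // A worked example of the result (illustrative only): for a 3x4 matrix
  // with dimension types (kDense, kCompressed), i.e. CSR, and the sorted
  // COO elements {(0,1)=1.0, (0,3)=2.0, (2,2)=3.0}, `fromCOO` produces
  //
  //   pointers[1] = {0, 2, 2, 3}
  //   indices[1]  = {1, 3, 2}
  //   values      = {1.0, 2.0, 3.0}
  //
  // while the dense dimension 0 keeps no overhead storage at all.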
705 
706   /// Finalize the sparse pointer structure at this dimension.
707   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
708     if (count == 0)
709       return; // Short-circuit, since it'll be a nop.
710     if (isCompressedDim(d)) {
711       appendPointer(d, indices[d].size(), count);
712     } else { // Dense dimension.
713       const uint64_t sz = getDimSizes()[d];
714       assert(sz >= full && "Segment is overfull");
715       count = checkedMul(count, sz - full);
716       // For dense storage we must enumerate all the remaining coordinates
717       // in this dimension (i.e., coordinates after the last non-zero
718       // element), and either fill in their zero values or else recurse
719       // to finalize some deeper dimension.
720       if (d + 1 == getRank())
721         values.insert(values.end(), count, 0);
722       else
723         finalizeSegment(d + 1, 0, count);
724     }
725   }
726 
727   /// Wraps up a single insertion path, inner to outer.
728   void endPath(uint64_t diff) {
729     uint64_t rank = getRank();
730     assert(diff <= rank);
731     for (uint64_t i = 0; i < rank - diff; i++) {
732       const uint64_t d = rank - i - 1;
733       finalizeSegment(d, idx[d] + 1);
734     }
735   }
736 
737   /// Continues a single insertion path, outer to inner.
738   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
739     uint64_t rank = getRank();
740     assert(diff < rank);
741     for (uint64_t d = diff; d < rank; d++) {
742       uint64_t i = cursor[d];
743       appendIndex(d, top, i);
744       top = 0;
745       idx[d] = i;
746     }
747     values.push_back(val);
748   }
749 
750   /// Finds the first dimension where `cursor` differs from `idx`.
751   uint64_t lexDiff(const uint64_t *cursor) const {
752     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
753       if (cursor[r] > idx[r])
754         return r;
755       else
756         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
757     assert(0 && "duplicate insertion");
758     return -1u;
759   }
760 
761   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
762   // the cost of virtual-function dispatch in inner loops), without
763   // making them public to other client code.
764   friend class SparseTensorEnumerator<P, I, V>;
765 
766   std::vector<std::vector<P>> pointers;
767   std::vector<std::vector<I>> indices;
768   std::vector<V> values;
769   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
770 };
771 
772 /// A (higher-order) function object for enumerating the elements of some
773 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
774 /// method encapsulates the loop-nest for enumerating the elements of
775 /// the source tensor (in whatever order is best for the source tensor),
776 /// and applies a permutation to the coordinates/indices before handing
777 /// each element to the callback.  A single enumerator object can be
778 /// freely reused for several calls to `forallElements`, just so long
779 /// as each call is sequential with respect to one another.
780 ///
781 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
782 /// passed to the constructor; thus, objects of this class must not
783 /// outlive the sparse tensor they depend on.
784 ///
785 /// Design Note: The reason we define this class instead of simply using
786 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
787 /// the `<P,I>` template parameters from MLIR client code (to simplify the
788 /// type parameters used for direct sparse-to-sparse conversion).  And the
789 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
790 /// than simply using this class, is to avoid the cost of virtual-method
791 /// dispatch within the loop-nest.
792 template <typename V>
793 class SparseTensorEnumeratorBase {
794 public:
795   /// Constructs an enumerator with the given permutation for mapping
796   /// the semantic-ordering of dimensions to the desired target-ordering.
797   ///
798   /// Preconditions:
799   /// * the `tensor` must have the same `V` value type.
800   /// * `perm` must be valid for `rank`.
801   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
802                              uint64_t rank, const uint64_t *perm)
803       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
804         cursor(getRank()) {
805     assert(perm && "Received nullptr for permutation");
806     assert(rank == getRank() && "Permutation rank mismatch");
807     const auto &rev = src.getRev();        // source stg-order -> semantic-order
808     const auto &sizes = src.getDimSizes(); // in source storage-order
809     for (uint64_t s = 0; s < rank; s++) {  // `s` source storage-order
810       uint64_t t = perm[rev[s]];           // `t` target-order
811       reord[s] = t;
812       permsz[t] = sizes[s];
813     }
814   }
815 
816   virtual ~SparseTensorEnumeratorBase() = default;
817 
818   // We disallow copying to help avoid leaking the `src` reference.
819   // (In addition to avoiding the problem of slicing.)
820   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
821   SparseTensorEnumeratorBase &
822   operator=(const SparseTensorEnumeratorBase &) = delete;
823 
824   /// Returns the source/target tensor's rank.  (The source-rank and
825   /// target-rank are always equal since we only support permutations.
826   /// Though once we add support for other dimension mappings, this
827   /// method will have to be split in two.)
828   uint64_t getRank() const { return permsz.size(); }
829 
830   /// Returns the target tensor's dimension sizes.
831   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
832 
833   /// Enumerates all elements of the source tensor, permutes their
834   /// indices, and passes the permuted element to the callback.
835   /// The callback must not store the cursor reference directly,
836   /// since this function reuses the storage.  Instead, the callback
837   /// must copy it if it wants to keep it.
838   virtual void forallElements(ElementConsumer<V> yield) = 0;
839 
840 protected:
841   const SparseTensorStorageBase &src;
842   std::vector<uint64_t> permsz; // in target order.
843   std::vector<uint64_t> reord;  // source storage-order -> target order.
844   std::vector<uint64_t> cursor; // in target order.
845 };
846 
847 template <typename P, typename I, typename V>
848 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
849   using Base = SparseTensorEnumeratorBase<V>;
850 
851 public:
852   /// Constructs an enumerator with the given permutation for mapping
853   /// the semantic-ordering of dimensions to the desired target-ordering.
854   ///
855   /// Precondition: `perm` must be valid for `rank`.
856   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
857                          uint64_t rank, const uint64_t *perm)
858       : Base(tensor, rank, perm) {}
859 
860   ~SparseTensorEnumerator() final override = default;
861 
862   void forallElements(ElementConsumer<V> yield) final override {
863     forallElements(yield, 0, 0);
864   }
865 
866 private:
867   /// The recursive component of the public `forallElements`.
868   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
869                       uint64_t d) {
870     // Recover the `<P,I,V>` type parameters of `src`.
871     const auto &src =
872         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
873     if (d == Base::getRank()) {
874       assert(parentPos < src.values.size() &&
875              "Value position is out of bounds");
876       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
877       yield(this->cursor, src.values[parentPos]);
878     } else if (src.isCompressedDim(d)) {
879       // Look up the bounds of the `d`-level segment determined by the
880       // `d-1`-level position `parentPos`.
881       const std::vector<P> &pointers_d = src.pointers[d];
882       assert(parentPos + 1 < pointers_d.size() &&
883              "Parent pointer position is out of bounds");
884       const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
885       const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
886       // Loop-invariant code for looking up the `d`-level coordinates/indices.
887       const std::vector<I> &indices_d = src.indices[d];
888       assert(pstop - 1 < indices_d.size() && "Index position is out of bounds");
889       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
890       for (uint64_t pos = pstart; pos < pstop; pos++) {
891         cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
892         forallElements(yield, pos, d + 1);
893       }
894     } else { // Dense dimension.
895       const uint64_t sz = src.getDimSizes()[d];
896       const uint64_t pstart = parentPos * sz;
897       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
898       for (uint64_t i = 0; i < sz; i++) {
899         cursor_reord_d = i;
900         forallElements(yield, pstart + i, d + 1);
901       }
902     }
903   }
904 };
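
// A minimal usage sketch (illustrative only; `tensor` and `perm` are
// hypothetical): sum all stored values through the type-erased base
// class, much as `toCOO` above does for conversion.
//
//   SparseTensorEnumeratorBase<double> *enumerator;
//   tensor.newEnumerator(&enumerator, tensor.getRank(), perm);
//   double sum = 0.0;
//   enumerator->forallElements(
//       [&sum](const std::vector<uint64_t> &ind, double val) { sum += val; });
//   delete enumerator;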
905 
906 /// Statistics regarding the number of nonzero subtensors in
907 /// a source tensor, for direct sparse=>sparse conversion a la
908 /// <https://arxiv.org/abs/2001.02609>.
909 ///
910 /// N.B., this class stores references to the parameters passed to
911 /// the constructor; thus, objects of this class must not outlive
912 /// those parameters.
913 class SparseTensorNNZ final {
914 public:
915   /// Allocate the statistics structure for the desired sizes and
916   /// sparsity (in the target tensor's storage-order).  This constructor
917   /// does not actually populate the statistics, however; for that see
918   /// `initialize`.
919   ///
920   /// Precondition: `szs` must not contain zeros.
921   SparseTensorNNZ(const std::vector<uint64_t> &szs,
922                   const std::vector<DimLevelType> &sparsity)
923       : dimSizes(szs), dimTypes(sparsity), nnz(getRank()) {
924     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
925     bool uncompressed = true;
926     uint64_t sz = 1; // the product of all `dimSizes` strictly less than `r`.
927     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
928       switch (dimTypes[r]) {
929       case DimLevelType::kCompressed:
930         assert(uncompressed &&
931                "Multiple compressed layers not currently supported");
932         uncompressed = false;
933         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
934         break;
935       case DimLevelType::kDense:
936         assert(uncompressed &&
937                "Dense after compressed not currently supported");
938         break;
939       case DimLevelType::kSingleton:
940         // Singleton after Compressed causes no problems for allocating
941         // `nnz` nor for the yieldPos loop.  This remains true even
942         // when adding support for multiple compressed dimensions or
943         // for dense-after-compressed.
944         break;
945       }
946       sz = checkedMul(sz, dimSizes[r]);
947     }
948   }
949 
950   // We disallow copying to help avoid leaking the stored references.
951   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
952   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
953 
954   /// Returns the rank of the target tensor.
955   uint64_t getRank() const { return dimSizes.size(); }
956 
957   /// Enumerate the source tensor to fill in the statistics.  The
958   /// enumerator should already incorporate the permutation (from
959   /// semantic-order to the target storage-order).
960   template <typename V>
961   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
962     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
963     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
964     enumerator.forallElements(
965         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
966   }
967 
968   /// The type of callback functions which receive an nnz-statistic.
969   using NNZConsumer = const std::function<void(uint64_t)> &;
970 
971   /// Lexicographically enumerates all indices for dimensions strictly
972   /// less than `stopDim`, and passes their nnz statistic to the callback.
973   /// Since our use-case only requires the statistic not the coordinates
974   /// themselves, we do not bother to construct those coordinates.
975   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
976     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
977     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
978            "Cannot look up non-compressed dimensions");
979     forallIndices(yield, stopDim, 0, 0);
980   }
981 
982 private:
983   /// Adds a new element (i.e., increments its statistics).  We use
984   /// a method rather than inlining into the lambda in `initialize`,
985   /// to avoid spurious templating over `V`.  And this method is private
986   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
987   /// by `forallElements`).
988   void add(const std::vector<uint64_t> &ind) {
989     uint64_t parentPos = 0;
990     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
991       if (dimTypes[r] == DimLevelType::kCompressed)
992         nnz[r][parentPos]++;
993       parentPos = parentPos * dimSizes[r] + ind[r];
994     }
995   }
996 
997   /// Recursive component of the public `forallIndices`.
998   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
999                      uint64_t d) const {
1000     assert(d <= stopDim);
1001     if (d == stopDim) {
1002       assert(parentPos < nnz[d].size() && "Cursor is out of range");
1003       yield(nnz[d][parentPos]);
1004     } else {
1005       const uint64_t sz = dimSizes[d];
1006       const uint64_t pstart = parentPos * sz;
1007       for (uint64_t i = 0; i < sz; i++)
1008         forallIndices(yield, stopDim, pstart + i, d + 1);
1009     }
1010   }
1011 
1012   // All of these are in the target storage-order.
1013   const std::vector<uint64_t> &dimSizes;
1014   const std::vector<DimLevelType> &dimTypes;
1015   std::vector<std::vector<uint64_t>> nnz;
1016 };
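
// A worked example (illustrative only): for a 3x4 CSR target, i.e.
// dimension types (kDense, kCompressed), holding the three elements
// {(0,1), (0,3), (2,2)}, `initialize` fills nnz[1] = {2, 0, 1} (the
// per-row nonzero counts), and `forallIndices(1, yield)` then yields
// 2, 0, 1 in that order.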
1017 
1018 template <typename P, typename I, typename V>
1019 SparseTensorStorage<P, I, V>::SparseTensorStorage(
1020     const std::vector<uint64_t> &szs, const uint64_t *perm,
1021     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1022     : SparseTensorStorage(szs, perm, sparsity) {
1023   SparseTensorEnumeratorBase<V> *enumerator;
1024   tensor.newEnumerator(&enumerator, getRank(), perm);
1025   {
1026     // Initialize the statistics structure.
1027     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1028     nnz.initialize(*enumerator);
1029     // Initialize "pointers" overhead (and allocate "indices", "values").
1030     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1031     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1032       if (isCompressedDim(r)) {
1033         pointers[r].reserve(parentSz + 1);
1034         pointers[r].push_back(0);
1035         uint64_t currentPos = 0;
1036         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1037           currentPos += n;
1038           appendPointer(r, currentPos);
1039         });
1040         assert(pointers[r].size() == parentSz + 1 &&
1041                "Final pointers size doesn't match allocated size");
1042         // That assertion entails `assembledSize(parentSz, r)`
1043         // is now in a valid state.  That is, `pointers[r][parentSz]`
1044         // equals the present value of `currentPos`, which is the
1045         // correct assembled-size for `indices[r]`.
1046       }
1047       // Update assembled-size for the next iteration.
1048       parentSz = assembledSize(parentSz, r);
1049       // Ideally we need only `indices[r].reserve(parentSz)`, however
1050       // the `std::vector` implementation forces us to initialize it too.
1051       // That is, in the yieldPos loop we need random-access assignment
1052       // to `indices[r]`; however, `std::vector`'s subscript-assignment
1053       // only allows assigning to already-initialized positions.
1054       if (isCompressedDim(r))
1055         indices[r].resize(parentSz, 0);
1056     }
1057     values.resize(parentSz, 0); // Both allocate and zero-initialize.
1058   }
1059   // The yieldPos loop
1060   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1061     uint64_t parentSz = 1, parentPos = 0;
1062     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1063       if (isCompressedDim(r)) {
1064         // If `parentPos == parentSz` then it's valid as an array-lookup;
1065         // however, it's semantically invalid here since that entry
1066         // does not represent a segment of `indices[r]`.  Moreover, that
1067         // entry must be immutable for `assembledSize` to remain valid.
1068         assert(parentPos < parentSz && "Pointers position is out of bounds");
1069         const uint64_t currentPos = pointers[r][parentPos];
1070         // This increment won't overflow the `P` type, since it can't
1071         // exceed the original value of `pointers[r][parentPos+1]`
1072         // which was already verified to be within bounds for `P`
1073         // when it was written to the array.
1074         pointers[r][parentPos]++;
1075         writeIndex(r, currentPos, ind[r]);
1076         parentPos = currentPos;
1077       } else { // Dense dimension.
1078         parentPos = parentPos * getDimSizes()[r] + ind[r];
1079       }
1080       parentSz = assembledSize(parentSz, r);
1081     }
1082     assert(parentPos < values.size() && "Value position is out of bounds");
1083     values[parentPos] = val;
1084   });
1085   // No longer need the enumerator, so we'll delete it ASAP.
1086   delete enumerator;
1087   // The finalizeYieldPos loop
1088   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1089     if (isCompressedDim(r)) {
1090       assert(parentSz == pointers[r].size() - 1 &&
1091              "Actual pointers size doesn't match the expected size");
1092       // Can't check all of them, but at least we can check the last one.
1093       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1094              "Pointers got corrupted");
1095       // TODO: optimize this by using `memmove` or similar.
1096       for (uint64_t n = 0; n < parentSz; n++) {
1097         const uint64_t parentPos = parentSz - n;
1098         pointers[r][parentPos] = pointers[r][parentPos - 1];
1099       }
1100       pointers[r][0] = 0;
1101     }
1102     parentSz = assembledSize(parentSz, r);
1103   }
1104 }
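
// A worked trace of the pointer bookkeeping above (illustrative only):
// for a 3x4 CSR target whose rows hold {2, 0, 1} nonzeros,
//
//   after the statistics pass:  pointers[1] = {0, 2, 2, 3}
//   after the yieldPos loop:    pointers[1] = {2, 2, 3, 3}
//   after finalizeYieldPos:     pointers[1] = {0, 2, 2, 3}
//
// i.e., the yieldPos loop temporarily uses `pointers[1][p]` as the write
// cursor for segment `p`, and the final shift-right restores the usual
// "start of segment" semantics.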
1105 
1106 /// Helper to convert a string to lower case (in place).
1107 static char *toLower(char *token) {
1108   for (char *c = token; *c; c++)
1109     *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
1110   return token;
1111 }
1112 
1113 /// Read the MME header of a general sparse matrix of type real.
1114 static void readMMEHeader(FILE *file, char *filename, char *line,
1115                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
1116   char header[64];
1117   char object[64];
1118   char format[64];
1119   char field[64];
1120   char symmetry[64];
1121   // Read header line.
1122   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1123              symmetry) != 5) {
1124     fprintf(stderr, "Corrupt header in %s\n", filename);
1125     exit(1);
1126   }
1127   // Set properties
1128   *isPattern = (strcmp(toLower(field), "pattern") == 0);
1129   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
1130   // Make sure this is a general sparse matrix.
1131   if (strcmp(toLower(header), "%%matrixmarket") ||
1132       strcmp(toLower(object), "matrix") ||
1133       strcmp(toLower(format), "coordinate") ||
1134       (strcmp(toLower(field), "real") && !(*isPattern)) ||
1135       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
1136     fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
1137     exit(1);
1138   }
1139   // Skip comments.
1140   while (true) {
1141     if (!fgets(line, kColWidth, file)) {
1142       fprintf(stderr, "Cannot find data in %s\n", filename);
1143       exit(1);
1144     }
1145     if (line[0] != '%')
1146       break;
1147   }
1148   // Next line contains M N NNZ.
1149   idata[0] = 2; // rank
1150   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1151              idata + 1) != 3) {
1152     fprintf(stderr, "Cannot find size in %s\n", filename);
1153     exit(1);
1154   }
1155 }
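
// For reference, a small matrix accepted by `readMMEHeader` above
// (a 3x3 general real matrix with 2 nonzeros, using 1-based indices;
// the header keywords are matched case-insensitively):
//
//   %%MatrixMarket matrix coordinate real general
//   % comments may follow the header
//   3 3 2
//   1 1 1.5
//   3 2 2.5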
1156 
1157 /// Read the "extended" FROSTT header. Although not part of the documented
1158 /// format, we assume that the file starts with optional comments followed
1159 /// by two lines that define the rank, the number of nonzeros, and the
1160 /// dimension sizes (one per rank) of the sparse tensor.
1161 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
1162                                 uint64_t *idata) {
1163   // Skip comments.
1164   while (true) {
1165     if (!fgets(line, kColWidth, file)) {
1166       fprintf(stderr, "Cannot find data in %s\n", filename);
1167       exit(1);
1168     }
1169     if (line[0] != '#')
1170       break;
1171   }
1172   // Next line contains RANK and NNZ.
1173   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
1174     fprintf(stderr, "Cannot find metadata in %s\n", filename);
1175     exit(1);
1176   }
1177   // Followed by a line with the dimension sizes (one per rank).
1178   for (uint64_t r = 0; r < idata[0]; r++) {
1179     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
1180       fprintf(stderr, "Cannot find dimension size %s\n", filename);
1181       exit(1);
1182     }
1183   }
1184   fgets(line, kColWidth, file); // end of line
1185 }
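
// For reference, a small tensor in the "extended" FROSTT format assumed
// here (a 2x3x4 tensor with 2 nonzeros, using 1-based indices):
//
//   # comments may precede the header
//   3 2
//   2 3 4
//   1 1 1 1.0
//   2 3 4 2.5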
1186 
1187 /// Reads a sparse tensor with the given filename into a memory-resident
1188 /// sparse tensor in coordinate scheme.
1189 template <typename V>
1190 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
1191                                                const uint64_t *shape,
1192                                                const uint64_t *perm) {
1193   // Open the file.
1194   FILE *file = fopen(filename, "r");
1195   if (!file) {
1196     assert(filename && "Received nullptr for filename");
1197     fprintf(stderr, "Cannot find file %s\n", filename);
1198     exit(1);
1199   }
1200   // Perform some file format dependent set up.
1201   char line[kColWidth];
1202   uint64_t idata[512];
1203   bool isPattern = false;
1204   bool isSymmetric = false;
1205   if (strstr(filename, ".mtx")) {
1206     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
1207   } else if (strstr(filename, ".tns")) {
1208     readExtFROSTTHeader(file, filename, line, idata);
1209   } else {
1210     fprintf(stderr, "Unknown format %s\n", filename);
1211     exit(1);
1212   }
1213   // Prepare sparse tensor object with per-dimension sizes
1214   // and the number of nonzeros as initial capacity.
1215   assert(rank == idata[0] && "rank mismatch");
1216   uint64_t nnz = idata[1];
1217   for (uint64_t r = 0; r < rank; r++)
1218     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1219            "dimension size mismatch");
1220   SparseTensorCOO<V> *tensor =
1221       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
1222   // Read all nonzero elements.
1223   std::vector<uint64_t> indices(rank);
1224   for (uint64_t k = 0; k < nnz; k++) {
1225     if (!fgets(line, kColWidth, file)) {
1226       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
1227       exit(1);
1228     }
1229     char *linePtr = line;
1230     for (uint64_t r = 0; r < rank; r++) {
1231       uint64_t idx = strtoul(linePtr, &linePtr, 10);
1232       // Add 0-based index.
1233       indices[perm[r]] = idx - 1;
1234     }
1235     // The external formats always store the numerical values with the type
1236     // double, but we cast these values to the sparse tensor object type.
1237     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1238     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
1239     tensor->add(indices, value);
1240     // We currently choose to deal with symmetric matrices by fully constructing
1241     // them. In the future, we may want to make symmetry implicit for storage
1242     // reasons.
1243     if (isSymmetric && indices[0] != indices[1])
1244       tensor->add({indices[1], indices[0]}, value);
1245   }
1246   // Close the file and return tensor.
1247   fclose(file);
1248   return tensor;
1249 }
1250 
1251 /// Writes the sparse tensor to extended FROSTT format.
1252 template <typename V>
1253 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1254   assert(tensor && dest);
1255   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1256   if (sort)
1257     coo->sort();
1258   char *filename = static_cast<char *>(dest);
1259   auto &sizes = coo->getSizes();
1260   auto &elements = coo->getElements();
1261   uint64_t rank = coo->getRank();
1262   uint64_t nnz = elements.size();
1263   std::fstream file;
1264   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1265   assert(file.is_open());
1266   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1267   for (uint64_t r = 0; r < rank - 1; r++)
1268     file << sizes[r] << " ";
1269   file << sizes[rank - 1] << std::endl;
1270   for (uint64_t i = 0; i < nnz; i++) {
1271     auto &idx = elements[i].indices;
1272     for (uint64_t r = 0; r < rank; r++)
1273       file << (idx[r] + 1) << " ";
1274     file << elements[i].value << std::endl;
1275   }
1276   file.flush();
1277   file.close();
1278   assert(file.good());
1279 }
1280 
1281 /// Initializes sparse tensor from an external COO-flavored format.
1282 template <typename V>
1283 static SparseTensorStorage<uint64_t, uint64_t, V> *
1284 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1285                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
  const auto *sparsity = reinterpret_cast<const DimLevelType *>(sparse);
1287 #ifndef NDEBUG
1288   // Verify that perm is a permutation of 0..(rank-1).
1289   std::vector<uint64_t> order(perm, perm + rank);
1290   std::sort(order.begin(), order.end());
1291   for (uint64_t i = 0; i < rank; ++i) {
1292     if (i != order[i]) {
1293       fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
1294       exit(1);
1295     }
1296   }
1297 
1298   // Verify that the sparsity values are supported.
1299   for (uint64_t i = 0; i < rank; ++i) {
1300     if (sparsity[i] != DimLevelType::kDense &&
1301         sparsity[i] != DimLevelType::kCompressed) {
1302       fprintf(stderr, "Unsupported sparsity value %d\n",
1303               static_cast<int>(sparsity[i]));
1304       exit(1);
1305     }
1306   }
1307 #endif
1308 
1309   // Convert external format to internal COO.
1310   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1311   std::vector<uint64_t> idx(rank);
1312   for (uint64_t i = 0, base = 0; i < nse; i++) {
1313     for (uint64_t r = 0; r < rank; r++)
1314       idx[perm[r]] = indices[base + r];
1315     coo->add(idx, values[i]);
1316     base += rank;
1317   }
1318   // Return sparse tensor storage format as opaque pointer.
1319   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1320       rank, shape, perm, sparsity, coo);
1321   delete coo;
1322   return tensor;
1323 }
1324 
1325 /// Converts a sparse tensor to an external COO-flavored format.
1326 template <typename V>
1327 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1328                                  uint64_t **pShape, V **pValues,
1329                                  uint64_t **pIndices) {
1330   assert(tensor);
1331   auto sparseTensor =
1332       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1333   uint64_t rank = sparseTensor->getRank();
1334   std::vector<uint64_t> perm(rank);
1335   std::iota(perm.begin(), perm.end(), 0);
1336   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1337 
1338   const std::vector<Element<V>> &elements = coo->getElements();
1339   uint64_t nse = elements.size();
1340 
1341   uint64_t *shape = new uint64_t[rank];
1342   for (uint64_t i = 0; i < rank; i++)
1343     shape[i] = coo->getSizes()[i];
1344 
1345   V *values = new V[nse];
1346   uint64_t *indices = new uint64_t[rank * nse];
1347 
1348   for (uint64_t i = 0, base = 0; i < nse; i++) {
1349     values[i] = elements[i].value;
1350     for (uint64_t j = 0; j < rank; j++)
1351       indices[base + j] = elements[i].indices[j];
1352     base += rank;
1353   }
1354 
1355   delete coo;
1356   *pRank = rank;
1357   *pNse = nse;
1358   *pShape = shape;
1359   *pValues = values;
1360   *pIndices = indices;
1361 }
1362 
1363 } // namespace
1364 
1365 extern "C" {
1366 
1367 //===----------------------------------------------------------------------===//
1368 //
1369 // Public API with methods that operate on MLIR buffers (memrefs) to interact
1370 // with sparse tensors, which are only visible as opaque pointers externally.
1371 // These methods should be used exclusively by MLIR compiler-generated code.
1372 //
1373 // Some macro magic is used to generate implementations for all required type
1374 // combinations that can be called from MLIR compiler-generated code.
1375 //
1376 //===----------------------------------------------------------------------===//
1377 
1378 #define CASE(p, i, v, P, I, V)                                                 \
1379   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1380     SparseTensorCOO<V> *coo = nullptr;                                         \
1381     if (action <= Action::kFromCOO) {                                          \
1382       if (action == Action::kFromFile) {                                       \
1383         char *filename = static_cast<char *>(ptr);                             \
1384         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1385       } else if (action == Action::kFromCOO) {                                 \
1386         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1387       } else {                                                                 \
1388         assert(action == Action::kEmpty);                                      \
1389       }                                                                        \
1390       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1391           rank, shape, perm, sparsity, coo);                                   \
1392       if (action == Action::kFromFile)                                         \
1393         delete coo;                                                            \
1394       return tensor;                                                           \
1395     }                                                                          \
1396     if (action == Action::kSparseToSparse) {                                   \
1397       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
1398       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
1399                                                            sparsity, tensor);  \
1400     }                                                                          \
1401     if (action == Action::kEmptyCOO)                                           \
1402       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1403     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1404     if (action == Action::kToIterator) {                                       \
1405       coo->startIterator();                                                    \
1406     } else {                                                                   \
1407       assert(action == Action::kToCOO);                                        \
1408     }                                                                          \
1409     return coo;                                                                \
1410   }
1411 
1412 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
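
// For example, CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t,
// int64_t) is shorthand for CASE(OverheadType::kU64, OverheadType::kU64,
// PrimaryType::kI64, uint64_t, uint64_t, int64_t), i.e. the case where the
// pointer and index overhead share the same storage type.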
1413 
1414 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1415 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1416 // that this file cannot get out of sync with its header.
1417 static_assert(std::is_same<index_type, uint64_t>::value,
1418               "Expected index_type == uint64_t");
1419 
/// Constructs a new sparse tensor. This is the "Swiss army knife"
1421 /// method for materializing sparse tensors into the computation.
1422 ///
1423 /// Action:
1424 /// kEmpty = returns empty storage to fill later
1425 /// kFromFile = returns storage, where ptr contains filename to read
1426 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
1427 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
1428 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1429 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
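///
/// For illustration only, a schematic sequence (as issued by MLIR
/// compiler-generated code) that materializes storage from a freshly filled
/// coordinate scheme might look like:
///
///   void *coo =
///       _mlir_ciface_newSparseTensor(..., Action::kEmptyCOO, nullptr);
///   // ... repeated calls to _mlir_ciface_addElt* on coo ...
///   void *tensor =
///       _mlir_ciface_newSparseTensor(..., Action::kFromCOO, coo);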
1430 void *
1431 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1432                              StridedMemRefType<index_type, 1> *sref,
1433                              StridedMemRefType<index_type, 1> *pref,
1434                              OverheadType ptrTp, OverheadType indTp,
1435                              PrimaryType valTp, Action action, void *ptr) {
1436   assert(aref && sref && pref);
1437   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1438          pref->strides[0] == 1);
1439   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1440   const DimLevelType *sparsity = aref->data + aref->offset;
1441   const index_type *shape = sref->data + sref->offset;
1442   const index_type *perm = pref->data + pref->offset;
1443   uint64_t rank = aref->sizes[0];
1444 
1445   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1446   // This is safe because of the static_assert above.
1447   if (ptrTp == OverheadType::kIndex)
1448     ptrTp = OverheadType::kU64;
1449   if (indTp == OverheadType::kIndex)
1450     indTp = OverheadType::kU64;
1451 
1452   // Double matrices with all combinations of overhead storage.
1453   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1454        uint64_t, double);
1455   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1456        uint32_t, double);
1457   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1458        uint16_t, double);
1459   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1460        uint8_t, double);
1461   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1462        uint64_t, double);
1463   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1464        uint32_t, double);
1465   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1466        uint16_t, double);
1467   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1468        uint8_t, double);
1469   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1470        uint64_t, double);
1471   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1472        uint32_t, double);
1473   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1474        uint16_t, double);
1475   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1476        uint8_t, double);
1477   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1478        uint64_t, double);
1479   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1480        uint32_t, double);
1481   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1482        uint16_t, double);
1483   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1484        uint8_t, double);
1485 
1486   // Float matrices with all combinations of overhead storage.
1487   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1488        uint64_t, float);
1489   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1490        uint32_t, float);
1491   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1492        uint16_t, float);
1493   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1494        uint8_t, float);
1495   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1496        uint64_t, float);
1497   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1498        uint32_t, float);
1499   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1500        uint16_t, float);
1501   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1502        uint8_t, float);
1503   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1504        uint64_t, float);
1505   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1506        uint32_t, float);
1507   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1508        uint16_t, float);
1509   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1510        uint8_t, float);
1511   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1512        uint64_t, float);
1513   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1514        uint32_t, float);
1515   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1516        uint16_t, float);
1517   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1518        uint8_t, float);
1519 
1520   // Integral matrices with both overheads of the same type.
1521   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1522   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1523   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1524   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1525   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1526   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1527   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1528   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1529   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1530   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1531   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1532   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1533   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1534 
1535   // Complex matrices with wide overhead.
1536   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1537   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1538 
1539   // Unsupported case (add above if needed).
1540   fputs("unsupported combination of types\n", stderr);
1541   exit(1);
1542 }
1543 #undef CASE
1544 #undef CASE_SECSAME
1545 
1546 /// Methods that provide direct access to values.
1547 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1548   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1549                                         void *tensor) {                        \
1550     assert(ref &&tensor);                                                      \
1551     std::vector<V> *v;                                                         \
1552     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1553     ref->basePtr = ref->data = v->data();                                      \
1554     ref->offset = 0;                                                           \
1555     ref->sizes[0] = v->size();                                                 \
1556     ref->strides[0] = 1;                                                       \
1557   }
1558 FOREVERY_V(IMPL_SPARSEVALUES)
1559 #undef IMPL_SPARSEVALUES
1560 
1561 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1562   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1563                            index_type d) {                                     \
1564     assert(ref &&tensor);                                                      \
1565     std::vector<TYPE> *v;                                                      \
1566     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1567     ref->basePtr = ref->data = v->data();                                      \
1568     ref->offset = 0;                                                           \
1569     ref->sizes[0] = v->size();                                                 \
1570     ref->strides[0] = 1;                                                       \
1571   }
1572 /// Methods that provide direct access to pointers.
1573 IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
1574 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
1575 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
1576 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
1577 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
1578 
1579 /// Methods that provide direct access to indices.
1580 IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1581 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
1582 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
1583 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
1584 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
1585 #undef IMPL_GETOVERHEAD
1586 
1587 /// Helper to add value to coordinate scheme, one per value type.
1588 #define IMPL_ADDELT(VNAME, V)                                                  \
1589   void *_mlir_ciface_addElt##VNAME(void *coo, V value,                         \
1590                                    StridedMemRefType<index_type, 1> *iref,     \
1591                                    StridedMemRefType<index_type, 1> *pref) {   \
1592     assert(coo &&iref &&pref);                                                 \
1593     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1594     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1595     const index_type *indx = iref->data + iref->offset;                        \
1596     const index_type *perm = pref->data + pref->offset;                        \
1597     uint64_t isize = iref->sizes[0];                                           \
1598     std::vector<index_type> indices(isize);                                    \
1599     for (uint64_t r = 0; r < isize; r++)                                       \
1600       indices[perm[r]] = indx[r];                                              \
1601     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value);               \
1602     return coo;                                                                \
1603   }
1604 FOREVERY_SIMPLEX_V(IMPL_ADDELT)
1605 // `complex64` apparently doesn't encounter any ABI issues (yet).
1606 IMPL_ADDELT(C64, complex64)
1607 // TODO: cleaner way to avoid ABI padding problem?
1608 IMPL_ADDELT(C32ABI, complex32)
1609 void *_mlir_ciface_addEltC32(void *coo, float r, float i,
1610                              StridedMemRefType<index_type, 1> *iref,
1611                              StridedMemRefType<index_type, 1> *pref) {
1612   return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
1613 }
1614 #undef IMPL_ADDELT
1615 
1616 /// Helper to enumerate elements of coordinate scheme, one per value type.
1617 #define IMPL_GETNEXT(VNAME, V)                                                 \
1618   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1619                                    StridedMemRefType<index_type, 1> *iref,     \
1620                                    StridedMemRefType<V, 0> *vref) {            \
1621     assert(coo &&iref &&vref);                                                 \
1622     assert(iref->strides[0] == 1);                                             \
1623     index_type *indx = iref->data + iref->offset;                              \
1624     V *value = vref->data + vref->offset;                                      \
1625     const uint64_t isize = iref->sizes[0];                                     \
1626     const Element<V> *elem =                                                   \
1627         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1628     if (elem == nullptr)                                                       \
1629       return false;                                                            \
1630     for (uint64_t r = 0; r < isize; r++)                                       \
1631       indx[r] = elem->indices[r];                                              \
1632     *value = elem->value;                                                      \
1633     return true;                                                               \
1634   }
1635 FOREVERY_V(IMPL_GETNEXT)
1636 #undef IMPL_GETNEXT
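
// For illustration only, a schematic consumer of a coordinate scheme obtained
// via Action::kToIterator (using the F64 instantiation for doubles, with
// `iref` wrapping a rank-sized index buffer and `vref` a rank-0 value buffer):
//
//   while (_mlir_ciface_getNextF64(coo, &iref, &vref)) {
//     // Indices are now in the buffer wrapped by iref; the value is in
//     // the rank-0 buffer wrapped by vref.
//   }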
1637 
1638 /// Insert elements in lexicographical index order, one per value type.
1639 #define IMPL_LEXINSERT(VNAME, V)                                               \
1640   void _mlir_ciface_lexInsert##VNAME(                                          \
1641       void *tensor, StridedMemRefType<index_type, 1> *cref, V val) {           \
1642     assert(tensor &&cref);                                                     \
1643     assert(cref->strides[0] == 1);                                             \
1644     index_type *cursor = cref->data + cref->offset;                            \
1645     assert(cursor);                                                            \
1646     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1647   }
1648 FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
1649 // `complex64` apparently doesn't encounter any ABI issues (yet).
1650 IMPL_LEXINSERT(C64, complex64)
1651 // TODO: cleaner way to avoid ABI padding problem?
1652 IMPL_LEXINSERT(C32ABI, complex32)
1653 void _mlir_ciface_lexInsertC32(void *tensor,
1654                                StridedMemRefType<index_type, 1> *cref, float r,
1655                                float i) {
1656   _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
1657 }
1658 #undef IMPL_LEXINSERT
1659 
1660 /// Insert using expansion, one per value type.
1661 #define IMPL_EXPINSERT(VNAME, V)                                               \
1662   void _mlir_ciface_expInsert##VNAME(                                          \
1663       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1664       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1665       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1666     assert(tensor &&cref &&vref &&fref &&aref);                                \
1667     assert(cref->strides[0] == 1);                                             \
1668     assert(vref->strides[0] == 1);                                             \
1669     assert(fref->strides[0] == 1);                                             \
1670     assert(aref->strides[0] == 1);                                             \
1671     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1672     index_type *cursor = cref->data + cref->offset;                            \
1673     V *values = vref->data + vref->offset;                                     \
1674     bool *filled = fref->data + fref->offset;                                  \
1675     index_type *added = aref->data + aref->offset;                             \
1676     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1677         cursor, values, filled, added, count);                                 \
1678   }
1679 FOREVERY_V(IMPL_EXPINSERT)
1680 #undef IMPL_EXPINSERT
1681 
1682 /// Output a sparse tensor, one per value type.
1683 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
1684   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
1685     return outSparseTensor<V>(coo, dest, sort);                                \
1686   }
1687 FOREVERY_V(IMPL_OUTSPARSETENSOR)
1688 #undef IMPL_OUTSPARSETENSOR
1689 
1690 //===----------------------------------------------------------------------===//
1691 //
1692 // Public API with methods that accept C-style data structures to interact
1693 // with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used both by MLIR compiler-generated code and by an
// external runtime that wants to interact with MLIR compiler-generated code.
1696 //
1697 //===----------------------------------------------------------------------===//
1698 
1699 /// Helper method to read a sparse tensor filename from the environment,
1700 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1701 char *getTensorFilename(index_type id) {
1702   char var[80];
1703   sprintf(var, "TENSOR%" PRIu64, id);
1704   char *env = getenv(var);
1705   if (!env) {
1706     fprintf(stderr, "Environment variable %s is not set\n", var);
1707     exit(1);
1708   }
1709   return env;
1710 }
1711 
1712 /// Returns size of sparse tensor in given dimension.
1713 index_type sparseDimSize(void *tensor, index_type d) {
1714   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1715 }
1716 
1717 /// Finalizes lexicographic insertions.
1718 void endInsert(void *tensor) {
1719   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1720 }
1721 
1722 /// Releases sparse tensor storage.
1723 void delSparseTensor(void *tensor) {
1724   delete static_cast<SparseTensorStorageBase *>(tensor);
1725 }
1726 
1727 /// Releases sparse tensor coordinate scheme.
1728 #define IMPL_DELCOO(VNAME, V)                                                  \
1729   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1730     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1731   }
1732 FOREVERY_V(IMPL_DELCOO)
1733 #undef IMPL_DELCOO
1734 
1735 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
1736 /// data structures. The expected parameters are:
1737 ///
1738 ///   rank:    rank of tensor
1739 ///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
1741 ///   values:  a "nse" array with values for all specified elements
1742 ///   indices: a flat "nse x rank" array with indices for all specified elements
1743 ///   perm:    the permutation of the dimensions in the storage
1744 ///   sparse:  the sparsity for the dimensions
1745 ///
1746 /// For example, the sparse matrix
1747 ///     | 1.0 0.0 0.0 |
1748 ///     | 0.0 5.0 3.0 |
1749 /// can be passed as
1750 ///      rank    = 2
1751 ///      nse     = 3
1752 ///      shape   = [2, 3]
1753 ///      values  = [1.0, 5.0, 3.0]
1754 ///      indices = [ 0, 0,  1, 1,  1, 2]
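/// together with, for instance, an identity dimension ordering and a
/// CSR-like sparsity annotation:
///      perm    = [0, 1]
///      sparse  = [DimLevelType::kDense, DimLevelType::kCompressed]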
1755 //
1756 // TODO: generalize beyond 64-bit indices.
1757 //
1758 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
1759   void *convertToMLIRSparseTensor##VNAME(                                      \
1760       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
1761       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
1762     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
1763                                  sparse);                                      \
1764   }
1765 FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
1766 #undef IMPL_CONVERTTOMLIRSPARSETENSOR
1767 
1768 /// Converts a sparse tensor to COO-flavored format expressed using C-style
1769 /// data structures. The expected output parameters are pointers for these
1770 /// values:
1771 ///
1772 ///   rank:    rank of tensor
1773 ///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
1775 ///   values:  a "nse" array with values for all specified elements
1776 ///   indices: a flat "nse x rank" array with indices for all specified elements
1777 ///
1778 /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
1779 /// from convertToMLIRSparseTensor.
1780 ///
1781 //  TODO: Currently, values are copied from SparseTensorStorage to
1782 //  SparseTensorCOO, then to the output. We may want to reduce the number of
1783 //  copies.
1784 //
1785 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1786 // compressed
1787 //
1788 #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
1789   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
1790                                           uint64_t *pNse, uint64_t **pShape,   \
1791                                           V **pValues, uint64_t **pIndices) {  \
1792     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
1793   }
1794 FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
1795 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
1796 
1797 } // extern "C"
1798 
1799 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1800