1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <complex>
25 #include <cctype>
26 #include <cinttypes>
27 #include <cstdio>
28 #include <cstdlib>
29 #include <cstring>
30 #include <fstream>
31 #include <functional>
32 #include <iostream>
33 #include <limits>
34 #include <numeric>
35 #include <vector>
36 
37 using complex64 = std::complex<double>;
38 using complex32 = std::complex<float>;
39 
40 //===----------------------------------------------------------------------===//
41 //
42 // Internal support for storing and reading sparse tensors.
43 //
44 // The following memory-resident sparse storage schemes are supported:
45 //
46 // (a) A coordinate scheme for temporarily storing and lexicographically
47 //     sorting a sparse tensor by index (SparseTensorCOO).
48 //
49 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
51 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
52 //
53 // The following external formats are supported:
54 //
55 // (1) Matrix Market Exchange (MME): *.mtx
56 //     https://math.nist.gov/MatrixMarket/formats.html
57 //
58 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
59 //     http://frostt.io/tensors/file-formats.html
60 //
61 // Two public APIs are supported:
62 //
63 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
64 //     tensors. These methods should be used exclusively by MLIR
65 //     compiler-generated code.
66 //
67 // (II) Methods that accept C-style data structures to interact with sparse
68 //      tensors. These methods can be used by any external runtime that wants
69 //      to interact with MLIR compiler-generated code.
70 //
// In both cases (I) and (II), the SparseTensorStorage format is only
// visible externally as an opaque pointer.
73 //
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 
78 static constexpr int kColWidth = 1025;
79 
80 /// A version of `operator*` on `uint64_t` which checks for overflows.
81 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
82   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
83          "Integer overflow");
84   return lhs * rhs;
85 }
86 
87 // TODO: adjust this so it can be used by `openSparseTensorCOO` too.
88 // That version doesn't have the permutation, and the `dimSizes` are
89 // a pointer/C-array rather than `std::vector`.
90 //
91 /// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
92 /// semantic-order to target-order) are a refinement of the desired `shape`
93 /// (in semantic-order).
94 ///
95 /// Precondition: `perm` and `shape` must be valid for `rank`.
96 static inline void
97 assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
98                               uint64_t rank, const uint64_t *perm,
99                               const uint64_t *shape) {
100   assert(perm && shape);
101   assert(rank == dimSizes.size() && "Rank mismatch");
102   for (uint64_t r = 0; r < rank; r++)
103     assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
104            "Dimension size mismatch");
105 }
106 
107 /// A sparse tensor element in coordinate scheme (value and indices).
108 /// For example, a rank-1 vector element would look like
109 ///   ({i}, a[i])
110 /// and a rank-5 tensor element like
111 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer into a shared index pool rather than e.g. a direct
/// vector, since that (1) reduces the per-element memory footprint and
/// (2) centralizes the memory reservation and (re)allocation in one place.
115 template <typename V>
116 struct Element final {
117   Element(uint64_t *ind, V val) : indices(ind), value(val){};
118   uint64_t *indices; // pointer into shared index pool
119   V value;
120 };
121 
122 /// The type of callback functions which receive an element.  We avoid
123 /// packaging the coordinates and value together as an `Element` object
124 /// because this helps keep code somewhat cleaner.
125 template <typename V>
126 using ElementConsumer =
127     const std::function<void(const std::vector<uint64_t> &, V)> &;
128 
129 /// A memory-resident sparse tensor in coordinate scheme (collection of
130 /// elements). This data structure is used to read a sparse tensor from
131 /// any external format into memory and sort the elements lexicographically
132 /// by indices before passing it back to the client (most packed storage
133 /// formats require the elements to appear in lexicographic index order).
134 template <typename V>
135 struct SparseTensorCOO final {
136 public:
137   SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
138       : dimSizes(dimSizes) {
139     if (capacity) {
140       elements.reserve(capacity);
141       indices.reserve(capacity * getRank());
142     }
143   }
144 
145   /// Adds element as indices and value.
146   void add(const std::vector<uint64_t> &ind, V val) {
147     assert(!iteratorLocked && "Attempt to add() after startIterator()");
148     uint64_t *base = indices.data();
149     uint64_t size = indices.size();
150     uint64_t rank = getRank();
151     assert(ind.size() == rank && "Element rank mismatch");
152     for (uint64_t r = 0; r < rank; r++) {
153       assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
154       indices.push_back(ind[r]);
155     }
156     // This base only changes if indices were reallocated. In that case, we
157     // need to correct all previous pointers into the vector. Note that this
158     // only happens if we did not set the initial capacity right, and then only
159     // for every internal vector reallocation (which with the doubling rule
160     // should only incur an amortized linear overhead).
161     uint64_t *newBase = indices.data();
162     if (newBase != base) {
163       for (uint64_t i = 0, n = elements.size(); i < n; i++)
164         elements[i].indices = newBase + (elements[i].indices - base);
165       base = newBase;
166     }
167     // Add element as (pointer into shared index pool, value) pair.
168     elements.emplace_back(base + size, val);
169   }
170 
171   /// Sorts elements lexicographically by index.
172   void sort() {
173     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
174     // TODO: we may want to cache an `isSorted` bit, to avoid
175     // unnecessary/redundant sorting.
176     std::sort(elements.begin(), elements.end(),
177               [this](const Element<V> &e1, const Element<V> &e2) {
178                 uint64_t rank = getRank();
179                 for (uint64_t r = 0; r < rank; r++) {
180                   if (e1.indices[r] == e2.indices[r])
181                     continue;
182                   return e1.indices[r] < e2.indices[r];
183                 }
184                 return false;
185               });
186   }
187 
188   /// Get the rank of the tensor.
189   uint64_t getRank() const { return dimSizes.size(); }
190 
191   /// Getter for the dimension-sizes array.
192   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
193 
194   /// Getter for the elements array.
195   const std::vector<Element<V>> &getElements() const { return elements; }
196 
197   /// Switch into iterator mode.
198   void startIterator() {
199     iteratorLocked = true;
200     iteratorPos = 0;
201   }
202 
203   /// Get the next element.
204   const Element<V> *getNext() {
205     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
206     if (iteratorPos < elements.size())
207       return &(elements[iteratorPos++]);
208     iteratorLocked = false;
209     return nullptr;
210   }
211 
212   /// Factory method. Permutes the original dimensions according to
213   /// the given ordering and expects subsequent add() calls to honor
214   /// that same ordering for the given indices. The result is a
215   /// fully permuted coordinate scheme.
216   ///
217   /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
218   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
219                                                 const uint64_t *dimSizes,
220                                                 const uint64_t *perm,
221                                                 uint64_t capacity = 0) {
222     std::vector<uint64_t> permsz(rank);
223     for (uint64_t r = 0; r < rank; r++) {
224       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
225       permsz[perm[r]] = dimSizes[r];
226     }
227     return new SparseTensorCOO<V>(permsz, capacity);
228   }
229 
230 private:
231   const std::vector<uint64_t> dimSizes; // per-dimension sizes
232   std::vector<Element<V>> elements;     // all COO elements
233   std::vector<uint64_t> indices;        // shared index pool
234   bool iteratorLocked = false;
235   unsigned iteratorPos = 0;
236 };
237 
238 // See <https://en.wikipedia.org/wiki/X_Macro>
239 //
240 // `FOREVERY_SIMPLEX_V` only specifies the non-complex `V` types, because
241 // the ABI for complex types has compiler/architecture dependent complexities
242 // we need to work around.  Namely, when a function takes a parameter of
243 // C/C++ type `complex32` (per se), then there is additional padding that
244 // causes it not to match the LLVM type `!llvm.struct<(f32, f32)>`.  This
245 // only happens with the `complex32` type itself, not with pointers/arrays
// of complex values.  So far `complex64` doesn't exhibit this ABI
// incompatibility, but we exclude it anyway just to be safe.
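//
// Each `FOREVERY_*` macro applies the given `DO(VNAME, V)` macro once per
// supported value type; e.g., `FOREVERY_V(DECL_GETVALUES)` below expands to
// one `getValues` overload per type, from `double` down to `complex32`.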
248 #define FOREVERY_SIMPLEX_V(DO)                                                 \
249   DO(F64, double)                                                              \
250   DO(F32, float)                                                               \
251   DO(I64, int64_t)                                                             \
252   DO(I32, int32_t)                                                             \
253   DO(I16, int16_t)                                                             \
254   DO(I8, int8_t)
255 
256 #define FOREVERY_V(DO)                                                         \
257   FOREVERY_SIMPLEX_V(DO)                                                       \
258   DO(C64, complex64)                                                           \
259   DO(C32, complex32)
260 
261 // Forward.
262 template <typename V>
263 class SparseTensorEnumeratorBase;
264 
265 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
266 /// takes responsibility for all the `<P,I,V>`-independent aspects
267 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
268 /// we use function overloading to implement "partial" method
269 /// specialization, which the C-API relies on to catch type errors
270 /// arising from our use of opaque pointers.
271 class SparseTensorStorageBase {
272 public:
273   /// Constructs a new storage object.  The `perm` maps the tensor's
274   /// semantic-ordering of dimensions to this object's storage-order.
275   /// The `dimSizes` and `sparsity` arrays are already in storage-order.
276   ///
277   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
278   SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
279                           const uint64_t *perm, const DimLevelType *sparsity)
280       : dimSizes(dimSizes), rev(getRank()),
281         dimTypes(sparsity, sparsity + getRank()) {
282     assert(perm && sparsity);
283     const uint64_t rank = getRank();
284     // Validate parameters.
285     assert(rank > 0 && "Trivial shape is unsupported");
286     for (uint64_t r = 0; r < rank; r++) {
287       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
288       assert((dimTypes[r] == DimLevelType::kDense ||
289               dimTypes[r] == DimLevelType::kCompressed) &&
290              "Unsupported DimLevelType");
291     }
292     // Construct the "reverse" (i.e., inverse) permutation.
293     for (uint64_t r = 0; r < rank; r++)
294       rev[perm[r]] = r;
295   }
296 
297   virtual ~SparseTensorStorageBase() = default;
298 
299   /// Get the rank of the tensor.
300   uint64_t getRank() const { return dimSizes.size(); }
301 
302   /// Getter for the dimension-sizes array, in storage-order.
303   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
304 
  /// Safely looks up the size of the given (storage-order) dimension.
306   uint64_t getDimSize(uint64_t d) const {
307     assert(d < getRank());
308     return dimSizes[d];
309   }
310 
311   /// Getter for the "reverse" permutation, which maps this object's
312   /// storage-order to the tensor's semantic-order.
313   const std::vector<uint64_t> &getRev() const { return rev; }
314 
315   /// Getter for the dimension-types array, in storage-order.
316   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
317 
318   /// Safely check if the (storage-order) dimension uses compressed storage.
319   bool isCompressedDim(uint64_t d) const {
320     assert(d < getRank());
321     return (dimTypes[d] == DimLevelType::kCompressed);
322   }
323 
324   /// Allocate a new enumerator.
325 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
326   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
327                              const uint64_t *) const {                         \
328     fatal("newEnumerator" #VNAME);                                             \
329   }
330   FOREVERY_V(DECL_NEWENUMERATOR)
331 #undef DECL_NEWENUMERATOR
332 
333   /// Overhead storage.
334   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
335   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
336   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
337   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
338   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
339   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
340   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
341   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
342 
343   /// Primary storage.
344 #define DECL_GETVALUES(VNAME, V)                                               \
345   virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
346   FOREVERY_V(DECL_GETVALUES)
347 #undef DECL_GETVALUES
348 
349   /// Element-wise insertion in lexicographic index order.
350 #define DECL_LEXINSERT(VNAME, V)                                               \
351   virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
352   FOREVERY_V(DECL_LEXINSERT)
353 #undef DECL_LEXINSERT
354 
355   /// Expanded insertion.
356 #define DECL_EXPINSERT(VNAME, V)                                               \
357   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
358     fatal("expInsert" #VNAME);                                                 \
359   }
360   FOREVERY_V(DECL_EXPINSERT)
361 #undef DECL_EXPINSERT
362 
363   /// Finishes insertion.
364   virtual void endInsert() = 0;
365 
366 protected:
367   // Since this class is virtual, we must disallow public copying in
368   // order to avoid "slicing".  Since this class has data members,
369   // that means making copying protected.
370   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
371   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
372   // Copy-assignment would be implicitly deleted (because `dimSizes`
373   // is const), so we explicitly delete it for clarity.
374   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
375 
376 private:
377   static void fatal(const char *tp) {
378     fprintf(stderr, "unsupported %s\n", tp);
379     exit(1);
380   }
381 
382   const std::vector<uint64_t> dimSizes;
383   std::vector<uint64_t> rev;
384   const std::vector<DimLevelType> dimTypes;
385 };
386 
387 // Forward.
388 template <typename P, typename I, typename V>
389 class SparseTensorEnumerator;
390 
391 /// A memory-resident sparse tensor using a storage scheme based on
392 /// per-dimension sparse/dense annotations. This data structure provides a
393 /// bufferized form of a sparse tensor type. In contrast to generating setup
394 /// methods for each differently annotated sparse tensor, this method provides
395 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
396 /// and annotations to implement all required setup in a general manner.
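///
/// As an illustrative sketch (values chosen arbitrarily): with both
/// dimensions annotated compressed, a 4x4 matrix with nonzeros a[0][1]=5
/// and a[2][3]=7 is stored roughly as
///   pointers[0] = [0, 2]       indices[0] = [0, 2]   (rows with nonzeros)
///   pointers[1] = [0, 1, 2]    indices[1] = [1, 3]   (columns per such row)
///   values      = [5, 7]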
397 template <typename P, typename I, typename V>
398 class SparseTensorStorage final : public SparseTensorStorageBase {
399   /// Private constructor to share code between the other constructors.
  /// Beware that the object is not guaranteed to be in a valid state
  /// after this constructor alone; e.g., `isCompressedDim(d)`
  /// doesn't entail `!(pointers[d].empty())`.
403   ///
404   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
405   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
406                       const uint64_t *perm, const DimLevelType *sparsity)
407       : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
408         indices(getRank()), idx(getRank()) {}
409 
410 public:
411   /// Constructs a sparse tensor storage scheme with the given dimensions,
412   /// permutation, and per-dimension dense/sparse annotations, using
413   /// the coordinate scheme tensor for the initial contents if provided.
414   ///
415   /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
416   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
417                       const uint64_t *perm, const DimLevelType *sparsity,
418                       SparseTensorCOO<V> *coo)
419       : SparseTensorStorage(dimSizes, perm, sparsity) {
420     // Provide hints on capacity of pointers and indices.
421     // TODO: needs much fine-tuning based on actual sparsity; currently
422     //       we reserve pointer/index space based on all previous dense
423     //       dimensions, which works well up to first sparse dim; but
424     //       we should really use nnz and dense/sparse distribution.
425     bool allDense = true;
426     uint64_t sz = 1;
427     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
428       if (isCompressedDim(r)) {
429         // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
430         // `sz` by that before reserving. (For now we just use 1.)
431         pointers[r].reserve(sz + 1);
432         pointers[r].push_back(0);
433         indices[r].reserve(sz);
434         sz = 1;
435         allDense = false;
436       } else { // Dense dimension.
437         sz = checkedMul(sz, getDimSizes()[r]);
438       }
439     }
440     // Then assign contents from coordinate scheme tensor if provided.
441     if (coo) {
442       // Ensure both preconditions of `fromCOO`.
443       assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
444       coo->sort();
445       // Now actually insert the `elements`.
446       const std::vector<Element<V>> &elements = coo->getElements();
447       uint64_t nnz = elements.size();
448       values.reserve(nnz);
449       fromCOO(elements, 0, nnz, 0);
450     } else if (allDense) {
451       values.resize(sz, 0);
452     }
453   }
454 
455   /// Constructs a sparse tensor storage scheme with the given dimensions,
456   /// permutation, and per-dimension dense/sparse annotations, using
457   /// the given sparse tensor for the initial contents.
458   ///
459   /// Preconditions:
460   /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
461   /// * The `tensor` must have the same value type `V`.
462   SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
463                       const uint64_t *perm, const DimLevelType *sparsity,
464                       const SparseTensorStorageBase &tensor);
465 
466   ~SparseTensorStorage() final override = default;
467 
468   /// Partially specialize these getter methods based on template types.
469   void getPointers(std::vector<P> **out, uint64_t d) final override {
470     assert(d < getRank());
471     *out = &pointers[d];
472   }
473   void getIndices(std::vector<I> **out, uint64_t d) final override {
474     assert(d < getRank());
475     *out = &indices[d];
476   }
477   void getValues(std::vector<V> **out) final override { *out = &values; }
478 
479   /// Partially specialize lexicographical insertions based on template types.
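  /// In essence: find the first dimension where `cursor` differs from the
  /// previous insertion point (`idx`), finalize all deeper dimensions of the
  /// pending path, and then append the new suffix of indices plus the value.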
480   void lexInsert(const uint64_t *cursor, V val) final override {
481     // First, wrap up pending insertion path.
482     uint64_t diff = 0;
483     uint64_t top = 0;
484     if (!values.empty()) {
485       diff = lexDiff(cursor);
486       endPath(diff + 1);
487       top = idx[diff] + 1;
488     }
489     // Then continue with insertion path.
490     insPath(cursor, diff, top, val);
491   }
492 
493   /// Partially specialize expanded insertions based on template types.
494   /// Note that this method resets the values/filled-switch array back
495   /// to all-zero/false while only iterating over the nonzero elements.
496   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
497                  uint64_t count) final override {
498     if (count == 0)
499       return;
500     // Sort.
501     std::sort(added, added + count);
502     // Restore insertion path for first insert.
503     const uint64_t lastDim = getRank() - 1;
504     uint64_t index = added[0];
505     cursor[lastDim] = index;
506     lexInsert(cursor, values[index]);
507     assert(filled[index]);
508     values[index] = 0;
509     filled[index] = false;
510     // Subsequent insertions are quick.
511     for (uint64_t i = 1; i < count; i++) {
512       assert(index < added[i] && "non-lexicographic insertion");
513       index = added[i];
514       cursor[lastDim] = index;
515       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
516       assert(filled[index]);
517       values[index] = 0;
518       filled[index] = false;
519     }
520   }
521 
522   /// Finalizes lexicographic insertions.
523   void endInsert() final override {
524     if (values.empty())
525       finalizeSegment(0);
526     else
527       endPath(0);
528   }
529 
530   void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
531                      const uint64_t *perm) const final override {
532     *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
533   }
534 
535   /// Returns this sparse tensor storage scheme as a new memory-resident
536   /// sparse tensor in coordinate scheme with the given dimension order.
537   ///
538   /// Precondition: `perm` must be valid for `getRank()`.
539   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
540     SparseTensorEnumeratorBase<V> *enumerator;
541     newEnumerator(&enumerator, getRank(), perm);
542     SparseTensorCOO<V> *coo =
543         new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
544     enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
545       coo->add(ind, val);
546     });
547     // TODO: This assertion assumes there are no stored zeros,
548     // or if there are then that we don't filter them out.
549     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
550     assert(coo->getElements().size() == values.size());
551     delete enumerator;
552     return coo;
553   }
554 
555   /// Factory method. Constructs a sparse tensor storage scheme with the given
556   /// dimensions, permutation, and per-dimension dense/sparse annotations,
557   /// using the coordinate scheme tensor for the initial contents if provided.
558   /// In the latter case, the coordinate scheme must respect the same
559   /// permutation as is desired for the new sparse tensor storage.
560   ///
561   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
562   static SparseTensorStorage<P, I, V> *
563   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
564                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
565     SparseTensorStorage<P, I, V> *n = nullptr;
566     if (coo) {
567       const auto &coosz = coo->getDimSizes();
568       assertPermutedSizesMatchShape(coosz, rank, perm, shape);
569       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
570     } else {
571       std::vector<uint64_t> permsz(rank);
572       for (uint64_t r = 0; r < rank; r++) {
573         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
574         permsz[perm[r]] = shape[r];
575       }
576       // We pass the null `coo` to ensure we select the intended constructor.
577       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
578     }
579     return n;
580   }
581 
582   /// Factory method. Constructs a sparse tensor storage scheme with
583   /// the given dimensions, permutation, and per-dimension dense/sparse
584   /// annotations, using the sparse tensor for the initial contents.
585   ///
586   /// Preconditions:
587   /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
588   /// * The `tensor` must have the same value type `V`.
589   static SparseTensorStorage<P, I, V> *
590   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
591                   const DimLevelType *sparsity,
592                   const SparseTensorStorageBase *source) {
593     assert(source && "Got nullptr for source");
594     SparseTensorEnumeratorBase<V> *enumerator;
595     source->newEnumerator(&enumerator, rank, perm);
596     const auto &permsz = enumerator->permutedSizes();
597     assertPermutedSizesMatchShape(permsz, rank, perm, shape);
598     auto *tensor =
599         new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
600     delete enumerator;
601     return tensor;
602   }
603 
604 private:
605   /// Appends an arbitrary new position to `pointers[d]`.  This method
606   /// checks that `pos` is representable in the `P` type; however, it
607   /// does not check that `pos` is semantically valid (i.e., larger than
608   /// the previous position and smaller than `indices[d].capacity()`).
609   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
610     assert(isCompressedDim(d));
611     assert(pos <= std::numeric_limits<P>::max() &&
612            "Pointer value is too large for the P-type");
613     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
614   }
615 
616   /// Appends index `i` to dimension `d`, in the semantically general
617   /// sense.  For non-dense dimensions, that means appending to the
618   /// `indices[d]` array, checking that `i` is representable in the `I`
619   /// type; however, we do not verify other semantic requirements (e.g.,
620   /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
621   /// in the same segment).  For dense dimensions, this method instead
622   /// appends the appropriate number of zeros to the `values` array,
623   /// where `full` is the number of "entries" already written to `values`
624   /// for this segment (aka one after the highest index previously appended).
625   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
626     if (isCompressedDim(d)) {
627       assert(i <= std::numeric_limits<I>::max() &&
628              "Index value is too large for the I-type");
629       indices[d].push_back(static_cast<I>(i));
630     } else { // Dense dimension.
631       assert(i >= full && "Index was already filled");
632       if (i == full)
633         return; // Short-circuit, since it'll be a nop.
634       if (d + 1 == getRank())
635         values.insert(values.end(), i - full, 0);
636       else
637         finalizeSegment(d + 1, 0, i - full);
638     }
639   }
640 
641   /// Writes the given coordinate to `indices[d][pos]`.  This method
642   /// checks that `i` is representable in the `I` type; however, it
643   /// does not check that `i` is semantically valid (i.e., in bounds
644   /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
645   void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
646     assert(isCompressedDim(d));
647     // Subscript assignment to `std::vector` requires that the `pos`-th
648     // entry has been initialized; thus we must be sure to check `size()`
649     // here, instead of `capacity()` as would be ideal.
650     assert(pos < indices[d].size() && "Index position is out of bounds");
651     assert(i <= std::numeric_limits<I>::max() &&
652            "Index value is too large for the I-type");
653     indices[d][pos] = static_cast<I>(i);
654   }
655 
656   /// Computes the assembled-size associated with the `d`-th dimension,
657   /// given the assembled-size associated with the `(d-1)`-th dimension.
658   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
659   /// storage, as opposed to "dimension-sizes" which are the cardinality
660   /// of coordinates for that dimension.
661   ///
662   /// Precondition: the `pointers[d]` array must be fully initialized
663   /// before calling this method.
664   uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
665     if (isCompressedDim(d))
666       return pointers[d][parentSz];
667     // else if dense:
668     return parentSz * getDimSizes()[d];
669   }
670 
671   /// Initializes sparse tensor storage scheme from a memory-resident sparse
672   /// tensor in coordinate scheme. This method prepares the pointers and
673   /// indices arrays under the given per-dimension dense/sparse annotations.
674   ///
675   /// Preconditions:
676   /// (1) the `elements` must be lexicographically sorted.
677   /// (2) the indices of every element are valid for `dimSizes` (equal rank
678   ///     and pointwise less-than).
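  ///
  /// For example, given the sorted elements ({0,1},5), ({2,3},7) of a 4x4
  /// matrix with a compressed outermost dimension, the call at `d == 0`
  /// finds the two row segments [0,1) and [1,2), appends row indices 0 and
  /// 2, and recurses into each segment to handle the column dimension.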
679   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
680                uint64_t hi, uint64_t d) {
681     uint64_t rank = getRank();
682     assert(d <= rank && hi <= elements.size());
683     // Once dimensions are exhausted, insert the numerical values.
684     if (d == rank) {
685       assert(lo < hi);
686       values.push_back(elements[lo].value);
687       return;
688     }
689     // Visit all elements in this interval.
690     uint64_t full = 0;
691     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
692       // Find segment in interval with same index elements in this dimension.
693       uint64_t i = elements[lo].indices[d];
694       uint64_t seg = lo + 1;
695       while (seg < hi && elements[seg].indices[d] == i)
696         seg++;
697       // Handle segment in interval for sparse or dense dimension.
698       appendIndex(d, full, i);
699       full = i + 1;
700       fromCOO(elements, lo, seg, d + 1);
701       // And move on to next segment in interval.
702       lo = seg;
703     }
704     // Finalize the sparse pointer structure at this dimension.
705     finalizeSegment(d, full);
706   }
707 
708   /// Finalize the sparse pointer structure at this dimension.
709   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
710     if (count == 0)
711       return; // Short-circuit, since it'll be a nop.
712     if (isCompressedDim(d)) {
713       appendPointer(d, indices[d].size(), count);
714     } else { // Dense dimension.
715       const uint64_t sz = getDimSizes()[d];
716       assert(sz >= full && "Segment is overfull");
717       count = checkedMul(count, sz - full);
718       // For dense storage we must enumerate all the remaining coordinates
719       // in this dimension (i.e., coordinates after the last non-zero
720       // element), and either fill in their zero values or else recurse
721       // to finalize some deeper dimension.
722       if (d + 1 == getRank())
723         values.insert(values.end(), count, 0);
724       else
725         finalizeSegment(d + 1, 0, count);
726     }
727   }
728 
729   /// Wraps up a single insertion path, inner to outer.
730   void endPath(uint64_t diff) {
731     uint64_t rank = getRank();
732     assert(diff <= rank);
733     for (uint64_t i = 0; i < rank - diff; i++) {
734       const uint64_t d = rank - i - 1;
735       finalizeSegment(d, idx[d] + 1);
736     }
737   }
738 
739   /// Continues a single insertion path, outer to inner.
740   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
741     uint64_t rank = getRank();
742     assert(diff < rank);
743     for (uint64_t d = diff; d < rank; d++) {
744       uint64_t i = cursor[d];
745       appendIndex(d, top, i);
746       top = 0;
747       idx[d] = i;
748     }
749     values.push_back(val);
750   }
751 
752   /// Finds the lexicographic differing dimension.
753   uint64_t lexDiff(const uint64_t *cursor) const {
754     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
755       if (cursor[r] > idx[r])
756         return r;
757       else
758         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
760     return -1u;
761   }
762 
763   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
764   // the cost of virtual-function dispatch in inner loops), without
765   // making them public to other client code.
766   friend class SparseTensorEnumerator<P, I, V>;
767 
768   std::vector<std::vector<P>> pointers;
769   std::vector<std::vector<I>> indices;
770   std::vector<V> values;
771   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
772 };
773 
774 /// A (higher-order) function object for enumerating the elements of some
775 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
776 /// method encapsulates the loop-nest for enumerating the elements of
777 /// the source tensor (in whatever order is best for the source tensor),
778 /// and applies a permutation to the coordinates/indices before handing
779 /// each element to the callback.  A single enumerator object can be
780 /// freely reused for several calls to `forallElements`, just so long
781 /// as each call is sequential with respect to one another.
782 ///
783 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
784 /// passed to the constructor; thus, objects of this class must not
785 /// outlive the sparse tensor they depend on.
786 ///
787 /// Design Note: The reason we define this class instead of simply using
788 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
789 /// the `<P,I>` template parameters from MLIR client code (to simplify the
790 /// type parameters used for direct sparse-to-sparse conversion).  And the
791 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
792 /// than simply using this class, is to avoid the cost of virtual-method
793 /// dispatch within the loop-nest.
794 template <typename V>
795 class SparseTensorEnumeratorBase {
796 public:
797   /// Constructs an enumerator with the given permutation for mapping
798   /// the semantic-ordering of dimensions to the desired target-ordering.
799   ///
800   /// Preconditions:
801   /// * the `tensor` must have the same `V` value type.
802   /// * `perm` must be valid for `rank`.
803   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
804                              uint64_t rank, const uint64_t *perm)
805       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
806         cursor(getRank()) {
807     assert(perm && "Received nullptr for permutation");
808     assert(rank == getRank() && "Permutation rank mismatch");
809     const auto &rev = src.getRev();           // source-order -> semantic-order
810     const auto &dimSizes = src.getDimSizes(); // in source storage-order
811     for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
812       uint64_t t = perm[rev[s]];              // `t` target-order
813       reord[s] = t;
814       permsz[t] = dimSizes[s];
815     }
816   }
817 
818   virtual ~SparseTensorEnumeratorBase() = default;
819 
820   // We disallow copying to help avoid leaking the `src` reference.
821   // (In addition to avoiding the problem of slicing.)
822   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
823   SparseTensorEnumeratorBase &
824   operator=(const SparseTensorEnumeratorBase &) = delete;
825 
826   /// Returns the source/target tensor's rank.  (The source-rank and
827   /// target-rank are always equal since we only support permutations.
828   /// Though once we add support for other dimension mappings, this
829   /// method will have to be split in two.)
830   uint64_t getRank() const { return permsz.size(); }
831 
832   /// Returns the target tensor's dimension sizes.
833   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
834 
835   /// Enumerates all elements of the source tensor, permutes their
836   /// indices, and passes the permuted element to the callback.
  /// The callback must not store the cursor reference directly,
  /// since this function reuses the storage.  Instead, the callback
  /// must copy the cursor if it wants to keep the indices.
840   virtual void forallElements(ElementConsumer<V> yield) = 0;
841 
842 protected:
843   const SparseTensorStorageBase &src;
844   std::vector<uint64_t> permsz; // in target order.
845   std::vector<uint64_t> reord;  // source storage-order -> target order.
846   std::vector<uint64_t> cursor; // in target order.
847 };
848 
849 template <typename P, typename I, typename V>
850 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
851   using Base = SparseTensorEnumeratorBase<V>;
852 
853 public:
854   /// Constructs an enumerator with the given permutation for mapping
855   /// the semantic-ordering of dimensions to the desired target-ordering.
856   ///
857   /// Precondition: `perm` must be valid for `rank`.
858   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
859                          uint64_t rank, const uint64_t *perm)
860       : Base(tensor, rank, perm) {}
861 
862   ~SparseTensorEnumerator() final override = default;
863 
864   void forallElements(ElementConsumer<V> yield) final override {
865     forallElements(yield, 0, 0);
866   }
867 
868 private:
869   /// The recursive component of the public `forallElements`.
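  /// For a compressed dimension `d`, `parentPos` is a position into
  /// `pointers[d]` delimiting the current segment; for a dense dimension it
  /// is the linearized coordinate prefix; once all dimensions are exhausted
  /// it indexes directly into `values`.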
870   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
871                       uint64_t d) {
872     // Recover the `<P,I,V>` type parameters of `src`.
873     const auto &src =
874         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
875     if (d == Base::getRank()) {
876       assert(parentPos < src.values.size() &&
877              "Value position is out of bounds");
878       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
879       yield(this->cursor, src.values[parentPos]);
880     } else if (src.isCompressedDim(d)) {
881       // Look up the bounds of the `d`-level segment determined by the
882       // `d-1`-level position `parentPos`.
883       const std::vector<P> &pointers_d = src.pointers[d];
884       assert(parentPos + 1 < pointers_d.size() &&
885              "Parent pointer position is out of bounds");
886       const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
887       const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
888       // Loop-invariant code for looking up the `d`-level coordinates/indices.
889       const std::vector<I> &indices_d = src.indices[d];
890       assert(pstop <= indices_d.size() && "Index position is out of bounds");
891       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
892       for (uint64_t pos = pstart; pos < pstop; pos++) {
893         cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
894         forallElements(yield, pos, d + 1);
895       }
896     } else { // Dense dimension.
897       const uint64_t sz = src.getDimSizes()[d];
898       const uint64_t pstart = parentPos * sz;
899       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
900       for (uint64_t i = 0; i < sz; i++) {
901         cursor_reord_d = i;
902         forallElements(yield, pstart + i, d + 1);
903       }
904     }
905   }
906 };
907 
908 /// Statistics regarding the number of nonzero subtensors in
909 /// a source tensor, for direct sparse=>sparse conversion a la
910 /// <https://arxiv.org/abs/2001.02609>.
911 ///
912 /// N.B., this class stores references to the parameters passed to
913 /// the constructor; thus, objects of this class must not outlive
914 /// those parameters.
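///
/// For example, for a matrix stored with a dense row dimension followed by
/// a compressed column dimension, `nnz[1][i]` ends up counting the nonzeros
/// of row `i`, which is exactly the increment needed to assemble
/// `pointers[1]` by prefix summation.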
915 class SparseTensorNNZ final {
916 public:
917   /// Allocate the statistics structure for the desired sizes and
918   /// sparsity (in the target tensor's storage-order).  This constructor
919   /// does not actually populate the statistics, however; for that see
920   /// `initialize`.
921   ///
922   /// Precondition: `dimSizes` must not contain zeros.
923   SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
924                   const std::vector<DimLevelType> &sparsity)
925       : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
926     assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
927     bool uncompressed = true;
928     uint64_t sz = 1; // the product of all `dimSizes` strictly less than `r`.
929     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
930       switch (dimTypes[r]) {
931       case DimLevelType::kCompressed:
932         assert(uncompressed &&
933                "Multiple compressed layers not currently supported");
934         uncompressed = false;
935         nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
936         break;
937       case DimLevelType::kDense:
938         assert(uncompressed &&
939                "Dense after compressed not currently supported");
940         break;
941       case DimLevelType::kSingleton:
942         // Singleton after Compressed causes no problems for allocating
943         // `nnz` nor for the yieldPos loop.  This remains true even
944         // when adding support for multiple compressed dimensions or
945         // for dense-after-compressed.
946         break;
947       }
948       sz = checkedMul(sz, dimSizes[r]);
949     }
950   }
951 
952   // We disallow copying to help avoid leaking the stored references.
953   SparseTensorNNZ(const SparseTensorNNZ &) = delete;
954   SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
955 
956   /// Returns the rank of the target tensor.
957   uint64_t getRank() const { return dimSizes.size(); }
958 
959   /// Enumerate the source tensor to fill in the statistics.  The
960   /// enumerator should already incorporate the permutation (from
961   /// semantic-order to the target storage-order).
962   template <typename V>
963   void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
964     assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
965     assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
966     enumerator.forallElements(
967         [this](const std::vector<uint64_t> &ind, V) { add(ind); });
968   }
969 
970   /// The type of callback functions which receive an nnz-statistic.
971   using NNZConsumer = const std::function<void(uint64_t)> &;
972 
  /// Lexicographically enumerates all indices for dimensions strictly
974   /// less than `stopDim`, and passes their nnz statistic to the callback.
975   /// Since our use-case only requires the statistic not the coordinates
976   /// themselves, we do not bother to construct those coordinates.
977   void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
978     assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
979     assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
980            "Cannot look up non-compressed dimensions");
981     forallIndices(yield, stopDim, 0, 0);
982   }
983 
984 private:
985   /// Adds a new element (i.e., increment its statistics).  We use
986   /// a method rather than inlining into the lambda in `initialize`,
987   /// to avoid spurious templating over `V`.  And this method is private
988   /// to avoid needing to re-assert validity of `ind` (which is guaranteed
989   /// by `forallElements`).
990   void add(const std::vector<uint64_t> &ind) {
991     uint64_t parentPos = 0;
992     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
993       if (dimTypes[r] == DimLevelType::kCompressed)
994         nnz[r][parentPos]++;
995       parentPos = parentPos * dimSizes[r] + ind[r];
996     }
997   }
998 
999   /// Recursive component of the public `forallIndices`.
1000   void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
1001                      uint64_t d) const {
1002     assert(d <= stopDim);
1003     if (d == stopDim) {
1004       assert(parentPos < nnz[d].size() && "Cursor is out of range");
1005       yield(nnz[d][parentPos]);
1006     } else {
1007       const uint64_t sz = dimSizes[d];
1008       const uint64_t pstart = parentPos * sz;
1009       for (uint64_t i = 0; i < sz; i++)
1010         forallIndices(yield, stopDim, pstart + i, d + 1);
1011     }
1012   }
1013 
1014   // All of these are in the target storage-order.
1015   const std::vector<uint64_t> &dimSizes;
1016   const std::vector<DimLevelType> &dimTypes;
1017   std::vector<std::vector<uint64_t>> nnz;
1018 };
1019 
1020 template <typename P, typename I, typename V>
1021 SparseTensorStorage<P, I, V>::SparseTensorStorage(
1022     const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
1023     const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
1024     : SparseTensorStorage(dimSizes, perm, sparsity) {
1025   SparseTensorEnumeratorBase<V> *enumerator;
1026   tensor.newEnumerator(&enumerator, getRank(), perm);
1027   {
1028     // Initialize the statistics structure.
1029     SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1030     nnz.initialize(*enumerator);
1031     // Initialize "pointers" overhead (and allocate "indices", "values").
1032     uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1033     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1034       if (isCompressedDim(r)) {
1035         pointers[r].reserve(parentSz + 1);
1036         pointers[r].push_back(0);
1037         uint64_t currentPos = 0;
1038         nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1039           currentPos += n;
1040           appendPointer(r, currentPos);
1041         });
1042         assert(pointers[r].size() == parentSz + 1 &&
1043                "Final pointers size doesn't match allocated size");
1044         // That assertion entails `assembledSize(parentSz, r)`
1045         // is now in a valid state.  That is, `pointers[r][parentSz]`
1046         // equals the present value of `currentPos`, which is the
1047         // correct assembled-size for `indices[r]`.
1048       }
1049       // Update assembled-size for the next iteration.
1050       parentSz = assembledSize(parentSz, r);
1051       // Ideally we need only `indices[r].reserve(parentSz)`, however
1052       // the `std::vector` implementation forces us to initialize it too.
1053       // That is, in the yieldPos loop we need random-access assignment
1054       // to `indices[r]`; however, `std::vector`'s subscript-assignment
1055       // only allows assigning to already-initialized positions.
1056       if (isCompressedDim(r))
1057         indices[r].resize(parentSz, 0);
1058     }
1059     values.resize(parentSz, 0); // Both allocate and zero-initialize.
1060   }
1061   // The yieldPos loop
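  // This pass writes each element's index into `indices[r]` at the slot
  // tracked by `pointers[r][parentPos]`, temporarily advancing that entry
  // past the positions already filled for its segment.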
1062   enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1063     uint64_t parentSz = 1, parentPos = 0;
1064     for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1065       if (isCompressedDim(r)) {
1066         // If `parentPos == parentSz` then it's valid as an array-lookup;
1067         // however, it's semantically invalid here since that entry
1068         // does not represent a segment of `indices[r]`.  Moreover, that
1069         // entry must be immutable for `assembledSize` to remain valid.
1070         assert(parentPos < parentSz && "Pointers position is out of bounds");
1071         const uint64_t currentPos = pointers[r][parentPos];
1072         // This increment won't overflow the `P` type, since it can't
1073         // exceed the original value of `pointers[r][parentPos+1]`
1074         // which was already verified to be within bounds for `P`
1075         // when it was written to the array.
1076         pointers[r][parentPos]++;
1077         writeIndex(r, currentPos, ind[r]);
1078         parentPos = currentPos;
1079       } else { // Dense dimension.
1080         parentPos = parentPos * getDimSizes()[r] + ind[r];
1081       }
1082       parentSz = assembledSize(parentSz, r);
1083     }
1084     assert(parentPos < values.size() && "Value position is out of bounds");
1085     values[parentPos] = val;
1086   });
1087   // No longer need the enumerator, so we'll delete it ASAP.
1088   delete enumerator;
1089   // The finalizeYieldPos loop
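  // This pass shifts each `pointers[r]` array right by one entry, undoing
  // the advancing done in the yieldPos loop so that entry `p` once again
  // holds the start (rather than the end) of segment `p`.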
1090   for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1091     if (isCompressedDim(r)) {
1092       assert(parentSz == pointers[r].size() - 1 &&
1093              "Actual pointers size doesn't match the expected size");
1094       // Can't check all of them, but at least we can check the last one.
1095       assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1096              "Pointers got corrupted");
1097       // TODO: optimize this by using `memmove` or similar.
1098       for (uint64_t n = 0; n < parentSz; n++) {
1099         const uint64_t parentPos = parentSz - n;
1100         pointers[r][parentPos] = pointers[r][parentPos - 1];
1101       }
1102       pointers[r][0] = 0;
1103     }
1104     parentSz = assembledSize(parentSz, r);
1105   }
1106 }
1107 
1108 /// Helper to convert string to lower case.
1109 static char *toLower(char *token) {
1110   for (char *c = token; *c; c++)
1111     *c = tolower(*c);
1112   return token;
1113 }
1114 
1115 /// Read the MME header of a general sparse matrix of type real.
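/// On success, `idata[0]` holds the rank (always 2), `idata[1]` the number
/// of nonzeros, and `idata[2]`, `idata[3]` the matrix dimensions.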
1116 static void readMMEHeader(FILE *file, char *filename, char *line,
1117                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
1118   char header[64];
1119   char object[64];
1120   char format[64];
1121   char field[64];
1122   char symmetry[64];
1123   // Read header line.
1124   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1125              symmetry) != 5) {
1126     fprintf(stderr, "Corrupt header in %s\n", filename);
1127     exit(1);
1128   }
1129   // Set properties
1130   *isPattern = (strcmp(toLower(field), "pattern") == 0);
1131   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
1132   // Make sure this is a general sparse matrix.
1133   if (strcmp(toLower(header), "%%matrixmarket") ||
1134       strcmp(toLower(object), "matrix") ||
1135       strcmp(toLower(format), "coordinate") ||
1136       (strcmp(toLower(field), "real") && !(*isPattern)) ||
1137       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
1138     fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
1139     exit(1);
1140   }
1141   // Skip comments.
1142   while (true) {
1143     if (!fgets(line, kColWidth, file)) {
1144       fprintf(stderr, "Cannot find data in %s\n", filename);
1145       exit(1);
1146     }
1147     if (line[0] != '%')
1148       break;
1149   }
1150   // Next line contains M N NNZ.
1151   idata[0] = 2; // rank
1152   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1153              idata + 1) != 3) {
1154     fprintf(stderr, "Cannot find size in %s\n", filename);
1155     exit(1);
1156   }
1157 }
1158 
1159 /// Read the "extended" FROSTT header. Although not part of the documented
1160 /// format, we assume that the file starts with optional comments followed
1161 /// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
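/// On success, `idata[0]` holds the rank, `idata[1]` the number of nonzeros,
/// and `idata[2]` through `idata[2 + rank - 1]` the dimension sizes.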
1163 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
1164                                 uint64_t *idata) {
1165   // Skip comments.
1166   while (true) {
1167     if (!fgets(line, kColWidth, file)) {
1168       fprintf(stderr, "Cannot find data in %s\n", filename);
1169       exit(1);
1170     }
1171     if (line[0] != '#')
1172       break;
1173   }
1174   // Next line contains RANK and NNZ.
1175   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
1176     fprintf(stderr, "Cannot find metadata in %s\n", filename);
1177     exit(1);
1178   }
1179   // Followed by a line with the dimension sizes (one per rank).
1180   for (uint64_t r = 0; r < idata[0]; r++) {
1181     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size in %s\n", filename);
1183       exit(1);
1184     }
1185   }
1186   fgets(line, kColWidth, file); // end of line
1187 }
1188 
1189 /// Reads a sparse tensor with the given filename into a memory-resident
1190 /// sparse tensor in coordinate scheme.
1191 template <typename V>
1192 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
1193                                                const uint64_t *shape,
1194                                                const uint64_t *perm) {
1195   // Open the file.
1196   FILE *file = fopen(filename, "r");
1197   if (!file) {
1198     assert(filename && "Received nullptr for filename");
1199     fprintf(stderr, "Cannot find file %s\n", filename);
1200     exit(1);
1201   }
1202   // Perform some file format dependent set up.
1203   char line[kColWidth];
1204   uint64_t idata[512];
1205   bool isPattern = false;
1206   bool isSymmetric = false;
1207   if (strstr(filename, ".mtx")) {
1208     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
1209   } else if (strstr(filename, ".tns")) {
1210     readExtFROSTTHeader(file, filename, line, idata);
1211   } else {
1212     fprintf(stderr, "Unknown format %s\n", filename);
1213     exit(1);
1214   }
1215   // Prepare sparse tensor object with per-dimension sizes
1216   // and the number of nonzeros as initial capacity.
1217   assert(rank == idata[0] && "rank mismatch");
1218   uint64_t nnz = idata[1];
1219   for (uint64_t r = 0; r < rank; r++)
1220     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1221            "dimension size mismatch");
1222   SparseTensorCOO<V> *tensor =
1223       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
1224   // Read all nonzero elements.
1225   std::vector<uint64_t> indices(rank);
1226   for (uint64_t k = 0; k < nnz; k++) {
1227     if (!fgets(line, kColWidth, file)) {
1228       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
1229       exit(1);
1230     }
1231     char *linePtr = line;
1232     for (uint64_t r = 0; r < rank; r++) {
1233       uint64_t idx = strtoul(linePtr, &linePtr, 10);
1234       // Add 0-based index.
1235       indices[perm[r]] = idx - 1;
1236     }
1237     // The external formats always store the numerical values with the type
1238     // double, but we cast these values to the sparse tensor object type.
1239     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1240     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
1241     tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing them. In the future, we may want to make symmetry implicit
    // for storage reasons.
1245     if (isSymmetric && indices[0] != indices[1])
1246       tensor->add({indices[1], indices[0]}, value);
1247   }
1248   // Close the file and return tensor.
1249   fclose(file);
1250   return tensor;
1251 }
1252 
1253 /// Writes the sparse tensor to extended FROSTT format.
1254 template <typename V>
1255 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1256   assert(tensor && dest);
1257   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1258   if (sort)
1259     coo->sort();
1260   char *filename = static_cast<char *>(dest);
1261   auto &dimSizes = coo->getDimSizes();
1262   auto &elements = coo->getElements();
1263   uint64_t rank = coo->getRank();
1264   uint64_t nnz = elements.size();
1265   std::fstream file;
1266   file.open(filename, std::ios_base::out | std::ios_base::trunc);
1267   assert(file.is_open());
1268   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1269   for (uint64_t r = 0; r < rank - 1; r++)
1270     file << dimSizes[r] << " ";
1271   file << dimSizes[rank - 1] << std::endl;
1272   for (uint64_t i = 0; i < nnz; i++) {
1273     auto &idx = elements[i].indices;
1274     for (uint64_t r = 0; r < rank; r++)
1275       file << (idx[r] + 1) << " ";
1276     file << elements[i].value << std::endl;
1277   }
1278   file.flush();
1279   file.close();
1280   assert(file.good());
1281 }
1282 
1283 /// Initializes sparse tensor from an external COO-flavored format.
1284 template <typename V>
1285 static SparseTensorStorage<uint64_t, uint64_t, V> *
1286 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1287                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
  const DimLevelType *sparsity =
      reinterpret_cast<const DimLevelType *>(sparse);
1289 #ifndef NDEBUG
1290   // Verify that perm is a permutation of 0..(rank-1).
1291   std::vector<uint64_t> order(perm, perm + rank);
1292   std::sort(order.begin(), order.end());
1293   for (uint64_t i = 0; i < rank; ++i) {
1294     if (i != order[i]) {
1295       fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
1296       exit(1);
1297     }
1298   }
1299 
1300   // Verify that the sparsity values are supported.
1301   for (uint64_t i = 0; i < rank; ++i) {
1302     if (sparsity[i] != DimLevelType::kDense &&
1303         sparsity[i] != DimLevelType::kCompressed) {
1304       fprintf(stderr, "Unsupported sparsity value %d\n",
1305               static_cast<int>(sparsity[i]));
1306       exit(1);
1307     }
1308   }
1309 #endif
1310 
1311   // Convert external format to internal COO.
1312   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1313   std::vector<uint64_t> idx(rank);
1314   for (uint64_t i = 0, base = 0; i < nse; i++) {
1315     for (uint64_t r = 0; r < rank; r++)
1316       idx[perm[r]] = indices[base + r];
1317     coo->add(idx, values[i]);
1318     base += rank;
1319   }
1320   // Return sparse tensor storage format as opaque pointer.
1321   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1322       rank, shape, perm, sparsity, coo);
1323   delete coo;
1324   return tensor;
1325 }
1326 
1327 /// Converts a sparse tensor to an external COO-flavored format.
1328 template <typename V>
1329 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1330                                  uint64_t **pShape, V **pValues,
1331                                  uint64_t **pIndices) {
1332   assert(tensor);
1333   auto sparseTensor =
1334       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1335   uint64_t rank = sparseTensor->getRank();
1336   std::vector<uint64_t> perm(rank);
1337   std::iota(perm.begin(), perm.end(), 0);
1338   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1339 
1340   const std::vector<Element<V>> &elements = coo->getElements();
1341   uint64_t nse = elements.size();
1342 
1343   uint64_t *shape = new uint64_t[rank];
1344   for (uint64_t i = 0; i < rank; i++)
1345     shape[i] = coo->getDimSizes()[i];
1346 
1347   V *values = new V[nse];
1348   uint64_t *indices = new uint64_t[rank * nse];
1349 
1350   for (uint64_t i = 0, base = 0; i < nse; i++) {
1351     values[i] = elements[i].value;
1352     for (uint64_t j = 0; j < rank; j++)
1353       indices[base + j] = elements[i].indices[j];
1354     base += rank;
1355   }
1356 
1357   delete coo;
1358   *pRank = rank;
1359   *pNse = nse;
1360   *pShape = shape;
1361   *pValues = values;
1362   *pIndices = indices;
1363 }
1364 
1365 } // namespace
1366 
1367 extern "C" {
1368 
1369 //===----------------------------------------------------------------------===//
1370 //
1371 // Public API with methods that operate on MLIR buffers (memrefs) to interact
1372 // with sparse tensors, which are only visible as opaque pointers externally.
1373 // These methods should be used exclusively by MLIR compiler-generated code.
1374 //
1375 // Some macro magic is used to generate implementations for all required type
1376 // combinations that can be called from MLIR compiler-generated code.
1377 //
1378 //===----------------------------------------------------------------------===//
1379 
1380 #define CASE(p, i, v, P, I, V)                                                 \
1381   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1382     SparseTensorCOO<V> *coo = nullptr;                                         \
1383     if (action <= Action::kFromCOO) {                                          \
1384       if (action == Action::kFromFile) {                                       \
1385         char *filename = static_cast<char *>(ptr);                             \
1386         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1387       } else if (action == Action::kFromCOO) {                                 \
1388         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1389       } else {                                                                 \
1390         assert(action == Action::kEmpty);                                      \
1391       }                                                                        \
1392       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1393           rank, shape, perm, sparsity, coo);                                   \
1394       if (action == Action::kFromFile)                                         \
1395         delete coo;                                                            \
1396       return tensor;                                                           \
1397     }                                                                          \
1398     if (action == Action::kSparseToSparse) {                                   \
1399       auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
1400       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
1401                                                            sparsity, tensor);  \
1402     }                                                                          \
1403     if (action == Action::kEmptyCOO)                                           \
1404       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1405     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1406     if (action == Action::kToIterator) {                                       \
1407       coo->startIterator();                                                    \
1408     } else {                                                                   \
1409       assert(action == Action::kToCOO);                                        \
1410     }                                                                          \
1411     return coo;                                                                \
1412   }
1413 
1414 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
1415 
1416 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1417 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1418 // that this file cannot get out of sync with its header.
1419 static_assert(std::is_same<index_type, uint64_t>::value,
1420               "Expected index_type == uint64_t");
1421 
/// Constructs a new sparse tensor. This is the "Swiss army knife"
/// method for materializing sparse tensors into the computation.
1424 ///
1425 /// Action:
1426 /// kEmpty = returns empty storage to fill later
1427 /// kFromFile = returns storage, where ptr contains filename to read
1428 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
1429 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
1430 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1431 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
1432 void *
1433 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1434                              StridedMemRefType<index_type, 1> *sref,
1435                              StridedMemRefType<index_type, 1> *pref,
1436                              OverheadType ptrTp, OverheadType indTp,
1437                              PrimaryType valTp, Action action, void *ptr) {
1438   assert(aref && sref && pref);
1439   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1440          pref->strides[0] == 1);
1441   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1442   const DimLevelType *sparsity = aref->data + aref->offset;
1443   const index_type *shape = sref->data + sref->offset;
1444   const index_type *perm = pref->data + pref->offset;
1445   uint64_t rank = aref->sizes[0];
1446 
1447   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1448   // This is safe because of the static_assert above.
1449   if (ptrTp == OverheadType::kIndex)
1450     ptrTp = OverheadType::kU64;
1451   if (indTp == OverheadType::kIndex)
1452     indTp = OverheadType::kU64;
1453 
1454   // Double matrices with all combinations of overhead storage.
1455   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1456        uint64_t, double);
1457   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1458        uint32_t, double);
1459   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1460        uint16_t, double);
1461   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1462        uint8_t, double);
1463   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1464        uint64_t, double);
1465   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1466        uint32_t, double);
1467   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1468        uint16_t, double);
1469   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1470        uint8_t, double);
1471   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1472        uint64_t, double);
1473   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1474        uint32_t, double);
1475   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1476        uint16_t, double);
1477   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1478        uint8_t, double);
1479   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1480        uint64_t, double);
1481   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1482        uint32_t, double);
1483   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1484        uint16_t, double);
1485   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1486        uint8_t, double);
1487 
1488   // Float matrices with all combinations of overhead storage.
1489   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1490        uint64_t, float);
1491   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1492        uint32_t, float);
1493   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1494        uint16_t, float);
1495   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1496        uint8_t, float);
1497   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1498        uint64_t, float);
1499   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1500        uint32_t, float);
1501   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1502        uint16_t, float);
1503   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1504        uint8_t, float);
1505   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1506        uint64_t, float);
1507   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1508        uint32_t, float);
1509   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1510        uint16_t, float);
1511   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1512        uint8_t, float);
1513   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1514        uint64_t, float);
1515   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1516        uint32_t, float);
1517   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1518        uint16_t, float);
1519   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1520        uint8_t, float);
1521 
1522   // Integral matrices with both overheads of the same type.
1523   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1524   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1525   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1526   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1527   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1528   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1529   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1530   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1531   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1532   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1533   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1534   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1535   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1536 
1537   // Complex matrices with wide overhead.
1538   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1539   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1540 
1541   // Unsupported case (add above if needed).
1542   fputs("unsupported combination of types\n", stderr);
1543   exit(1);
1544 }
1545 #undef CASE
1546 #undef CASE_SECSAME
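
// Lifecycle sketch of the Action values as seen from generated code; the
// memref descriptors aref/sref/pref and the type arguments are assumed to be
// set up already, and the names are illustrative:
//
//   // Build a coordinate scheme element by element, then pack it.
//   void *coo = _mlir_ciface_newSparseTensor(aref, sref, pref, ptrTp, indTp,
//                                            valTp, Action::kEmptyCOO, nullptr);
//   // ... repeated _mlir_ciface_addElt* calls on coo ...
//   void *tensor = _mlir_ciface_newSparseTensor(aref, sref, pref, ptrTp, indTp,
//                                               valTp, Action::kFromCOO, coo);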
1547 
1548 /// Methods that provide direct access to values.
1549 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
1550   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
1551                                         void *tensor) {                        \
1552     assert(ref &&tensor);                                                      \
1553     std::vector<V> *v;                                                         \
1554     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
1555     ref->basePtr = ref->data = v->data();                                      \
1556     ref->offset = 0;                                                           \
1557     ref->sizes[0] = v->size();                                                 \
1558     ref->strides[0] = 1;                                                       \
1559   }
1560 FOREVERY_V(IMPL_SPARSEVALUES)
1561 #undef IMPL_SPARSEVALUES
1562 
1563 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1564   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1565                            index_type d) {                                     \
1566     assert(ref &&tensor);                                                      \
1567     std::vector<TYPE> *v;                                                      \
1568     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1569     ref->basePtr = ref->data = v->data();                                      \
1570     ref->offset = 0;                                                           \
1571     ref->sizes[0] = v->size();                                                 \
1572     ref->strides[0] = 1;                                                       \
1573   }
1574 /// Methods that provide direct access to pointers.
1575 IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
1576 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
1577 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
1578 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
1579 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
1580 
1581 /// Methods that provide direct access to indices.
1582 IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1583 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
1584 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
1585 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
1586 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
1587 #undef IMPL_GETOVERHEAD
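
// Access sketch for a matrix of doubles (names are illustrative): each call
// fills a rank-1 memref descriptor that aliases the tensor's internal arrays,
// so no data is copied and nothing needs to be freed by the caller.
//
//   StridedMemRefType<double, 1> vals;
//   StridedMemRefType<index_type, 1> ptrs, inds;
//   _mlir_ciface_sparseValuesF64(&vals, tensor);
//   _mlir_ciface_sparsePointers(&ptrs, tensor, 1); // pointers of dimension 1
//   _mlir_ciface_sparseIndices(&inds, tensor, 1);  // indices of dimension 1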
1588 
/// Helpers to add a value to the coordinate scheme, one per value type.
1590 #define IMPL_ADDELT(VNAME, V)                                                  \
1591   void *_mlir_ciface_addElt##VNAME(void *coo, V value,                         \
1592                                    StridedMemRefType<index_type, 1> *iref,     \
1593                                    StridedMemRefType<index_type, 1> *pref) {   \
1594     assert(coo &&iref &&pref);                                                 \
1595     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1596     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1597     const index_type *indx = iref->data + iref->offset;                        \
1598     const index_type *perm = pref->data + pref->offset;                        \
1599     uint64_t isize = iref->sizes[0];                                           \
1600     std::vector<index_type> indices(isize);                                    \
1601     for (uint64_t r = 0; r < isize; r++)                                       \
1602       indices[perm[r]] = indx[r];                                              \
1603     static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value);               \
1604     return coo;                                                                \
1605   }
1606 FOREVERY_SIMPLEX_V(IMPL_ADDELT)
1607 // `complex64` apparently doesn't encounter any ABI issues (yet).
1608 IMPL_ADDELT(C64, complex64)
1609 // TODO: cleaner way to avoid ABI padding problem?
1610 IMPL_ADDELT(C32ABI, complex32)
1611 void *_mlir_ciface_addEltC32(void *coo, float r, float i,
1612                              StridedMemRefType<index_type, 1> *iref,
1613                              StridedMemRefType<index_type, 1> *pref) {
1614   return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
1615 }
1616 #undef IMPL_ADDELT
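
// Element-insertion sketch (names are illustrative): iref carries the indices
// of one element and pref the dimension permutation, both as rank-sized
// memrefs that are assumed to be set up already.
//
//   // With iref.data = {i, j} and pref.data = {0, 1}:
//   coo = _mlir_ciface_addEltF64(coo, 42.0, &iref, &pref);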
1617 
/// Helpers to enumerate elements of a coordinate scheme, one per value type.
1619 #define IMPL_GETNEXT(VNAME, V)                                                 \
1620   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
1621                                    StridedMemRefType<index_type, 1> *iref,     \
1622                                    StridedMemRefType<V, 0> *vref) {            \
1623     assert(coo &&iref &&vref);                                                 \
1624     assert(iref->strides[0] == 1);                                             \
1625     index_type *indx = iref->data + iref->offset;                              \
1626     V *value = vref->data + vref->offset;                                      \
1627     const uint64_t isize = iref->sizes[0];                                     \
1628     const Element<V> *elem =                                                   \
1629         static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
1630     if (elem == nullptr)                                                       \
1631       return false;                                                            \
1632     for (uint64_t r = 0; r < isize; r++)                                       \
1633       indx[r] = elem->indices[r];                                              \
1634     *value = elem->value;                                                      \
1635     return true;                                                               \
1636   }
1637 FOREVERY_V(IMPL_GETNEXT)
1638 #undef IMPL_GETNEXT
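
// Iteration sketch (names are illustrative): after obtaining a coordinate
// scheme with Action::kToIterator, generated code drains it one element at a
// time; iref is assumed to point at a rank-sized index buffer and vref at a
// single value slot.
//
//   while (_mlir_ciface_getNextF64(coo, &iref, &vref)) {
//     // iref.data[0..rank-1] holds the indices, *vref.data the value.
//   }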
1639 
1640 /// Insert elements in lexicographical index order, one per value type.
1641 #define IMPL_LEXINSERT(VNAME, V)                                               \
1642   void _mlir_ciface_lexInsert##VNAME(                                          \
1643       void *tensor, StridedMemRefType<index_type, 1> *cref, V val) {           \
1644     assert(tensor &&cref);                                                     \
1645     assert(cref->strides[0] == 1);                                             \
1646     index_type *cursor = cref->data + cref->offset;                            \
1647     assert(cursor);                                                            \
1648     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1649   }
1650 FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
1651 // `complex64` apparently doesn't encounter any ABI issues (yet).
1652 IMPL_LEXINSERT(C64, complex64)
1653 // TODO: cleaner way to avoid ABI padding problem?
1654 IMPL_LEXINSERT(C32ABI, complex32)
1655 void _mlir_ciface_lexInsertC32(void *tensor,
1656                                StridedMemRefType<index_type, 1> *cref, float r,
1657                                float i) {
1658   _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
1659 }
1660 #undef IMPL_LEXINSERT
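
// Insertion sketch (names are illustrative): cref holds the full
// lexicographic coordinate of the element being inserted; generated code
// calls endInsert() once all elements have been added.
//
//   // With cref.data = {i, j} already set up:
//   _mlir_ciface_lexInsertF64(tensor, &cref, 42.0);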
1661 
1662 /// Insert using expansion, one per value type.
1663 #define IMPL_EXPINSERT(VNAME, V)                                               \
1664   void _mlir_ciface_expInsert##VNAME(                                          \
1665       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1666       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1667       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1668     assert(tensor &&cref &&vref &&fref &&aref);                                \
1669     assert(cref->strides[0] == 1);                                             \
1670     assert(vref->strides[0] == 1);                                             \
1671     assert(fref->strides[0] == 1);                                             \
1672     assert(aref->strides[0] == 1);                                             \
1673     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1674     index_type *cursor = cref->data + cref->offset;                            \
1675     V *values = vref->data + vref->offset;                                     \
1676     bool *filled = fref->data + fref->offset;                                  \
1677     index_type *added = aref->data + aref->offset;                             \
1678     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1679         cursor, values, filled, added, count);                                 \
1680   }
1681 FOREVERY_V(IMPL_EXPINSERT)
1682 #undef IMPL_EXPINSERT
1683 
1684 /// Output a sparse tensor, one per value type.
1685 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
1686   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
1687     return outSparseTensor<V>(coo, dest, sort);                                \
1688   }
1689 FOREVERY_V(IMPL_OUTSPARSETENSOR)
1690 #undef IMPL_OUTSPARSETENSOR
1691 
1692 //===----------------------------------------------------------------------===//
1693 //
1694 // Public API with methods that accept C-style data structures to interact
1695 // with sparse tensors, which are only visible as opaque pointers externally.
1696 // These methods can be used both by MLIR compiler-generated code as well as by
1697 // an external runtime that wants to interact with MLIR compiler-generated code.
1698 //
1699 //===----------------------------------------------------------------------===//
1700 
1701 /// Helper method to read a sparse tensor filename from the environment,
1702 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1703 char *getTensorFilename(index_type id) {
1704   char var[80];
1705   sprintf(var, "TENSOR%" PRIu64, id);
1706   char *env = getenv(var);
1707   if (!env) {
1708     fprintf(stderr, "Environment variable %s is not set\n", var);
1709     exit(1);
1710   }
1711   return env;
1712 }
1713 
1714 /// Returns size of sparse tensor in given dimension.
1715 index_type sparseDimSize(void *tensor, index_type d) {
1716   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1717 }
1718 
1719 /// Finalizes lexicographic insertions.
1720 void endInsert(void *tensor) {
1721   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1722 }
1723 
1724 /// Releases sparse tensor storage.
1725 void delSparseTensor(void *tensor) {
1726   delete static_cast<SparseTensorStorageBase *>(tensor);
1727 }
1728 
1729 /// Releases sparse tensor coordinate scheme.
1730 #define IMPL_DELCOO(VNAME, V)                                                  \
1731   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1732     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1733   }
1734 FOREVERY_V(IMPL_DELCOO)
1735 #undef IMPL_DELCOO
1736 
1737 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
1738 /// data structures. The expected parameters are:
1739 ///
1740 ///   rank:    rank of tensor
1741 ///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
1743 ///   values:  a "nse" array with values for all specified elements
1744 ///   indices: a flat "nse x rank" array with indices for all specified elements
1745 ///   perm:    the permutation of the dimensions in the storage
1746 ///   sparse:  the sparsity for the dimensions
1747 ///
1748 /// For example, the sparse matrix
1749 ///     | 1.0 0.0 0.0 |
1750 ///     | 0.0 5.0 3.0 |
1751 /// can be passed as
1752 ///      rank    = 2
1753 ///      nse     = 3
1754 ///      shape   = [2, 3]
1755 ///      values  = [1.0, 5.0, 3.0]
1756 ///      indices = [ 0, 0,  1, 1,  1, 2]
1757 //
1758 // TODO: generalize beyond 64-bit indices.
1759 //
1760 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
1761   void *convertToMLIRSparseTensor##VNAME(                                      \
1762       uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
1763       uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
1764     return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
1765                                  sparse);                                      \
1766   }
1767 FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
1768 #undef IMPL_CONVERTTOMLIRSPARSETENSOR
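
// For the worked example above, a caller would invoke the F64 instantiation
// (the array setup mirrors the values listed in the comment and is omitted
// here; the sparsity bytes are assumed to encode per-dimension level types):
//
//   void *tensor = convertToMLIRSparseTensorF64(2, 3, shape, values, indices,
//                                               perm, sparse);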
1769 
1770 /// Converts a sparse tensor to COO-flavored format expressed using C-style
1771 /// data structures. The expected output parameters are pointers for these
1772 /// values:
1773 ///
1774 ///   rank:    rank of tensor
1775 ///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
1777 ///   values:  a "nse" array with values for all specified elements
1778 ///   indices: a flat "nse x rank" array with indices for all specified elements
1779 ///
1780 /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
1781 /// from convertToMLIRSparseTensor.
1782 ///
1783 //  TODO: Currently, values are copied from SparseTensorStorage to
1784 //  SparseTensorCOO, then to the output. We may want to reduce the number of
1785 //  copies.
1786 //
1787 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1788 // compressed
1789 //
1790 #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
1791   void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
1792                                           uint64_t *pNse, uint64_t **pShape,   \
1793                                           V **pValues, uint64_t **pIndices) {  \
1794     fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
1795   }
1796 FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
1797 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
1798 
1799 } // extern "C"
1800 
1801 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1802