1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cctype>
25 #include <cinttypes>
26 #include <cstdio>
27 #include <cstdlib>
28 #include <cstring>
29 #include <fstream>
30 #include <functional>
31 #include <iostream>
32 #include <limits>
33 #include <numeric>
34 #include <vector>
35 
36 //===----------------------------------------------------------------------===//
37 //
38 // Internal support for storing and reading sparse tensors.
39 //
40 // The following memory-resident sparse storage schemes are supported:
41 //
42 // (a) A coordinate scheme for temporarily storing and lexicographically
43 //     sorting a sparse tensor by index (SparseTensorCOO).
44 //
45 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
47 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
48 //
49 // The following external formats are supported:
50 //
51 // (1) Matrix Market Exchange (MME): *.mtx
52 //     https://math.nist.gov/MatrixMarket/formats.html
53 //
54 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
55 //     http://frostt.io/tensors/file-formats.html
56 //
57 // Two public APIs are supported:
58 //
59 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
60 //     tensors. These methods should be used exclusively by MLIR
61 //     compiler-generated code.
62 //
63 // (II) Methods that accept C-style data structures to interact with sparse
64 //      tensors. These methods can be used by any external runtime that wants
65 //      to interact with MLIR compiler-generated code.
66 //
// In both cases (I) and (II), the SparseTensorStorage format is only
// visible externally as an opaque pointer.
69 //
70 //===----------------------------------------------------------------------===//
71 
72 namespace {
73 
74 static constexpr int kColWidth = 1025;
75 
76 /// A version of `operator*` on `uint64_t` which checks for overflows.
77 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
78   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
79          "Integer overflow");
80   return lhs * rhs;
81 }
82 
83 /// A sparse tensor element in coordinate scheme (value and indices).
84 /// For example, a rank-1 vector element would look like
85 ///   ({i}, a[i])
86 /// and a rank-5 tensor element like
87 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer into a shared index pool rather than, e.g., a direct
/// vector, since that (1) reduces the per-element memory footprint, and
/// (2) centralizes the memory reservation and (re)allocation in one place.
91 template <typename V>
92 struct Element {
  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
94   uint64_t *indices; // pointer into shared index pool
95   V value;
96 };
97 
98 /// The type of callback functions which receive an element.  We avoid
99 /// packaging the coordinates and value together as an `Element` object
100 /// because this helps keep code somewhat cleaner.
101 template <typename V>
102 using ElementConsumer =
103     const std::function<void(const std::vector<uint64_t> &, V)> &;
104 
105 /// A memory-resident sparse tensor in coordinate scheme (collection of
106 /// elements). This data structure is used to read a sparse tensor from
107 /// any external format into memory and sort the elements lexicographically
108 /// by indices before passing it back to the client (most packed storage
109 /// formats require the elements to appear in lexicographic index order).
110 template <typename V>
111 struct SparseTensorCOO {
112 public:
113   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
114       : sizes(szs) {
115     if (capacity) {
116       elements.reserve(capacity);
117       indices.reserve(capacity * getRank());
118     }
119   }
120 
121   /// Adds element as indices and value.
122   void add(const std::vector<uint64_t> &ind, V val) {
123     assert(!iteratorLocked && "Attempt to add() after startIterator()");
124     uint64_t *base = indices.data();
125     uint64_t size = indices.size();
126     uint64_t rank = getRank();
127     assert(rank == ind.size());
128     for (uint64_t r = 0; r < rank; r++) {
129       assert(ind[r] < sizes[r]); // within bounds
130       indices.push_back(ind[r]);
131     }
132     // This base only changes if indices were reallocated. In that case, we
133     // need to correct all previous pointers into the vector. Note that this
134     // only happens if we did not set the initial capacity right, and then only
135     // for every internal vector reallocation (which with the doubling rule
136     // should only incur an amortized linear overhead).
137     uint64_t *newBase = indices.data();
138     if (newBase != base) {
139       for (uint64_t i = 0, n = elements.size(); i < n; i++)
140         elements[i].indices = newBase + (elements[i].indices - base);
141       base = newBase;
142     }
143     // Add element as (pointer into shared index pool, value) pair.
144     elements.emplace_back(base + size, val);
145   }
146 
147   /// Sorts elements lexicographically by index.
148   void sort() {
149     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
150     // TODO: we may want to cache an `isSorted` bit, to avoid
151     // unnecessary/redundant sorting.
152     std::sort(elements.begin(), elements.end(),
153               [this](const Element<V> &e1, const Element<V> &e2) {
154                 uint64_t rank = getRank();
155                 for (uint64_t r = 0; r < rank; r++) {
156                   if (e1.indices[r] == e2.indices[r])
157                     continue;
158                   return e1.indices[r] < e2.indices[r];
159                 }
160                 return false;
161               });
162   }
163 
164   /// Returns rank.
165   uint64_t getRank() const { return sizes.size(); }
166 
167   /// Getter for sizes array.
168   const std::vector<uint64_t> &getSizes() const { return sizes; }
169 
170   /// Getter for elements array.
171   const std::vector<Element<V>> &getElements() const { return elements; }
172 
173   /// Switch into iterator mode.
174   void startIterator() {
175     iteratorLocked = true;
176     iteratorPos = 0;
177   }
178 
179   /// Get the next element.
180   const Element<V> *getNext() {
181     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
182     if (iteratorPos < elements.size())
183       return &(elements[iteratorPos++]);
184     iteratorLocked = false;
185     return nullptr;
186   }
187 
188   /// Factory method. Permutes the original dimensions according to
189   /// the given ordering and expects subsequent add() calls to honor
190   /// that same ordering for the given indices. The result is a
191   /// fully permuted coordinate scheme.
192   ///
193   /// Precondition: `sizes` and `perm` must be valid for `rank`.
194   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
195                                                 const uint64_t *sizes,
196                                                 const uint64_t *perm,
197                                                 uint64_t capacity = 0) {
198     std::vector<uint64_t> permsz(rank);
199     for (uint64_t r = 0; r < rank; r++) {
200       assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
201       permsz[perm[r]] = sizes[r];
202     }
203     return new SparseTensorCOO<V>(permsz, capacity);
204   }
205 
206 private:
207   const std::vector<uint64_t> sizes; // per-dimension sizes
208   std::vector<Element<V>> elements;  // all COO elements
209   std::vector<uint64_t> indices;     // shared index pool
210   bool iteratorLocked = false;
211   unsigned iteratorPos = 0;
212 };
213 
214 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
215 /// takes responsibility for all the `<P,I,V>`-independent aspects
216 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
217 /// we use function overloading to implement "partial" method
218 /// specialization, which the C-API relies on to catch type errors
219 /// arising from our use of opaque pointers.
220 class SparseTensorStorageBase {
221 public:
222   /// Constructs a new storage object.  The `perm` maps the tensor's
223   /// semantic-ordering of dimensions to this object's storage-order.
224   /// The `szs` and `sparsity` arrays are already in storage-order.
225   ///
226   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
227   SparseTensorStorageBase(const std::vector<uint64_t> &szs,
228                           const uint64_t *perm, const DimLevelType *sparsity)
229       : dimSizes(szs), rev(getRank()),
230         dimTypes(sparsity, sparsity + getRank()) {
231     assert(perm && sparsity);
232     const uint64_t rank = getRank();
233     // Validate parameters.
234     assert(rank > 0 && "Trivial shape is unsupported");
235     for (uint64_t r = 0; r < rank; r++) {
236       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
237       assert((dimTypes[r] == DimLevelType::kDense ||
238               dimTypes[r] == DimLevelType::kCompressed) &&
239              "Unsupported DimLevelType");
240     }
241     // Construct the "reverse" (i.e., inverse) permutation.
242     for (uint64_t r = 0; r < rank; r++)
243       rev[perm[r]] = r;
244   }
245 
246   virtual ~SparseTensorStorageBase() = default;
247 
248   /// Get the rank of the tensor.
249   uint64_t getRank() const { return dimSizes.size(); }
250 
251   /// Getter for the dimension-sizes array, in storage-order.
252   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
253 
254   /// Safely lookup the size of the given (storage-order) dimension.
255   uint64_t getDimSize(uint64_t d) const {
256     assert(d < getRank());
257     return dimSizes[d];
258   }
259 
260   /// Getter for the "reverse" permutation, which maps this object's
261   /// storage-order to the tensor's semantic-order.
262   const std::vector<uint64_t> &getRev() const { return rev; }
263 
264   /// Getter for the dimension-types array, in storage-order.
265   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
266 
267   /// Safely check if the (storage-order) dimension uses compressed storage.
268   bool isCompressedDim(uint64_t d) const {
269     assert(d < getRank());
270     return (dimTypes[d] == DimLevelType::kCompressed);
271   }
272 
273   /// Overhead storage.
274   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
275   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
276   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
277   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
278   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
279   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
280   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
281   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
282 
283   /// Primary storage.
284   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
285   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
286   virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
287   virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
288   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
289   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
290 
291   /// Element-wise insertion in lexicographic index order.
292   virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
293   virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
294   virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
295   virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(const uint64_t *, int16_t) { fatal("insi16"); }
297   virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
298 
299   /// Expanded insertion.
300   virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
301     fatal("expf64");
302   }
303   virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
304     fatal("expf32");
305   }
306   virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
307     fatal("expi64");
308   }
309   virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
310     fatal("expi32");
311   }
312   virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
313     fatal("expi16");
314   }
315   virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
316     fatal("expi8");
317   }
318 
319   /// Finishes insertion.
320   virtual void endInsert() = 0;
321 
322 protected:
323   // Since this class is virtual, we must disallow public copying in
324   // order to avoid "slicing".  Since this class has data members,
325   // that means making copying protected.
326   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
327   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
328   // Copy-assignment would be implicitly deleted (because `dimSizes`
329   // is const), so we explicitly delete it for clarity.
330   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
331 
332 private:
333   static void fatal(const char *tp) {
334     fprintf(stderr, "unsupported %s\n", tp);
335     exit(1);
336   }
337 
338   const std::vector<uint64_t> dimSizes;
339   std::vector<uint64_t> rev;
340   const std::vector<DimLevelType> dimTypes;
341 };
342 
343 // Forward.
344 template <typename P, typename I, typename V>
345 class SparseTensorEnumerator;
346 
347 /// A memory-resident sparse tensor using a storage scheme based on
348 /// per-dimension sparse/dense annotations. This data structure provides a
349 /// bufferized form of a sparse tensor type. In contrast to generating setup
350 /// methods for each differently annotated sparse tensor, this method provides
351 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
352 /// and annotations to implement all required setup in a general manner.
353 template <typename P, typename I, typename V>
354 class SparseTensorStorage : public SparseTensorStorageBase {
355 public:
356   /// Constructs a sparse tensor storage scheme with the given dimensions,
357   /// permutation, and per-dimension dense/sparse annotations, using
358   /// the coordinate scheme tensor for the initial contents if provided.
359   ///
360   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
361   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
362                       const DimLevelType *sparsity,
363                       SparseTensorCOO<V> *coo = nullptr)
364       : SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
365         indices(getRank()), idx(getRank()) {
366     // Provide hints on capacity of pointers and indices.
367     // TODO: needs much fine-tuning based on actual sparsity; currently
368     //       we reserve pointer/index space based on all previous dense
369     //       dimensions, which works well up to first sparse dim; but
370     //       we should really use nnz and dense/sparse distribution.
371     bool allDense = true;
372     uint64_t sz = 1;
373     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
374       if (isCompressedDim(r)) {
375         // TODO: Take a parameter between 1 and `sizes[r]`, and multiply
376         // `sz` by that before reserving. (For now we just use 1.)
377         pointers[r].reserve(sz + 1);
378         pointers[r].push_back(0);
379         indices[r].reserve(sz);
380         sz = 1;
381         allDense = false;
382       } else { // Dense dimension.
383         sz = checkedMul(sz, getDimSizes()[r]);
384       }
385     }
386     // Then assign contents from coordinate scheme tensor if provided.
387     if (coo) {
388       // Ensure both preconditions of `fromCOO`.
389       assert(coo->getSizes() == getDimSizes() && "Tensor size mismatch");
390       coo->sort();
391       // Now actually insert the `elements`.
392       const std::vector<Element<V>> &elements = coo->getElements();
393       uint64_t nnz = elements.size();
394       values.reserve(nnz);
395       fromCOO(elements, 0, nnz, 0);
396     } else if (allDense) {
397       values.resize(sz, 0);
398     }
399   }
400 
401   ~SparseTensorStorage() override = default;
402 
403   /// Partially specialize these getter methods based on template types.
404   void getPointers(std::vector<P> **out, uint64_t d) override {
405     assert(d < getRank());
406     *out = &pointers[d];
407   }
408   void getIndices(std::vector<I> **out, uint64_t d) override {
409     assert(d < getRank());
410     *out = &indices[d];
411   }
412   void getValues(std::vector<V> **out) override { *out = &values; }
413 
414   /// Partially specialize lexicographical insertions based on template types.
415   void lexInsert(const uint64_t *cursor, V val) override {
416     // First, wrap up pending insertion path.
417     uint64_t diff = 0;
418     uint64_t top = 0;
419     if (!values.empty()) {
420       diff = lexDiff(cursor);
421       endPath(diff + 1);
422       top = idx[diff] + 1;
423     }
424     // Then continue with insertion path.
425     insPath(cursor, diff, top, val);
426   }
427 
428   /// Partially specialize expanded insertions based on template types.
429   /// Note that this method resets the values/filled-switch array back
430   /// to all-zero/false while only iterating over the nonzero elements.
431   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
432                  uint64_t count) override {
433     if (count == 0)
434       return;
435     // Sort.
436     std::sort(added, added + count);
437     // Restore insertion path for first insert.
438     const uint64_t lastDim = getRank() - 1;
439     uint64_t index = added[0];
440     cursor[lastDim] = index;
441     lexInsert(cursor, values[index]);
442     assert(filled[index]);
443     values[index] = 0;
444     filled[index] = false;
445     // Subsequent insertions are quick.
446     for (uint64_t i = 1; i < count; i++) {
447       assert(index < added[i] && "non-lexicographic insertion");
448       index = added[i];
449       cursor[lastDim] = index;
450       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
451       assert(filled[index]);
452       values[index] = 0;
453       filled[index] = false;
454     }
455   }
456 
457   /// Finalizes lexicographic insertions.
458   void endInsert() override {
459     if (values.empty())
460       finalizeSegment(0);
461     else
462       endPath(0);
463   }
464 
465   /// Returns this sparse tensor storage scheme as a new memory-resident
466   /// sparse tensor in coordinate scheme with the given dimension order.
467   ///
468   /// Precondition: `perm` must be valid for `getRank()`.
469   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
470     SparseTensorEnumerator<P, I, V> enumerator(*this, getRank(), perm);
471     SparseTensorCOO<V> *coo =
472         new SparseTensorCOO<V>(enumerator.permutedSizes(), values.size());
473     enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
474       coo->add(ind, val);
475     });
476     // TODO: This assertion assumes there are no stored zeros,
477     // or if there are then that we don't filter them out.
478     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
479     assert(coo->getElements().size() == values.size());
480     return coo;
481   }
482 
483   /// Factory method. Constructs a sparse tensor storage scheme with the given
484   /// dimensions, permutation, and per-dimension dense/sparse annotations,
485   /// using the coordinate scheme tensor for the initial contents if provided.
486   /// In the latter case, the coordinate scheme must respect the same
487   /// permutation as is desired for the new sparse tensor storage.
488   ///
489   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
490   static SparseTensorStorage<P, I, V> *
491   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
492                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
493     SparseTensorStorage<P, I, V> *n = nullptr;
494     if (coo) {
495       assert(coo->getRank() == rank && "Tensor rank mismatch");
496       const auto &coosz = coo->getSizes();
497       for (uint64_t r = 0; r < rank; r++)
498         assert(shape[r] == 0 || shape[r] == coosz[perm[r]]);
499       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
500     } else {
501       std::vector<uint64_t> permsz(rank);
502       for (uint64_t r = 0; r < rank; r++) {
503         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
504         permsz[perm[r]] = shape[r];
505       }
506       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
507     }
508     return n;
509   }
510 
511 private:
512   /// Appends an arbitrary new position to `pointers[d]`.  This method
513   /// checks that `pos` is representable in the `P` type; however, it
514   /// does not check that `pos` is semantically valid (i.e., larger than
515   /// the previous position and smaller than `indices[d].capacity()`).
516   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
517     assert(isCompressedDim(d));
518     assert(pos <= std::numeric_limits<P>::max() &&
519            "Pointer value is too large for the P-type");
520     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
521   }
522 
523   /// Appends index `i` to dimension `d`, in the semantically general
524   /// sense.  For non-dense dimensions, that means appending to the
525   /// `indices[d]` array, checking that `i` is representable in the `I`
526   /// type; however, we do not verify other semantic requirements (e.g.,
527   /// that `i` is in bounds for `sizes[d]`, and not previously occurring
528   /// in the same segment).  For dense dimensions, this method instead
529   /// appends the appropriate number of zeros to the `values` array,
530   /// where `full` is the number of "entries" already written to `values`
531   /// for this segment (aka one after the highest index previously appended).
532   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
533     if (isCompressedDim(d)) {
534       assert(i <= std::numeric_limits<I>::max() &&
535              "Index value is too large for the I-type");
536       indices[d].push_back(static_cast<I>(i));
537     } else { // Dense dimension.
538       assert(i >= full && "Index was already filled");
539       if (i == full)
540         return; // Short-circuit, since it'll be a nop.
541       if (d + 1 == getRank())
542         values.insert(values.end(), i - full, 0);
543       else
544         finalizeSegment(d + 1, 0, i - full);
545     }
546   }
547 
548   /// Initializes sparse tensor storage scheme from a memory-resident sparse
549   /// tensor in coordinate scheme. This method prepares the pointers and
550   /// indices arrays under the given per-dimension dense/sparse annotations.
551   ///
552   /// Preconditions:
553   /// (1) the `elements` must be lexicographically sorted.
554   /// (2) the indices of every element are valid for `sizes` (equal rank
555   ///     and pointwise less-than).
556   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
557                uint64_t hi, uint64_t d) {
558     uint64_t rank = getRank();
559     assert(d <= rank && hi <= elements.size());
560     // Once dimensions are exhausted, insert the numerical values.
561     if (d == rank) {
562       assert(lo < hi);
563       values.push_back(elements[lo].value);
564       return;
565     }
566     // Visit all elements in this interval.
567     uint64_t full = 0;
568     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
569       // Find segment in interval with same index elements in this dimension.
570       uint64_t i = elements[lo].indices[d];
571       uint64_t seg = lo + 1;
572       while (seg < hi && elements[seg].indices[d] == i)
573         seg++;
574       // Handle segment in interval for sparse or dense dimension.
575       appendIndex(d, full, i);
576       full = i + 1;
577       fromCOO(elements, lo, seg, d + 1);
578       // And move on to next segment in interval.
579       lo = seg;
580     }
581     // Finalize the sparse pointer structure at this dimension.
582     finalizeSegment(d, full);
583   }
584 
585   /// Finalize the sparse pointer structure at this dimension.
586   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
587     if (count == 0)
588       return; // Short-circuit, since it'll be a nop.
589     if (isCompressedDim(d)) {
590       appendPointer(d, indices[d].size(), count);
591     } else { // Dense dimension.
592       const uint64_t sz = getDimSizes()[d];
593       assert(sz >= full && "Segment is overfull");
594       count = checkedMul(count, sz - full);
595       // For dense storage we must enumerate all the remaining coordinates
596       // in this dimension (i.e., coordinates after the last non-zero
597       // element), and either fill in their zero values or else recurse
598       // to finalize some deeper dimension.
599       if (d + 1 == getRank())
600         values.insert(values.end(), count, 0);
601       else
602         finalizeSegment(d + 1, 0, count);
603     }
604   }
605 
606   /// Wraps up a single insertion path, inner to outer.
607   void endPath(uint64_t diff) {
608     uint64_t rank = getRank();
609     assert(diff <= rank);
610     for (uint64_t i = 0; i < rank - diff; i++) {
611       const uint64_t d = rank - i - 1;
612       finalizeSegment(d, idx[d] + 1);
613     }
614   }
615 
616   /// Continues a single insertion path, outer to inner.
617   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
618     uint64_t rank = getRank();
619     assert(diff < rank);
620     for (uint64_t d = diff; d < rank; d++) {
621       uint64_t i = cursor[d];
622       appendIndex(d, top, i);
623       top = 0;
624       idx[d] = i;
625     }
626     values.push_back(val);
627   }
628 
  /// Finds the lexicographically differing dimension.
630   uint64_t lexDiff(const uint64_t *cursor) const {
631     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
632       if (cursor[r] > idx[r])
633         return r;
634       else
635         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
637     return -1u;
638   }
639 
640   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
641   // the cost of virtual-function dispatch in inner loops), without
642   // making them public to other client code.
643   friend class SparseTensorEnumerator<P, I, V>;
644 
645   std::vector<std::vector<P>> pointers;
646   std::vector<std::vector<I>> indices;
647   std::vector<V> values;
648   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
649 };
650 
651 /// A (higher-order) function object for enumerating the elements of some
652 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
653 /// method encapsulates the loop-nest for enumerating the elements of
654 /// the source tensor (in whatever order is best for the source tensor),
655 /// and applies a permutation to the coordinates/indices before handing
656 /// each element to the callback.  A single enumerator object can be
657 /// freely reused for several calls to `forallElements`, just so long
658 /// as each call is sequential with respect to one another.
659 ///
660 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
661 /// passed to the constructor; thus, objects of this class must not
662 /// outlive the sparse tensor they depend on.
663 ///
/// Design Note: The reason we define this class instead of simply using
/// `SparseTensorEnumerator<P,I,V>` is that we need to hide/generalize the
/// `<P,I>` template parameters from MLIR client code (to simplify the
/// type parameters used for direct sparse-to-sparse conversion).  And the
/// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
/// than simply using this class is to avoid the cost of virtual-method
/// dispatch within the loop-nest.
671 template <typename V>
672 class SparseTensorEnumeratorBase {
673 public:
674   /// Constructs an enumerator with the given permutation for mapping
675   /// the semantic-ordering of dimensions to the desired target-ordering.
676   ///
677   /// Preconditions:
678   /// * the `tensor` must have the same `V` value type.
679   /// * `perm` must be valid for `rank`.
680   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
681                              uint64_t rank, const uint64_t *perm)
682       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
683         cursor(getRank()) {
684     assert(perm && "Received nullptr for permutation");
685     assert(rank == getRank() && "Permutation rank mismatch");
686     const auto &rev = src.getRev();        // source stg-order -> semantic-order
687     const auto &sizes = src.getDimSizes(); // in source storage-order
688     for (uint64_t s = 0; s < rank; s++) {  // `s` source storage-order
689       uint64_t t = perm[rev[s]];           // `t` target-order
690       reord[s] = t;
691       permsz[t] = sizes[s];
692     }
693   }
694 
695   virtual ~SparseTensorEnumeratorBase() = default;
696 
697   // We disallow copying to help avoid leaking the `src` reference.
698   // (In addition to avoiding the problem of slicing.)
699   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
700   SparseTensorEnumeratorBase &
701   operator=(const SparseTensorEnumeratorBase &) = delete;
702 
703   /// Returns the source/target tensor's rank.  (The source-rank and
704   /// target-rank are always equal since we only support permutations.
705   /// Though once we add support for other dimension mappings, this
706   /// method will have to be split in two.)
707   uint64_t getRank() const { return permsz.size(); }
708 
709   /// Returns the target tensor's dimension sizes.
710   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
711 
712   /// Enumerates all elements of the source tensor, permutes their
713   /// indices, and passes the permuted element to the callback.
714   /// The callback must not store the cursor reference directly,
  /// since this function reuses the storage.  Instead, the callback
  /// must copy the cursor if it wants to keep it.
717   virtual void forallElements(ElementConsumer<V> yield) = 0;
718 
719 protected:
720   const SparseTensorStorageBase &src;
721   std::vector<uint64_t> permsz; // in target order.
722   std::vector<uint64_t> reord;  // source storage-order -> target order.
723   std::vector<uint64_t> cursor; // in target order.
724 };
725 
726 template <typename P, typename I, typename V>
727 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
728   using Base = SparseTensorEnumeratorBase<V>;
729 
730 public:
731   /// Constructs an enumerator with the given permutation for mapping
732   /// the semantic-ordering of dimensions to the desired target-ordering.
733   ///
734   /// Precondition: `perm` must be valid for `rank`.
735   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
736                          uint64_t rank, const uint64_t *perm)
737       : Base(tensor, rank, perm) {}
738 
739   ~SparseTensorEnumerator() final override = default;
740 
741   void forallElements(ElementConsumer<V> yield) final override {
742     forallElements(yield, 0, 0);
743   }
744 
745 private:
746   /// The recursive component of the public `forallElements`.
747   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
748                       uint64_t d) {
749     // Recover the `<P,I,V>` type parameters of `src`.
750     const auto &src =
751         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
752     if (d == Base::getRank()) {
753       assert(parentPos < src.values.size() &&
754              "Value position is out of bounds");
755       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
756       yield(this->cursor, src.values[parentPos]);
757     } else if (src.isCompressedDim(d)) {
758       // Look up the bounds of the `d`-level segment determined by the
759       // `d-1`-level position `parentPos`.
760       const std::vector<P> &pointers_d = src.pointers[d];
761       assert(parentPos + 1 < pointers_d.size() &&
762              "Parent pointer position is out of bounds");
763       const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
764       const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
765       // Loop-invariant code for looking up the `d`-level coordinates/indices.
766       const std::vector<I> &indices_d = src.indices[d];
767       assert(pstop - 1 < indices_d.size() && "Index position is out of bounds");
768       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
769       for (uint64_t pos = pstart; pos < pstop; pos++) {
770         cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
771         forallElements(yield, pos, d + 1);
772       }
773     } else { // Dense dimension.
774       const uint64_t sz = src.getDimSizes()[d];
775       const uint64_t pstart = parentPos * sz;
776       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
777       for (uint64_t i = 0; i < sz; i++) {
778         cursor_reord_d = i;
779         forallElements(yield, pstart + i, d + 1);
780       }
781     }
782   }
783 };
784 
/// Helper to convert a string to lower case.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    // Cast through unsigned char to avoid UB for negative char values.
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
  return token;
}
791 
792 /// Read the MME header of a general sparse matrix of type real.
793 static void readMMEHeader(FILE *file, char *filename, char *line,
794                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
795   char header[64];
796   char object[64];
797   char format[64];
798   char field[64];
799   char symmetry[64];
800   // Read header line.
801   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
802              symmetry) != 5) {
803     fprintf(stderr, "Corrupt header in %s\n", filename);
804     exit(1);
805   }
806   // Set properties
807   *isPattern = (strcmp(toLower(field), "pattern") == 0);
808   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
809   // Make sure this is a general sparse matrix.
810   if (strcmp(toLower(header), "%%matrixmarket") ||
811       strcmp(toLower(object), "matrix") ||
812       strcmp(toLower(format), "coordinate") ||
813       (strcmp(toLower(field), "real") && !(*isPattern)) ||
814       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
815     fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
816     exit(1);
817   }
818   // Skip comments.
819   while (true) {
820     if (!fgets(line, kColWidth, file)) {
821       fprintf(stderr, "Cannot find data in %s\n", filename);
822       exit(1);
823     }
824     if (line[0] != '%')
825       break;
826   }
827   // Next line contains M N NNZ.
828   idata[0] = 2; // rank
829   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
830              idata + 1) != 3) {
831     fprintf(stderr, "Cannot find size in %s\n", filename);
832     exit(1);
833   }
834 }
835 
836 /// Read the "extended" FROSTT header. Although not part of the documented
837 /// format, we assume that the file starts with optional comments followed
838 /// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
840 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
841                                 uint64_t *idata) {
842   // Skip comments.
843   while (true) {
844     if (!fgets(line, kColWidth, file)) {
845       fprintf(stderr, "Cannot find data in %s\n", filename);
846       exit(1);
847     }
848     if (line[0] != '#')
849       break;
850   }
851   // Next line contains RANK and NNZ.
852   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
853     fprintf(stderr, "Cannot find metadata in %s\n", filename);
854     exit(1);
855   }
856   // Followed by a line with the dimension sizes (one per rank).
857   for (uint64_t r = 0; r < idata[0]; r++) {
858     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
859       fprintf(stderr, "Cannot find dimension size %s\n", filename);
860       exit(1);
861     }
862   }
863   fgets(line, kColWidth, file); // end of line
864 }
865 
866 /// Reads a sparse tensor with the given filename into a memory-resident
867 /// sparse tensor in coordinate scheme.
868 template <typename V>
869 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
870                                                const uint64_t *shape,
871                                                const uint64_t *perm) {
872   // Open the file.
873   FILE *file = fopen(filename, "r");
874   if (!file) {
875     assert(filename && "Received nullptr for filename");
876     fprintf(stderr, "Cannot find file %s\n", filename);
877     exit(1);
878   }
  // Perform some file-format-dependent setup.
880   char line[kColWidth];
881   uint64_t idata[512];
882   bool isPattern = false;
883   bool isSymmetric = false;
884   if (strstr(filename, ".mtx")) {
885     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
886   } else if (strstr(filename, ".tns")) {
887     readExtFROSTTHeader(file, filename, line, idata);
888   } else {
889     fprintf(stderr, "Unknown format %s\n", filename);
890     exit(1);
891   }
892   // Prepare sparse tensor object with per-dimension sizes
893   // and the number of nonzeros as initial capacity.
894   assert(rank == idata[0] && "rank mismatch");
895   uint64_t nnz = idata[1];
896   for (uint64_t r = 0; r < rank; r++)
897     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
898            "dimension size mismatch");
899   SparseTensorCOO<V> *tensor =
900       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
  // Read all nonzero elements.
902   std::vector<uint64_t> indices(rank);
903   for (uint64_t k = 0; k < nnz; k++) {
904     if (!fgets(line, kColWidth, file)) {
905       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
906       exit(1);
907     }
908     char *linePtr = line;
909     for (uint64_t r = 0; r < rank; r++) {
910       uint64_t idx = strtoul(linePtr, &linePtr, 10);
911       // Add 0-based index.
912       indices[perm[r]] = idx - 1;
913     }
914     // The external formats always store the numerical values with the type
915     // double, but we cast these values to the sparse tensor object type.
916     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
917     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
918     tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully constructing
920     // them. In the future, we may want to make symmetry implicit for storage
921     // reasons.
922     if (isSymmetric && indices[0] != indices[1])
923       tensor->add({indices[1], indices[0]}, value);
924   }
925   // Close the file and return tensor.
926   fclose(file);
927   return tensor;
928 }
929 
930 /// Writes the sparse tensor to extended FROSTT format.
931 template <typename V>
932 static void outSparseTensor(void *tensor, void *dest, bool sort) {
933   assert(tensor && dest);
934   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
935   if (sort)
936     coo->sort();
937   char *filename = static_cast<char *>(dest);
938   auto &sizes = coo->getSizes();
939   auto &elements = coo->getElements();
940   uint64_t rank = coo->getRank();
941   uint64_t nnz = elements.size();
942   std::fstream file;
943   file.open(filename, std::ios_base::out | std::ios_base::trunc);
944   assert(file.is_open());
945   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
946   for (uint64_t r = 0; r < rank - 1; r++)
947     file << sizes[r] << " ";
948   file << sizes[rank - 1] << std::endl;
949   for (uint64_t i = 0; i < nnz; i++) {
950     auto &idx = elements[i].indices;
951     for (uint64_t r = 0; r < rank; r++)
952       file << (idx[r] + 1) << " ";
953     file << elements[i].value << std::endl;
954   }
955   file.flush();
956   file.close();
957   assert(file.good());
958 }
959 
960 /// Initializes sparse tensor from an external COO-flavored format.
961 template <typename V>
962 static SparseTensorStorage<uint64_t, uint64_t, V> *
963 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
964                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
965   const DimLevelType *sparsity = (DimLevelType *)(sparse);
966 #ifndef NDEBUG
967   // Verify that perm is a permutation of 0..(rank-1).
968   std::vector<uint64_t> order(perm, perm + rank);
969   std::sort(order.begin(), order.end());
970   for (uint64_t i = 0; i < rank; ++i) {
971     if (i != order[i]) {
972       fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
973       exit(1);
974     }
975   }
976 
977   // Verify that the sparsity values are supported.
978   for (uint64_t i = 0; i < rank; ++i) {
979     if (sparsity[i] != DimLevelType::kDense &&
980         sparsity[i] != DimLevelType::kCompressed) {
981       fprintf(stderr, "Unsupported sparsity value %d\n",
982               static_cast<int>(sparsity[i]));
983       exit(1);
984     }
985   }
986 #endif
987 
988   // Convert external format to internal COO.
989   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
990   std::vector<uint64_t> idx(rank);
991   for (uint64_t i = 0, base = 0; i < nse; i++) {
992     for (uint64_t r = 0; r < rank; r++)
993       idx[perm[r]] = indices[base + r];
994     coo->add(idx, values[i]);
995     base += rank;
996   }
997   // Return sparse tensor storage format as opaque pointer.
998   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
999       rank, shape, perm, sparsity, coo);
1000   delete coo;
1001   return tensor;
1002 }
1003 
1004 /// Converts a sparse tensor to an external COO-flavored format.
1005 template <typename V>
1006 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1007                                  uint64_t **pShape, V **pValues,
1008                                  uint64_t **pIndices) {
1009   auto sparseTensor =
1010       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1011   uint64_t rank = sparseTensor->getRank();
1012   std::vector<uint64_t> perm(rank);
1013   std::iota(perm.begin(), perm.end(), 0);
1014   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1015 
1016   const std::vector<Element<V>> &elements = coo->getElements();
1017   uint64_t nse = elements.size();
1018 
1019   uint64_t *shape = new uint64_t[rank];
1020   for (uint64_t i = 0; i < rank; i++)
1021     shape[i] = coo->getSizes()[i];
1022 
1023   V *values = new V[nse];
1024   uint64_t *indices = new uint64_t[rank * nse];
1025 
1026   for (uint64_t i = 0, base = 0; i < nse; i++) {
1027     values[i] = elements[i].value;
1028     for (uint64_t j = 0; j < rank; j++)
1029       indices[base + j] = elements[i].indices[j];
1030     base += rank;
1031   }
1032 
1033   delete coo;
1034   *pRank = rank;
1035   *pNse = nse;
1036   *pShape = shape;
1037   *pValues = values;
1038   *pIndices = indices;
1039 }
1040 
1041 } // namespace
1042 
1043 extern "C" {
1044 
1045 //===----------------------------------------------------------------------===//
1046 //
1047 // Public API with methods that operate on MLIR buffers (memrefs) to interact
1048 // with sparse tensors, which are only visible as opaque pointers externally.
1049 // These methods should be used exclusively by MLIR compiler-generated code.
1050 //
1051 // Some macro magic is used to generate implementations for all required type
1052 // combinations that can be called from MLIR compiler-generated code.
1053 //
1054 //===----------------------------------------------------------------------===//
1055 
1056 #define CASE(p, i, v, P, I, V)                                                 \
1057   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1058     SparseTensorCOO<V> *coo = nullptr;                                         \
1059     if (action <= Action::kFromCOO) {                                          \
1060       if (action == Action::kFromFile) {                                       \
1061         char *filename = static_cast<char *>(ptr);                             \
1062         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1063       } else if (action == Action::kFromCOO) {                                 \
1064         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1065       } else {                                                                 \
1066         assert(action == Action::kEmpty);                                      \
1067       }                                                                        \
1068       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1069           rank, shape, perm, sparsity, coo);                                   \
1070       if (action == Action::kFromFile)                                         \
1071         delete coo;                                                            \
1072       return tensor;                                                           \
1073     }                                                                          \
1074     if (action == Action::kEmptyCOO)                                           \
1075       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1076     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1077     if (action == Action::kToIterator) {                                       \
1078       coo->startIterator();                                                    \
1079     } else {                                                                   \
1080       assert(action == Action::kToCOO);                                        \
1081     }                                                                          \
1082     return coo;                                                                \
1083   }
1084 
1085 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
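
// For example, `CASE_SECSAME(OverheadType::kU64, PrimaryType::kF64, uint64_t,
// double)` expands to `CASE(OverheadType::kU64, OverheadType::kU64,
// PrimaryType::kF64, uint64_t, uint64_t, double)`, i.e. the branch handling
// tensors whose pointers and indices are both stored as `uint64_t` and whose
// values are stored as `double`.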
1086 
1087 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
1088   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
1089     assert(ref &&tensor);                                                      \
1090     std::vector<TYPE> *v;                                                      \
1091     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
1092     ref->basePtr = ref->data = v->data();                                      \
1093     ref->offset = 0;                                                           \
1094     ref->sizes[0] = v->size();                                                 \
1095     ref->strides[0] = 1;                                                       \
1096   }
1097 
1098 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1099   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1100                            index_type d) {                                     \
1101     assert(ref &&tensor);                                                      \
1102     std::vector<TYPE> *v;                                                      \
1103     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1104     ref->basePtr = ref->data = v->data();                                      \
1105     ref->offset = 0;                                                           \
1106     ref->sizes[0] = v->size();                                                 \
1107     ref->strides[0] = 1;                                                       \
1108   }
1109 
1110 #define IMPL_ADDELT(NAME, TYPE)                                                \
1111   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
1112                             StridedMemRefType<index_type, 1> *iref,            \
1113                             StridedMemRefType<index_type, 1> *pref) {          \
1114     assert(tensor &&iref &&pref);                                              \
1115     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1116     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1117     const index_type *indx = iref->data + iref->offset;                        \
1118     const index_type *perm = pref->data + pref->offset;                        \
1119     uint64_t isize = iref->sizes[0];                                           \
1120     std::vector<index_type> indices(isize);                                    \
1121     for (uint64_t r = 0; r < isize; r++)                                       \
1122       indices[perm[r]] = indx[r];                                              \
1123     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
1124     return tensor;                                                             \
1125   }
1126 
1127 #define IMPL_GETNEXT(NAME, V)                                                  \
1128   bool _mlir_ciface_##NAME(void *tensor,                                       \
1129                            StridedMemRefType<index_type, 1> *iref,             \
1130                            StridedMemRefType<V, 0> *vref) {                    \
1131     assert(tensor &&iref &&vref);                                              \
1132     assert(iref->strides[0] == 1);                                             \
1133     index_type *indx = iref->data + iref->offset;                              \
1134     V *value = vref->data + vref->offset;                                      \
1135     const uint64_t isize = iref->sizes[0];                                     \
1136     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
1137     const Element<V> *elem = iter->getNext();                                  \
1138     if (elem == nullptr)                                                       \
1139       return false;                                                            \
1140     for (uint64_t r = 0; r < isize; r++)                                       \
1141       indx[r] = elem->indices[r];                                              \
1142     *value = elem->value;                                                      \
1143     return true;                                                               \
1144   }
1145 
1146 #define IMPL_LEXINSERT(NAME, V)                                                \
1147   void _mlir_ciface_##NAME(void *tensor,                                       \
1148                            StridedMemRefType<index_type, 1> *cref, V val) {    \
1149     assert(tensor &&cref);                                                     \
1150     assert(cref->strides[0] == 1);                                             \
1151     index_type *cursor = cref->data + cref->offset;                            \
1152     assert(cursor);                                                            \
1153     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1154   }
1155 
1156 #define IMPL_EXPINSERT(NAME, V)                                                \
1157   void _mlir_ciface_##NAME(                                                    \
1158       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1159       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1160       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1161     assert(tensor &&cref &&vref &&fref &&aref);                                \
1162     assert(cref->strides[0] == 1);                                             \
1163     assert(vref->strides[0] == 1);                                             \
1164     assert(fref->strides[0] == 1);                                             \
1165     assert(aref->strides[0] == 1);                                             \
1166     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1167     index_type *cursor = cref->data + cref->offset;                            \
1168     V *values = vref->data + vref->offset;                                     \
1169     bool *filled = fref->data + fref->offset;                                  \
1170     index_type *added = aref->data + aref->offset;                             \
1171     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1172         cursor, values, filled, added, count);                                 \
1173   }
1174 
1175 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1176 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1177 // that this file cannot get out of sync with its header.
1178 static_assert(std::is_same<index_type, uint64_t>::value,
1179               "Expected index_type == uint64_t");
1180 
/// Constructs a new sparse tensor. This is the "Swiss army knife"
/// method for materializing sparse tensors into the computation.
1183 ///
1184 /// Action:
1185 /// kEmpty = returns empty storage to fill later
1186 /// kFromFile = returns storage, where ptr contains filename to read
1187 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
1188 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
1189 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1190 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
void *
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
                             StridedMemRefType<index_type, 1> *sref,
                             StridedMemRefType<index_type, 1> *pref,
                             OverheadType ptrTp, OverheadType indTp,
                             PrimaryType valTp, Action action, void *ptr) {
  assert(aref && sref && pref);
  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
         pref->strides[0] == 1);
  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
  const DimLevelType *sparsity = aref->data + aref->offset;
  const index_type *shape = sref->data + sref->offset;
  const index_type *perm = pref->data + pref->offset;
  uint64_t rank = aref->sizes[0];

  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
  // This is safe because of the static_assert above.
  if (ptrTp == OverheadType::kIndex)
    ptrTp = OverheadType::kU64;
  if (indTp == OverheadType::kIndex)
    indTp = OverheadType::kU64;

  // Double-precision tensors with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
       uint64_t, double);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
       uint32_t, double);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
       uint16_t, double);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
       uint8_t, double);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
       uint64_t, double);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
       uint32_t, double);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
       uint16_t, double);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
       uint8_t, double);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
       uint64_t, double);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
       uint32_t, double);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
       uint16_t, double);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
       uint8_t, double);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
       uint64_t, double);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
       uint32_t, double);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
       uint16_t, double);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
       uint8_t, double);

  // Single-precision tensors with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
       uint64_t, float);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
       uint32_t, float);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
       uint16_t, float);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
       uint8_t, float);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
       uint64_t, float);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
       uint32_t, float);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
       uint16_t, float);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
       uint8_t, float);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
       uint64_t, float);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
       uint32_t, float);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
       uint16_t, float);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
       uint8_t, float);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
       uint64_t, float);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
       uint32_t, float);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
       uint16_t, float);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
       uint8_t, float);

  // Integral tensors with pointer and index overhead of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);

  // Unsupported case (add above if needed).
  fputs("unsupported combination of types\n", stderr);
  exit(1);
}

/// Methods that provide direct access to pointers.
IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)

/// Methods that provide direct access to indices.
IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)

/// Methods that provide direct access to values.
IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
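
// For illustration only (a sketch of the calling convention, assuming the
// (descriptor, tensor) parameter order of the IMPL_SPARSEVALUES macro defined
// earlier in this file): retrieving the values array of a double tensor fills
// a rank-1 memref descriptor in place, after which the data can be addressed
// directly:
//
//   StridedMemRefType<double, 1> vref;
//   _mlir_ciface_sparseValuesF64(&vref, tensor);
//   double *values = vref.data + vref.offset; // vref.sizes[0] entries
//
// The sparsePointers* and sparseIndices* entry points follow the same pattern,
// with an additional dimension argument that selects the overhead vector.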

/// Helpers to add a value to the coordinate scheme, one per value type.
IMPL_ADDELT(addEltF64, double)
IMPL_ADDELT(addEltF32, float)
IMPL_ADDELT(addEltI64, int64_t)
IMPL_ADDELT(addEltI32, int32_t)
IMPL_ADDELT(addEltI16, int16_t)
IMPL_ADDELT(addEltI8, int8_t)

/// Helpers to enumerate elements of a coordinate scheme, one per value type.
IMPL_GETNEXT(getNextF64, double)
IMPL_GETNEXT(getNextF32, float)
IMPL_GETNEXT(getNextI64, int64_t)
IMPL_GETNEXT(getNextI32, int32_t)
IMPL_GETNEXT(getNextI16, int16_t)
IMPL_GETNEXT(getNextI8, int8_t)

/// Insert elements in lexicographical index order, one per value type.
IMPL_LEXINSERT(lexInsertF64, double)
IMPL_LEXINSERT(lexInsertF32, float)
IMPL_LEXINSERT(lexInsertI64, int64_t)
IMPL_LEXINSERT(lexInsertI32, int32_t)
IMPL_LEXINSERT(lexInsertI16, int16_t)
IMPL_LEXINSERT(lexInsertI8, int8_t)

/// Insert using expansion, one per value type.
IMPL_EXPINSERT(expInsertF64, double)
IMPL_EXPINSERT(expInsertF32, float)
IMPL_EXPINSERT(expInsertI64, int64_t)
IMPL_EXPINSERT(expInsertI32, int32_t)
IMPL_EXPINSERT(expInsertI16, int16_t)
IMPL_EXPINSERT(expInsertI8, int8_t)

#undef CASE
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
#undef IMPL_ADDELT
#undef IMPL_GETNEXT
#undef IMPL_LEXINSERT
#undef IMPL_EXPINSERT

/// Output a sparse tensor, one per value type.
void outSparseTensorF64(void *tensor, void *dest, bool sort) {
  return outSparseTensor<double>(tensor, dest, sort);
}
void outSparseTensorF32(void *tensor, void *dest, bool sort) {
  return outSparseTensor<float>(tensor, dest, sort);
}
void outSparseTensorI64(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int64_t>(tensor, dest, sort);
}
void outSparseTensorI32(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int32_t>(tensor, dest, sort);
}
void outSparseTensorI16(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int16_t>(tensor, dest, sort);
}
void outSparseTensorI8(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int8_t>(tensor, dest, sort);
}
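
// For illustration only (a sketch, assuming `dest` names the output file, as
// interpreted by the templated outSparseTensor helper earlier in this file):
//
//   char filename[] = "output.tns";
//   outSparseTensorF64(tensor, filename, /*sort=*/true);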

//===----------------------------------------------------------------------===//
//
// Public API with methods that accept C-style data structures to interact
// with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used both by MLIR compiler-generated code and by an
// external runtime that wants to interact with MLIR compiler-generated code.
//
//===----------------------------------------------------------------------===//

/// Helper method to read a sparse tensor filename from the environment,
/// using the naming convention ${TENSOR0}, ${TENSOR1}, etc.
char *getTensorFilename(index_type id) {
  char var[80];
  snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env) {
    fprintf(stderr, "Environment variable %s is not set\n", var);
    exit(1);
  }
  return env;
}
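
// For example (illustrative only), with the environment variable TENSOR0 set
// to the path of a tensor file, a test driver retrieves that path with:
//
//   char *filename = getTensorFilename(0); // reads ${TENSOR0}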

/// Returns the size of the sparse tensor in the given dimension.
index_type sparseDimSize(void *tensor, index_type d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

/// Finalizes lexicographic insertions.
void endInsert(void *tensor) {
  return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
}

/// Releases sparse tensor storage.
void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}

/// Releases sparse tensor coordinate scheme.
#define IMPL_DELCOO(VNAME, V)                                                  \
  void delSparseTensorCOO##VNAME(void *coo) {                                  \
    delete static_cast<SparseTensorCOO<V> *>(coo);                             \
  }
IMPL_DELCOO(F64, double)
IMPL_DELCOO(F32, float)
IMPL_DELCOO(I64, int64_t)
IMPL_DELCOO(I32, int32_t)
IMPL_DELCOO(I16, int16_t)
IMPL_DELCOO(I8, int8_t)
#undef IMPL_DELCOO

/// Initializes a sparse tensor from a COO-flavored format expressed using
/// C-style data structures. The expected parameters are:
///
///   rank:    rank of the tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
///   values:  an "nse"-length array with the values of all specified elements
///   indices: a flat "nse x rank" array with the indices of all specified
///            elements
///   perm:    the permutation of the dimensions in the storage
///   sparse:  the sparsity annotation for each dimension
///
/// For example, the sparse matrix
///     | 1.0 0.0 0.0 |
///     | 0.0 5.0 3.0 |
/// can be passed as
///      rank    = 2
///      nse     = 3
///      shape   = [2, 3]
///      values  = [1.0, 5.0, 3.0]
///      indices = [ 0, 0,  1, 1,  1, 2]
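///
/// For illustration, a complete call for this example could look as follows
/// (the identity "perm" and the 0 = dense / 1 = compressed encoding of the
/// "sparse" annotations are assumptions added for this sketch):
///
///   uint64_t shape[] = {2, 3};
///   double values[] = {1.0, 5.0, 3.0};
///   uint64_t indices[] = {0, 0, 1, 1, 1, 2};
///   uint64_t perm[] = {0, 1};  // identity dimension ordering
///   uint8_t sparse[] = {0, 1}; // dense outer, compressed inner dimension
///   void *tensor = convertToMLIRSparseTensorF64(2, 3, shape, values,
///                                               indices, perm, sparse);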
//
// TODO: generalize beyond 64-bit indices.
//
void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   double *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
                                    sparse);
}
void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   float *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
                                   sparse);
}
void *convertToMLIRSparseTensorI64(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   int64_t *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int64_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
}
void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   int32_t *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int32_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
}
void *convertToMLIRSparseTensorI16(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   int16_t *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int16_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
}
void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
                                  int8_t *values, uint64_t *indices,
                                  uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
                                    sparse);
}

/// Converts a sparse tensor to a COO-flavored format expressed using C-style
/// data structures. The expected output parameters are pointers to these
/// values:
///
///   rank:    rank of the tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
///   values:  an "nse"-length array with the values of all specified elements
///   indices: a flat "nse x rank" array with the indices of all specified
///            elements
///
/// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
/// from convertToMLIRSparseTensor.
///
//  TODO: Currently, values are copied from SparseTensorStorage to
//  SparseTensorCOO, then to the output. We may want to reduce the number of
//  copies.
//
// TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
// compressed
//
void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    double **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    float **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI64(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    int64_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int64_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    int32_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int32_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI16(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    int16_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int16_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
                                   uint64_t *pNse, uint64_t **pShape,
                                   int8_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
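
// For illustration only (a sketch; error handling and ownership of the
// returned buffers are omitted), reading a double tensor back into C-style
// arrays and then releasing the underlying storage could look like:
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   convertFromMLIRSparseTensorF64(tensor, &rank, &nse, &shape, &values,
//                                  &indices);
//   // ... consume shape[0..rank), values[0..nse), indices[0..nse*rank) ...
//   delSparseTensor(tensor);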

} // extern "C"

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS