1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cctype>
25 #include <cinttypes>
26 #include <cstdio>
27 #include <cstdlib>
28 #include <cstring>
29 #include <fstream>
30 #include <iostream>
31 #include <limits>
32 #include <numeric>
33 #include <vector>
34 
35 //===----------------------------------------------------------------------===//
36 //
37 // Internal support for storing and reading sparse tensors.
38 //
39 // The following memory-resident sparse storage schemes are supported:
40 //
41 // (a) A coordinate scheme for temporarily storing and lexicographically
42 //     sorting a sparse tensor by index (SparseTensorCOO).
43 //
44 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
46 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
47 //
48 // The following external formats are supported:
49 //
50 // (1) Matrix Market Exchange (MME): *.mtx
51 //     https://math.nist.gov/MatrixMarket/formats.html
52 //
53 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
54 //     http://frostt.io/tensors/file-formats.html
55 //
56 // Two public APIs are supported:
57 //
58 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
59 //     tensors. These methods should be used exclusively by MLIR
60 //     compiler-generated code.
61 //
62 // (II) Methods that accept C-style data structures to interact with sparse
63 //      tensors. These methods can be used by any external runtime that wants
64 //      to interact with MLIR compiler-generated code.
65 //
// In both cases (I) and (II), the SparseTensorStorage format is only visible
// externally as an opaque pointer.
68 //
69 //===----------------------------------------------------------------------===//
70 
71 namespace {
72 
73 static constexpr int kColWidth = 1025;
74 
75 /// A version of `operator*` on `uint64_t` which checks for overflows.
76 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
77   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
78          "Integer overflow");
79   return lhs * rhs;
80 }
81 
82 /// A sparse tensor element in coordinate scheme (value and indices).
83 /// For example, a rank-1 vector element would look like
84 ///   ({i}, a[i])
85 /// and a rank-5 tensor element like
86 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
87 template <typename V>
88 struct Element {
  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val) {}
90   std::vector<uint64_t> indices;
91   V value;
92   /// Returns true if indices of e1 < indices of e2.
93   static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
94     uint64_t rank = e1.indices.size();
95     assert(rank == e2.indices.size());
96     for (uint64_t r = 0; r < rank; r++) {
97       if (e1.indices[r] == e2.indices[r])
98         continue;
99       return e1.indices[r] < e2.indices[r];
100     }
101     return false;
102   }
103 };
104 
105 /// A memory-resident sparse tensor in coordinate scheme (collection of
106 /// elements). This data structure is used to read a sparse tensor from
107 /// any external format into memory and sort the elements lexicographically
108 /// by indices before passing it back to the client (most packed storage
109 /// formats require the elements to appear in lexicographic index order).
110 template <typename V>
111 struct SparseTensorCOO {
112 public:
113   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
114       : sizes(szs) {
115     if (capacity)
116       elements.reserve(capacity);
117   }
118   /// Adds element as indices and value.
119   void add(const std::vector<uint64_t> &ind, V val) {
120     assert(!iteratorLocked && "Attempt to add() after startIterator()");
121     uint64_t rank = getRank();
122     assert(rank == ind.size());
123     for (uint64_t r = 0; r < rank; r++)
124       assert(ind[r] < sizes[r]); // within bounds
125     elements.emplace_back(ind, val);
126   }
127   /// Sorts elements lexicographically by index.
128   void sort() {
129     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
130     // TODO: we may want to cache an `isSorted` bit, to avoid
131     // unnecessary/redundant sorting.
132     std::sort(elements.begin(), elements.end(), Element<V>::lexOrder);
133   }
134   /// Returns rank.
135   uint64_t getRank() const { return sizes.size(); }
136   /// Getter for sizes array.
137   const std::vector<uint64_t> &getSizes() const { return sizes; }
138   /// Getter for elements array.
139   const std::vector<Element<V>> &getElements() const { return elements; }
140 
141   /// Switch into iterator mode.
142   void startIterator() {
143     iteratorLocked = true;
144     iteratorPos = 0;
145   }
146   /// Get the next element.
147   const Element<V> *getNext() {
148     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
149     if (iteratorPos < elements.size())
150       return &(elements[iteratorPos++]);
151     iteratorLocked = false;
152     return nullptr;
153   }
154 
155   /// Factory method. Permutes the original dimensions according to
156   /// the given ordering and expects subsequent add() calls to honor
157   /// that same ordering for the given indices. The result is a
158   /// fully permuted coordinate scheme.
159   ///
160   /// Precondition: `sizes` and `perm` must be valid for `rank`.
161   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
162                                                 const uint64_t *sizes,
163                                                 const uint64_t *perm,
164                                                 uint64_t capacity = 0) {
165     std::vector<uint64_t> permsz(rank);
166     for (uint64_t r = 0; r < rank; r++) {
167       assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
168       permsz[perm[r]] = sizes[r];
169     }
170     return new SparseTensorCOO<V>(permsz, capacity);
171   }
172 
173 private:
174   const std::vector<uint64_t> sizes; // per-dimension sizes
175   std::vector<Element<V>> elements;
176   bool iteratorLocked = false;
  uint64_t iteratorPos = 0;
178 };
179 
180 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
181 /// takes responsibility for all the `<P,I,V>`-independent aspects
182 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
183 /// we use function overloading to implement "partial" method
184 /// specialization, which the C-API relies on to catch type errors
185 /// arising from our use of opaque pointers.
186 class SparseTensorStorageBase {
187 public:
188   /// Constructs a new storage object.  The `perm` maps the tensor's
189   /// semantic-ordering of dimensions to this object's storage-order.
190   /// The `szs` and `sparsity` arrays are already in storage-order.
191   ///
192   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
193   SparseTensorStorageBase(const std::vector<uint64_t> &szs,
194                           const uint64_t *perm, const DimLevelType *sparsity)
195       : dimSizes(szs), rev(getRank()),
196         dimTypes(sparsity, sparsity + getRank()) {
197     const uint64_t rank = getRank();
198     // Validate parameters.
199     assert(rank > 0 && "Trivial shape is unsupported");
200     for (uint64_t r = 0; r < rank; r++) {
201       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
202       assert((dimTypes[r] == DimLevelType::kDense ||
203               dimTypes[r] == DimLevelType::kCompressed) &&
204              "Unsupported DimLevelType");
205     }
206     // Construct the "reverse" (i.e., inverse) permutation.
207     for (uint64_t r = 0; r < rank; r++)
208       rev[perm[r]] = r;
209   }
210 
211   virtual ~SparseTensorStorageBase() = default;
212 
213   /// Get the rank of the tensor.
214   uint64_t getRank() const { return dimSizes.size(); }
215 
216   /// Getter for the dimension-sizes array, in storage-order.
217   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
218 
219   /// Safely lookup the size of the given (storage-order) dimension.
220   uint64_t getDimSize(uint64_t d) const {
221     assert(d < getRank());
222     return dimSizes[d];
223   }
224 
225   /// Getter for the "reverse" permutation, which maps this object's
226   /// storage-order to the tensor's semantic-order.
227   const std::vector<uint64_t> &getRev() const { return rev; }
228 
229   /// Getter for the dimension-types array, in storage-order.
230   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
231 
232   /// Safely check if the (storage-order) dimension uses compressed storage.
233   bool isCompressedDim(uint64_t d) const {
234     assert(d < getRank());
235     return (dimTypes[d] == DimLevelType::kCompressed);
236   }
237 
238   /// Overhead storage.
239   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
240   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
241   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
242   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
243   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
244   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
245   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
246   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
247 
248   /// Primary storage.
249   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
250   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
251   virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
252   virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
253   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
254   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
255 
256   /// Element-wise insertion in lexicographic index order.
257   virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
258   virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
259   virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
260   virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(const uint64_t *, int16_t) { fatal("insi16"); }
262   virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
263 
264   /// Expanded insertion.
265   virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
266     fatal("expf64");
267   }
268   virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
269     fatal("expf32");
270   }
271   virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
272     fatal("expi64");
273   }
274   virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
275     fatal("expi32");
276   }
277   virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
278     fatal("expi16");
279   }
280   virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
281     fatal("expi8");
282   }
283 
284   /// Finishes insertion.
285   virtual void endInsert() = 0;
286 
287 private:
288   static void fatal(const char *tp) {
289     fprintf(stderr, "unsupported %s\n", tp);
290     exit(1);
291   }
292 
293   const std::vector<uint64_t> dimSizes;
294   std::vector<uint64_t> rev;
295   const std::vector<DimLevelType> dimTypes;
296 };
297 
298 /// A memory-resident sparse tensor using a storage scheme based on
299 /// per-dimension sparse/dense annotations. This data structure provides a
300 /// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this class provides
302 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
303 /// and annotations to implement all required setup in a general manner.
304 template <typename P, typename I, typename V>
305 class SparseTensorStorage : public SparseTensorStorageBase {
306 public:
307   /// Constructs a sparse tensor storage scheme with the given dimensions,
308   /// permutation, and per-dimension dense/sparse annotations, using
309   /// the coordinate scheme tensor for the initial contents if provided.
310   ///
311   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
312   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
313                       const DimLevelType *sparsity,
314                       SparseTensorCOO<V> *coo = nullptr)
315       : SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
316         indices(getRank()), idx(getRank()) {
317     // Provide hints on capacity of pointers and indices.
318     // TODO: needs much fine-tuning based on actual sparsity; currently
319     //       we reserve pointer/index space based on all previous dense
320     //       dimensions, which works well up to first sparse dim; but
321     //       we should really use nnz and dense/sparse distribution.
322     bool allDense = true;
323     uint64_t sz = 1;
324     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
325       if (isCompressedDim(r)) {
326         // TODO: Take a parameter between 1 and `sizes[r]`, and multiply
327         // `sz` by that before reserving. (For now we just use 1.)
328         pointers[r].reserve(sz + 1);
329         pointers[r].push_back(0);
330         indices[r].reserve(sz);
331         sz = 1;
332         allDense = false;
333       } else { // Dense dimension.
334         sz = checkedMul(sz, getDimSizes()[r]);
335       }
336     }
337     // Then assign contents from coordinate scheme tensor if provided.
338     if (coo) {
339       // Ensure both preconditions of `fromCOO`.
340       assert(coo->getSizes() == getDimSizes() && "Tensor size mismatch");
341       coo->sort();
342       // Now actually insert the `elements`.
343       const std::vector<Element<V>> &elements = coo->getElements();
344       uint64_t nnz = elements.size();
345       values.reserve(nnz);
346       fromCOO(elements, 0, nnz, 0);
347     } else if (allDense) {
348       values.resize(sz, 0);
349     }
350   }
351 
352   ~SparseTensorStorage() override = default;
353 
354   /// Partially specialize these getter methods based on template types.
355   void getPointers(std::vector<P> **out, uint64_t d) override {
356     assert(d < getRank());
357     *out = &pointers[d];
358   }
359   void getIndices(std::vector<I> **out, uint64_t d) override {
360     assert(d < getRank());
361     *out = &indices[d];
362   }
363   void getValues(std::vector<V> **out) override { *out = &values; }
364 
365   /// Partially specialize lexicographical insertions based on template types.
366   void lexInsert(const uint64_t *cursor, V val) override {
367     // First, wrap up pending insertion path.
368     uint64_t diff = 0;
369     uint64_t top = 0;
370     if (!values.empty()) {
371       diff = lexDiff(cursor);
372       endPath(diff + 1);
373       top = idx[diff] + 1;
374     }
375     // Then continue with insertion path.
376     insPath(cursor, diff, top, val);
377   }
378 
379   /// Partially specialize expanded insertions based on template types.
380   /// Note that this method resets the values/filled-switch array back
381   /// to all-zero/false while only iterating over the nonzero elements.
382   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
383                  uint64_t count) override {
384     if (count == 0)
385       return;
386     // Sort.
387     std::sort(added, added + count);
388     // Restore insertion path for first insert.
389     const uint64_t lastDim = getRank() - 1;
390     uint64_t index = added[0];
391     cursor[lastDim] = index;
392     lexInsert(cursor, values[index]);
393     assert(filled[index]);
394     values[index] = 0;
395     filled[index] = false;
396     // Subsequent insertions are quick.
397     for (uint64_t i = 1; i < count; i++) {
398       assert(index < added[i] && "non-lexicographic insertion");
399       index = added[i];
400       cursor[lastDim] = index;
401       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
402       assert(filled[index]);
403       values[index] = 0;
404       filled[index] = false;
405     }
406   }
407 
408   /// Finalizes lexicographic insertions.
409   void endInsert() override {
410     if (values.empty())
411       finalizeSegment(0);
412     else
413       endPath(0);
414   }
415 
416   /// Returns this sparse tensor storage scheme as a new memory-resident
417   /// sparse tensor in coordinate scheme with the given dimension order.
418   ///
419   /// Precondition: `perm` must be valid for `getRank()`.
420   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
421     // Restore original order of the dimension sizes and allocate coordinate
422     // scheme with desired new ordering specified in perm.
423     const uint64_t rank = getRank();
424     const auto &rev = getRev();
425     const auto &sizes = getDimSizes();
426     std::vector<uint64_t> orgsz(rank);
427     for (uint64_t r = 0; r < rank; r++)
428       orgsz[rev[r]] = sizes[r];
429     SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO(
430         rank, orgsz.data(), perm, values.size());
    // Populate the coordinate scheme, restoring the original ordering and
    // then applying the desired new ordering. Rather than applying both
    // reorderings during the recursion, we compute the combined permutation
    // in advance.
434     std::vector<uint64_t> reord(rank);
435     for (uint64_t r = 0; r < rank; r++)
436       reord[r] = perm[rev[r]];
437     toCOO(*coo, reord, 0, 0);
438     // TODO: This assertion assumes there are no stored zeros,
439     // or if there are then that we don't filter them out.
440     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
441     assert(coo->getElements().size() == values.size());
442     return coo;
443   }
444 
445   /// Factory method. Constructs a sparse tensor storage scheme with the given
446   /// dimensions, permutation, and per-dimension dense/sparse annotations,
447   /// using the coordinate scheme tensor for the initial contents if provided.
448   /// In the latter case, the coordinate scheme must respect the same
449   /// permutation as is desired for the new sparse tensor storage.
450   ///
451   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
452   static SparseTensorStorage<P, I, V> *
453   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
454                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
455     SparseTensorStorage<P, I, V> *n = nullptr;
456     if (coo) {
457       assert(coo->getRank() == rank && "Tensor rank mismatch");
458       const auto &coosz = coo->getSizes();
459       for (uint64_t r = 0; r < rank; r++)
460         assert(shape[r] == 0 || shape[r] == coosz[perm[r]]);
461       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
462     } else {
463       std::vector<uint64_t> permsz(rank);
464       for (uint64_t r = 0; r < rank; r++) {
465         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
466         permsz[perm[r]] = shape[r];
467       }
468       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
469     }
470     return n;
471   }
472 
473 private:
474   /// Appends an arbitrary new position to `pointers[d]`.  This method
475   /// checks that `pos` is representable in the `P` type; however, it
476   /// does not check that `pos` is semantically valid (i.e., larger than
477   /// the previous position and smaller than `indices[d].capacity()`).
478   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
479     assert(isCompressedDim(d));
480     assert(pos <= std::numeric_limits<P>::max() &&
481            "Pointer value is too large for the P-type");
482     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
483   }
484 
485   /// Appends index `i` to dimension `d`, in the semantically general
486   /// sense.  For non-dense dimensions, that means appending to the
487   /// `indices[d]` array, checking that `i` is representable in the `I`
488   /// type; however, we do not verify other semantic requirements (e.g.,
489   /// that `i` is in bounds for `sizes[d]`, and not previously occurring
490   /// in the same segment).  For dense dimensions, this method instead
491   /// appends the appropriate number of zeros to the `values` array,
492   /// where `full` is the number of "entries" already written to `values`
493   /// for this segment (aka one after the highest index previously appended).
494   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
495     if (isCompressedDim(d)) {
496       assert(i <= std::numeric_limits<I>::max() &&
497              "Index value is too large for the I-type");
498       indices[d].push_back(static_cast<I>(i));
499     } else { // Dense dimension.
500       assert(i >= full && "Index was already filled");
501       if (i == full)
502         return; // Short-circuit, since it'll be a nop.
503       if (d + 1 == getRank())
504         values.insert(values.end(), i - full, 0);
505       else
506         finalizeSegment(d + 1, 0, i - full);
507     }
508   }
509 
510   /// Initializes sparse tensor storage scheme from a memory-resident sparse
511   /// tensor in coordinate scheme. This method prepares the pointers and
512   /// indices arrays under the given per-dimension dense/sparse annotations.
513   ///
514   /// Preconditions:
515   /// (1) the `elements` must be lexicographically sorted.
516   /// (2) the indices of every element are valid for `sizes` (equal rank
517   ///     and pointwise less-than).
518   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
519                uint64_t hi, uint64_t d) {
520     // Once dimensions are exhausted, insert the numerical values.
521     assert(d <= getRank() && hi <= elements.size());
522     if (d == getRank()) {
523       assert(lo < hi);
524       values.push_back(elements[lo].value);
525       return;
526     }
527     // Visit all elements in this interval.
528     uint64_t full = 0;
529     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
      // Find the segment of elements with the same index in this dimension.
531       uint64_t i = elements[lo].indices[d];
532       uint64_t seg = lo + 1;
533       while (seg < hi && elements[seg].indices[d] == i)
534         seg++;
535       // Handle segment in interval for sparse or dense dimension.
536       appendIndex(d, full, i);
537       full = i + 1;
538       fromCOO(elements, lo, seg, d + 1);
539       // And move on to next segment in interval.
540       lo = seg;
541     }
542     // Finalize the sparse pointer structure at this dimension.
543     finalizeSegment(d, full);
544   }
545 
546   /// Stores the sparse tensor storage scheme into a memory-resident sparse
547   /// tensor in coordinate scheme.
548   void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
549              uint64_t pos, uint64_t d) {
550     assert(d <= getRank());
551     if (d == getRank()) {
552       assert(pos < values.size());
553       tensor.add(idx, values[pos]);
554     } else if (isCompressedDim(d)) {
555       // Sparse dimension.
556       for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
557         idx[reord[d]] = indices[d][ii];
558         toCOO(tensor, reord, ii, d + 1);
559       }
560     } else {
561       // Dense dimension.
562       const uint64_t sz = getDimSizes()[d];
563       const uint64_t off = pos * sz;
564       for (uint64_t i = 0; i < sz; i++) {
565         idx[reord[d]] = i;
566         toCOO(tensor, reord, off + i, d + 1);
567       }
568     }
569   }
570 
571   /// Finalize the sparse pointer structure at this dimension.
572   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
573     if (count == 0)
574       return; // Short-circuit, since it'll be a nop.
575     if (isCompressedDim(d)) {
576       appendPointer(d, indices[d].size(), count);
577     } else { // Dense dimension.
578       const uint64_t sz = getDimSizes()[d];
579       assert(sz >= full && "Segment is overfull");
580       count = checkedMul(count, sz - full);
581       // For dense storage we must enumerate all the remaining coordinates
582       // in this dimension (i.e., coordinates after the last non-zero
583       // element), and either fill in their zero values or else recurse
584       // to finalize some deeper dimension.
585       if (d + 1 == getRank())
586         values.insert(values.end(), count, 0);
587       else
588         finalizeSegment(d + 1, 0, count);
589     }
590   }
591 
592   /// Wraps up a single insertion path, inner to outer.
593   void endPath(uint64_t diff) {
594     uint64_t rank = getRank();
595     assert(diff <= rank);
596     for (uint64_t i = 0; i < rank - diff; i++) {
597       const uint64_t d = rank - i - 1;
598       finalizeSegment(d, idx[d] + 1);
599     }
600   }
601 
602   /// Continues a single insertion path, outer to inner.
603   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
604     uint64_t rank = getRank();
605     assert(diff < rank);
606     for (uint64_t d = diff; d < rank; d++) {
607       uint64_t i = cursor[d];
608       appendIndex(d, top, i);
609       top = 0;
610       idx[d] = i;
611     }
612     values.push_back(val);
613   }
614 
  /// Finds the lexicographically differing dimension.
616   uint64_t lexDiff(const uint64_t *cursor) const {
617     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
618       if (cursor[r] > idx[r])
619         return r;
620       else
621         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
622     assert(0 && "duplication insertion");
623     return -1u;
624   }
625 
626 private:
627   std::vector<std::vector<P>> pointers;
628   std::vector<std::vector<I>> indices;
629   std::vector<V> values;
630   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
631 };
632 
633 /// Helper to convert string to lower case.
634 static char *toLower(char *token) {
635   for (char *c = token; *c; c++)
    *c = tolower(static_cast<unsigned char>(*c));
637   return token;
638 }
639 
640 /// Read the MME header of a general sparse matrix of type real.
641 static void readMMEHeader(FILE *file, char *filename, char *line,
642                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
643   char header[64];
644   char object[64];
645   char format[64];
646   char field[64];
647   char symmetry[64];
648   // Read header line.
649   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
650              symmetry) != 5) {
651     fprintf(stderr, "Corrupt header in %s\n", filename);
652     exit(1);
653   }
654   // Set properties
655   *isPattern = (strcmp(toLower(field), "pattern") == 0);
656   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
657   // Make sure this is a general sparse matrix.
658   if (strcmp(toLower(header), "%%matrixmarket") ||
659       strcmp(toLower(object), "matrix") ||
660       strcmp(toLower(format), "coordinate") ||
661       (strcmp(toLower(field), "real") && !(*isPattern)) ||
662       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
663     fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
664     exit(1);
665   }
666   // Skip comments.
667   while (true) {
668     if (!fgets(line, kColWidth, file)) {
669       fprintf(stderr, "Cannot find data in %s\n", filename);
670       exit(1);
671     }
672     if (line[0] != '%')
673       break;
674   }
675   // Next line contains M N NNZ.
676   idata[0] = 2; // rank
677   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
678              idata + 1) != 3) {
679     fprintf(stderr, "Cannot find size in %s\n", filename);
680     exit(1);
681   }
682 }
683 
684 /// Read the "extended" FROSTT header. Although not part of the documented
685 /// format, we assume that the file starts with optional comments followed
686 /// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
688 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
689                                 uint64_t *idata) {
690   // Skip comments.
691   while (true) {
692     if (!fgets(line, kColWidth, file)) {
693       fprintf(stderr, "Cannot find data in %s\n", filename);
694       exit(1);
695     }
696     if (line[0] != '#')
697       break;
698   }
699   // Next line contains RANK and NNZ.
700   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
701     fprintf(stderr, "Cannot find metadata in %s\n", filename);
702     exit(1);
703   }
704   // Followed by a line with the dimension sizes (one per rank).
705   for (uint64_t r = 0; r < idata[0]; r++) {
706     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
707       fprintf(stderr, "Cannot find dimension size %s\n", filename);
708       exit(1);
709     }
710   }
711   fgets(line, kColWidth, file); // end of line
712 }
713 
714 /// Reads a sparse tensor with the given filename into a memory-resident
715 /// sparse tensor in coordinate scheme.
716 template <typename V>
717 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
718                                                const uint64_t *shape,
719                                                const uint64_t *perm) {
720   // Open the file.
721   FILE *file = fopen(filename, "r");
722   if (!file) {
723     assert(filename && "Received nullptr for filename");
724     fprintf(stderr, "Cannot find file %s\n", filename);
725     exit(1);
726   }
  // Perform some file-format-dependent setup.
728   char line[kColWidth];
729   uint64_t idata[512];
730   bool isPattern = false;
731   bool isSymmetric = false;
732   if (strstr(filename, ".mtx")) {
733     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
734   } else if (strstr(filename, ".tns")) {
735     readExtFROSTTHeader(file, filename, line, idata);
736   } else {
737     fprintf(stderr, "Unknown format %s\n", filename);
738     exit(1);
739   }
740   // Prepare sparse tensor object with per-dimension sizes
741   // and the number of nonzeros as initial capacity.
742   assert(rank == idata[0] && "rank mismatch");
743   uint64_t nnz = idata[1];
744   for (uint64_t r = 0; r < rank; r++)
745     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
746            "dimension size mismatch");
747   SparseTensorCOO<V> *tensor =
748       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
749   //  Read all nonzero elements.
750   std::vector<uint64_t> indices(rank);
751   for (uint64_t k = 0; k < nnz; k++) {
752     if (!fgets(line, kColWidth, file)) {
753       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
754       exit(1);
755     }
756     char *linePtr = line;
757     for (uint64_t r = 0; r < rank; r++) {
758       uint64_t idx = strtoul(linePtr, &linePtr, 10);
759       // Add 0-based index.
760       indices[perm[r]] = idx - 1;
761     }
762     // The external formats always store the numerical values with the type
763     // double, but we cast these values to the sparse tensor object type.
764     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
765     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
766     tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully constructing
768     // them. In the future, we may want to make symmetry implicit for storage
769     // reasons.
770     if (isSymmetric && indices[0] != indices[1])
771       tensor->add({indices[1], indices[0]}, value);
772   }
773   // Close the file and return tensor.
774   fclose(file);
775   return tensor;
776 }
777 
778 /// Writes the sparse tensor to extended FROSTT format.
779 template <typename V>
780 static void outSparseTensor(void *tensor, void *dest, bool sort) {
781   assert(tensor && dest);
782   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
783   if (sort)
784     coo->sort();
785   char *filename = static_cast<char *>(dest);
786   auto &sizes = coo->getSizes();
787   auto &elements = coo->getElements();
788   uint64_t rank = coo->getRank();
789   uint64_t nnz = elements.size();
790   std::fstream file;
791   file.open(filename, std::ios_base::out | std::ios_base::trunc);
792   assert(file.is_open());
793   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
794   for (uint64_t r = 0; r < rank - 1; r++)
795     file << sizes[r] << " ";
796   file << sizes[rank - 1] << std::endl;
797   for (uint64_t i = 0; i < nnz; i++) {
798     auto &idx = elements[i].indices;
799     for (uint64_t r = 0; r < rank; r++)
800       file << (idx[r] + 1) << " ";
801     file << elements[i].value << std::endl;
802   }
803   file.flush();
804   file.close();
805   assert(file.good());
806 }
807 
808 /// Initializes sparse tensor from an external COO-flavored format.
809 template <typename V>
810 static SparseTensorStorage<uint64_t, uint64_t, V> *
811 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
812                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
  const DimLevelType *sparsity =
      reinterpret_cast<const DimLevelType *>(sparse);
814 #ifndef NDEBUG
815   // Verify that perm is a permutation of 0..(rank-1).
816   std::vector<uint64_t> order(perm, perm + rank);
817   std::sort(order.begin(), order.end());
818   for (uint64_t i = 0; i < rank; ++i) {
819     if (i != order[i]) {
820       fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
821       exit(1);
822     }
823   }
824 
825   // Verify that the sparsity values are supported.
826   for (uint64_t i = 0; i < rank; ++i) {
827     if (sparsity[i] != DimLevelType::kDense &&
828         sparsity[i] != DimLevelType::kCompressed) {
829       fprintf(stderr, "Unsupported sparsity value %d\n",
830               static_cast<int>(sparsity[i]));
831       exit(1);
832     }
833   }
834 #endif
835 
836   // Convert external format to internal COO.
837   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
838   std::vector<uint64_t> idx(rank);
839   for (uint64_t i = 0, base = 0; i < nse; i++) {
840     for (uint64_t r = 0; r < rank; r++)
841       idx[perm[r]] = indices[base + r];
842     coo->add(idx, values[i]);
843     base += rank;
844   }
845   // Return sparse tensor storage format as opaque pointer.
846   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
847       rank, shape, perm, sparsity, coo);
848   delete coo;
849   return tensor;
850 }
851 
852 /// Converts a sparse tensor to an external COO-flavored format.
853 template <typename V>
854 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
855                                  uint64_t **pShape, V **pValues,
856                                  uint64_t **pIndices) {
857   auto sparseTensor =
858       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
859   uint64_t rank = sparseTensor->getRank();
860   std::vector<uint64_t> perm(rank);
861   std::iota(perm.begin(), perm.end(), 0);
862   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
863 
864   const std::vector<Element<V>> &elements = coo->getElements();
865   uint64_t nse = elements.size();
866 
867   uint64_t *shape = new uint64_t[rank];
868   for (uint64_t i = 0; i < rank; i++)
869     shape[i] = coo->getSizes()[i];
870 
871   V *values = new V[nse];
872   uint64_t *indices = new uint64_t[rank * nse];
873 
874   for (uint64_t i = 0, base = 0; i < nse; i++) {
875     values[i] = elements[i].value;
876     for (uint64_t j = 0; j < rank; j++)
877       indices[base + j] = elements[i].indices[j];
878     base += rank;
879   }
880 
881   delete coo;
882   *pRank = rank;
883   *pNse = nse;
884   *pShape = shape;
885   *pValues = values;
886   *pIndices = indices;
887 }
888 
889 } // namespace
890 
891 extern "C" {
892 
893 //===----------------------------------------------------------------------===//
894 //
895 // Public API with methods that operate on MLIR buffers (memrefs) to interact
896 // with sparse tensors, which are only visible as opaque pointers externally.
897 // These methods should be used exclusively by MLIR compiler-generated code.
898 //
899 // Some macro magic is used to generate implementations for all required type
900 // combinations that can be called from MLIR compiler-generated code.
901 //
902 //===----------------------------------------------------------------------===//
903 
904 #define CASE(p, i, v, P, I, V)                                                 \
905   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
906     SparseTensorCOO<V> *coo = nullptr;                                         \
907     if (action <= Action::kFromCOO) {                                          \
908       if (action == Action::kFromFile) {                                       \
909         char *filename = static_cast<char *>(ptr);                             \
910         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
911       } else if (action == Action::kFromCOO) {                                 \
912         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
913       } else {                                                                 \
914         assert(action == Action::kEmpty);                                      \
915       }                                                                        \
916       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
917           rank, shape, perm, sparsity, coo);                                   \
918       if (action == Action::kFromFile)                                         \
919         delete coo;                                                            \
920       return tensor;                                                           \
921     }                                                                          \
922     if (action == Action::kEmptyCOO)                                           \
923       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
924     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
925     if (action == Action::kToIterator) {                                       \
926       coo->startIterator();                                                    \
927     } else {                                                                   \
928       assert(action == Action::kToCOO);                                        \
929     }                                                                          \
930     return coo;                                                                \
931   }
932 
933 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
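// For example, `CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t,
// int64_t)` expands to `CASE(OverheadType::kU64, OverheadType::kU64,
// PrimaryType::kI64, uint64_t, uint64_t, int64_t)`, i.e. pointer and index
// overhead share the same type.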
934 
935 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
936   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
937     assert(ref &&tensor);                                                      \
938     std::vector<TYPE> *v;                                                      \
939     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
940     ref->basePtr = ref->data = v->data();                                      \
941     ref->offset = 0;                                                           \
942     ref->sizes[0] = v->size();                                                 \
943     ref->strides[0] = 1;                                                       \
944   }
945 
946 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
947   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
948                            index_type d) {                                     \
949     assert(ref &&tensor);                                                      \
950     std::vector<TYPE> *v;                                                      \
951     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
952     ref->basePtr = ref->data = v->data();                                      \
953     ref->offset = 0;                                                           \
954     ref->sizes[0] = v->size();                                                 \
955     ref->strides[0] = 1;                                                       \
956   }
957 
958 #define IMPL_ADDELT(NAME, TYPE)                                                \
959   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
960                             StridedMemRefType<index_type, 1> *iref,            \
961                             StridedMemRefType<index_type, 1> *pref) {          \
962     assert(tensor &&iref &&pref);                                              \
963     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
964     assert(iref->sizes[0] == pref->sizes[0]);                                  \
965     const index_type *indx = iref->data + iref->offset;                        \
966     const index_type *perm = pref->data + pref->offset;                        \
967     uint64_t isize = iref->sizes[0];                                           \
968     std::vector<index_type> indices(isize);                                    \
969     for (uint64_t r = 0; r < isize; r++)                                       \
970       indices[perm[r]] = indx[r];                                              \
971     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
972     return tensor;                                                             \
973   }
974 
975 #define IMPL_GETNEXT(NAME, V)                                                  \
976   bool _mlir_ciface_##NAME(void *tensor,                                       \
977                            StridedMemRefType<index_type, 1> *iref,             \
978                            StridedMemRefType<V, 0> *vref) {                    \
979     assert(tensor &&iref &&vref);                                              \
980     assert(iref->strides[0] == 1);                                             \
981     index_type *indx = iref->data + iref->offset;                              \
982     V *value = vref->data + vref->offset;                                      \
983     const uint64_t isize = iref->sizes[0];                                     \
984     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
985     const Element<V> *elem = iter->getNext();                                  \
986     if (elem == nullptr)                                                       \
987       return false;                                                            \
988     for (uint64_t r = 0; r < isize; r++)                                       \
989       indx[r] = elem->indices[r];                                              \
990     *value = elem->value;                                                      \
991     return true;                                                               \
992   }
993 
994 #define IMPL_LEXINSERT(NAME, V)                                                \
995   void _mlir_ciface_##NAME(void *tensor,                                       \
996                            StridedMemRefType<index_type, 1> *cref, V val) {    \
997     assert(tensor &&cref);                                                     \
998     assert(cref->strides[0] == 1);                                             \
999     index_type *cursor = cref->data + cref->offset;                            \
1000     assert(cursor);                                                            \
1001     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1002   }
1003 
1004 #define IMPL_EXPINSERT(NAME, V)                                                \
1005   void _mlir_ciface_##NAME(                                                    \
1006       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1007       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1008       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1009     assert(tensor &&cref &&vref &&fref &&aref);                                \
1010     assert(cref->strides[0] == 1);                                             \
1011     assert(vref->strides[0] == 1);                                             \
1012     assert(fref->strides[0] == 1);                                             \
1013     assert(aref->strides[0] == 1);                                             \
1014     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1015     index_type *cursor = cref->data + cref->offset;                            \
1016     V *values = vref->data + vref->offset;                                     \
1017     bool *filled = fref->data + fref->offset;                                  \
1018     index_type *added = aref->data + aref->offset;                             \
1019     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1020         cursor, values, filled, added, count);                                 \
1021   }
1022 
1023 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1024 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1025 // that this file cannot get out of sync with its header.
1026 static_assert(std::is_same<index_type, uint64_t>::value,
1027               "Expected index_type == uint64_t");
1028 
1029 /// Constructs a new sparse tensor. This is the "swiss army knife"
1030 /// method for materializing sparse tensors into the computation.
1031 ///
1032 /// Action:
1033 /// kEmpty = returns empty storage to fill later
1034 /// kFromFile = returns storage, where ptr contains filename to read
1035 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
1036 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
1037 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1038 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
1039 void *
1040 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1041                              StridedMemRefType<index_type, 1> *sref,
1042                              StridedMemRefType<index_type, 1> *pref,
1043                              OverheadType ptrTp, OverheadType indTp,
1044                              PrimaryType valTp, Action action, void *ptr) {
1045   assert(aref && sref && pref);
1046   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1047          pref->strides[0] == 1);
1048   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1049   const DimLevelType *sparsity = aref->data + aref->offset;
1050   const index_type *shape = sref->data + sref->offset;
1051   const index_type *perm = pref->data + pref->offset;
1052   uint64_t rank = aref->sizes[0];
1053 
1054   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1055   // This is safe because of the static_assert above.
1056   if (ptrTp == OverheadType::kIndex)
1057     ptrTp = OverheadType::kU64;
1058   if (indTp == OverheadType::kIndex)
1059     indTp = OverheadType::kU64;
1060 
1061   // Double matrices with all combinations of overhead storage.
1062   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1063        uint64_t, double);
1064   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1065        uint32_t, double);
1066   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1067        uint16_t, double);
1068   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1069        uint8_t, double);
1070   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1071        uint64_t, double);
1072   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1073        uint32_t, double);
1074   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1075        uint16_t, double);
1076   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1077        uint8_t, double);
1078   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1079        uint64_t, double);
1080   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1081        uint32_t, double);
1082   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1083        uint16_t, double);
1084   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1085        uint8_t, double);
1086   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1087        uint64_t, double);
1088   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1089        uint32_t, double);
1090   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1091        uint16_t, double);
1092   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1093        uint8_t, double);
1094 
1095   // Float matrices with all combinations of overhead storage.
1096   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1097        uint64_t, float);
1098   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1099        uint32_t, float);
1100   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1101        uint16_t, float);
1102   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1103        uint8_t, float);
1104   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1105        uint64_t, float);
1106   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1107        uint32_t, float);
1108   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1109        uint16_t, float);
1110   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1111        uint8_t, float);
1112   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1113        uint64_t, float);
1114   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1115        uint32_t, float);
1116   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1117        uint16_t, float);
1118   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1119        uint8_t, float);
1120   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1121        uint64_t, float);
1122   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1123        uint32_t, float);
1124   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1125        uint16_t, float);
1126   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1127        uint8_t, float);
1128 
1129   // Integral matrices with both overheads of the same type.
1130   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1131   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1132   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1133   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1134   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1135   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1136   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1137   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1138   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1139   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1140   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1141   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1142   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1143 
1144   // Unsupported case (add above if needed).
1145   fputs("unsupported combination of types\n", stderr);
1146   exit(1);
1147 }
1148 
1149 /// Methods that provide direct access to pointers.
1150 IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
1151 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
1152 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
1153 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
1154 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
1155 
1156 /// Methods that provide direct access to indices.
1157 IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1158 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
1159 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
1160 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
1161 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
1162 
1163 /// Methods that provide direct access to values.
1164 IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
1165 IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
1166 IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
1167 IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
1168 IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
1169 IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
1170 
1171 /// Helper to add value to coordinate scheme, one per value type.
1172 IMPL_ADDELT(addEltF64, double)
1173 IMPL_ADDELT(addEltF32, float)
1174 IMPL_ADDELT(addEltI64, int64_t)
1175 IMPL_ADDELT(addEltI32, int32_t)
1176 IMPL_ADDELT(addEltI16, int16_t)
1177 IMPL_ADDELT(addEltI8, int8_t)
1178 
1179 /// Helper to enumerate elements of coordinate scheme, one per value type.
1180 IMPL_GETNEXT(getNextF64, double)
1181 IMPL_GETNEXT(getNextF32, float)
1182 IMPL_GETNEXT(getNextI64, int64_t)
1183 IMPL_GETNEXT(getNextI32, int32_t)
1184 IMPL_GETNEXT(getNextI16, int16_t)
1185 IMPL_GETNEXT(getNextI8, int8_t)
1186 
1187 /// Insert elements in lexicographical index order, one per value type.
1188 IMPL_LEXINSERT(lexInsertF64, double)
1189 IMPL_LEXINSERT(lexInsertF32, float)
1190 IMPL_LEXINSERT(lexInsertI64, int64_t)
1191 IMPL_LEXINSERT(lexInsertI32, int32_t)
1192 IMPL_LEXINSERT(lexInsertI16, int16_t)
1193 IMPL_LEXINSERT(lexInsertI8, int8_t)
1194 
1195 /// Insert using expansion, one per value type.
1196 IMPL_EXPINSERT(expInsertF64, double)
1197 IMPL_EXPINSERT(expInsertF32, float)
1198 IMPL_EXPINSERT(expInsertI64, int64_t)
1199 IMPL_EXPINSERT(expInsertI32, int32_t)
1200 IMPL_EXPINSERT(expInsertI16, int16_t)
1201 IMPL_EXPINSERT(expInsertI8, int8_t)
1202 
1203 #undef CASE
1204 #undef IMPL_SPARSEVALUES
1205 #undef IMPL_GETOVERHEAD
1206 #undef IMPL_ADDELT
1207 #undef IMPL_GETNEXT
1208 #undef IMPL_LEXINSERT
1209 #undef IMPL_EXPINSERT
1210 
1211 /// Output a sparse tensor, one per value type.
1212 void outSparseTensorF64(void *tensor, void *dest, bool sort) {
1213   return outSparseTensor<double>(tensor, dest, sort);
1214 }
1215 void outSparseTensorF32(void *tensor, void *dest, bool sort) {
1216   return outSparseTensor<float>(tensor, dest, sort);
1217 }
1218 void outSparseTensorI64(void *tensor, void *dest, bool sort) {
1219   return outSparseTensor<int64_t>(tensor, dest, sort);
1220 }
1221 void outSparseTensorI32(void *tensor, void *dest, bool sort) {
1222   return outSparseTensor<int32_t>(tensor, dest, sort);
1223 }
1224 void outSparseTensorI16(void *tensor, void *dest, bool sort) {
1225   return outSparseTensor<int16_t>(tensor, dest, sort);
1226 }
1227 void outSparseTensorI8(void *tensor, void *dest, bool sort) {
1228   return outSparseTensor<int8_t>(tensor, dest, sort);
1229 }
1230 
1231 //===----------------------------------------------------------------------===//
1232 //
1233 // Public API with methods that accept C-style data structures to interact
1234 // with sparse tensors, which are only visible as opaque pointers externally.
1235 // These methods can be used both by MLIR compiler-generated code as well as by
1236 // an external runtime that wants to interact with MLIR compiler-generated code.
1237 //
1238 //===----------------------------------------------------------------------===//
1239 
1240 /// Helper method to read a sparse tensor filename from the environment,
1241 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1242 char *getTensorFilename(index_type id) {
1243   char var[80];
1244   sprintf(var, "TENSOR%" PRIu64, id);
1245   char *env = getenv(var);
1246   if (!env) {
1247     fprintf(stderr, "Environment variable %s is not set\n", var);
1248     exit(1);
1249   }
1250   return env;
1251 }
1252 
1253 /// Returns size of sparse tensor in given dimension.
1254 index_type sparseDimSize(void *tensor, index_type d) {
1255   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1256 }
1257 
1258 /// Finalizes lexicographic insertions.
1259 void endInsert(void *tensor) {
1260   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1261 }
1262 
1263 /// Releases sparse tensor storage.
1264 void delSparseTensor(void *tensor) {
1265   delete static_cast<SparseTensorStorageBase *>(tensor);
1266 }
1267 
1268 /// Releases sparse tensor coordinate scheme.
1269 #define IMPL_DELCOO(VNAME, V)                                                  \
1270   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1271     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1272   }
1273 IMPL_DELCOO(F64, double)
1274 IMPL_DELCOO(F32, float)
1275 IMPL_DELCOO(I64, int64_t)
1276 IMPL_DELCOO(I32, int32_t)
1277 IMPL_DELCOO(I16, int16_t)
1278 IMPL_DELCOO(I8, int8_t)
1279 #undef IMPL_DELCOO
1280 
1281 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
1282 /// data structures. The expected parameters are:
1283 ///
1284 ///   rank:    rank of tensor
1285 ///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
1287 ///   values:  a "nse" array with values for all specified elements
1288 ///   indices: a flat "nse x rank" array with indices for all specified elements
1289 ///   perm:    the permutation of the dimensions in the storage
1290 ///   sparse:  the sparsity for the dimensions
1291 ///
1292 /// For example, the sparse matrix
1293 ///     | 1.0 0.0 0.0 |
1294 ///     | 0.0 5.0 3.0 |
1295 /// can be passed as
1296 ///      rank    = 2
1297 ///      nse     = 3
1298 ///      shape   = [2, 3]
1299 ///      values  = [1.0, 5.0, 3.0]
1300 ///      indices = [ 0, 0,  1, 1,  1, 2]
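///
/// A possible (illustrative, not prescribed) choice for the remaining
/// arguments would be perm = [0, 1] and sparse = [kDense, kCompressed],
/// which yields a CSR-like storage of this matrix.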
1301 //
1302 // TODO: generalize beyond 64-bit indices.
1303 //
1304 void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
1305                                    double *values, uint64_t *indices,
1306                                    uint64_t *perm, uint8_t *sparse) {
1307   return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
1308                                     sparse);
1309 }
1310 void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
1311                                    float *values, uint64_t *indices,
1312                                    uint64_t *perm, uint8_t *sparse) {
1313   return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
1314                                    sparse);
1315 }
1316 
1317 /// Converts a sparse tensor to COO-flavored format expressed using C-style
1318 /// data structures. The expected output parameters are pointers for these
1319 /// values:
1320 ///
1321 ///   rank:    rank of tensor
1322 ///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
1324 ///   values:  a "nse" array with values for all specified elements
1325 ///   indices: a flat "nse x rank" array with indices for all specified elements
1326 ///
1327 /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
1328 /// from convertToMLIRSparseTensor.
1329 ///
1330 //  TODO: Currently, values are copied from SparseTensorStorage to
1331 //  SparseTensorCOO, then to the output. We may want to reduce the number of
1332 //  copies.
1333 //
1334 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1335 // compressed
1336 //
1337 void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
1338                                     uint64_t *pNse, uint64_t **pShape,
1339                                     double **pValues, uint64_t **pIndices) {
1340   fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
1341 }
1342 void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
1343                                     uint64_t *pNse, uint64_t **pShape,
1344                                     float **pValues, uint64_t **pIndices) {
1345   fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
1346 }
1347 
1348 } // extern "C"
1349 
1350 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1351