1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cctype>
25 #include <cinttypes>
26 #include <cstdio>
27 #include <cstdlib>
28 #include <cstring>
29 #include <fstream>
30 #include <iostream>
31 #include <limits>
32 #include <numeric>
33 #include <vector>
34 
35 //===----------------------------------------------------------------------===//
36 //
37 // Internal support for storing and reading sparse tensors.
38 //
39 // The following memory-resident sparse storage schemes are supported:
40 //
41 // (a) A coordinate scheme for temporarily storing and lexicographically
42 //     sorting a sparse tensor by index (SparseTensorCOO).
43 //
44 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
45 //     per-dimension sparse/dense annotations together with a dimension
46 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
47 //
48 // The following external formats are supported:
49 //
50 // (1) Matrix Market Exchange (MME): *.mtx
51 //     https://math.nist.gov/MatrixMarket/formats.html
52 //
53 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
54 //     http://frostt.io/tensors/file-formats.html
55 //
56 // Two public APIs are supported:
57 //
58 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
59 //     tensors. These methods should be used exclusively by MLIR
60 //     compiler-generated code.
61 //
62 // (II) Methods that accept C-style data structures to interact with sparse
63 //      tensors. These methods can be used by any external runtime that wants
64 //      to interact with MLIR compiler-generated code.
65 //
66 // In both cases (I) and (II), the SparseTensorStorage format is visible
67 // externally only as an opaque pointer.
68 //
69 //===----------------------------------------------------------------------===//
70 
71 namespace {
72 
73 static constexpr int kColWidth = 1025;
74 
75 /// A version of `operator*` on `uint64_t` which checks for overflows.
76 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
77   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
78          "Integer overflow");
79   return lhs * rhs;
80 }
81 
82 /// A sparse tensor element in coordinate scheme (value and indices).
83 /// For example, a rank-1 vector element would look like
84 ///   ({i}, a[i])
85 /// and a rank-5 tensor element like
86 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
87 /// We use a pointer into a shared index pool rather than e.g. a direct
88 /// vector, since this (1) reduces the per-element memory footprint, and
89 /// (2) centralizes the memory reservation and (re)allocation in one place.
90 template <typename V>
91 struct Element {
92   Element(uint64_t *ind, V val) : indices(ind), value(val) {}
93   uint64_t *indices; // pointer into shared index pool
94   V value;
95 };
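// For illustration only (not part of the original code): if one COO tensor
// holds the rank-2 elements ({0,3}, 1.0) and ({1,2}, 2.0), the shared index
// pool is {0, 3, 1, 2} and each Element points into it at a fixed offset:
//
//   std::vector<uint64_t> pool = {0, 3, 1, 2};
//   Element<double> e1(pool.data() + 0, 1.0); // indices {0, 3}
//   Element<double> e2(pool.data() + 2, 2.0); // indices {1, 2}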
96 
97 /// A memory-resident sparse tensor in coordinate scheme (collection of
98 /// elements). This data structure is used to read a sparse tensor from
99 /// any external format into memory and sort the elements lexicographically
100 /// by indices before passing it back to the client (most packed storage
101 /// formats require the elements to appear in lexicographic index order).
102 template <typename V>
103 struct SparseTensorCOO {
104 public:
105   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
106       : sizes(szs) {
107     if (capacity) {
108       elements.reserve(capacity);
109       indices.reserve(capacity * getRank());
110     }
111   }
112 
113   /// Adds element as indices and value.
114   void add(const std::vector<uint64_t> &ind, V val) {
115     assert(!iteratorLocked && "Attempt to add() after startIterator()");
116     uint64_t *base = indices.data();
117     uint64_t size = indices.size();
118     uint64_t rank = getRank();
119     assert(rank == ind.size());
120     for (uint64_t r = 0; r < rank; r++) {
121       assert(ind[r] < sizes[r]); // within bounds
122       indices.push_back(ind[r]);
123     }
124     // The base pointer changes only if the `indices` vector was reallocated.
125     // In that case, we must rebase all pointers of previously added elements.
126     // Note that this only happens if the initial capacity was not set right,
127     // and then only once per internal vector reallocation (which, with the
128     // usual doubling policy, incurs no more than amortized linear overhead).
129     uint64_t *newBase = indices.data();
130     if (newBase != base) {
131       for (uint64_t i = 0, n = elements.size(); i < n; i++)
132         elements[i].indices = newBase + (elements[i].indices - base);
133       base = newBase;
134     }
135     // Add element as (pointer into shared index pool, value) pair.
136     elements.emplace_back(base + size, val);
137   }
138 
139   /// Sorts elements lexicographically by index.
140   void sort() {
141     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
142     // TODO: we may want to cache an `isSorted` bit, to avoid
143     // unnecessary/redundant sorting.
144     std::sort(elements.begin(), elements.end(),
145               [this](const Element<V> &e1, const Element<V> &e2) {
146                 uint64_t rank = getRank();
147                 for (uint64_t r = 0; r < rank; r++) {
148                   if (e1.indices[r] == e2.indices[r])
149                     continue;
150                   return e1.indices[r] < e2.indices[r];
151                 }
152                 return false;
153               });
154   }
155 
156   /// Returns rank.
157   uint64_t getRank() const { return sizes.size(); }
158 
159   /// Getter for sizes array.
160   const std::vector<uint64_t> &getSizes() const { return sizes; }
161 
162   /// Getter for elements array.
163   const std::vector<Element<V>> &getElements() const { return elements; }
164 
165   /// Switch into iterator mode.
166   void startIterator() {
167     iteratorLocked = true;
168     iteratorPos = 0;
169   }
170 
171   /// Get the next element.
172   const Element<V> *getNext() {
173     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
174     if (iteratorPos < elements.size())
175       return &(elements[iteratorPos++]);
176     iteratorLocked = false;
177     return nullptr;
178   }
179 
180   /// Factory method. Permutes the original dimensions according to
181   /// the given ordering and expects subsequent add() calls to honor
182   /// that same ordering for the given indices. The result is a
183   /// fully permuted coordinate scheme.
184   ///
185   /// Precondition: `sizes` and `perm` must be valid for `rank`.
186   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
187                                                 const uint64_t *sizes,
188                                                 const uint64_t *perm,
189                                                 uint64_t capacity = 0) {
190     std::vector<uint64_t> permsz(rank);
191     for (uint64_t r = 0; r < rank; r++) {
192       assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
193       permsz[perm[r]] = sizes[r];
194     }
195     return new SparseTensorCOO<V>(permsz, capacity);
196   }
197 
198 private:
199   const std::vector<uint64_t> sizes; // per-dimension sizes
200   std::vector<Element<V>> elements;  // all COO elements
201   std::vector<uint64_t> indices;     // shared index pool
202   bool iteratorLocked = false;
203   uint64_t iteratorPos = 0;
204 };
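// A minimal usage sketch of SparseTensorCOO (illustrative only; the sizes and
// values are made up): build a 3x4 coordinate-scheme tensor with an identity
// permutation, add elements, sort, and iterate.
//
//   uint64_t sizes[] = {3, 4};
//   uint64_t perm[] = {0, 1};
//   SparseTensorCOO<double> *coo =
//       SparseTensorCOO<double>::newSparseTensorCOO(2, sizes, perm,
//                                                   /*capacity=*/4);
//   coo->add({2, 1}, 3.0);
//   coo->add({0, 3}, 1.0);
//   coo->sort(); // element {0,3} now precedes {2,1}
//   coo->startIterator();
//   while (const Element<double> *e = coo->getNext())
//     ; // consume e->indices[0], e->indices[1], and e->value
//   delete coo;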
205 
206 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
207 /// takes responsibility for all the `<P,I,V>`-independent aspects
208 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
209 /// we use function overloading to implement "partial" method
210 /// specialization, which the C-API relies on to catch type errors
211 /// arising from our use of opaque pointers.
212 class SparseTensorStorageBase {
213 public:
214   /// Constructs a new storage object.  The `perm` maps the tensor's
215   /// semantic-ordering of dimensions to this object's storage-order.
216   /// The `szs` and `sparsity` arrays are already in storage-order.
217   ///
218   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
219   SparseTensorStorageBase(const std::vector<uint64_t> &szs,
220                           const uint64_t *perm, const DimLevelType *sparsity)
221       : dimSizes(szs), rev(getRank()),
222         dimTypes(sparsity, sparsity + getRank()) {
223     const uint64_t rank = getRank();
224     // Validate parameters.
225     assert(rank > 0 && "Trivial shape is unsupported");
226     for (uint64_t r = 0; r < rank; r++) {
227       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
228       assert((dimTypes[r] == DimLevelType::kDense ||
229               dimTypes[r] == DimLevelType::kCompressed) &&
230              "Unsupported DimLevelType");
231     }
232     // Construct the "reverse" (i.e., inverse) permutation.
233     for (uint64_t r = 0; r < rank; r++)
234       rev[perm[r]] = r;
235   }
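  // For example (illustrative only): with rank 3 and perm = {1, 2, 0},
  // semantic dimension 0 is stored at position 1, semantic dimension 1 at
  // position 2, and semantic dimension 2 at position 0; the inverse-permutation
  // loop in the constructor above then yields rev = {2, 0, 1}, i.e.
  // rev[storage-position] == semantic-dimension.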
236 
237   virtual ~SparseTensorStorageBase() = default;
238 
239   /// Get the rank of the tensor.
240   uint64_t getRank() const { return dimSizes.size(); }
241 
242   /// Getter for the dimension-sizes array, in storage-order.
243   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
244 
245   /// Safely looks up the size of the given (storage-order) dimension.
246   uint64_t getDimSize(uint64_t d) const {
247     assert(d < getRank());
248     return dimSizes[d];
249   }
250 
251   /// Getter for the "reverse" permutation, which maps this object's
252   /// storage-order to the tensor's semantic-order.
253   const std::vector<uint64_t> &getRev() const { return rev; }
254 
255   /// Getter for the dimension-types array, in storage-order.
256   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
257 
258   /// Safely check if the (storage-order) dimension uses compressed storage.
259   bool isCompressedDim(uint64_t d) const {
260     assert(d < getRank());
261     return (dimTypes[d] == DimLevelType::kCompressed);
262   }
263 
264   /// Overhead storage.
265   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
266   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
267   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
268   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
269   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
270   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
271   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
272   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
273 
274   /// Primary storage.
275   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
276   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
277   virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
278   virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
279   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
280   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
281 
282   /// Element-wise insertion in lexicographic index order.
283   virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
284   virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
285   virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
286   virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
287   virtual void lexInsert(const uint64_t *, int16_t) { fatal("insi16"); }
288   virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
289 
290   /// Expanded insertion.
291   virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
292     fatal("expf64");
293   }
294   virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
295     fatal("expf32");
296   }
297   virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
298     fatal("expi64");
299   }
300   virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
301     fatal("expi32");
302   }
303   virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
304     fatal("expi16");
305   }
306   virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
307     fatal("expi8");
308   }
309 
310   /// Finishes insertion.
311   virtual void endInsert() = 0;
312 
313 private:
314   static void fatal(const char *tp) {
315     fprintf(stderr, "unsupported %s\n", tp);
316     exit(1);
317   }
318 
319   const std::vector<uint64_t> dimSizes;
320   std::vector<uint64_t> rev;
321   const std::vector<DimLevelType> dimTypes;
322 };
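// For illustration only: the overload-based dispatch above is what turns a
// type mismatch behind an opaque pointer into a clean runtime failure. If a
// tensor was created as SparseTensorStorage<uint64_t, uint64_t, double> but a
// client asks for float values,
//
//   std::vector<float> *v;
//   static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);
//
// then the call resolves to the base-class getValues(std::vector<float> **),
// which reports "unsupported valf32" and exits, since only the double overload
// is overridden by that storage instance.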
323 
324 /// A memory-resident sparse tensor using a storage scheme based on
325 /// per-dimension sparse/dense annotations. This data structure provides a
326 /// bufferized form of a sparse tensor type. In contrast to generating setup
327 /// methods for each differently annotated sparse tensor, this class provides
328 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
329 /// and annotations to implement all required setup in a general manner.
330 template <typename P, typename I, typename V>
331 class SparseTensorStorage : public SparseTensorStorageBase {
332 public:
333   /// Constructs a sparse tensor storage scheme with the given dimensions,
334   /// permutation, and per-dimension dense/sparse annotations, using
335   /// the coordinate scheme tensor for the initial contents if provided.
336   ///
337   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
338   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
339                       const DimLevelType *sparsity,
340                       SparseTensorCOO<V> *coo = nullptr)
341       : SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
342         indices(getRank()), idx(getRank()) {
343     // Provide hints on capacity of pointers and indices.
344     // TODO: needs much fine-tuning based on actual sparsity; currently
345     //       we reserve pointer/index space based on all previous dense
346     //       dimensions, which works well up to first sparse dim; but
347     //       we should really use nnz and dense/sparse distribution.
348     bool allDense = true;
349     uint64_t sz = 1;
350     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
351       if (isCompressedDim(r)) {
352         // TODO: Take a parameter between 1 and `sizes[r]`, and multiply
353         // `sz` by that before reserving. (For now we just use 1.)
354         pointers[r].reserve(sz + 1);
355         pointers[r].push_back(0);
356         indices[r].reserve(sz);
357         sz = 1;
358         allDense = false;
359       } else { // Dense dimension.
360         sz = checkedMul(sz, getDimSizes()[r]);
361       }
362     }
363     // Then assign contents from coordinate scheme tensor if provided.
364     if (coo) {
365       // Ensure both preconditions of `fromCOO`.
366       assert(coo->getSizes() == getDimSizes() && "Tensor size mismatch");
367       coo->sort();
368       // Now actually insert the `elements`.
369       const std::vector<Element<V>> &elements = coo->getElements();
370       uint64_t nnz = elements.size();
371       values.reserve(nnz);
372       fromCOO(elements, 0, nnz, 0);
373     } else if (allDense) {
374       values.resize(sz, 0);
375     }
376   }
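  // A worked example (illustrative only): for the 4x4 matrix
  //
  //   [ 1 0 2 0 ]
  //   [ 0 0 0 0 ]
  //   [ 0 3 0 0 ]
  //   [ 4 0 0 5 ]
  //
  // with dimension 0 dense, dimension 1 compressed (i.e., CSR), and an
  // identity permutation, the fromCOO() call above produces
  //
  //   pointers[1] = { 0, 2, 2, 3, 5 }
  //   indices[1]  = { 0, 2, 1, 0, 3 }
  //   values      = { 1, 2, 3, 4, 5 }
  //
  // while pointers[0] and indices[0] remain empty because dimension 0 is dense.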
377 
378   ~SparseTensorStorage() override = default;
379 
380   /// Partially specialize these getter methods based on template types.
381   void getPointers(std::vector<P> **out, uint64_t d) override {
382     assert(d < getRank());
383     *out = &pointers[d];
384   }
385   void getIndices(std::vector<I> **out, uint64_t d) override {
386     assert(d < getRank());
387     *out = &indices[d];
388   }
389   void getValues(std::vector<V> **out) override { *out = &values; }
390 
391   /// Partially specialize lexicographical insertions based on template types.
392   void lexInsert(const uint64_t *cursor, V val) override {
393     // First, wrap up pending insertion path.
394     uint64_t diff = 0;
395     uint64_t top = 0;
396     if (!values.empty()) {
397       diff = lexDiff(cursor);
398       endPath(diff + 1);
399       top = idx[diff] + 1;
400     }
401     // Then continue with insertion path.
402     insPath(cursor, diff, top, val);
403   }
404 
405   /// Partially specialize expanded insertions based on template types.
406   /// Note that this method resets the values/filled-switch array back
407   /// to all-zero/false while only iterating over the nonzero elements.
408   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
409                  uint64_t count) override {
410     if (count == 0)
411       return;
412     // Sort.
413     std::sort(added, added + count);
414     // Restore insertion path for first insert.
415     const uint64_t lastDim = getRank() - 1;
416     uint64_t index = added[0];
417     cursor[lastDim] = index;
418     lexInsert(cursor, values[index]);
419     assert(filled[index]);
420     values[index] = 0;
421     filled[index] = false;
422     // Subsequent insertions are quick.
423     for (uint64_t i = 1; i < count; i++) {
424       assert(index < added[i] && "non-lexicographic insertion");
425       index = added[i];
426       cursor[lastDim] = index;
427       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
428       assert(filled[index]);
429       values[index] = 0;
430       filled[index] = false;
431     }
432   }
433 
434   /// Finalizes lexicographic insertions.
435   void endInsert() override {
436     if (values.empty())
437       finalizeSegment(0);
438     else
439       endPath(0);
440   }
441 
442   /// Returns this sparse tensor storage scheme as a new memory-resident
443   /// sparse tensor in coordinate scheme with the given dimension order.
444   ///
445   /// Precondition: `perm` must be valid for `getRank()`.
446   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
447     // Restore original order of the dimension sizes and allocate coordinate
448     // scheme with desired new ordering specified in perm.
449     const uint64_t rank = getRank();
450     const auto &rev = getRev();
451     const auto &sizes = getDimSizes();
452     std::vector<uint64_t> orgsz(rank);
453     for (uint64_t r = 0; r < rank; r++)
454       orgsz[rev[r]] = sizes[r];
455     SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO(
456         rank, orgsz.data(), perm, values.size());
457     // Populate the coordinate scheme, restoring the old ordering and applying
458     // the new ordering. Rather than applying both reorderings during the
459     // recursion, we compute the combined permutation in advance.
460     std::vector<uint64_t> reord(rank);
461     for (uint64_t r = 0; r < rank; r++)
462       reord[r] = perm[rev[r]];
463     toCOO(*coo, reord, 0, 0);
464     // TODO: This assertion assumes there are no stored zeros,
465     // or if there are then that we don't filter them out.
466     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
467     assert(coo->getElements().size() == values.size());
468     return coo;
469   }
470 
471   /// Factory method. Constructs a sparse tensor storage scheme with the given
472   /// dimensions, permutation, and per-dimension dense/sparse annotations,
473   /// using the coordinate scheme tensor for the initial contents if provided.
474   /// In the latter case, the coordinate scheme must respect the same
475   /// permutation as is desired for the new sparse tensor storage.
476   ///
477   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
478   static SparseTensorStorage<P, I, V> *
479   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
480                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
481     SparseTensorStorage<P, I, V> *n = nullptr;
482     if (coo) {
483       assert(coo->getRank() == rank && "Tensor rank mismatch");
484       const auto &coosz = coo->getSizes();
485       for (uint64_t r = 0; r < rank; r++)
486         assert(shape[r] == 0 || shape[r] == coosz[perm[r]]);
487       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
488     } else {
489       std::vector<uint64_t> permsz(rank);
490       for (uint64_t r = 0; r < rank; r++) {
491         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
492         permsz[perm[r]] = shape[r];
493       }
494       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
495     }
496     return n;
497   }
498 
499 private:
500   /// Appends an arbitrary new position to `pointers[d]`.  This method
501   /// checks that `pos` is representable in the `P` type; however, it
502   /// does not check that `pos` is semantically valid (i.e., larger than
503   /// the previous position and smaller than `indices[d].capacity()`).
504   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
505     assert(isCompressedDim(d));
506     assert(pos <= std::numeric_limits<P>::max() &&
507            "Pointer value is too large for the P-type");
508     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
509   }
510 
511   /// Appends index `i` to dimension `d`, in the semantically general
512   /// sense.  For non-dense dimensions, that means appending to the
513   /// `indices[d]` array, checking that `i` is representable in the `I`
514   /// type; however, we do not verify other semantic requirements (e.g.,
515   /// that `i` is in bounds for `sizes[d]`, and not previously occurring
516   /// in the same segment).  For dense dimensions, this method instead
517   /// appends the appropriate number of zeros to the `values` array,
518   /// where `full` is the number of "entries" already written to `values`
519   /// for this segment (aka one after the highest index previously appended).
520   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
521     if (isCompressedDim(d)) {
522       assert(i <= std::numeric_limits<I>::max() &&
523              "Index value is too large for the I-type");
524       indices[d].push_back(static_cast<I>(i));
525     } else { // Dense dimension.
526       assert(i >= full && "Index was already filled");
527       if (i == full)
528         return; // Short-circuit, since it'll be a nop.
529       if (d + 1 == getRank())
530         values.insert(values.end(), i - full, 0);
531       else
532         finalizeSegment(d + 1, 0, i - full);
533     }
534   }
535 
536   /// Initializes sparse tensor storage scheme from a memory-resident sparse
537   /// tensor in coordinate scheme. This method prepares the pointers and
538   /// indices arrays under the given per-dimension dense/sparse annotations.
539   ///
540   /// Preconditions:
541   /// (1) the `elements` must be lexicographically sorted.
542   /// (2) the indices of every element are valid for `sizes` (equal rank
543   ///     and pointwise less-than).
544   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
545                uint64_t hi, uint64_t d) {
546     // Once dimensions are exhausted, insert the numerical values.
547     assert(d <= getRank() && hi <= elements.size());
548     if (d == getRank()) {
549       assert(lo < hi);
550       values.push_back(elements[lo].value);
551       return;
552     }
553     // Visit all elements in this interval.
554     uint64_t full = 0;
555     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
556       // Find segment in interval with same index elements in this dimension.
557       uint64_t i = elements[lo].indices[d];
558       uint64_t seg = lo + 1;
559       while (seg < hi && elements[seg].indices[d] == i)
560         seg++;
561       // Handle segment in interval for sparse or dense dimension.
562       appendIndex(d, full, i);
563       full = i + 1;
564       fromCOO(elements, lo, seg, d + 1);
565       // And move on to next segment in interval.
566       lo = seg;
567     }
568     // Finalize the sparse pointer structure at this dimension.
569     finalizeSegment(d, full);
570   }
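  // For example (illustrative only): given the sorted rank-2 elements
  // ({0,0}, a), ({0,2}, b), ({1,1}, c) and d == 0, the first segment is
  // [lo, seg) == [0, 2) since both leading indices equal 0; the recursion
  // handles ({0,0}, a) and ({0,2}, b) at d == 1 before the loop above moves
  // on to the segment for leading index 1.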
571 
572   /// Stores the sparse tensor storage scheme into a memory-resident sparse
573   /// tensor in coordinate scheme.
574   void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
575              uint64_t pos, uint64_t d) {
576     assert(d <= getRank());
577     if (d == getRank()) {
578       assert(pos < values.size());
579       tensor.add(idx, values[pos]);
580     } else if (isCompressedDim(d)) {
581       // Sparse dimension.
582       for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
583         idx[reord[d]] = indices[d][ii];
584         toCOO(tensor, reord, ii, d + 1);
585       }
586     } else {
587       // Dense dimension.
588       const uint64_t sz = getDimSizes()[d];
589       const uint64_t off = pos * sz;
590       for (uint64_t i = 0; i < sz; i++) {
591         idx[reord[d]] = i;
592         toCOO(tensor, reord, off + i, d + 1);
593       }
594     }
595   }
596 
597   /// Finalize the sparse pointer structure at this dimension.
598   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
599     if (count == 0)
600       return; // Short-circuit, since it'll be a nop.
601     if (isCompressedDim(d)) {
602       appendPointer(d, indices[d].size(), count);
603     } else { // Dense dimension.
604       const uint64_t sz = getDimSizes()[d];
605       assert(sz >= full && "Segment is overfull");
606       count = checkedMul(count, sz - full);
607       // For dense storage we must enumerate all the remaining coordinates
608       // in this dimension (i.e., coordinates after the last non-zero
609       // element), and either fill in their zero values or else recurse
610       // to finalize some deeper dimension.
611       if (d + 1 == getRank())
612         values.insert(values.end(), count, 0);
613       else
614         finalizeSegment(d + 1, 0, count);
615     }
616   }
617 
618   /// Wraps up a single insertion path, inner to outer.
619   void endPath(uint64_t diff) {
620     uint64_t rank = getRank();
621     assert(diff <= rank);
622     for (uint64_t i = 0; i < rank - diff; i++) {
623       const uint64_t d = rank - i - 1;
624       finalizeSegment(d, idx[d] + 1);
625     }
626   }
627 
628   /// Continues a single insertion path, outer to inner.
629   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
630     uint64_t rank = getRank();
631     assert(diff < rank);
632     for (uint64_t d = diff; d < rank; d++) {
633       uint64_t i = cursor[d];
634       appendIndex(d, top, i);
635       top = 0;
636       idx[d] = i;
637     }
638     values.push_back(val);
639   }
640 
641   /// Finds the first dimension in which `cursor` differs from `idx`.
642   uint64_t lexDiff(const uint64_t *cursor) const {
643     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
644       if (cursor[r] > idx[r])
645         return r;
646       else
647         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
648     assert(0 && "duplicate insertion");
649     return -1u;
650   }
651 
652 private:
653   std::vector<std::vector<P>> pointers;
654   std::vector<std::vector<I>> indices;
655   std::vector<V> values;
656   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
657 };
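// A minimal usage sketch of SparseTensorStorage (illustrative only; the shape
// and annotations are made up): build an empty 10x10 storage with a dense
// outer and compressed inner dimension, insert two entries in lexicographic
// index order, and convert back to a coordinate scheme.
//
//   uint64_t shape[] = {10, 10};
//   uint64_t perm[] = {0, 1};
//   DimLevelType sparsity[] = {DimLevelType::kDense,
//                              DimLevelType::kCompressed};
//   auto *st = SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
//       2, shape, perm, sparsity, /*coo=*/nullptr);
//   uint64_t cursor[] = {1, 3};
//   st->lexInsert(cursor, 42.0);
//   cursor[0] = 7;
//   cursor[1] = 0;
//   st->lexInsert(cursor, 7.0);
//   st->endInsert();
//   SparseTensorCOO<double> *coo = st->toCOO(perm);
//   delete coo;
//   delete st;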
658 
659 /// Helper to convert string to lower case.
660 static char *toLower(char *token) {
661   for (char *c = token; *c; c++)
662     *c = tolower(static_cast<unsigned char>(*c));
663   return token;
664 }
665 
666 /// Read the MME header of a general sparse matrix of type real.
667 static void readMMEHeader(FILE *file, char *filename, char *line,
668                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
669   char header[64];
670   char object[64];
671   char format[64];
672   char field[64];
673   char symmetry[64];
674   // Read header line.
675   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
676              symmetry) != 5) {
677     fprintf(stderr, "Corrupt header in %s\n", filename);
678     exit(1);
679   }
680   // Set properties
681   *isPattern = (strcmp(toLower(field), "pattern") == 0);
682   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
683   // Make sure this is a general sparse matrix.
684   if (strcmp(toLower(header), "%%matrixmarket") ||
685       strcmp(toLower(object), "matrix") ||
686       strcmp(toLower(format), "coordinate") ||
687       (strcmp(toLower(field), "real") && !(*isPattern)) ||
688       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
689     fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
690     exit(1);
691   }
692   // Skip comments.
693   while (true) {
694     if (!fgets(line, kColWidth, file)) {
695       fprintf(stderr, "Cannot find data in %s\n", filename);
696       exit(1);
697     }
698     if (line[0] != '%')
699       break;
700   }
701   // Next line contains M N NNZ.
702   idata[0] = 2; // rank
703   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
704              idata + 1) != 3) {
705     fprintf(stderr, "Cannot find size in %s\n", filename);
706     exit(1);
707   }
708 }
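// For reference, a small MME file accepted by this reader looks as follows
// (illustrative contents):
//
//   %%MatrixMarket matrix coordinate real general
//   % optional comment lines
//   3 3 2
//   1 1 1.5
//   3 2 2.5
//
// after which idata[0] == 2 (rank), idata[1] == 2 (nnz), and
// idata[2] == idata[3] == 3 (the matrix dimensions).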
709 
710 /// Read the "extended" FROSTT header. Although not part of the documented
711 /// format, we assume that the file starts with optional comments followed
712 /// by two lines that define the rank, the number of nonzeros, and the
713 /// dimension sizes (one per rank) of the sparse tensor.
714 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
715                                 uint64_t *idata) {
716   // Skip comments.
717   while (true) {
718     if (!fgets(line, kColWidth, file)) {
719       fprintf(stderr, "Cannot find data in %s\n", filename);
720       exit(1);
721     }
722     if (line[0] != '#')
723       break;
724   }
725   // Next line contains RANK and NNZ.
726   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
727     fprintf(stderr, "Cannot find metadata in %s\n", filename);
728     exit(1);
729   }
730   // Followed by a line with the dimension sizes (one per rank).
731   for (uint64_t r = 0; r < idata[0]; r++) {
732     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
733       fprintf(stderr, "Cannot find dimension size in %s\n", filename);
734       exit(1);
735     }
736   }
737   fgets(line, kColWidth, file); // end of line
738 }
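// For reference, the "extended" FROSTT header described above looks as
// follows (illustrative contents):
//
//   # optional comment lines
//   2 3
//   4 5
//   1 1 1.0
//   2 3 2.0
//   4 5 3.0
//
// i.e., rank 2 with 3 nonzeros (idata[0] and idata[1]), followed by the
// dimension sizes 4 and 5 (idata[2] and idata[3]), followed by the nonzero
// elements themselves in 1-based index notation.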
739 
740 /// Reads a sparse tensor with the given filename into a memory-resident
741 /// sparse tensor in coordinate scheme.
742 template <typename V>
743 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
744                                                const uint64_t *shape,
745                                                const uint64_t *perm) {
746   // Open the file.
747   assert(filename && "Received nullptr for filename");
748   FILE *file = fopen(filename, "r");
749   if (!file) {
750     fprintf(stderr, "Cannot find file %s\n", filename);
751     exit(1);
752   }
753   // Perform some file-format-dependent setup.
754   char line[kColWidth];
755   uint64_t idata[512];
756   bool isPattern = false;
757   bool isSymmetric = false;
758   if (strstr(filename, ".mtx")) {
759     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
760   } else if (strstr(filename, ".tns")) {
761     readExtFROSTTHeader(file, filename, line, idata);
762   } else {
763     fprintf(stderr, "Unknown format %s\n", filename);
764     exit(1);
765   }
766   // Prepare sparse tensor object with per-dimension sizes
767   // and the number of nonzeros as initial capacity.
768   assert(rank == idata[0] && "rank mismatch");
769   uint64_t nnz = idata[1];
770   for (uint64_t r = 0; r < rank; r++)
771     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
772            "dimension size mismatch");
773   SparseTensorCOO<V> *tensor =
774       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
775   //  Read all nonzero elements.
776   std::vector<uint64_t> indices(rank);
777   for (uint64_t k = 0; k < nnz; k++) {
778     if (!fgets(line, kColWidth, file)) {
779       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
780       exit(1);
781     }
782     char *linePtr = line;
783     for (uint64_t r = 0; r < rank; r++) {
784       uint64_t idx = strtoul(linePtr, &linePtr, 10);
785       // Add 0-based index.
786       indices[perm[r]] = idx - 1;
787     }
788     // The external formats always store the numerical values with the type
789     // double, but we cast these values to the sparse tensor object type.
790     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
791     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
792     tensor->add(indices, value);
793     // We currently choose to deal with symmetric matrices by fully
794     // constructing them. In the future, we may want to make symmetry
795     // implicit for storage reasons.
796     if (isSymmetric && indices[0] != indices[1])
797       tensor->add({indices[1], indices[0]}, value);
798   }
799   // Close the file and return tensor.
800   fclose(file);
801   return tensor;
802 }
803 
804 /// Writes the sparse tensor to extended FROSTT format.
805 template <typename V>
806 static void outSparseTensor(void *tensor, void *dest, bool sort) {
807   assert(tensor && dest);
808   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
809   if (sort)
810     coo->sort();
811   char *filename = static_cast<char *>(dest);
812   auto &sizes = coo->getSizes();
813   auto &elements = coo->getElements();
814   uint64_t rank = coo->getRank();
815   uint64_t nnz = elements.size();
816   std::fstream file;
817   file.open(filename, std::ios_base::out | std::ios_base::trunc);
818   assert(file.is_open());
819   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
820   for (uint64_t r = 0; r < rank - 1; r++)
821     file << sizes[r] << " ";
822   file << sizes[rank - 1] << std::endl;
823   for (uint64_t i = 0; i < nnz; i++) {
824     auto &idx = elements[i].indices;
825     for (uint64_t r = 0; r < rank; r++)
826       file << (idx[r] + 1) << " ";
827     file << elements[i].value << std::endl;
828   }
829   file.flush();
830   file.close();
831   assert(file.good());
832 }
833 
834 /// Initializes sparse tensor from an external COO-flavored format.
835 template <typename V>
836 static SparseTensorStorage<uint64_t, uint64_t, V> *
837 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
838                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
839   const DimLevelType *sparsity = reinterpret_cast<const DimLevelType *>(sparse);
840 #ifndef NDEBUG
841   // Verify that perm is a permutation of 0..(rank-1).
842   std::vector<uint64_t> order(perm, perm + rank);
843   std::sort(order.begin(), order.end());
844   for (uint64_t i = 0; i < rank; ++i) {
845     if (i != order[i]) {
846       fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank - 1);
847       exit(1);
848     }
849   }
850 
851   // Verify that the sparsity values are supported.
852   for (uint64_t i = 0; i < rank; ++i) {
853     if (sparsity[i] != DimLevelType::kDense &&
854         sparsity[i] != DimLevelType::kCompressed) {
855       fprintf(stderr, "Unsupported sparsity value %d\n",
856               static_cast<int>(sparsity[i]));
857       exit(1);
858     }
859   }
860 #endif
861 
862   // Convert external format to internal COO.
863   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
864   std::vector<uint64_t> idx(rank);
865   for (uint64_t i = 0, base = 0; i < nse; i++) {
866     for (uint64_t r = 0; r < rank; r++)
867       idx[perm[r]] = indices[base + r];
868     coo->add(idx, values[i]);
869     base += rank;
870   }
871   // Return sparse tensor storage format as opaque pointer.
872   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
873       rank, shape, perm, sparsity, coo);
874   delete coo;
875   return tensor;
876 }
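// For illustration only (made-up data): an external runtime handing over a
// 3x4 matrix with two nonzeros in this COO-flavored form could call
//
//   uint64_t shape[] = {3, 4};
//   double values[] = {1.0, 2.0};
//   uint64_t indices[] = {0, 3,  // indices of values[0]
//                         2, 1}; // indices of values[1]
//   uint64_t perm[] = {0, 1};
//   uint8_t sparse[] = {static_cast<uint8_t>(DimLevelType::kDense),
//                       static_cast<uint8_t>(DimLevelType::kCompressed)};
//   auto *st = toMLIRSparseTensor<double>(/*rank=*/2, /*nse=*/2, shape,
//                                         values, indices, perm, sparse);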
877 
878 /// Converts a sparse tensor to an external COO-flavored format.
879 template <typename V>
880 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
881                                  uint64_t **pShape, V **pValues,
882                                  uint64_t **pIndices) {
883   auto sparseTensor =
884       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
885   uint64_t rank = sparseTensor->getRank();
886   std::vector<uint64_t> perm(rank);
887   std::iota(perm.begin(), perm.end(), 0);
888   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
889 
890   const std::vector<Element<V>> &elements = coo->getElements();
891   uint64_t nse = elements.size();
892 
893   uint64_t *shape = new uint64_t[rank];
894   for (uint64_t i = 0; i < rank; i++)
895     shape[i] = coo->getSizes()[i];
896 
897   V *values = new V[nse];
898   uint64_t *indices = new uint64_t[rank * nse];
899 
900   for (uint64_t i = 0, base = 0; i < nse; i++) {
901     values[i] = elements[i].value;
902     for (uint64_t j = 0; j < rank; j++)
903       indices[base + j] = elements[i].indices[j];
904     base += rank;
905   }
906 
907   delete coo;
908   *pRank = rank;
909   *pNse = nse;
910   *pShape = shape;
911   *pValues = values;
912   *pIndices = indices;
913 }
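// Note that the shape, values, and indices buffers returned above are
// allocated with new[]; the caller takes ownership and is responsible for
// eventually releasing them (e.g., with delete[]).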
914 
915 } // namespace
916 
917 extern "C" {
918 
919 //===----------------------------------------------------------------------===//
920 //
921 // Public API with methods that operate on MLIR buffers (memrefs) to interact
922 // with sparse tensors, which are only visible as opaque pointers externally.
923 // These methods should be used exclusively by MLIR compiler-generated code.
924 //
925 // Some macro magic is used to generate implementations for all required type
926 // combinations that can be called from MLIR compiler-generated code.
927 //
928 //===----------------------------------------------------------------------===//
929 
930 #define CASE(p, i, v, P, I, V)                                                 \
931   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
932     SparseTensorCOO<V> *coo = nullptr;                                         \
933     if (action <= Action::kFromCOO) {                                          \
934       if (action == Action::kFromFile) {                                       \
935         char *filename = static_cast<char *>(ptr);                             \
936         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
937       } else if (action == Action::kFromCOO) {                                 \
938         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
939       } else {                                                                 \
940         assert(action == Action::kEmpty);                                      \
941       }                                                                        \
942       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
943           rank, shape, perm, sparsity, coo);                                   \
944       if (action == Action::kFromFile)                                         \
945         delete coo;                                                            \
946       return tensor;                                                           \
947     }                                                                          \
948     if (action == Action::kEmptyCOO)                                           \
949       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
950     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
951     if (action == Action::kToIterator) {                                       \
952       coo->startIterator();                                                    \
953     } else {                                                                   \
954       assert(action == Action::kToCOO);                                        \
955     }                                                                          \
956     return coo;                                                                \
957   }
958 
959 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
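// For illustration, the first CASE invocation below (kU64/kU64/kF64) expands
// to a guarded block that roughly reads
//
//   if (ptrTp == OverheadType::kU64 && indTp == OverheadType::kU64 &&
//       valTp == PrimaryType::kF64) {
//     ... // create/convert a SparseTensorStorage<uint64_t, uint64_t, double>
//   }
//
// so exactly one branch is taken per (pointer, index, value) type triple.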
960 
961 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
962   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
963     assert(ref &&tensor);                                                      \
964     std::vector<TYPE> *v;                                                      \
965     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
966     ref->basePtr = ref->data = v->data();                                      \
967     ref->offset = 0;                                                           \
968     ref->sizes[0] = v->size();                                                 \
969     ref->strides[0] = 1;                                                       \
970   }
971 
972 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
973   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
974                            index_type d) {                                     \
975     assert(ref &&tensor);                                                      \
976     std::vector<TYPE> *v;                                                      \
977     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
978     ref->basePtr = ref->data = v->data();                                      \
979     ref->offset = 0;                                                           \
980     ref->sizes[0] = v->size();                                                 \
981     ref->strides[0] = 1;                                                       \
982   }
983 
984 #define IMPL_ADDELT(NAME, TYPE)                                                \
985   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
986                             StridedMemRefType<index_type, 1> *iref,            \
987                             StridedMemRefType<index_type, 1> *pref) {          \
988     assert(tensor &&iref &&pref);                                              \
989     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
990     assert(iref->sizes[0] == pref->sizes[0]);                                  \
991     const index_type *indx = iref->data + iref->offset;                        \
992     const index_type *perm = pref->data + pref->offset;                        \
993     uint64_t isize = iref->sizes[0];                                           \
994     std::vector<index_type> indices(isize);                                    \
995     for (uint64_t r = 0; r < isize; r++)                                       \
996       indices[perm[r]] = indx[r];                                              \
997     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
998     return tensor;                                                             \
999   }
1000 
1001 #define IMPL_GETNEXT(NAME, V)                                                  \
1002   bool _mlir_ciface_##NAME(void *tensor,                                       \
1003                            StridedMemRefType<index_type, 1> *iref,             \
1004                            StridedMemRefType<V, 0> *vref) {                    \
1005     assert(tensor &&iref &&vref);                                              \
1006     assert(iref->strides[0] == 1);                                             \
1007     index_type *indx = iref->data + iref->offset;                              \
1008     V *value = vref->data + vref->offset;                                      \
1009     const uint64_t isize = iref->sizes[0];                                     \
1010     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
1011     const Element<V> *elem = iter->getNext();                                  \
1012     if (elem == nullptr)                                                       \
1013       return false;                                                            \
1014     for (uint64_t r = 0; r < isize; r++)                                       \
1015       indx[r] = elem->indices[r];                                              \
1016     *value = elem->value;                                                      \
1017     return true;                                                               \
1018   }
1019 
1020 #define IMPL_LEXINSERT(NAME, V)                                                \
1021   void _mlir_ciface_##NAME(void *tensor,                                       \
1022                            StridedMemRefType<index_type, 1> *cref, V val) {    \
1023     assert(tensor &&cref);                                                     \
1024     assert(cref->strides[0] == 1);                                             \
1025     index_type *cursor = cref->data + cref->offset;                            \
1026     assert(cursor);                                                            \
1027     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1028   }
1029 
1030 #define IMPL_EXPINSERT(NAME, V)                                                \
1031   void _mlir_ciface_##NAME(                                                    \
1032       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1033       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1034       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1035     assert(tensor &&cref &&vref &&fref &&aref);                                \
1036     assert(cref->strides[0] == 1);                                             \
1037     assert(vref->strides[0] == 1);                                             \
1038     assert(fref->strides[0] == 1);                                             \
1039     assert(aref->strides[0] == 1);                                             \
1040     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1041     index_type *cursor = cref->data + cref->offset;                            \
1042     V *values = vref->data + vref->offset;                                     \
1043     bool *filled = fref->data + fref->offset;                                  \
1044     index_type *added = aref->data + aref->offset;                             \
1045     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1046         cursor, values, filled, added, count);                                 \
1047   }
1048 
1049 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1050 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1051 // that this file cannot get out of sync with its header.
1052 static_assert(std::is_same<index_type, uint64_t>::value,
1053               "Expected index_type == uint64_t");
1054 
1055 /// Constructs a new sparse tensor. This is the "Swiss army knife"
1056 /// method for materializing sparse tensors into the computation.
1057 ///
1058 /// Action:
1059 /// kEmpty = returns empty storage to fill later
1060 /// kFromFile = returns storage, where ptr contains filename to read
1061 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
1062 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
1063 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1064 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
1065 void *
1066 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1067                              StridedMemRefType<index_type, 1> *sref,
1068                              StridedMemRefType<index_type, 1> *pref,
1069                              OverheadType ptrTp, OverheadType indTp,
1070                              PrimaryType valTp, Action action, void *ptr) {
1071   assert(aref && sref && pref);
1072   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1073          pref->strides[0] == 1);
1074   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1075   const DimLevelType *sparsity = aref->data + aref->offset;
1076   const index_type *shape = sref->data + sref->offset;
1077   const index_type *perm = pref->data + pref->offset;
1078   uint64_t rank = aref->sizes[0];
1079 
1080   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1081   // This is safe because of the static_assert above.
1082   if (ptrTp == OverheadType::kIndex)
1083     ptrTp = OverheadType::kU64;
1084   if (indTp == OverheadType::kIndex)
1085     indTp = OverheadType::kU64;
1086 
1087   // Double matrices with all combinations of overhead storage.
1088   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1089        uint64_t, double);
1090   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1091        uint32_t, double);
1092   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1093        uint16_t, double);
1094   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1095        uint8_t, double);
1096   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1097        uint64_t, double);
1098   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1099        uint32_t, double);
1100   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1101        uint16_t, double);
1102   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1103        uint8_t, double);
1104   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1105        uint64_t, double);
1106   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1107        uint32_t, double);
1108   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1109        uint16_t, double);
1110   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1111        uint8_t, double);
1112   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1113        uint64_t, double);
1114   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1115        uint32_t, double);
1116   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1117        uint16_t, double);
1118   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1119        uint8_t, double);
1120 
1121   // Float matrices with all combinations of overhead storage.
1122   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1123        uint64_t, float);
1124   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1125        uint32_t, float);
1126   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1127        uint16_t, float);
1128   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1129        uint8_t, float);
1130   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1131        uint64_t, float);
1132   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1133        uint32_t, float);
1134   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1135        uint16_t, float);
1136   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1137        uint8_t, float);
1138   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1139        uint64_t, float);
1140   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1141        uint32_t, float);
1142   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1143        uint16_t, float);
1144   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1145        uint8_t, float);
1146   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1147        uint64_t, float);
1148   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1149        uint32_t, float);
1150   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1151        uint16_t, float);
1152   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1153        uint8_t, float);
1154 
1155   // Integral matrices with both overheads of the same type.
1156   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1157   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1158   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1159   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1160   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1161   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1162   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1163   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1164   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1165   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1166   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1167   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1168   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1169 
1170   // Unsupported case (add above if needed).
1171   fputs("unsupported combination of types\n", stderr);
1172   exit(1);
1173 }
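// A rough sketch (illustrative only) of how a caller outside of MLIR
// compiler-generated code could drive this entry point to read a matrix from
// a hypothetical file "matrix.mtx"; it assumes the StridedMemRefType
// descriptors from CRunnerUtils.h can be filled in by hand as shown.
//
//   DimLevelType annot[] = {DimLevelType::kDense, DimLevelType::kCompressed};
//   index_type shape[] = {0, 0}; // zero means: take sizes from the file
//   index_type perm[] = {0, 1};
//   StridedMemRefType<DimLevelType, 1> aref{annot, annot, 0, {2}, {1}};
//   StridedMemRefType<index_type, 1> sref{shape, shape, 0, {2}, {1}};
//   StridedMemRefType<index_type, 1> pref{perm, perm, 0, {2}, {1}};
//   char filename[] = "matrix.mtx";
//   void *tensor = _mlir_ciface_newSparseTensor(
//       &aref, &sref, &pref, OverheadType::kIndex, OverheadType::kIndex,
//       PrimaryType::kF64, Action::kFromFile, filename);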
1174 
1175 /// Methods that provide direct access to pointers.
1176 IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
1177 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
1178 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
1179 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
1180 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
1181 
1182 /// Methods that provide direct access to indices.
1183 IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1184 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
1185 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
1186 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
1187 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
1188 
1189 /// Methods that provide direct access to values.
1190 IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
1191 IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
1192 IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
1193 IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
1194 IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
1195 IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
1196 
/// Helper to add a value to the coordinate scheme, one per value type.
IMPL_ADDELT(addEltF64, double)
IMPL_ADDELT(addEltF32, float)
IMPL_ADDELT(addEltI64, int64_t)
IMPL_ADDELT(addEltI32, int32_t)
IMPL_ADDELT(addEltI16, int16_t)
IMPL_ADDELT(addEltI8, int8_t)

/// Helper to enumerate the elements of a coordinate scheme, one per value type.
IMPL_GETNEXT(getNextF64, double)
IMPL_GETNEXT(getNextF32, float)
IMPL_GETNEXT(getNextI64, int64_t)
IMPL_GETNEXT(getNextI32, int32_t)
IMPL_GETNEXT(getNextI16, int16_t)
IMPL_GETNEXT(getNextI8, int8_t)

/// Insert elements in lexicographical index order, one per value type.
IMPL_LEXINSERT(lexInsertF64, double)
IMPL_LEXINSERT(lexInsertF32, float)
IMPL_LEXINSERT(lexInsertI64, int64_t)
IMPL_LEXINSERT(lexInsertI32, int32_t)
IMPL_LEXINSERT(lexInsertI16, int16_t)
IMPL_LEXINSERT(lexInsertI8, int8_t)

/// Insert using expansion, one per value type.
IMPL_EXPINSERT(expInsertF64, double)
IMPL_EXPINSERT(expInsertF32, float)
IMPL_EXPINSERT(expInsertI64, int64_t)
IMPL_EXPINSERT(expInsertI32, int32_t)
IMPL_EXPINSERT(expInsertI16, int16_t)
IMPL_EXPINSERT(expInsertI8, int8_t)

#undef CASE
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
#undef IMPL_ADDELT
#undef IMPL_GETNEXT
#undef IMPL_LEXINSERT
#undef IMPL_EXPINSERT

/// Output a sparse tensor, one per value type.
void outSparseTensorF64(void *tensor, void *dest, bool sort) {
  return outSparseTensor<double>(tensor, dest, sort);
}
void outSparseTensorF32(void *tensor, void *dest, bool sort) {
  return outSparseTensor<float>(tensor, dest, sort);
}
void outSparseTensorI64(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int64_t>(tensor, dest, sort);
}
void outSparseTensorI32(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int32_t>(tensor, dest, sort);
}
void outSparseTensorI16(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int16_t>(tensor, dest, sort);
}
void outSparseTensorI8(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int8_t>(tensor, dest, sort);
}

//===----------------------------------------------------------------------===//
//
// Public API with methods that accept C-style data structures to interact
// with sparse tensors, which are only visible externally as opaque pointers.
// These methods can be used both by MLIR compiler-generated code and by any
// external runtime that wants to interact with MLIR compiler-generated code.
//
//===----------------------------------------------------------------------===//

/// Helper method to read a sparse tensor filename from the environment,
/// using the naming convention ${TENSOR0}, ${TENSOR1}, etc.
char *getTensorFilename(index_type id) {
  char var[80];
  sprintf(var, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env) {
    fprintf(stderr, "Environment variable %s is not set\n", var);
    exit(1);
  }
  return env;
}
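
// A minimal usage sketch (hypothetical caller setup, not part of this
// library): assuming the environment was prepared with, e.g.,
//   export TENSOR0="wide.mtx"
// the corresponding filename can be retrieved as
//   char *filename = getTensorFilename(0); // yields "wide.mtx"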

/// Returns the size of the sparse tensor in the given dimension.
index_type sparseDimSize(void *tensor, index_type d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

/// Finalizes lexicographic insertions.
void endInsert(void *tensor) {
  return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
}

/// Releases sparse tensor storage.
void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}
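
// A minimal sketch of the opaque-pointer lifecycle (hypothetical caller code):
// given a `void *tensor` produced by this library, e.g. by the conversion
// routines below,
//   index_type rows = sparseDimSize(tensor, 0); // size of dimension 0
//   index_type cols = sparseDimSize(tensor, 1); // size of dimension 1
//   delSparseTensor(tensor);                    // release the storage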

/// Releases sparse tensor coordinate scheme.
#define IMPL_DELCOO(VNAME, V)                                                  \
  void delSparseTensorCOO##VNAME(void *coo) {                                  \
    delete static_cast<SparseTensorCOO<V> *>(coo);                             \
  }
IMPL_DELCOO(F64, double)
IMPL_DELCOO(F32, float)
IMPL_DELCOO(I64, int64_t)
IMPL_DELCOO(I32, int32_t)
IMPL_DELCOO(I16, int16_t)
IMPL_DELCOO(I8, int8_t)
#undef IMPL_DELCOO

/// Initializes a sparse tensor from a COO-flavored format expressed using
/// C-style data structures. The expected parameters are:
///
///   rank:    rank of the tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array of size `rank` with the size of each dimension
///   values:  array of size `nse` with the values of all specified elements
///   indices: flat array of size `nse * rank` with the indices of all
///            specified elements
///   perm:    the permutation of the dimensions in the storage
///   sparse:  the per-dimension sparsity annotations
///
/// For example, the sparse matrix
///     | 1.0 0.0 0.0 |
///     | 0.0 5.0 3.0 |
/// can be passed as
///      rank    = 2
///      nse     = 3
///      shape   = [2, 3]
///      values  = [1.0, 5.0, 3.0]
///      indices = [ 0, 0,  1, 1,  1, 2]
//
// TODO: generalize beyond 64-bit indices.
//
void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   double *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
                                    sparse);
}
void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   float *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
                                   sparse);
}
void *convertToMLIRSparseTensorI64(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   int64_t *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int64_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
}
void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   int32_t *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int32_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
}
void *convertToMLIRSparseTensorI16(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   int16_t *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int16_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
}
void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
                                  int8_t *values, uint64_t *indices,
                                  uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
                                    sparse);
}
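
// A minimal sketch of a call for the 2x3 example documented above
// (hypothetical caller code; the arrays are owned by the caller, and the
// encoding of `sparse` as 0 = dense, 1 = compressed is an assumption here):
//   uint64_t shape[]   = {2, 3};
//   double   values[]  = {1.0, 5.0, 3.0};
//   uint64_t indices[] = {0, 0,  1, 1,  1, 2};
//   uint64_t perm[]    = {0, 1};  // identity dimension ordering
//   uint8_t  sparse[]  = {0, 1};  // see assumption above
//   void *tensor = convertToMLIRSparseTensorF64(/*rank=*/2, /*nse=*/3, shape,
//                                               values, indices, perm, sparse);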

/// Converts a sparse tensor to a COO-flavored format expressed using C-style
/// data structures. The output parameters are pointers through which the
/// following values are returned:
///
///   rank:    rank of the tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array of size `rank` with the size of each dimension
///   values:  array of size `nse` with the values of all specified elements
///   indices: flat array of size `nse * rank` with the indices of all
///            specified elements
///
/// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
/// from convertToMLIRSparseTensor.
///
//  TODO: Currently, values are copied from SparseTensorStorage to
//  SparseTensorCOO, then to the output. We may want to reduce the number of
//  copies.
//
// TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
// compressed
//
void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    double **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    float **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI64(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    int64_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int64_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    int32_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int32_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI16(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    int16_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int16_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
                                   uint64_t *pNse, uint64_t **pShape,
                                   int8_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
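
// A minimal sketch of reading a tensor back out through this interface
// (hypothetical caller code; buffer ownership and deallocation are not shown):
//   uint64_t rank, nse;
//   uint64_t *shape, *indices;
//   double *values;
//   convertFromMLIRSparseTensorF64(tensor, &rank, &nse, &shape, &values,
//                                  &indices);
//   for (uint64_t k = 0; k < nse; k++) {
//     // element k has value values[k] at indices
//     // (indices[k * rank], ..., indices[k * rank + rank - 1])
//   }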

} // extern "C"

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS