1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
#include <cctype>
#include <cinttypes>
#include <complex>
27 #include <cstdio>
28 #include <cstdlib>
29 #include <cstring>
30 #include <fstream>
31 #include <functional>
32 #include <iostream>
33 #include <limits>
34 #include <numeric>
35 #include <vector>
36 
37 using complex64 = std::complex<double>;
38 using complex32 = std::complex<float>;
39 
40 //===----------------------------------------------------------------------===//
41 //
42 // Internal support for storing and reading sparse tensors.
43 //
44 // The following memory-resident sparse storage schemes are supported:
45 //
46 // (a) A coordinate scheme for temporarily storing and lexicographically
47 //     sorting a sparse tensor by index (SparseTensorCOO).
48 //
49 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
51 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
52 //
53 // The following external formats are supported:
54 //
55 // (1) Matrix Market Exchange (MME): *.mtx
56 //     https://math.nist.gov/MatrixMarket/formats.html
57 //
58 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
59 //     http://frostt.io/tensors/file-formats.html
60 //
61 // Two public APIs are supported:
62 //
63 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
64 //     tensors. These methods should be used exclusively by MLIR
65 //     compiler-generated code.
66 //
67 // (II) Methods that accept C-style data structures to interact with sparse
68 //      tensors. These methods can be used by any external runtime that wants
69 //      to interact with MLIR compiler-generated code.
70 //
// In both cases (I) and (II), the SparseTensorStorage format is only
// visible externally as an opaque pointer.
73 //
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 
78 static constexpr int kColWidth = 1025;
79 
80 /// A version of `operator*` on `uint64_t` which checks for overflows.
81 static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
82   assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
83          "Integer overflow");
84   return lhs * rhs;
85 }
86 
87 /// A sparse tensor element in coordinate scheme (value and indices).
88 /// For example, a rank-1 vector element would look like
89 ///   ({i}, a[i])
90 /// and a rank-5 tensor element like
91 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer into a shared index pool rather than, e.g., a direct
/// vector, since that (1) reduces the per-element memory footprint and
/// (2) centralizes the memory reservation and (re)allocation in one place.
95 template <typename V>
96 struct Element {
  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
98   uint64_t *indices; // pointer into shared index pool
99   V value;
100 };
101 
102 /// The type of callback functions which receive an element.  We avoid
103 /// packaging the coordinates and value together as an `Element` object
104 /// because this helps keep code somewhat cleaner.
105 template <typename V>
106 using ElementConsumer =
107     const std::function<void(const std::vector<uint64_t> &, V)> &;
108 
109 /// A memory-resident sparse tensor in coordinate scheme (collection of
110 /// elements). This data structure is used to read a sparse tensor from
111 /// any external format into memory and sort the elements lexicographically
112 /// by indices before passing it back to the client (most packed storage
113 /// formats require the elements to appear in lexicographic index order).
114 template <typename V>
115 struct SparseTensorCOO {
116 public:
117   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
118       : sizes(szs) {
119     if (capacity) {
120       elements.reserve(capacity);
121       indices.reserve(capacity * getRank());
122     }
123   }
124 
125   /// Adds element as indices and value.
126   void add(const std::vector<uint64_t> &ind, V val) {
127     assert(!iteratorLocked && "Attempt to add() after startIterator()");
128     uint64_t *base = indices.data();
129     uint64_t size = indices.size();
130     uint64_t rank = getRank();
131     assert(rank == ind.size());
132     for (uint64_t r = 0; r < rank; r++) {
133       assert(ind[r] < sizes[r]); // within bounds
134       indices.push_back(ind[r]);
135     }
136     // This base only changes if indices were reallocated. In that case, we
137     // need to correct all previous pointers into the vector. Note that this
138     // only happens if we did not set the initial capacity right, and then only
139     // for every internal vector reallocation (which with the doubling rule
140     // should only incur an amortized linear overhead).
141     uint64_t *newBase = indices.data();
142     if (newBase != base) {
143       for (uint64_t i = 0, n = elements.size(); i < n; i++)
144         elements[i].indices = newBase + (elements[i].indices - base);
145       base = newBase;
146     }
147     // Add element as (pointer into shared index pool, value) pair.
148     elements.emplace_back(base + size, val);
149   }
150 
151   /// Sorts elements lexicographically by index.
152   void sort() {
153     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
154     // TODO: we may want to cache an `isSorted` bit, to avoid
155     // unnecessary/redundant sorting.
156     std::sort(elements.begin(), elements.end(),
157               [this](const Element<V> &e1, const Element<V> &e2) {
158                 uint64_t rank = getRank();
159                 for (uint64_t r = 0; r < rank; r++) {
160                   if (e1.indices[r] == e2.indices[r])
161                     continue;
162                   return e1.indices[r] < e2.indices[r];
163                 }
164                 return false;
165               });
166   }
167 
168   /// Returns rank.
169   uint64_t getRank() const { return sizes.size(); }
170 
171   /// Getter for sizes array.
172   const std::vector<uint64_t> &getSizes() const { return sizes; }
173 
174   /// Getter for elements array.
175   const std::vector<Element<V>> &getElements() const { return elements; }
176 
177   /// Switch into iterator mode.
178   void startIterator() {
179     iteratorLocked = true;
180     iteratorPos = 0;
181   }
182 
183   /// Get the next element.
184   const Element<V> *getNext() {
185     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
186     if (iteratorPos < elements.size())
187       return &(elements[iteratorPos++]);
188     iteratorLocked = false;
189     return nullptr;
190   }
191 
192   /// Factory method. Permutes the original dimensions according to
193   /// the given ordering and expects subsequent add() calls to honor
194   /// that same ordering for the given indices. The result is a
195   /// fully permuted coordinate scheme.
196   ///
197   /// Precondition: `sizes` and `perm` must be valid for `rank`.
198   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
199                                                 const uint64_t *sizes,
200                                                 const uint64_t *perm,
201                                                 uint64_t capacity = 0) {
202     std::vector<uint64_t> permsz(rank);
203     for (uint64_t r = 0; r < rank; r++) {
204       assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
205       permsz[perm[r]] = sizes[r];
206     }
207     return new SparseTensorCOO<V>(permsz, capacity);
208   }
209 
210 private:
211   const std::vector<uint64_t> sizes; // per-dimension sizes
212   std::vector<Element<V>> elements;  // all COO elements
213   std::vector<uint64_t> indices;     // shared index pool
214   bool iteratorLocked = false;
215   unsigned iteratorPos = 0;
216 };
217 
218 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
219 /// takes responsibility for all the `<P,I,V>`-independent aspects
220 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
221 /// we use function overloading to implement "partial" method
222 /// specialization, which the C-API relies on to catch type errors
223 /// arising from our use of opaque pointers.
224 class SparseTensorStorageBase {
225 public:
226   /// Constructs a new storage object.  The `perm` maps the tensor's
227   /// semantic-ordering of dimensions to this object's storage-order.
228   /// The `szs` and `sparsity` arrays are already in storage-order.
229   ///
230   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
231   SparseTensorStorageBase(const std::vector<uint64_t> &szs,
232                           const uint64_t *perm, const DimLevelType *sparsity)
233       : dimSizes(szs), rev(getRank()),
234         dimTypes(sparsity, sparsity + getRank()) {
235     assert(perm && sparsity);
236     const uint64_t rank = getRank();
237     // Validate parameters.
238     assert(rank > 0 && "Trivial shape is unsupported");
239     for (uint64_t r = 0; r < rank; r++) {
240       assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
241       assert((dimTypes[r] == DimLevelType::kDense ||
242               dimTypes[r] == DimLevelType::kCompressed) &&
243              "Unsupported DimLevelType");
244     }
245     // Construct the "reverse" (i.e., inverse) permutation.
246     for (uint64_t r = 0; r < rank; r++)
247       rev[perm[r]] = r;
248   }
249 
250   virtual ~SparseTensorStorageBase() = default;
251 
252   /// Get the rank of the tensor.
253   uint64_t getRank() const { return dimSizes.size(); }
254 
255   /// Getter for the dimension-sizes array, in storage-order.
256   const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
257 
258   /// Safely lookup the size of the given (storage-order) dimension.
259   uint64_t getDimSize(uint64_t d) const {
260     assert(d < getRank());
261     return dimSizes[d];
262   }
263 
264   /// Getter for the "reverse" permutation, which maps this object's
265   /// storage-order to the tensor's semantic-order.
266   const std::vector<uint64_t> &getRev() const { return rev; }
267 
268   /// Getter for the dimension-types array, in storage-order.
269   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
270 
271   /// Safely check if the (storage-order) dimension uses compressed storage.
272   bool isCompressedDim(uint64_t d) const {
273     assert(d < getRank());
274     return (dimTypes[d] == DimLevelType::kCompressed);
275   }
276 
277   /// Overhead storage.
278   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
279   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
280   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
281   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
282   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
283   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
284   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
285   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
286 
287   /// Primary storage.
288   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
289   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
290   virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
291   virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
292   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
293   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
294   virtual void getValues(std::vector<complex64> **) { fatal("valc64"); }
295   virtual void getValues(std::vector<complex32> **) { fatal("valc32"); }
296 
297   /// Element-wise insertion in lexicographic index order.
298   virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
299   virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
300   virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
301   virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(const uint64_t *, int16_t) { fatal("insi16"); }
303   virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
304   virtual void lexInsert(const uint64_t *, complex64) { fatal("insc64"); }
305   virtual void lexInsert(const uint64_t *, complex32) { fatal("insc32"); }
306 
307   /// Expanded insertion.
308   virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
309     fatal("expf64");
310   }
311   virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
312     fatal("expf32");
313   }
314   virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
315     fatal("expi64");
316   }
317   virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
318     fatal("expi32");
319   }
320   virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
321     fatal("expi16");
322   }
323   virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
324     fatal("expi8");
325   }
326   virtual void expInsert(uint64_t *, complex64 *, bool *, uint64_t *,
327                          uint64_t) {
328     fatal("expc64");
329   }
330   virtual void expInsert(uint64_t *, complex32 *, bool *, uint64_t *,
331                          uint64_t) {
332     fatal("expc32");
333   }
334 
335   /// Finishes insertion.
336   virtual void endInsert() = 0;
337 
338 protected:
339   // Since this class is virtual, we must disallow public copying in
340   // order to avoid "slicing".  Since this class has data members,
341   // that means making copying protected.
342   // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
343   SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
344   // Copy-assignment would be implicitly deleted (because `dimSizes`
345   // is const), so we explicitly delete it for clarity.
346   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
347 
348 private:
349   static void fatal(const char *tp) {
350     fprintf(stderr, "unsupported %s\n", tp);
351     exit(1);
352   }
353 
354   const std::vector<uint64_t> dimSizes;
355   std::vector<uint64_t> rev;
356   const std::vector<DimLevelType> dimTypes;
357 };
358 
359 // Forward.
360 template <typename P, typename I, typename V>
361 class SparseTensorEnumerator;
362 
363 /// A memory-resident sparse tensor using a storage scheme based on
364 /// per-dimension sparse/dense annotations. This data structure provides a
365 /// bufferized form of a sparse tensor type. In contrast to generating setup
366 /// methods for each differently annotated sparse tensor, this method provides
367 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
368 /// and annotations to implement all required setup in a general manner.
369 template <typename P, typename I, typename V>
370 class SparseTensorStorage : public SparseTensorStorageBase {
371 public:
372   /// Constructs a sparse tensor storage scheme with the given dimensions,
373   /// permutation, and per-dimension dense/sparse annotations, using
374   /// the coordinate scheme tensor for the initial contents if provided.
375   ///
376   /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
377   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
378                       const DimLevelType *sparsity,
379                       SparseTensorCOO<V> *coo = nullptr)
380       : SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
381         indices(getRank()), idx(getRank()) {
382     // Provide hints on capacity of pointers and indices.
383     // TODO: needs much fine-tuning based on actual sparsity; currently
384     //       we reserve pointer/index space based on all previous dense
385     //       dimensions, which works well up to first sparse dim; but
386     //       we should really use nnz and dense/sparse distribution.
387     bool allDense = true;
388     uint64_t sz = 1;
389     for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
390       if (isCompressedDim(r)) {
391         // TODO: Take a parameter between 1 and `sizes[r]`, and multiply
392         // `sz` by that before reserving. (For now we just use 1.)
393         pointers[r].reserve(sz + 1);
394         pointers[r].push_back(0);
395         indices[r].reserve(sz);
396         sz = 1;
397         allDense = false;
398       } else { // Dense dimension.
399         sz = checkedMul(sz, getDimSizes()[r]);
400       }
401     }
402     // Then assign contents from coordinate scheme tensor if provided.
403     if (coo) {
404       // Ensure both preconditions of `fromCOO`.
405       assert(coo->getSizes() == getDimSizes() && "Tensor size mismatch");
406       coo->sort();
407       // Now actually insert the `elements`.
408       const std::vector<Element<V>> &elements = coo->getElements();
409       uint64_t nnz = elements.size();
410       values.reserve(nnz);
411       fromCOO(elements, 0, nnz, 0);
412     } else if (allDense) {
413       values.resize(sz, 0);
414     }
415   }
416 
417   ~SparseTensorStorage() override = default;
418 
419   /// Partially specialize these getter methods based on template types.
420   void getPointers(std::vector<P> **out, uint64_t d) override {
421     assert(d < getRank());
422     *out = &pointers[d];
423   }
424   void getIndices(std::vector<I> **out, uint64_t d) override {
425     assert(d < getRank());
426     *out = &indices[d];
427   }
428   void getValues(std::vector<V> **out) override { *out = &values; }
429 
430   /// Partially specialize lexicographical insertions based on template types.
431   void lexInsert(const uint64_t *cursor, V val) override {
432     // First, wrap up pending insertion path.
433     uint64_t diff = 0;
434     uint64_t top = 0;
435     if (!values.empty()) {
436       diff = lexDiff(cursor);
437       endPath(diff + 1);
438       top = idx[diff] + 1;
439     }
440     // Then continue with insertion path.
441     insPath(cursor, diff, top, val);
442   }
443 
444   /// Partially specialize expanded insertions based on template types.
445   /// Note that this method resets the values/filled-switch array back
446   /// to all-zero/false while only iterating over the nonzero elements.
447   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
448                  uint64_t count) override {
449     if (count == 0)
450       return;
451     // Sort.
452     std::sort(added, added + count);
453     // Restore insertion path for first insert.
454     const uint64_t lastDim = getRank() - 1;
455     uint64_t index = added[0];
456     cursor[lastDim] = index;
457     lexInsert(cursor, values[index]);
458     assert(filled[index]);
459     values[index] = 0;
460     filled[index] = false;
461     // Subsequent insertions are quick.
462     for (uint64_t i = 1; i < count; i++) {
463       assert(index < added[i] && "non-lexicographic insertion");
464       index = added[i];
465       cursor[lastDim] = index;
466       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
467       assert(filled[index]);
468       values[index] = 0;
469       filled[index] = false;
470     }
471   }
472 
473   /// Finalizes lexicographic insertions.
474   void endInsert() override {
475     if (values.empty())
476       finalizeSegment(0);
477     else
478       endPath(0);
479   }
480 
481   /// Returns this sparse tensor storage scheme as a new memory-resident
482   /// sparse tensor in coordinate scheme with the given dimension order.
483   ///
484   /// Precondition: `perm` must be valid for `getRank()`.
485   SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
486     SparseTensorEnumerator<P, I, V> enumerator(*this, getRank(), perm);
487     SparseTensorCOO<V> *coo =
488         new SparseTensorCOO<V>(enumerator.permutedSizes(), values.size());
489     enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
490       coo->add(ind, val);
491     });
492     // TODO: This assertion assumes there are no stored zeros,
493     // or if there are then that we don't filter them out.
494     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
495     assert(coo->getElements().size() == values.size());
496     return coo;
497   }
498 
499   /// Factory method. Constructs a sparse tensor storage scheme with the given
500   /// dimensions, permutation, and per-dimension dense/sparse annotations,
501   /// using the coordinate scheme tensor for the initial contents if provided.
502   /// In the latter case, the coordinate scheme must respect the same
503   /// permutation as is desired for the new sparse tensor storage.
504   ///
505   /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
506   static SparseTensorStorage<P, I, V> *
507   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
508                   const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
509     SparseTensorStorage<P, I, V> *n = nullptr;
510     if (coo) {
511       assert(coo->getRank() == rank && "Tensor rank mismatch");
512       const auto &coosz = coo->getSizes();
513       for (uint64_t r = 0; r < rank; r++)
514         assert(shape[r] == 0 || shape[r] == coosz[perm[r]]);
515       n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
516     } else {
517       std::vector<uint64_t> permsz(rank);
518       for (uint64_t r = 0; r < rank; r++) {
519         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
520         permsz[perm[r]] = shape[r];
521       }
522       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
523     }
524     return n;
525   }
526 
527 private:
528   /// Appends an arbitrary new position to `pointers[d]`.  This method
529   /// checks that `pos` is representable in the `P` type; however, it
530   /// does not check that `pos` is semantically valid (i.e., larger than
531   /// the previous position and smaller than `indices[d].capacity()`).
532   void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
533     assert(isCompressedDim(d));
534     assert(pos <= std::numeric_limits<P>::max() &&
535            "Pointer value is too large for the P-type");
536     pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
537   }
538 
539   /// Appends index `i` to dimension `d`, in the semantically general
540   /// sense.  For non-dense dimensions, that means appending to the
541   /// `indices[d]` array, checking that `i` is representable in the `I`
542   /// type; however, we do not verify other semantic requirements (e.g.,
543   /// that `i` is in bounds for `sizes[d]`, and not previously occurring
544   /// in the same segment).  For dense dimensions, this method instead
545   /// appends the appropriate number of zeros to the `values` array,
546   /// where `full` is the number of "entries" already written to `values`
547   /// for this segment (aka one after the highest index previously appended).
548   void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
549     if (isCompressedDim(d)) {
550       assert(i <= std::numeric_limits<I>::max() &&
551              "Index value is too large for the I-type");
552       indices[d].push_back(static_cast<I>(i));
553     } else { // Dense dimension.
554       assert(i >= full && "Index was already filled");
555       if (i == full)
556         return; // Short-circuit, since it'll be a nop.
557       if (d + 1 == getRank())
558         values.insert(values.end(), i - full, 0);
559       else
560         finalizeSegment(d + 1, 0, i - full);
561     }
562   }
563 
564   /// Initializes sparse tensor storage scheme from a memory-resident sparse
565   /// tensor in coordinate scheme. This method prepares the pointers and
566   /// indices arrays under the given per-dimension dense/sparse annotations.
567   ///
568   /// Preconditions:
569   /// (1) the `elements` must be lexicographically sorted.
570   /// (2) the indices of every element are valid for `sizes` (equal rank
571   ///     and pointwise less-than).
572   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
573                uint64_t hi, uint64_t d) {
574     uint64_t rank = getRank();
575     assert(d <= rank && hi <= elements.size());
576     // Once dimensions are exhausted, insert the numerical values.
577     if (d == rank) {
578       assert(lo < hi);
579       values.push_back(elements[lo].value);
580       return;
581     }
582     // Visit all elements in this interval.
583     uint64_t full = 0;
584     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
585       // Find segment in interval with same index elements in this dimension.
586       uint64_t i = elements[lo].indices[d];
587       uint64_t seg = lo + 1;
588       while (seg < hi && elements[seg].indices[d] == i)
589         seg++;
590       // Handle segment in interval for sparse or dense dimension.
591       appendIndex(d, full, i);
592       full = i + 1;
593       fromCOO(elements, lo, seg, d + 1);
594       // And move on to next segment in interval.
595       lo = seg;
596     }
597     // Finalize the sparse pointer structure at this dimension.
598     finalizeSegment(d, full);
599   }
600 
601   /// Finalize the sparse pointer structure at this dimension.
602   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
603     if (count == 0)
604       return; // Short-circuit, since it'll be a nop.
605     if (isCompressedDim(d)) {
606       appendPointer(d, indices[d].size(), count);
607     } else { // Dense dimension.
608       const uint64_t sz = getDimSizes()[d];
609       assert(sz >= full && "Segment is overfull");
610       count = checkedMul(count, sz - full);
611       // For dense storage we must enumerate all the remaining coordinates
612       // in this dimension (i.e., coordinates after the last non-zero
613       // element), and either fill in their zero values or else recurse
614       // to finalize some deeper dimension.
615       if (d + 1 == getRank())
616         values.insert(values.end(), count, 0);
617       else
618         finalizeSegment(d + 1, 0, count);
619     }
620   }
621 
622   /// Wraps up a single insertion path, inner to outer.
623   void endPath(uint64_t diff) {
624     uint64_t rank = getRank();
625     assert(diff <= rank);
626     for (uint64_t i = 0; i < rank - diff; i++) {
627       const uint64_t d = rank - i - 1;
628       finalizeSegment(d, idx[d] + 1);
629     }
630   }
631 
632   /// Continues a single insertion path, outer to inner.
633   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
634     uint64_t rank = getRank();
635     assert(diff < rank);
636     for (uint64_t d = diff; d < rank; d++) {
637       uint64_t i = cursor[d];
638       appendIndex(d, top, i);
639       top = 0;
640       idx[d] = i;
641     }
642     values.push_back(val);
643   }
644 
645   /// Finds the lexicographic differing dimension.
646   uint64_t lexDiff(const uint64_t *cursor) const {
647     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
648       if (cursor[r] > idx[r])
649         return r;
650       else
651         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
652     assert(0 && "duplication insertion");
653     return -1u;
654   }
655 
656   // Allow `SparseTensorEnumerator` to access the data-members (to avoid
657   // the cost of virtual-function dispatch in inner loops), without
658   // making them public to other client code.
659   friend class SparseTensorEnumerator<P, I, V>;
660 
661   std::vector<std::vector<P>> pointers;
662   std::vector<std::vector<I>> indices;
663   std::vector<V> values;
664   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
665 };
666 
667 /// A (higher-order) function object for enumerating the elements of some
668 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
669 /// method encapsulates the loop-nest for enumerating the elements of
670 /// the source tensor (in whatever order is best for the source tensor),
671 /// and applies a permutation to the coordinates/indices before handing
672 /// each element to the callback.  A single enumerator object can be
673 /// freely reused for several calls to `forallElements`, just so long
674 /// as each call is sequential with respect to one another.
675 ///
676 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
677 /// passed to the constructor; thus, objects of this class must not
678 /// outlive the sparse tensor they depend on.
679 ///
680 /// Design Note: The reason we define this class instead of simply using
681 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
682 /// the `<P,I>` template parameters from MLIR client code (to simplify the
683 /// type parameters used for direct sparse-to-sparse conversion).  And the
684 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
685 /// than simply using this class, is to avoid the cost of virtual-method
686 /// dispatch within the loop-nest.
687 template <typename V>
688 class SparseTensorEnumeratorBase {
689 public:
690   /// Constructs an enumerator with the given permutation for mapping
691   /// the semantic-ordering of dimensions to the desired target-ordering.
692   ///
693   /// Preconditions:
694   /// * the `tensor` must have the same `V` value type.
695   /// * `perm` must be valid for `rank`.
696   SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
697                              uint64_t rank, const uint64_t *perm)
698       : src(tensor), permsz(src.getRev().size()), reord(getRank()),
699         cursor(getRank()) {
700     assert(perm && "Received nullptr for permutation");
701     assert(rank == getRank() && "Permutation rank mismatch");
702     const auto &rev = src.getRev();        // source stg-order -> semantic-order
703     const auto &sizes = src.getDimSizes(); // in source storage-order
704     for (uint64_t s = 0; s < rank; s++) {  // `s` source storage-order
705       uint64_t t = perm[rev[s]];           // `t` target-order
706       reord[s] = t;
707       permsz[t] = sizes[s];
708     }
709   }
710 
711   virtual ~SparseTensorEnumeratorBase() = default;
712 
713   // We disallow copying to help avoid leaking the `src` reference.
714   // (In addition to avoiding the problem of slicing.)
715   SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
716   SparseTensorEnumeratorBase &
717   operator=(const SparseTensorEnumeratorBase &) = delete;
718 
719   /// Returns the source/target tensor's rank.  (The source-rank and
720   /// target-rank are always equal since we only support permutations.
721   /// Though once we add support for other dimension mappings, this
722   /// method will have to be split in two.)
723   uint64_t getRank() const { return permsz.size(); }
724 
725   /// Returns the target tensor's dimension sizes.
726   const std::vector<uint64_t> &permutedSizes() const { return permsz; }
727 
728   /// Enumerates all elements of the source tensor, permutes their
729   /// indices, and passes the permuted element to the callback.
730   /// The callback must not store the cursor reference directly,
731   /// since this function reuses the storage.  Instead, the callback
  /// must copy the cursor if it needs to keep the indices.
733   virtual void forallElements(ElementConsumer<V> yield) = 0;
734 
735 protected:
736   const SparseTensorStorageBase &src;
737   std::vector<uint64_t> permsz; // in target order.
738   std::vector<uint64_t> reord;  // source storage-order -> target order.
739   std::vector<uint64_t> cursor; // in target order.
740 };
741 
742 template <typename P, typename I, typename V>
743 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
744   using Base = SparseTensorEnumeratorBase<V>;
745 
746 public:
747   /// Constructs an enumerator with the given permutation for mapping
748   /// the semantic-ordering of dimensions to the desired target-ordering.
749   ///
750   /// Precondition: `perm` must be valid for `rank`.
751   SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
752                          uint64_t rank, const uint64_t *perm)
753       : Base(tensor, rank, perm) {}
754 
755   ~SparseTensorEnumerator() final override = default;
756 
757   void forallElements(ElementConsumer<V> yield) final override {
758     forallElements(yield, 0, 0);
759   }
760 
761 private:
762   /// The recursive component of the public `forallElements`.
763   void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
764                       uint64_t d) {
765     // Recover the `<P,I,V>` type parameters of `src`.
766     const auto &src =
767         static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
768     if (d == Base::getRank()) {
769       assert(parentPos < src.values.size() &&
770              "Value position is out of bounds");
771       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
772       yield(this->cursor, src.values[parentPos]);
773     } else if (src.isCompressedDim(d)) {
774       // Look up the bounds of the `d`-level segment determined by the
775       // `d-1`-level position `parentPos`.
776       const std::vector<P> &pointers_d = src.pointers[d];
777       assert(parentPos + 1 < pointers_d.size() &&
778              "Parent pointer position is out of bounds");
779       const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
780       const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
781       // Loop-invariant code for looking up the `d`-level coordinates/indices.
782       const std::vector<I> &indices_d = src.indices[d];
783       assert(pstop - 1 < indices_d.size() && "Index position is out of bounds");
784       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
785       for (uint64_t pos = pstart; pos < pstop; pos++) {
786         cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
787         forallElements(yield, pos, d + 1);
788       }
789     } else { // Dense dimension.
790       const uint64_t sz = src.getDimSizes()[d];
791       const uint64_t pstart = parentPos * sz;
792       uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
793       for (uint64_t i = 0; i < sz; i++) {
794         cursor_reord_d = i;
795         forallElements(yield, pstart + i, d + 1);
796       }
797     }
798   }
799 };
800 
801 /// Helper to convert string to lower case.
802 static char *toLower(char *token) {
803   for (char *c = token; *c; c++)
804     *c = tolower(*c);
805   return token;
806 }
807 
808 /// Read the MME header of a general sparse matrix of type real.
809 static void readMMEHeader(FILE *file, char *filename, char *line,
810                           uint64_t *idata, bool *isPattern, bool *isSymmetric) {
811   char header[64];
812   char object[64];
813   char format[64];
814   char field[64];
815   char symmetry[64];
816   // Read header line.
817   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
818              symmetry) != 5) {
819     fprintf(stderr, "Corrupt header in %s\n", filename);
820     exit(1);
821   }
822   // Set properties
823   *isPattern = (strcmp(toLower(field), "pattern") == 0);
824   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
825   // Make sure this is a general sparse matrix.
826   if (strcmp(toLower(header), "%%matrixmarket") ||
827       strcmp(toLower(object), "matrix") ||
828       strcmp(toLower(format), "coordinate") ||
829       (strcmp(toLower(field), "real") && !(*isPattern)) ||
830       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
831     fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
832     exit(1);
833   }
834   // Skip comments.
835   while (true) {
836     if (!fgets(line, kColWidth, file)) {
837       fprintf(stderr, "Cannot find data in %s\n", filename);
838       exit(1);
839     }
840     if (line[0] != '%')
841       break;
842   }
843   // Next line contains M N NNZ.
844   idata[0] = 2; // rank
845   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
846              idata + 1) != 3) {
847     fprintf(stderr, "Cannot find size in %s\n", filename);
848     exit(1);
849   }
850 }
851 
852 /// Read the "extended" FROSTT header. Although not part of the documented
853 /// format, we assume that the file starts with optional comments followed
854 /// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
856 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
857                                 uint64_t *idata) {
858   // Skip comments.
859   while (true) {
860     if (!fgets(line, kColWidth, file)) {
861       fprintf(stderr, "Cannot find data in %s\n", filename);
862       exit(1);
863     }
864     if (line[0] != '#')
865       break;
866   }
867   // Next line contains RANK and NNZ.
868   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
869     fprintf(stderr, "Cannot find metadata in %s\n", filename);
870     exit(1);
871   }
872   // Followed by a line with the dimension sizes (one per rank).
873   for (uint64_t r = 0; r < idata[0]; r++) {
874     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
875       fprintf(stderr, "Cannot find dimension size %s\n", filename);
876       exit(1);
877     }
878   }
879   fgets(line, kColWidth, file); // end of line
880 }
881 
882 /// Reads a sparse tensor with the given filename into a memory-resident
883 /// sparse tensor in coordinate scheme.
884 template <typename V>
885 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
886                                                const uint64_t *shape,
887                                                const uint64_t *perm) {
888   // Open the file.
889   FILE *file = fopen(filename, "r");
890   if (!file) {
891     assert(filename && "Received nullptr for filename");
892     fprintf(stderr, "Cannot find file %s\n", filename);
893     exit(1);
894   }
  // Perform some file-format-dependent setup.
896   char line[kColWidth];
897   uint64_t idata[512];
898   bool isPattern = false;
899   bool isSymmetric = false;
900   if (strstr(filename, ".mtx")) {
901     readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
902   } else if (strstr(filename, ".tns")) {
903     readExtFROSTTHeader(file, filename, line, idata);
904   } else {
905     fprintf(stderr, "Unknown format %s\n", filename);
906     exit(1);
907   }
908   // Prepare sparse tensor object with per-dimension sizes
909   // and the number of nonzeros as initial capacity.
910   assert(rank == idata[0] && "rank mismatch");
911   uint64_t nnz = idata[1];
912   for (uint64_t r = 0; r < rank; r++)
913     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
914            "dimension size mismatch");
915   SparseTensorCOO<V> *tensor =
916       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
917   // Read all nonzero elements.
918   std::vector<uint64_t> indices(rank);
919   for (uint64_t k = 0; k < nnz; k++) {
920     if (!fgets(line, kColWidth, file)) {
921       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
922       exit(1);
923     }
924     char *linePtr = line;
925     for (uint64_t r = 0; r < rank; r++) {
926       uint64_t idx = strtoul(linePtr, &linePtr, 10);
927       // Add 0-based index.
928       indices[perm[r]] = idx - 1;
929     }
930     // The external formats always store the numerical values with the type
931     // double, but we cast these values to the sparse tensor object type.
932     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
933     double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
934     tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing
936     // them. In the future, we may want to make symmetry implicit for storage
937     // reasons.
938     if (isSymmetric && indices[0] != indices[1])
939       tensor->add({indices[1], indices[0]}, value);
940   }
941   // Close the file and return tensor.
942   fclose(file);
943   return tensor;
944 }
945 
946 /// Writes the sparse tensor to extended FROSTT format.
947 template <typename V>
948 static void outSparseTensor(void *tensor, void *dest, bool sort) {
949   assert(tensor && dest);
950   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
951   if (sort)
952     coo->sort();
953   char *filename = static_cast<char *>(dest);
954   auto &sizes = coo->getSizes();
955   auto &elements = coo->getElements();
956   uint64_t rank = coo->getRank();
957   uint64_t nnz = elements.size();
958   std::fstream file;
959   file.open(filename, std::ios_base::out | std::ios_base::trunc);
960   assert(file.is_open());
961   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
962   for (uint64_t r = 0; r < rank - 1; r++)
963     file << sizes[r] << " ";
964   file << sizes[rank - 1] << std::endl;
965   for (uint64_t i = 0; i < nnz; i++) {
966     auto &idx = elements[i].indices;
967     for (uint64_t r = 0; r < rank; r++)
968       file << (idx[r] + 1) << " ";
969     file << elements[i].value << std::endl;
970   }
971   file.flush();
972   file.close();
973   assert(file.good());
974 }
975 
976 /// Initializes sparse tensor from an external COO-flavored format.
977 template <typename V>
978 static SparseTensorStorage<uint64_t, uint64_t, V> *
979 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
980                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
  const DimLevelType *sparsity =
      reinterpret_cast<const DimLevelType *>(sparse);
982 #ifndef NDEBUG
983   // Verify that perm is a permutation of 0..(rank-1).
984   std::vector<uint64_t> order(perm, perm + rank);
985   std::sort(order.begin(), order.end());
986   for (uint64_t i = 0; i < rank; ++i) {
987     if (i != order[i]) {
988       fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
989       exit(1);
990     }
991   }
992 
993   // Verify that the sparsity values are supported.
994   for (uint64_t i = 0; i < rank; ++i) {
995     if (sparsity[i] != DimLevelType::kDense &&
996         sparsity[i] != DimLevelType::kCompressed) {
997       fprintf(stderr, "Unsupported sparsity value %d\n",
998               static_cast<int>(sparsity[i]));
999       exit(1);
1000     }
1001   }
1002 #endif
1003 
1004   // Convert external format to internal COO.
1005   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1006   std::vector<uint64_t> idx(rank);
1007   for (uint64_t i = 0, base = 0; i < nse; i++) {
1008     for (uint64_t r = 0; r < rank; r++)
1009       idx[perm[r]] = indices[base + r];
1010     coo->add(idx, values[i]);
1011     base += rank;
1012   }
1013   // Return sparse tensor storage format as opaque pointer.
1014   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1015       rank, shape, perm, sparsity, coo);
1016   delete coo;
1017   return tensor;
1018 }
1019 
1020 /// Converts a sparse tensor to an external COO-flavored format.
1021 template <typename V>
1022 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1023                                  uint64_t **pShape, V **pValues,
1024                                  uint64_t **pIndices) {
1025   assert(tensor);
1026   auto sparseTensor =
1027       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1028   uint64_t rank = sparseTensor->getRank();
1029   std::vector<uint64_t> perm(rank);
1030   std::iota(perm.begin(), perm.end(), 0);
1031   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1032 
1033   const std::vector<Element<V>> &elements = coo->getElements();
1034   uint64_t nse = elements.size();
1035 
1036   uint64_t *shape = new uint64_t[rank];
1037   for (uint64_t i = 0; i < rank; i++)
1038     shape[i] = coo->getSizes()[i];
1039 
1040   V *values = new V[nse];
1041   uint64_t *indices = new uint64_t[rank * nse];
1042 
1043   for (uint64_t i = 0, base = 0; i < nse; i++) {
1044     values[i] = elements[i].value;
1045     for (uint64_t j = 0; j < rank; j++)
1046       indices[base + j] = elements[i].indices[j];
1047     base += rank;
1048   }
1049 
1050   delete coo;
1051   *pRank = rank;
1052   *pNse = nse;
1053   *pShape = shape;
1054   *pValues = values;
1055   *pIndices = indices;
1056 }
1057 
1058 } // namespace
1059 
1060 extern "C" {
1061 
1062 //===----------------------------------------------------------------------===//
1063 //
1064 // Public API with methods that operate on MLIR buffers (memrefs) to interact
1065 // with sparse tensors, which are only visible as opaque pointers externally.
1066 // These methods should be used exclusively by MLIR compiler-generated code.
1067 //
1068 // Some macro magic is used to generate implementations for all required type
1069 // combinations that can be called from MLIR compiler-generated code.
1070 //
1071 //===----------------------------------------------------------------------===//
1072 
1073 #define CASE(p, i, v, P, I, V)                                                 \
1074   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
1075     SparseTensorCOO<V> *coo = nullptr;                                         \
1076     if (action <= Action::kFromCOO) {                                          \
1077       if (action == Action::kFromFile) {                                       \
1078         char *filename = static_cast<char *>(ptr);                             \
1079         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
1080       } else if (action == Action::kFromCOO) {                                 \
1081         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
1082       } else {                                                                 \
1083         assert(action == Action::kEmpty);                                      \
1084       }                                                                        \
1085       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
1086           rank, shape, perm, sparsity, coo);                                   \
1087       if (action == Action::kFromFile)                                         \
1088         delete coo;                                                            \
1089       return tensor;                                                           \
1090     }                                                                          \
1091     if (action == Action::kEmptyCOO)                                           \
1092       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
1093     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
1094     if (action == Action::kToIterator) {                                       \
1095       coo->startIterator();                                                    \
1096     } else {                                                                   \
1097       assert(action == Action::kToCOO);                                        \
1098     }                                                                          \
1099     return coo;                                                                \
1100   }
1101 
1102 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
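
// For example (illustrative only), the instantiation
//
//   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64,
//        uint64_t, uint32_t, double);
//
// expands into an `if` that handles exactly the (ptrTp, indTp, valTp)
// combination (kU64, kU32, kF64), dispatching the requested action to
// SparseTensorStorage<uint64_t, uint32_t, double> (or to
// SparseTensorCOO<double> for the COO/iterator actions), while CASE_SECSAME
// simply reuses the same type for both overhead parameters.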
1103 
1104 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
1105   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
1106     assert(ref &&tensor);                                                      \
1107     std::vector<TYPE> *v;                                                      \
1108     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
1109     ref->basePtr = ref->data = v->data();                                      \
1110     ref->offset = 0;                                                           \
1111     ref->sizes[0] = v->size();                                                 \
1112     ref->strides[0] = 1;                                                       \
1113   }
1114 
1115 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
1116   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
1117                            index_type d) {                                     \
1118     assert(ref &&tensor);                                                      \
1119     std::vector<TYPE> *v;                                                      \
1120     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
1121     ref->basePtr = ref->data = v->data();                                      \
1122     ref->offset = 0;                                                           \
1123     ref->sizes[0] = v->size();                                                 \
1124     ref->strides[0] = 1;                                                       \
1125   }
1126 
1127 #define IMPL_ADDELT(NAME, TYPE)                                                \
1128   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
1129                             StridedMemRefType<index_type, 1> *iref,            \
1130                             StridedMemRefType<index_type, 1> *pref) {          \
1131     assert(tensor &&iref &&pref);                                              \
1132     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
1133     assert(iref->sizes[0] == pref->sizes[0]);                                  \
1134     const index_type *indx = iref->data + iref->offset;                        \
1135     const index_type *perm = pref->data + pref->offset;                        \
1136     uint64_t isize = iref->sizes[0];                                           \
1137     std::vector<index_type> indices(isize);                                    \
1138     for (uint64_t r = 0; r < isize; r++)                                       \
1139       indices[perm[r]] = indx[r];                                              \
1140     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
1141     return tensor;                                                             \
1142   }
1143 
1144 #define IMPL_GETNEXT(NAME, V)                                                  \
1145   bool _mlir_ciface_##NAME(void *tensor,                                       \
1146                            StridedMemRefType<index_type, 1> *iref,             \
1147                            StridedMemRefType<V, 0> *vref) {                    \
1148     assert(tensor &&iref &&vref);                                              \
1149     assert(iref->strides[0] == 1);                                             \
1150     index_type *indx = iref->data + iref->offset;                              \
1151     V *value = vref->data + vref->offset;                                      \
1152     const uint64_t isize = iref->sizes[0];                                     \
1153     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
1154     const Element<V> *elem = iter->getNext();                                  \
1155     if (elem == nullptr)                                                       \
1156       return false;                                                            \
1157     for (uint64_t r = 0; r < isize; r++)                                       \
1158       indx[r] = elem->indices[r];                                              \
1159     *value = elem->value;                                                      \
1160     return true;                                                               \
1161   }
1162 
1163 #define IMPL_LEXINSERT(NAME, V)                                                \
1164   void _mlir_ciface_##NAME(void *tensor,                                       \
1165                            StridedMemRefType<index_type, 1> *cref, V val) {    \
1166     assert(tensor &&cref);                                                     \
1167     assert(cref->strides[0] == 1);                                             \
1168     index_type *cursor = cref->data + cref->offset;                            \
1169     assert(cursor);                                                            \
1170     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1171   }
1172 
1173 #define IMPL_EXPINSERT(NAME, V)                                                \
1174   void _mlir_ciface_##NAME(                                                    \
1175       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
1176       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1177       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
1178     assert(tensor &&cref &&vref &&fref &&aref);                                \
1179     assert(cref->strides[0] == 1);                                             \
1180     assert(vref->strides[0] == 1);                                             \
1181     assert(fref->strides[0] == 1);                                             \
1182     assert(aref->strides[0] == 1);                                             \
1183     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1184     index_type *cursor = cref->data + cref->offset;                            \
1185     V *values = vref->data + vref->offset;                                     \
1186     bool *filled = fref->data + fref->offset;                                  \
1187     index_type *added = aref->data + aref->offset;                             \
1188     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
1189         cursor, values, filled, added, count);                                 \
1190   }
1191 
1192 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1193 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
1194 // that this file cannot get out of sync with its header.
1195 static_assert(std::is_same<index_type, uint64_t>::value,
1196               "Expected index_type == uint64_t");
1197 
1198 /// Constructs a new sparse tensor. This is the "swiss army knife"
1199 /// method for materializing sparse tensors into the computation.
1200 ///
1201 /// Action:
1202 /// kEmpty = returns empty storage to fill later
1203 /// kFromFile = returns storage, where ptr contains filename to read
1204 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
1205 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
1206 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1207 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
1208 void *
1209 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1210                              StridedMemRefType<index_type, 1> *sref,
1211                              StridedMemRefType<index_type, 1> *pref,
1212                              OverheadType ptrTp, OverheadType indTp,
1213                              PrimaryType valTp, Action action, void *ptr) {
1214   assert(aref && sref && pref);
1215   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1216          pref->strides[0] == 1);
1217   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1218   const DimLevelType *sparsity = aref->data + aref->offset;
1219   const index_type *shape = sref->data + sref->offset;
1220   const index_type *perm = pref->data + pref->offset;
1221   uint64_t rank = aref->sizes[0];
1222 
1223   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1224   // This is safe because of the static_assert above.
1225   if (ptrTp == OverheadType::kIndex)
1226     ptrTp = OverheadType::kU64;
1227   if (indTp == OverheadType::kIndex)
1228     indTp = OverheadType::kU64;
1229 
1230   // Double matrices with all combinations of overhead storage.
1231   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1232        uint64_t, double);
1233   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1234        uint32_t, double);
1235   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1236        uint16_t, double);
1237   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1238        uint8_t, double);
1239   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1240        uint64_t, double);
1241   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1242        uint32_t, double);
1243   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1244        uint16_t, double);
1245   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1246        uint8_t, double);
1247   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1248        uint64_t, double);
1249   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1250        uint32_t, double);
1251   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1252        uint16_t, double);
1253   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1254        uint8_t, double);
1255   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1256        uint64_t, double);
1257   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1258        uint32_t, double);
1259   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1260        uint16_t, double);
1261   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1262        uint8_t, double);
1263 
1264   // Float matrices with all combinations of overhead storage.
1265   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1266        uint64_t, float);
1267   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1268        uint32_t, float);
1269   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1270        uint16_t, float);
1271   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1272        uint8_t, float);
1273   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1274        uint64_t, float);
1275   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1276        uint32_t, float);
1277   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1278        uint16_t, float);
1279   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1280        uint8_t, float);
1281   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1282        uint64_t, float);
1283   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1284        uint32_t, float);
1285   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1286        uint16_t, float);
1287   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1288        uint8_t, float);
1289   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1290        uint64_t, float);
1291   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1292        uint32_t, float);
1293   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1294        uint16_t, float);
1295   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1296        uint8_t, float);
1297 
1298   // Integral matrices with both overheads of the same type.
1299   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1300   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1301   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1302   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1303   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1304   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1305   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1306   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1307   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1308   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1309   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1310   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1311   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1312 
1313   // Complex matrices with wide overhead.
1314   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1315   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1316 
1317   // Unsupported case (add above if needed).
1318   fputs("unsupported combination of types\n", stderr);
1319   exit(1);
1320 }
1321 
1322 /// Methods that provide direct access to pointers.
1323 IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
1324 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
1325 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
1326 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
1327 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
1328 
1329 /// Methods that provide direct access to indices.
1330 IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
1331 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
1332 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
1333 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
1334 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
1335 
1336 /// Methods that provide direct access to values.
1337 IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
1338 IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
1339 IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
1340 IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
1341 IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
1342 IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
1343 IMPL_SPARSEVALUES(sparseValuesC64, complex64, getValues)
1344 IMPL_SPARSEVALUES(sparseValuesC32, complex32, getValues)
1345 
1346 /// Helper to add value to coordinate scheme, one per value type.
1347 IMPL_ADDELT(addEltF64, double)
1348 IMPL_ADDELT(addEltF32, float)
1349 IMPL_ADDELT(addEltI64, int64_t)
1350 IMPL_ADDELT(addEltI32, int32_t)
1351 IMPL_ADDELT(addEltI16, int16_t)
1352 IMPL_ADDELT(addEltI8, int8_t)
1353 IMPL_ADDELT(addEltC64, complex64)
1354 IMPL_ADDELT(addEltC32ABI, complex32)
// Make the prototype explicit to accept the !llvm.struct<(f32, f32)> without
// any padding (which seems to happen for complex32 when passed as a scalar;
// all other cases, e.g. passing a pointer to an array, work as expected).
// TODO: cleaner way to avoid the ABI padding problem?
1359 void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
1360                              StridedMemRefType<index_type, 1> *iref,
1361                              StridedMemRefType<index_type, 1> *pref) {
1362   return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
1363 }
1364 
1365 /// Helper to enumerate elements of coordinate scheme, one per value type.
1366 IMPL_GETNEXT(getNextF64, double)
1367 IMPL_GETNEXT(getNextF32, float)
1368 IMPL_GETNEXT(getNextI64, int64_t)
1369 IMPL_GETNEXT(getNextI32, int32_t)
1370 IMPL_GETNEXT(getNextI16, int16_t)
1371 IMPL_GETNEXT(getNextI8, int8_t)
1372 IMPL_GETNEXT(getNextC64, complex64)
1373 IMPL_GETNEXT(getNextC32, complex32)
1374 
1375 /// Insert elements in lexicographical index order, one per value type.
1376 IMPL_LEXINSERT(lexInsertF64, double)
1377 IMPL_LEXINSERT(lexInsertF32, float)
1378 IMPL_LEXINSERT(lexInsertI64, int64_t)
1379 IMPL_LEXINSERT(lexInsertI32, int32_t)
1380 IMPL_LEXINSERT(lexInsertI16, int16_t)
1381 IMPL_LEXINSERT(lexInsertI8, int8_t)
1382 IMPL_LEXINSERT(lexInsertC64, complex64)
1383 IMPL_LEXINSERT(lexInsertC32ABI, complex32)
// Make the prototype explicit to accept the !llvm.struct<(f32, f32)> without
// any padding (which seems to happen for complex32 when passed as a scalar;
// all other cases, e.g. passing a pointer to an array, work as expected).
// TODO: cleaner way to avoid the ABI padding problem?
1388 void _mlir_ciface_lexInsertC32(void *tensor,
1389                                StridedMemRefType<index_type, 1> *cref, float r,
1390                                float i) {
1391   _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
1392 }
1393 
1394 /// Insert using expansion, one per value type.
1395 IMPL_EXPINSERT(expInsertF64, double)
1396 IMPL_EXPINSERT(expInsertF32, float)
1397 IMPL_EXPINSERT(expInsertI64, int64_t)
1398 IMPL_EXPINSERT(expInsertI32, int32_t)
1399 IMPL_EXPINSERT(expInsertI16, int16_t)
1400 IMPL_EXPINSERT(expInsertI8, int8_t)
1401 IMPL_EXPINSERT(expInsertC64, complex64)
1402 IMPL_EXPINSERT(expInsertC32, complex32)
1403 
1404 #undef CASE
1405 #undef IMPL_SPARSEVALUES
1406 #undef IMPL_GETOVERHEAD
1407 #undef IMPL_ADDELT
1408 #undef IMPL_GETNEXT
1409 #undef IMPL_LEXINSERT
1410 #undef IMPL_EXPINSERT
1411 
1412 /// Output a sparse tensor, one per value type.
1413 void outSparseTensorF64(void *tensor, void *dest, bool sort) {
1414   return outSparseTensor<double>(tensor, dest, sort);
1415 }
1416 void outSparseTensorF32(void *tensor, void *dest, bool sort) {
1417   return outSparseTensor<float>(tensor, dest, sort);
1418 }
1419 void outSparseTensorI64(void *tensor, void *dest, bool sort) {
1420   return outSparseTensor<int64_t>(tensor, dest, sort);
1421 }
1422 void outSparseTensorI32(void *tensor, void *dest, bool sort) {
1423   return outSparseTensor<int32_t>(tensor, dest, sort);
1424 }
1425 void outSparseTensorI16(void *tensor, void *dest, bool sort) {
1426   return outSparseTensor<int16_t>(tensor, dest, sort);
1427 }
1428 void outSparseTensorI8(void *tensor, void *dest, bool sort) {
1429   return outSparseTensor<int8_t>(tensor, dest, sort);
1430 }
1431 void outSparseTensorC64(void *tensor, void *dest, bool sort) {
1432   return outSparseTensor<complex64>(tensor, dest, sort);
1433 }
1434 void outSparseTensorC32(void *tensor, void *dest, bool sort) {
1435   return outSparseTensor<complex32>(tensor, dest, sort);
1436 }
1437 
1438 //===----------------------------------------------------------------------===//
1439 //
1440 // Public API with methods that accept C-style data structures to interact
1441 // with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used both by MLIR compiler-generated code and by an
// external runtime that wants to interact with MLIR compiler-generated code.
1444 //
1445 //===----------------------------------------------------------------------===//
1446 
1447 /// Helper method to read a sparse tensor filename from the environment,
1448 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1449 char *getTensorFilename(index_type id) {
  char var[80];
  snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
1452   char *env = getenv(var);
1453   if (!env) {
1454     fprintf(stderr, "Environment variable %s is not set\n", var);
1455     exit(1);
1456   }
1457   return env;
1458 }
1459 
/// Returns the size of the sparse tensor in the given dimension.
1461 index_type sparseDimSize(void *tensor, index_type d) {
1462   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1463 }
1464 
1465 /// Finalizes lexicographic insertions.
1466 void endInsert(void *tensor) {
1467   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1468 }
1469 
1470 /// Releases sparse tensor storage.
1471 void delSparseTensor(void *tensor) {
1472   delete static_cast<SparseTensorStorageBase *>(tensor);
1473 }
1474 
1475 /// Releases sparse tensor coordinate scheme.
1476 #define IMPL_DELCOO(VNAME, V)                                                  \
1477   void delSparseTensorCOO##VNAME(void *coo) {                                  \
1478     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
1479   }
1480 IMPL_DELCOO(F64, double)
1481 IMPL_DELCOO(F32, float)
1482 IMPL_DELCOO(I64, int64_t)
1483 IMPL_DELCOO(I32, int32_t)
1484 IMPL_DELCOO(I16, int16_t)
1485 IMPL_DELCOO(I8, int8_t)
1486 IMPL_DELCOO(C64, complex64)
1487 IMPL_DELCOO(C32, complex32)
1488 #undef IMPL_DELCOO
1489 
/// Initializes a sparse tensor from a COO-flavored format expressed using
/// C-style data structures. The expected parameters are:
///
///   rank:    rank of tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
///   values:  an "nse" array with values for all specified elements
///   indices: a flat "nse x rank" array with indices for all specified elements
///   perm:    the permutation of the dimensions in the storage
///   sparse:  the sparsity annotation for each dimension
1500 ///
1501 /// For example, the sparse matrix
1502 ///     | 1.0 0.0 0.0 |
1503 ///     | 0.0 5.0 3.0 |
1504 /// can be passed as
1505 ///      rank    = 2
1506 ///      nse     = 3
1507 ///      shape   = [2, 3]
1508 ///      values  = [1.0, 5.0, 3.0]
1509 ///      indices = [ 0, 0,  1, 1,  1, 2]
1510 //
1511 // TODO: generalize beyond 64-bit indices.
1512 //
1513 void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
1514                                    double *values, uint64_t *indices,
1515                                    uint64_t *perm, uint8_t *sparse) {
1516   return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
1517                                     sparse);
1518 }
1519 void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
1520                                    float *values, uint64_t *indices,
1521                                    uint64_t *perm, uint8_t *sparse) {
1522   return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
1523                                    sparse);
1524 }
1525 void *convertToMLIRSparseTensorI64(uint64_t rank, uint64_t nse, uint64_t *shape,
1526                                    int64_t *values, uint64_t *indices,
1527                                    uint64_t *perm, uint8_t *sparse) {
1528   return toMLIRSparseTensor<int64_t>(rank, nse, shape, values, indices, perm,
1529                                      sparse);
1530 }
1531 void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse, uint64_t *shape,
1532                                    int32_t *values, uint64_t *indices,
1533                                    uint64_t *perm, uint8_t *sparse) {
1534   return toMLIRSparseTensor<int32_t>(rank, nse, shape, values, indices, perm,
1535                                      sparse);
1536 }
1537 void *convertToMLIRSparseTensorI16(uint64_t rank, uint64_t nse, uint64_t *shape,
1538                                    int16_t *values, uint64_t *indices,
1539                                    uint64_t *perm, uint8_t *sparse) {
1540   return toMLIRSparseTensor<int16_t>(rank, nse, shape, values, indices, perm,
1541                                      sparse);
1542 }
1543 void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
1544                                   int8_t *values, uint64_t *indices,
1545                                   uint64_t *perm, uint8_t *sparse) {
1546   return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
1547                                     sparse);
1548 }
1549 void *convertToMLIRSparseTensorC64(uint64_t rank, uint64_t nse, uint64_t *shape,
1550                                    complex64 *values, uint64_t *indices,
1551                                    uint64_t *perm, uint8_t *sparse) {
1552   return toMLIRSparseTensor<complex64>(rank, nse, shape, values, indices, perm,
1553                                        sparse);
1554 }
1555 void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
1556                                    complex32 *values, uint64_t *indices,
1557                                    uint64_t *perm, uint8_t *sparse) {
1558   return toMLIRSparseTensor<complex32>(rank, nse, shape, values, indices, perm,
1559                                        sparse);
1560 }
1561 
/// Converts a sparse tensor to a COO-flavored format expressed using C-style
/// data structures. The expected output parameters are pointers for these
/// values:
///
///   rank:    rank of tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
///   values:  an "nse" array with values for all specified elements
1570 ///   indices: a flat "nse x rank" array with indices for all specified elements
1571 ///
1572 /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
1573 /// from convertToMLIRSparseTensor.
1574 ///
1575 //  TODO: Currently, values are copied from SparseTensorStorage to
1576 //  SparseTensorCOO, then to the output. We may want to reduce the number of
1577 //  copies.
1578 //
1579 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1580 // compressed
1581 //
1582 void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
1583                                     uint64_t *pNse, uint64_t **pShape,
1584                                     double **pValues, uint64_t **pIndices) {
1585   fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
1586 }
1587 void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
1588                                     uint64_t *pNse, uint64_t **pShape,
1589                                     float **pValues, uint64_t **pIndices) {
1590   fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
1591 }
1592 void convertFromMLIRSparseTensorI64(void *tensor, uint64_t *pRank,
1593                                     uint64_t *pNse, uint64_t **pShape,
1594                                     int64_t **pValues, uint64_t **pIndices) {
1595   fromMLIRSparseTensor<int64_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
1596 }
1597 void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
1598                                     uint64_t *pNse, uint64_t **pShape,
1599                                     int32_t **pValues, uint64_t **pIndices) {
1600   fromMLIRSparseTensor<int32_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
1601 }
1602 void convertFromMLIRSparseTensorI16(void *tensor, uint64_t *pRank,
1603                                     uint64_t *pNse, uint64_t **pShape,
1604                                     int16_t **pValues, uint64_t **pIndices) {
1605   fromMLIRSparseTensor<int16_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
1606 }
1607 void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
1608                                    uint64_t *pNse, uint64_t **pShape,
1609                                    int8_t **pValues, uint64_t **pIndices) {
1610   fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
1611 }
1612 void convertFromMLIRSparseTensorC64(void *tensor, uint64_t *pRank,
1613                                     uint64_t *pNse, uint64_t **pShape,
1614                                     complex64 **pValues, uint64_t **pIndices) {
1615   fromMLIRSparseTensor<complex64>(tensor, pRank, pNse, pShape, pValues,
1616                                   pIndices);
1617 }
1618 void convertFromMLIRSparseTensorC32(void *tensor, uint64_t *pRank,
1619                                     uint64_t *pNse, uint64_t **pShape,
1620                                     complex32 **pValues, uint64_t **pIndices) {
1621   fromMLIRSparseTensor<complex32>(tensor, pRank, pNse, pShape, pValues,
1622                                   pIndices);
1623 }
1624 
1625 } // extern "C"
1626 
1627 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1628