1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a lightweight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cctype>
25 #include <cinttypes>
26 #include <cstdio>
27 #include <cstdlib>
28 #include <cstring>
29 #include <numeric>
30 #include <vector>
31 
32 //===----------------------------------------------------------------------===//
33 //
34 // Internal support for storing and reading sparse tensors.
35 //
36 // The following memory-resident sparse storage schemes are supported:
37 //
38 // (a) A coordinate scheme for temporarily storing and lexicographically
39 //     sorting a sparse tensor by index (SparseTensorCOO).
40 //
41 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
43 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
44 //
45 // The following external formats are supported:
46 //
47 // (1) Matrix Market Exchange (MME): *.mtx
48 //     https://math.nist.gov/MatrixMarket/formats.html
49 //
50 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
51 //     http://frostt.io/tensors/file-formats.html
52 //
53 // Two public APIs are supported:
54 //
55 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
56 //     tensors. These methods should be used exclusively by MLIR
57 //     compiler-generated code.
58 //
59 // (II) Methods that accept C-style data structures to interact with sparse
60 //      tensors. These methods can be used by any external runtime that wants
61 //      to interact with MLIR compiler-generated code.
62 //
// In both cases (I) and (II), the SparseTensorStorage format is only visible
// externally as an opaque pointer.
65 //
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 
/// Buffer size (in characters) used when reading lines from external files.
static constexpr int kColWidth = 1025;
71 
72 /// A sparse tensor element in coordinate scheme (value and indices).
73 /// For example, a rank-1 vector element would look like
74 ///   ({i}, a[i])
75 /// and a rank-5 tensor element like
76 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
77 template <typename V>
78 struct Element {
  Element(const std::vector<uint64_t> &ind, V val)
      : indices(ind), value(val) {}
80   std::vector<uint64_t> indices;
81   V value;
82 };
83 
84 /// A memory-resident sparse tensor in coordinate scheme (collection of
85 /// elements). This data structure is used to read a sparse tensor from
86 /// any external format into memory and sort the elements lexicographically
87 /// by indices before passing it back to the client (most packed storage
88 /// formats require the elements to appear in lexicographic index order).
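///
/// Illustrative usage sketch (such calls are normally issued through the
/// public C API further below; `rank`, `sizes`, `perm`, and `ind` are
/// assumed to be set up by the caller):
///   SparseTensorCOO<double> *coo =
///       SparseTensorCOO<double>::newSparseTensorCOO(rank, sizes, perm);
///   coo->add(ind, 3.14); // repeated for every stored element
///   coo->startIterator();
///   while (const Element<double> *e = coo->getNext()) {
///     // e->indices and e->value describe one element
///   }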
89 template <typename V>
90 struct SparseTensorCOO {
91 public:
92   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
93       : sizes(szs), iteratorLocked(false), iteratorPos(0) {
94     if (capacity)
95       elements.reserve(capacity);
96   }
97   /// Adds element as indices and value.
98   void add(const std::vector<uint64_t> &ind, V val) {
99     assert(!iteratorLocked && "Attempt to add() after startIterator()");
100     uint64_t rank = getRank();
101     assert(rank == ind.size());
102     for (uint64_t r = 0; r < rank; r++)
103       assert(ind[r] < sizes[r]); // within bounds
104     elements.emplace_back(ind, val);
105   }
106   /// Sorts elements lexicographically by index.
107   void sort() {
108     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
109     // TODO: we may want to cache an `isSorted` bit, to avoid
110     // unnecessary/redundant sorting.
111     std::sort(elements.begin(), elements.end(), lexOrder);
112   }
113   /// Returns rank.
114   uint64_t getRank() const { return sizes.size(); }
115   /// Getter for sizes array.
116   const std::vector<uint64_t> &getSizes() const { return sizes; }
117   /// Getter for elements array.
118   const std::vector<Element<V>> &getElements() const { return elements; }
119 
120   /// Switch into iterator mode.
121   void startIterator() {
122     iteratorLocked = true;
123     iteratorPos = 0;
124   }
125   /// Get the next element.
126   const Element<V> *getNext() {
127     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
128     if (iteratorPos < elements.size())
129       return &(elements[iteratorPos++]);
130     iteratorLocked = false;
131     return nullptr;
132   }
133 
134   /// Factory method. Permutes the original dimensions according to
135   /// the given ordering and expects subsequent add() calls to honor
136   /// that same ordering for the given indices. The result is a
137   /// fully permuted coordinate scheme.
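  /// For example, with sizes = {10, 20, 30} and perm = {2, 0, 1}, the
  /// resulting coordinate scheme has permuted sizes {20, 30, 10}.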
138   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
139                                                 const uint64_t *sizes,
140                                                 const uint64_t *perm,
141                                                 uint64_t capacity = 0) {
142     std::vector<uint64_t> permsz(rank);
143     for (uint64_t r = 0; r < rank; r++)
144       permsz[perm[r]] = sizes[r];
145     return new SparseTensorCOO<V>(permsz, capacity);
146   }
147 
148 private:
149   /// Returns true if indices of e1 < indices of e2.
150   static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
151     uint64_t rank = e1.indices.size();
152     assert(rank == e2.indices.size());
153     for (uint64_t r = 0; r < rank; r++) {
154       if (e1.indices[r] == e2.indices[r])
155         continue;
156       return e1.indices[r] < e2.indices[r];
157     }
158     return false;
159   }
160   const std::vector<uint64_t> sizes; // per-dimension sizes
161   std::vector<Element<V>> elements;
162   bool iteratorLocked;
  uint64_t iteratorPos;
164 };
165 
166 /// Abstract base class of sparse tensor storage. Note that we use
167 /// function overloading to implement "partial" method specialization.
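///
/// For example, compiler-generated code that only holds an opaque
/// SparseTensorStorageBase pointer can call getValues(std::vector<float> **);
/// the call resolves to the float-valued SparseTensorStorage override below
/// when the value types match, and to the fatal("valf32") default otherwise.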
168 class SparseTensorStorageBase {
169 public:
170   /// Dimension size query.
171   virtual uint64_t getDimSize(uint64_t) = 0;
172 
173   /// Overhead storage.
174   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
175   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
176   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
177   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
178   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
179   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
180   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
181   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
182 
183   /// Primary storage.
184   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
185   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
186   virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
187   virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
188   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
189   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
190 
191   /// Element-wise insertion in lexicographic index order.
192   virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
193   virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
194   virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
195   virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(const uint64_t *, int16_t) { fatal("insi16"); }
197   virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
198 
199   /// Expanded insertion.
200   virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
201     fatal("expf64");
202   }
203   virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
204     fatal("expf32");
205   }
206   virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
207     fatal("expi64");
208   }
209   virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
210     fatal("expi32");
211   }
212   virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
213     fatal("expi16");
214   }
215   virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
216     fatal("expi8");
217   }
218 
219   /// Finishes insertion.
220   virtual void endInsert() = 0;
221 
222   virtual ~SparseTensorStorageBase() = default;
223 
224 private:
225   void fatal(const char *tp) {
226     fprintf(stderr, "unsupported %s\n", tp);
227     exit(1);
228   }
229 };
230 
231 /// A memory-resident sparse tensor using a storage scheme based on
232 /// per-dimension sparse/dense annotations. This data structure provides a
233 /// bufferized form of a sparse tensor type. In contrast to generating setup
234 /// methods for each differently annotated sparse tensor, this method provides
235 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
236 /// and annotations to implement all required setup in a general manner.
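///
/// For example (illustrative), the 2x3 matrix
///     | 1.0 0.0 0.0 |
///     | 0.0 5.0 3.0 |
/// stored with per-dimension annotations {dense, compressed} and the identity
/// dimension ordering yields
///     pointers[1] = [ 0, 1, 3 ]
///     indices[1]  = [ 0, 1, 2 ]
///     values      = [ 1.0, 5.0, 3.0 ]
/// which is the familiar CSR format; the dense dimension keeps no overhead.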
237 template <typename P, typename I, typename V>
238 class SparseTensorStorage : public SparseTensorStorageBase {
239 public:
240   /// Constructs a sparse tensor storage scheme with the given dimensions,
241   /// permutation, and per-dimension dense/sparse annotations, using
242   /// the coordinate scheme tensor for the initial contents if provided.
243   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
244                       const DimLevelType *sparsity,
245                       SparseTensorCOO<V> *tensor = nullptr)
246       : sizes(szs), rev(getRank()), idx(getRank()), pointers(getRank()),
247         indices(getRank()) {
248     uint64_t rank = getRank();
249     // Store "reverse" permutation.
250     for (uint64_t r = 0; r < rank; r++)
251       rev[perm[r]] = r;
252     // Provide hints on capacity of pointers and indices.
253     // TODO: needs fine-tuning based on sparsity
254     bool allDense = true;
255     uint64_t sz = 1;
256     for (uint64_t r = 0; r < rank; r++) {
257       sz *= sizes[r];
258       if (sparsity[r] == DimLevelType::kCompressed) {
259         pointers[r].reserve(sz + 1);
260         indices[r].reserve(sz);
261         sz = 1;
262         allDense = false;
263       } else {
264         assert(sparsity[r] == DimLevelType::kDense &&
265                "singleton not yet supported");
266       }
267     }
268     // Prepare sparse pointer structures for all dimensions.
269     for (uint64_t r = 0; r < rank; r++)
270       if (sparsity[r] == DimLevelType::kCompressed)
271         pointers[r].push_back(0);
272     // Then assign contents from coordinate scheme tensor if provided.
273     if (tensor) {
274       // Lexicographically sort the tensor, to ensure precondition of `fromCOO`.
275       tensor->sort();
276       const std::vector<Element<V>> &elements = tensor->getElements();
277       uint64_t nnz = elements.size();
278       values.reserve(nnz);
279       fromCOO(elements, 0, nnz, 0);
280     } else if (allDense) {
281       values.resize(sz, 0);
282     }
283   }
284 
285   ~SparseTensorStorage() override = default;
286 
287   /// Get the rank of the tensor.
288   uint64_t getRank() const { return sizes.size(); }
289 
290   /// Get the size in the given dimension of the tensor.
291   uint64_t getDimSize(uint64_t d) override {
292     assert(d < getRank());
293     return sizes[d];
294   }
295 
296   /// Partially specialize these getter methods based on template types.
297   void getPointers(std::vector<P> **out, uint64_t d) override {
298     assert(d < getRank());
299     *out = &pointers[d];
300   }
301   void getIndices(std::vector<I> **out, uint64_t d) override {
302     assert(d < getRank());
303     *out = &indices[d];
304   }
305   void getValues(std::vector<V> **out) override { *out = &values; }
306 
307   /// Partially specialize lexicographical insertions based on template types.
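  /// Insertions must arrive in strictly increasing lexicographic index order;
  /// e.g. for a rank-2 tensor the cursors {0,0}, {0,3}, {1,2} may be inserted
  /// in that order, whereas inserting {0,3} after {1,2} trips an assertion.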
308   void lexInsert(const uint64_t *cursor, V val) override {
309     // First, wrap up pending insertion path.
310     uint64_t diff = 0;
311     uint64_t top = 0;
312     if (!values.empty()) {
313       diff = lexDiff(cursor);
314       endPath(diff + 1);
315       top = idx[diff] + 1;
316     }
317     // Then continue with insertion path.
318     insPath(cursor, diff, top, val);
319   }
320 
321   /// Partially specialize expanded insertions based on template types.
322   /// Note that this method resets the values/filled-switch array back
323   /// to all-zero/false while only iterating over the nonzero elements.
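  /// For example (illustrative), with a rank-2 cursor {i, *}, added = {4, 1},
  /// and count = 2, the positions are sorted to {1, 4} and inserted as (i,1)
  /// and (i,4), after which values[1], values[4], filled[1], and filled[4]
  /// are reset to zero/false.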
324   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
325                  uint64_t count) override {
326     if (count == 0)
327       return;
328     // Sort.
329     std::sort(added, added + count);
330     // Restore insertion path for first insert.
331     uint64_t rank = getRank();
332     uint64_t index = added[0];
333     cursor[rank - 1] = index;
334     lexInsert(cursor, values[index]);
335     assert(filled[index]);
336     values[index] = 0;
337     filled[index] = false;
338     // Subsequent insertions are quick.
339     for (uint64_t i = 1; i < count; i++) {
340       assert(index < added[i] && "non-lexicographic insertion");
341       index = added[i];
342       cursor[rank - 1] = index;
343       insPath(cursor, rank - 1, added[i - 1] + 1, values[index]);
344       assert(filled[index]);
      values[index] = 0;
346       filled[index] = false;
347     }
348   }
349 
350   /// Finalizes lexicographic insertions.
351   void endInsert() override {
352     if (values.empty())
353       endDim(0);
354     else
355       endPath(0);
356   }
357 
358   /// Returns this sparse tensor storage scheme as a new memory-resident
359   /// sparse tensor in coordinate scheme with the given dimension order.
360   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
361     // Restore original order of the dimension sizes and allocate coordinate
362     // scheme with desired new ordering specified in perm.
363     uint64_t rank = getRank();
364     std::vector<uint64_t> orgsz(rank);
365     for (uint64_t r = 0; r < rank; r++)
366       orgsz[rev[r]] = sizes[r];
367     SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
368         rank, orgsz.data(), perm, values.size());
369     // Populate coordinate scheme restored from old ordering and changed with
370     // new ordering. Rather than applying both reorderings during the recursion,
    // we compute the combined permutation in advance.
372     std::vector<uint64_t> reord(rank);
373     for (uint64_t r = 0; r < rank; r++)
374       reord[r] = perm[rev[r]];
375     toCOO(*tensor, reord, 0, 0);
376     assert(tensor->getElements().size() == values.size());
377     return tensor;
378   }
379 
380   /// Factory method. Constructs a sparse tensor storage scheme with the given
381   /// dimensions, permutation, and per-dimension dense/sparse annotations,
382   /// using the coordinate scheme tensor for the initial contents if provided.
383   /// In the latter case, the coordinate scheme must respect the same
384   /// permutation as is desired for the new sparse tensor storage.
385   static SparseTensorStorage<P, I, V> *
386   newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
387                   const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
388     SparseTensorStorage<P, I, V> *n = nullptr;
389     if (tensor) {
390       assert(tensor->getRank() == rank);
391       for (uint64_t r = 0; r < rank; r++)
392         assert(sizes[r] == 0 || tensor->getSizes()[perm[r]] == sizes[r]);
393       n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
394                                            tensor);
395       delete tensor;
396     } else {
397       std::vector<uint64_t> permsz(rank);
398       for (uint64_t r = 0; r < rank; r++)
399         permsz[perm[r]] = sizes[r];
400       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
401     }
402     return n;
403   }
404 
405 private:
406   /// Initializes sparse tensor storage scheme from a memory-resident sparse
407   /// tensor in coordinate scheme. This method prepares the pointers and
408   /// indices arrays under the given per-dimension dense/sparse annotations.
409   /// Precondition: the `elements` must be lexicographically sorted.
410   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
411                uint64_t hi, uint64_t d) {
412     // Once dimensions are exhausted, insert the numerical values.
413     assert(d <= getRank() && hi <= elements.size());
414     if (d == getRank()) {
415       assert(lo < hi);
416       values.push_back(elements[lo].value);
417       return;
418     }
419     // Visit all elements in this interval.
420     uint64_t full = 0;
421     while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
      // Find the segment of elements in this interval that share the same
      // index in this dimension.
423       uint64_t i = elements[lo].indices[d];
424       uint64_t seg = lo + 1;
425       while (seg < hi && elements[seg].indices[d] == i)
426         seg++;
427       // Handle segment in interval for sparse or dense dimension.
428       if (isCompressedDim(d)) {
429         indices[d].push_back(i);
430       } else {
        // For dense storage we must fill in all the zero values between
        // the previous element (from the previous iteration of this loop)
        // and the current element.
434         for (; full < i; full++)
435           endDim(d + 1);
436         full++;
437       }
438       fromCOO(elements, lo, seg, d + 1);
439       // And move on to next segment in interval.
440       lo = seg;
441     }
442     // Finalize the sparse pointer structure at this dimension.
443     if (isCompressedDim(d)) {
444       pointers[d].push_back(indices[d].size());
445     } else {
446       // For dense storage we must fill in all the zero values after
447       // the last element.
448       for (uint64_t sz = sizes[d]; full < sz; full++)
449         endDim(d + 1);
450     }
451   }
452 
453   /// Stores the sparse tensor storage scheme into a memory-resident sparse
454   /// tensor in coordinate scheme.
455   void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
456              uint64_t pos, uint64_t d) {
457     assert(d <= getRank());
458     if (d == getRank()) {
459       assert(pos < values.size());
460       tensor.add(idx, values[pos]);
461     } else if (isCompressedDim(d)) {
462       // Sparse dimension.
463       for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
464         idx[reord[d]] = indices[d][ii];
465         toCOO(tensor, reord, ii, d + 1);
466       }
467     } else {
468       // Dense dimension.
469       for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
470         idx[reord[d]] = i;
471         toCOO(tensor, reord, off + i, d + 1);
472       }
473     }
474   }
475 
  /// Ends a deeper, never-before-seen dimension.
477   void endDim(uint64_t d) {
478     assert(d <= getRank());
479     if (d == getRank()) {
480       values.push_back(0);
481     } else if (isCompressedDim(d)) {
482       pointers[d].push_back(indices[d].size());
483     } else {
484       for (uint64_t full = 0, sz = sizes[d]; full < sz; full++)
485         endDim(d + 1);
486     }
487   }
488 
489   /// Wraps up a single insertion path, inner to outer.
490   void endPath(uint64_t diff) {
491     uint64_t rank = getRank();
492     assert(diff <= rank);
493     for (uint64_t i = 0; i < rank - diff; i++) {
494       uint64_t d = rank - i - 1;
495       if (isCompressedDim(d)) {
496         pointers[d].push_back(indices[d].size());
497       } else {
498         for (uint64_t full = idx[d] + 1, sz = sizes[d]; full < sz; full++)
499           endDim(d + 1);
500       }
501     }
502   }
503 
504   /// Continues a single insertion path, outer to inner.
505   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
506     uint64_t rank = getRank();
507     assert(diff < rank);
508     for (uint64_t d = diff; d < rank; d++) {
509       uint64_t i = cursor[d];
510       if (isCompressedDim(d)) {
511         indices[d].push_back(i);
512       } else {
513         for (uint64_t full = top; full < i; full++)
514           endDim(d + 1);
515       }
516       top = 0;
517       idx[d] = i;
518     }
519     values.push_back(val);
520   }
521 
  /// Finds the lexicographically differing dimension.
523   uint64_t lexDiff(const uint64_t *cursor) {
524     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
525       if (cursor[r] > idx[r])
526         return r;
527       else
528         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
530     return -1u;
531   }
532 
533   /// Returns true if dimension is compressed.
534   inline bool isCompressedDim(uint64_t d) const {
535     return (!pointers[d].empty());
536   }
537 
538 private:
539   std::vector<uint64_t> sizes; // per-dimension sizes
540   std::vector<uint64_t> rev;   // "reverse" permutation
541   std::vector<uint64_t> idx;   // index cursor
542   std::vector<std::vector<P>> pointers;
543   std::vector<std::vector<I>> indices;
544   std::vector<V> values;
545 };
546 
547 /// Helper to convert string to lower case.
548 static char *toLower(char *token) {
549   for (char *c = token; *c; c++)
550     *c = tolower(*c);
551   return token;
552 }
553 
554 /// Read the MME header of a general sparse matrix of type real.
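/// For example, such a file typically starts with a banner line of the form
///   %%MatrixMarket matrix coordinate real general
/// followed by optional '%' comment lines and a line holding M, N, and NNZ.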
555 static void readMMEHeader(FILE *file, char *filename, char *line,
556                           uint64_t *idata, bool *isSymmetric) {
557   char header[64];
558   char object[64];
559   char format[64];
560   char field[64];
561   char symmetry[64];
562   // Read header line.
563   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
564              symmetry) != 5) {
565     fprintf(stderr, "Corrupt header in %s\n", filename);
566     exit(1);
567   }
568   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
569   // Make sure this is a general sparse matrix.
570   if (strcmp(toLower(header), "%%matrixmarket") ||
571       strcmp(toLower(object), "matrix") ||
572       strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
573       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
574     fprintf(stderr,
575             "Cannot find a general sparse matrix with type real in %s\n",
576             filename);
577     exit(1);
578   }
579   // Skip comments.
580   while (true) {
581     if (!fgets(line, kColWidth, file)) {
582       fprintf(stderr, "Cannot find data in %s\n", filename);
583       exit(1);
584     }
585     if (line[0] != '%')
586       break;
587   }
588   // Next line contains M N NNZ.
589   idata[0] = 2; // rank
590   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
591              idata + 1) != 3) {
592     fprintf(stderr, "Cannot find size in %s\n", filename);
593     exit(1);
594   }
595 }
596 
597 /// Read the "extended" FROSTT header. Although not part of the documented
598 /// format, we assume that the file starts with optional comments followed
599 /// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
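/// For example (illustrative), a 2x3x4 tensor with 5 stored elements would
/// start with (after any '#' comment lines):
///   3 5
///   2 3 4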
601 static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
602                                 uint64_t *idata) {
603   // Skip comments.
604   while (true) {
605     if (!fgets(line, kColWidth, file)) {
606       fprintf(stderr, "Cannot find data in %s\n", filename);
607       exit(1);
608     }
609     if (line[0] != '#')
610       break;
611   }
612   // Next line contains RANK and NNZ.
613   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
614     fprintf(stderr, "Cannot find metadata in %s\n", filename);
615     exit(1);
616   }
617   // Followed by a line with the dimension sizes (one per rank).
618   for (uint64_t r = 0; r < idata[0]; r++) {
619     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
620       fprintf(stderr, "Cannot find dimension size %s\n", filename);
621       exit(1);
622     }
623   }
624   fgets(line, kColWidth, file); // end of line
625 }
626 
627 /// Reads a sparse tensor with the given filename into a memory-resident
628 /// sparse tensor in coordinate scheme.
629 template <typename V>
630 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
631                                                const uint64_t *sizes,
632                                                const uint64_t *perm) {
633   // Open the file.
634   FILE *file = fopen(filename, "r");
635   if (!file) {
636     fprintf(stderr, "Cannot find %s\n", filename);
637     exit(1);
638   }
  // Perform some file-format-dependent setup.
640   char line[kColWidth];
641   uint64_t idata[512];
642   bool isSymmetric = false;
643   if (strstr(filename, ".mtx")) {
644     readMMEHeader(file, filename, line, idata, &isSymmetric);
645   } else if (strstr(filename, ".tns")) {
646     readExtFROSTTHeader(file, filename, line, idata);
647   } else {
648     fprintf(stderr, "Unknown format %s\n", filename);
649     exit(1);
650   }
651   // Prepare sparse tensor object with per-dimension sizes
652   // and the number of nonzeros as initial capacity.
653   assert(rank == idata[0] && "rank mismatch");
654   uint64_t nnz = idata[1];
655   for (uint64_t r = 0; r < rank; r++)
656     assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
657            "dimension size mismatch");
658   SparseTensorCOO<V> *tensor =
659       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
  // Read all nonzero elements.
661   std::vector<uint64_t> indices(rank);
662   for (uint64_t k = 0; k < nnz; k++) {
663     if (!fgets(line, kColWidth, file)) {
664       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
665       exit(1);
666     }
667     char *linePtr = line;
668     for (uint64_t r = 0; r < rank; r++) {
669       uint64_t idx = strtoul(linePtr, &linePtr, 10);
670       // Add 0-based index.
671       indices[perm[r]] = idx - 1;
672     }
673     // The external formats always store the numerical values with the type
674     // double, but we cast these values to the sparse tensor object type.
675     double value = strtod(linePtr, &linePtr);
676     tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing them. In the future, we may want to make symmetry
    // implicit for storage reasons.
680     if (isSymmetric && indices[0] != indices[1])
681       tensor->add({indices[1], indices[0]}, value);
682   }
683   // Close the file and return tensor.
684   fclose(file);
685   return tensor;
686 }
687 
688 } // namespace
689 
690 extern "C" {
691 
692 //===----------------------------------------------------------------------===//
693 //
694 // Public API with methods that operate on MLIR buffers (memrefs) to interact
695 // with sparse tensors, which are only visible as opaque pointers externally.
696 // These methods should be used exclusively by MLIR compiler-generated code.
697 //
698 // Some macro magic is used to generate implementations for all required type
699 // combinations that can be called from MLIR compiler-generated code.
700 //
701 //===----------------------------------------------------------------------===//
702 
703 #define CASE(p, i, v, P, I, V)                                                 \
704   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
705     SparseTensorCOO<V> *tensor = nullptr;                                      \
706     if (action <= Action::kFromCOO) {                                          \
707       if (action == Action::kFromFile) {                                       \
708         char *filename = static_cast<char *>(ptr);                             \
709         tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm);          \
710       } else if (action == Action::kFromCOO) {                                 \
711         tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
712       } else {                                                                 \
713         assert(action == Action::kEmpty);                                      \
714       }                                                                        \
715       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,  \
716                                                            sparsity, tensor);  \
717     }                                                                          \
718     if (action == Action::kEmptyCOO)                                           \
719       return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);        \
720     tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);    \
721     if (action == Action::kToIterator) {                                       \
722       tensor->startIterator();                                                 \
723     } else {                                                                   \
724       assert(action == Action::kToCOO);                                        \
725     }                                                                          \
726     return tensor;                                                             \
727   }
728 
729 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
730 
731 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
732   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
733     assert(ref &&tensor);                                                      \
734     std::vector<TYPE> *v;                                                      \
735     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
736     ref->basePtr = ref->data = v->data();                                      \
737     ref->offset = 0;                                                           \
738     ref->sizes[0] = v->size();                                                 \
739     ref->strides[0] = 1;                                                       \
740   }
741 
742 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
743   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
744                            index_t d) {                                        \
745     assert(ref &&tensor);                                                      \
746     std::vector<TYPE> *v;                                                      \
747     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
748     ref->basePtr = ref->data = v->data();                                      \
749     ref->offset = 0;                                                           \
750     ref->sizes[0] = v->size();                                                 \
751     ref->strides[0] = 1;                                                       \
752   }
753 
754 #define IMPL_ADDELT(NAME, TYPE)                                                \
755   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
756                             StridedMemRefType<index_t, 1> *iref,               \
757                             StridedMemRefType<index_t, 1> *pref) {             \
758     assert(tensor &&iref &&pref);                                              \
759     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
760     assert(iref->sizes[0] == pref->sizes[0]);                                  \
761     const index_t *indx = iref->data + iref->offset;                           \
762     const index_t *perm = pref->data + pref->offset;                           \
763     uint64_t isize = iref->sizes[0];                                           \
764     std::vector<index_t> indices(isize);                                       \
765     for (uint64_t r = 0; r < isize; r++)                                       \
766       indices[perm[r]] = indx[r];                                              \
767     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
768     return tensor;                                                             \
769   }
770 
771 #define IMPL_GETNEXT(NAME, V)                                                  \
772   bool _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *iref,  \
773                            StridedMemRefType<V, 0> *vref) {                    \
774     assert(tensor &&iref &&vref);                                              \
775     assert(iref->strides[0] == 1);                                             \
776     index_t *indx = iref->data + iref->offset;                                 \
777     V *value = vref->data + vref->offset;                                      \
778     const uint64_t isize = iref->sizes[0];                                     \
779     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
780     const Element<V> *elem = iter->getNext();                                  \
781     if (elem == nullptr) {                                                     \
782       delete iter;                                                             \
783       return false;                                                            \
784     }                                                                          \
785     for (uint64_t r = 0; r < isize; r++)                                       \
786       indx[r] = elem->indices[r];                                              \
787     *value = elem->value;                                                      \
788     return true;                                                               \
789   }
790 
791 #define IMPL_LEXINSERT(NAME, V)                                                \
792   void _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *cref,  \
793                            V val) {                                            \
794     assert(tensor &&cref);                                                     \
795     assert(cref->strides[0] == 1);                                             \
796     index_t *cursor = cref->data + cref->offset;                               \
797     assert(cursor);                                                            \
798     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
799   }
800 
801 #define IMPL_EXPINSERT(NAME, V)                                                \
802   void _mlir_ciface_##NAME(                                                    \
803       void *tensor, StridedMemRefType<index_t, 1> *cref,                       \
804       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
805       StridedMemRefType<index_t, 1> *aref, index_t count) {                    \
806     assert(tensor &&cref &&vref &&fref &&aref);                                \
807     assert(cref->strides[0] == 1);                                             \
808     assert(vref->strides[0] == 1);                                             \
809     assert(fref->strides[0] == 1);                                             \
810     assert(aref->strides[0] == 1);                                             \
811     assert(vref->sizes[0] == fref->sizes[0]);                                  \
812     index_t *cursor = cref->data + cref->offset;                               \
813     V *values = vref->data + vref->offset;                                     \
814     bool *filled = fref->data + fref->offset;                                  \
815     index_t *added = aref->data + aref->offset;                                \
816     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
817         cursor, values, filled, added, count);                                 \
818   }
819 
820 // Assume index_t is in fact uint64_t, so that _mlir_ciface_newSparseTensor
821 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
822 // that this file cannot get out of sync with its header.
823 static_assert(std::is_same<index_t, uint64_t>::value,
824               "Expected index_t == uint64_t");
825 
/// Constructs a new sparse tensor. This is the "Swiss army knife"
827 /// method for materializing sparse tensors into the computation.
828 ///
829 /// Action:
830 /// kEmpty = returns empty storage to fill later
831 /// kFromFile = returns storage, where ptr contains filename to read
832 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
833 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
834 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
835 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
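///
/// A typical flow that materializes storage from a coordinate scheme is
/// (illustrative pseudo-code; the memref marshalling performed by
/// compiler-generated code is omitted):
///   coo = newSparseTensor(..., Action::kEmptyCOO, nullptr)
///   coo = addEltF64(coo, value, indices, perm)  // repeated per element
///   st  = newSparseTensor(..., Action::kFromCOO, coo)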
836 void *
837 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
838                              StridedMemRefType<index_t, 1> *sref,
839                              StridedMemRefType<index_t, 1> *pref,
840                              OverheadType ptrTp, OverheadType indTp,
841                              PrimaryType valTp, Action action, void *ptr) {
842   assert(aref && sref && pref);
843   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
844          pref->strides[0] == 1);
845   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
846   const DimLevelType *sparsity = aref->data + aref->offset;
847   const index_t *sizes = sref->data + sref->offset;
848   const index_t *perm = pref->data + pref->offset;
849   uint64_t rank = aref->sizes[0];
850 
851   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
852   // This is safe because of the static_assert above.
853   if (ptrTp == OverheadType::kIndex)
854     ptrTp = OverheadType::kU64;
855   if (indTp == OverheadType::kIndex)
856     indTp = OverheadType::kU64;
857 
858   // Double matrices with all combinations of overhead storage.
859   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
860        uint64_t, double);
861   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
862        uint32_t, double);
863   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
864        uint16_t, double);
865   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
866        uint8_t, double);
867   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
868        uint64_t, double);
869   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
870        uint32_t, double);
871   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
872        uint16_t, double);
873   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
874        uint8_t, double);
875   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
876        uint64_t, double);
877   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
878        uint32_t, double);
879   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
880        uint16_t, double);
881   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
882        uint8_t, double);
883   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
884        uint64_t, double);
885   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
886        uint32_t, double);
887   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
888        uint16_t, double);
889   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
890        uint8_t, double);
891 
892   // Float matrices with all combinations of overhead storage.
893   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
894        uint64_t, float);
895   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
896        uint32_t, float);
897   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
898        uint16_t, float);
899   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
900        uint8_t, float);
901   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
902        uint64_t, float);
903   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
904        uint32_t, float);
905   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
906        uint16_t, float);
907   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
908        uint8_t, float);
909   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
910        uint64_t, float);
911   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
912        uint32_t, float);
913   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
914        uint16_t, float);
915   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
916        uint8_t, float);
917   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
918        uint64_t, float);
919   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
920        uint32_t, float);
921   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
922        uint16_t, float);
923   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
924        uint8_t, float);
925 
926   // Integral matrices with both overheads of the same type.
927   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
928   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
929   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
930   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
931   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
932   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
933   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
934   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
935   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
936   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
937   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
938   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
939   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
940 
941   // Unsupported case (add above if needed).
942   fputs("unsupported combination of types\n", stderr);
943   exit(1);
944 }
945 
946 /// Methods that provide direct access to pointers.
947 IMPL_GETOVERHEAD(sparsePointers, index_t, getPointers)
948 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
949 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
950 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
951 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
952 
953 /// Methods that provide direct access to indices.
954 IMPL_GETOVERHEAD(sparseIndices, index_t, getIndices)
955 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
956 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
957 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
958 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
959 
960 /// Methods that provide direct access to values.
961 IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
962 IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
963 IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
964 IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
965 IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
966 IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
967 
968 /// Helper to add value to coordinate scheme, one per value type.
969 IMPL_ADDELT(addEltF64, double)
970 IMPL_ADDELT(addEltF32, float)
971 IMPL_ADDELT(addEltI64, int64_t)
972 IMPL_ADDELT(addEltI32, int32_t)
973 IMPL_ADDELT(addEltI16, int16_t)
974 IMPL_ADDELT(addEltI8, int8_t)
975 
976 /// Helper to enumerate elements of coordinate scheme, one per value type.
977 IMPL_GETNEXT(getNextF64, double)
978 IMPL_GETNEXT(getNextF32, float)
979 IMPL_GETNEXT(getNextI64, int64_t)
980 IMPL_GETNEXT(getNextI32, int32_t)
981 IMPL_GETNEXT(getNextI16, int16_t)
982 IMPL_GETNEXT(getNextI8, int8_t)
983 
984 /// Helper to insert elements in lexicographical index order, one per value
985 /// type.
986 IMPL_LEXINSERT(lexInsertF64, double)
987 IMPL_LEXINSERT(lexInsertF32, float)
988 IMPL_LEXINSERT(lexInsertI64, int64_t)
989 IMPL_LEXINSERT(lexInsertI32, int32_t)
990 IMPL_LEXINSERT(lexInsertI16, int16_t)
991 IMPL_LEXINSERT(lexInsertI8, int8_t)
992 
993 /// Helper to insert using expansion, one per value type.
994 IMPL_EXPINSERT(expInsertF64, double)
995 IMPL_EXPINSERT(expInsertF32, float)
996 IMPL_EXPINSERT(expInsertI64, int64_t)
997 IMPL_EXPINSERT(expInsertI32, int32_t)
998 IMPL_EXPINSERT(expInsertI16, int16_t)
999 IMPL_EXPINSERT(expInsertI8, int8_t)
1000 
1001 #undef CASE
1002 #undef IMPL_SPARSEVALUES
1003 #undef IMPL_GETOVERHEAD
1004 #undef IMPL_ADDELT
1005 #undef IMPL_GETNEXT
1006 #undef IMPL_LEXINSERT
1007 #undef IMPL_EXPINSERT
1008 
1009 //===----------------------------------------------------------------------===//
1010 //
1011 // Public API with methods that accept C-style data structures to interact
1012 // with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used by MLIR compiler-generated code as well as by
1014 // an external runtime that wants to interact with MLIR compiler-generated code.
1015 //
1016 //===----------------------------------------------------------------------===//
1017 
1018 /// Helper method to read a sparse tensor filename from the environment,
1019 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
char *getTensorFilename(index_t id) {
  char var[80];
  sprintf(var, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env) {
    fprintf(stderr, "Environment variable %s is not set\n", var);
    exit(1);
  }
  return env;
}
1026 
1027 /// Returns size of sparse tensor in given dimension.
1028 index_t sparseDimSize(void *tensor, index_t d) {
1029   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1030 }
1031 
1032 /// Finalizes lexicographic insertions.
1033 void endInsert(void *tensor) {
1034   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1035 }
1036 
1037 /// Releases sparse tensor storage.
1038 void delSparseTensor(void *tensor) {
1039   delete static_cast<SparseTensorStorageBase *>(tensor);
1040 }
1041 
1042 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
1043 /// data structures. The expected parameters are:
1044 ///
1045 ///   rank:    rank of tensor
1046 ///   nse:     number of specified elements (usually the nonzeros)
1047 ///   shape:   array with dimension size for each rank
1048 ///   values:  a "nse" array with values for all specified elements
1049 ///   indices: a flat "nse x rank" array with indices for all specified elements
1050 ///
1051 /// For example, the sparse matrix
1052 ///     | 1.0 0.0 0.0 |
1053 ///     | 0.0 5.0 3.0 |
1054 /// can be passed as
1055 ///      rank    = 2
1056 ///      nse     = 3
1057 ///      shape   = [2, 3]
1058 ///      values  = [1.0, 5.0, 3.0]
1059 ///      indices = [ 0, 0,  1, 1,  1, 2]
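///
/// Illustrative call for the example above:
///   uint64_t shape[]   = {2, 3};
///   double   values[]  = {1.0, 5.0, 3.0};
///   uint64_t indices[] = {0, 0, 1, 1, 1, 2};
///   void *tensor = convertToMLIRSparseTensor(2, 3, shape, values, indices);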
1060 //
1061 // TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
1062 //
1063 void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
1064                                 double *values, uint64_t *indices) {
1065   // Setup all-dims compressed and default ordering.
1066   std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
1067   std::vector<uint64_t> perm(rank);
1068   std::iota(perm.begin(), perm.end(), 0);
1069   // Convert external format to internal COO.
1070   SparseTensorCOO<double> *tensor = SparseTensorCOO<double>::newSparseTensorCOO(
1071       rank, shape, perm.data(), nse);
1072   std::vector<uint64_t> idx(rank);
1073   for (uint64_t i = 0, base = 0; i < nse; i++) {
1074     for (uint64_t r = 0; r < rank; r++)
1075       idx[r] = indices[base + r];
1076     tensor->add(idx, values[i]);
1077     base += rank;
1078   }
1079   // Return sparse tensor storage format as opaque pointer.
1080   return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
1081       rank, shape, perm.data(), sparse.data(), tensor);
1082 }
1083 
1084 /// Converts a sparse tensor to COO-flavored format expressed using C-style
1085 /// data structures. The expected output parameters are pointers for these
1086 /// values:
1087 ///
1088 ///   rank:    rank of tensor
1089 ///   nse:     number of specified elements (usually the nonzeros)
1090 ///   shape:   array with dimension size for each rank
1091 ///   values:  a "nse" array with values for all specified elements
1092 ///   indices: a flat "nse x rank" array with indices for all specified elements
1093 ///
1094 /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
1095 /// from convertToMLIRSparseTensor.
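///
/// The output buffers are allocated here with operator new[] and ownership is
/// transferred to the caller, who must eventually release them, e.g.
/// (illustrative):
///   uint64_t rank, nse, *shape, *indices;
///   double *values;
///   convertFromMLIRSparseTensor(tensor, &rank, &nse, &shape, &values,
///                               &indices);
///   ... // use the data
///   delete[] shape;
///   delete[] values;
///   delete[] indices;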
1096 ///
1097 //  TODO: Currently, values are copied from SparseTensorStorage to
1098 //  SparseTensorCOO, then to the output. We may want to reduce the number of
1099 //  copies.
1100 //
1101 //  TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
1102 //
1103 void convertFromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1104                                  uint64_t **pShape, double **pValues,
1105                                  uint64_t **pIndices) {
1106   SparseTensorStorage<uint64_t, uint64_t, double> *sparseTensor =
1107       static_cast<SparseTensorStorage<uint64_t, uint64_t, double> *>(tensor);
1108   uint64_t rank = sparseTensor->getRank();
1109   std::vector<uint64_t> perm(rank);
1110   std::iota(perm.begin(), perm.end(), 0);
1111   SparseTensorCOO<double> *coo = sparseTensor->toCOO(perm.data());
1112 
1113   const std::vector<Element<double>> &elements = coo->getElements();
1114   uint64_t nse = elements.size();
1115 
1116   uint64_t *shape = new uint64_t[rank];
1117   for (uint64_t i = 0; i < rank; i++)
1118     shape[i] = coo->getSizes()[i];
1119 
1120   double *values = new double[nse];
1121   uint64_t *indices = new uint64_t[rank * nse];
1122 
1123   for (uint64_t i = 0, base = 0; i < nse; i++) {
1124     values[i] = elements[i].value;
1125     for (uint64_t j = 0; j < rank; j++)
1126       indices[base + j] = elements[i].indices[j];
1127     base += rank;
1128   }
1129 
1130   delete coo;
1131   *pRank = rank;
1132   *pNse = nse;
1133   *pShape = shape;
1134   *pValues = values;
1135   *pIndices = indices;
1136 }
1137 } // extern "C"
1138 
1139 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1140