1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a lightweight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cctype>
25 #include <cinttypes>
26 #include <cstdio>
27 #include <cstdlib>
28 #include <cstring>
29 #include <numeric>
30 #include <vector>
31 
32 //===----------------------------------------------------------------------===//
33 //
34 // Internal support for storing and reading sparse tensors.
35 //
36 // The following memory-resident sparse storage schemes are supported:
37 //
38 // (a) A coordinate scheme for temporarily storing and lexicographically
39 //     sorting a sparse tensor by index (SparseTensorCOO).
40 //
41 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
43 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
44 //
45 // The following external formats are supported:
46 //
47 // (1) Matrix Market Exchange (MME): *.mtx
48 //     https://math.nist.gov/MatrixMarket/formats.html
49 //
50 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
51 //     http://frostt.io/tensors/file-formats.html
52 //
53 // Two public APIs are supported:
54 //
55 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
56 //     tensors. These methods should be used exclusively by MLIR
57 //     compiler-generated code.
58 //
59 // (II) Methods that accept C-style data structures to interact with sparse
60 //      tensors. These methods can be used by any external runtime that wants
61 //      to interact with MLIR compiler-generated code.
62 //
// In both cases (I) and (II), the SparseTensorStorage format is visible
// externally only as an opaque pointer.
65 //
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 
70 /// A sparse tensor element in coordinate scheme (value and indices).
71 /// For example, a rank-1 vector element would look like
72 ///   ({i}, a[i])
73 /// and a rank-5 tensor element like
74 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
75 template <typename V>
76 struct Element {
  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val) {}
78   std::vector<uint64_t> indices;
79   V value;
80 };
81 
82 /// A memory-resident sparse tensor in coordinate scheme (collection of
83 /// elements). This data structure is used to read a sparse tensor from
84 /// any external format into memory and sort the elements lexicographically
85 /// by indices before passing it back to the client (most packed storage
86 /// formats require the elements to appear in lexicographic index order).
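/// For example, the matrix
///     | 1.0 0.0 0.0 |
///     | 0.0 5.0 3.0 |
/// is represented, after lexicographic sorting, by the element list
///     ({0,0}, 1.0), ({1,1}, 5.0), ({1,2}, 3.0)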
87 template <typename V>
88 struct SparseTensorCOO {
89 public:
90   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
91       : sizes(szs), iteratorLocked(false), iteratorPos(0) {
92     if (capacity)
93       elements.reserve(capacity);
94   }
95   /// Adds element as indices and value.
96   void add(const std::vector<uint64_t> &ind, V val) {
97     assert(!iteratorLocked && "Attempt to add() after startIterator()");
98     uint64_t rank = getRank();
99     assert(rank == ind.size());
100     for (uint64_t r = 0; r < rank; r++)
101       assert(ind[r] < sizes[r]); // within bounds
102     elements.emplace_back(ind, val);
103   }
104   /// Sorts elements lexicographically by index.
105   void sort() {
106     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
107     std::sort(elements.begin(), elements.end(), lexOrder);
108   }
109   /// Returns rank.
110   uint64_t getRank() const { return sizes.size(); }
111   /// Getter for sizes array.
112   const std::vector<uint64_t> &getSizes() const { return sizes; }
113   /// Getter for elements array.
114   const std::vector<Element<V>> &getElements() const { return elements; }
115 
116   /// Switch into iterator mode.
117   void startIterator() {
118     iteratorLocked = true;
119     iteratorPos = 0;
120   }
121   /// Get the next element.
122   const Element<V> *getNext() {
123     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
124     if (iteratorPos < elements.size())
125       return &(elements[iteratorPos++]);
126     iteratorLocked = false;
127     return nullptr;
128   }
129 
130   /// Factory method. Permutes the original dimensions according to
131   /// the given ordering and expects subsequent add() calls to honor
132   /// that same ordering for the given indices. The result is a
133   /// fully permuted coordinate scheme.
134   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
135                                                 const uint64_t *sizes,
136                                                 const uint64_t *perm,
137                                                 uint64_t capacity = 0) {
138     std::vector<uint64_t> permsz(rank);
139     for (uint64_t r = 0; r < rank; r++)
140       permsz[perm[r]] = sizes[r];
141     return new SparseTensorCOO<V>(permsz, capacity);
142   }
143 
144 private:
145   /// Returns true if indices of e1 < indices of e2.
146   static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
147     uint64_t rank = e1.indices.size();
148     assert(rank == e2.indices.size());
149     for (uint64_t r = 0; r < rank; r++) {
150       if (e1.indices[r] == e2.indices[r])
151         continue;
152       return e1.indices[r] < e2.indices[r];
153     }
154     return false;
155   }
156   const std::vector<uint64_t> sizes; // per-dimension sizes
157   std::vector<Element<V>> elements;
158   bool iteratorLocked;
  uint64_t iteratorPos;
160 };
161 
162 /// Abstract base class of sparse tensor storage. Note that we use
163 /// function overloading to implement "partial" method specialization.
164 class SparseTensorStorageBase {
165 public:
166   /// Dimension size query.
167   virtual uint64_t getDimSize(uint64_t) = 0;
168 
169   /// Overhead storage.
170   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
171   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
172   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
173   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
174   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
175   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
176   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
177   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
178 
179   /// Primary storage.
180   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
181   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
182   virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
183   virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
184   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
185   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
186 
187   /// Element-wise insertion in lexicographic index order.
188   virtual void lexInsert(uint64_t *, double) { fatal("insf64"); }
189   virtual void lexInsert(uint64_t *, float) { fatal("insf32"); }
190   virtual void lexInsert(uint64_t *, int64_t) { fatal("insi64"); }
191   virtual void lexInsert(uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(uint64_t *, int16_t) { fatal("insi16"); }
193   virtual void lexInsert(uint64_t *, int8_t) { fatal("insi8"); }
194 
195   /// Expanded insertion.
196   virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
197     fatal("expf64");
198   }
199   virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
200     fatal("expf32");
201   }
202   virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
203     fatal("expi64");
204   }
205   virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
206     fatal("expi32");
207   }
208   virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
209     fatal("expi16");
210   }
211   virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
212     fatal("expi8");
213   }
214 
215   /// Finishes insertion.
216   virtual void endInsert() = 0;
217 
218   virtual ~SparseTensorStorageBase() {}
219 
220 private:
221   void fatal(const char *tp) {
222     fprintf(stderr, "unsupported %s\n", tp);
223     exit(1);
224   }
225 };
226 
227 /// A memory-resident sparse tensor using a storage scheme based on
228 /// per-dimension sparse/dense annotations. This data structure provides a
229 /// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this class provides
231 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
232 /// and annotations to implement all required setup in a general manner.
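/// For example, under (dense, compressed) annotations and the identity
/// dimension ordering, the matrix
///     | 1.0 0.0 0.0 |
///     | 0.0 5.0 3.0 |
/// is stored CSR style as
///     pointers[1] = [ 0, 1, 3 ]
///     indices[1]  = [ 0, 1, 2 ]
///     values      = [ 1.0, 5.0, 3.0 ]
/// with the dense dimension 0 carrying no overhead storage.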
233 template <typename P, typename I, typename V>
234 class SparseTensorStorage : public SparseTensorStorageBase {
235 public:
236   /// Constructs a sparse tensor storage scheme with the given dimensions,
237   /// permutation, and per-dimension dense/sparse annotations, using
238   /// the coordinate scheme tensor for the initial contents if provided.
239   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
240                       const DimLevelType *sparsity,
241                       SparseTensorCOO<V> *tensor = nullptr)
242       : sizes(szs), rev(getRank()), idx(getRank()), pointers(getRank()),
243         indices(getRank()) {
244     uint64_t rank = getRank();
245     // Store "reverse" permutation.
246     for (uint64_t r = 0; r < rank; r++)
247       rev[perm[r]] = r;
248     // Provide hints on capacity of pointers and indices.
249     // TODO: needs fine-tuning based on sparsity
250     bool allDense = true;
251     uint64_t sz = 1;
252     for (uint64_t r = 0; r < rank; r++) {
253       sz *= sizes[r];
254       if (sparsity[r] == DimLevelType::kCompressed) {
255         pointers[r].reserve(sz + 1);
256         indices[r].reserve(sz);
257         sz = 1;
258         allDense = false;
259       } else {
260         assert(sparsity[r] == DimLevelType::kDense &&
261                "singleton not yet supported");
262       }
263     }
264     // Prepare sparse pointer structures for all dimensions.
265     for (uint64_t r = 0; r < rank; r++)
266       if (sparsity[r] == DimLevelType::kCompressed)
267         pointers[r].push_back(0);
268     // Then assign contents from coordinate scheme tensor if provided.
269     if (tensor) {
270       uint64_t nnz = tensor->getElements().size();
271       values.reserve(nnz);
272       fromCOO(tensor, 0, nnz, 0);
273     } else if (allDense) {
274       values.resize(sz, 0);
275     }
276   }
277 
278   virtual ~SparseTensorStorage() {}
279 
280   /// Get the rank of the tensor.
281   uint64_t getRank() const { return sizes.size(); }
282 
283   /// Get the size in the given dimension of the tensor.
284   uint64_t getDimSize(uint64_t d) override {
285     assert(d < getRank());
286     return sizes[d];
287   }
288 
289   /// Partially specialize these getter methods based on template types.
290   void getPointers(std::vector<P> **out, uint64_t d) override {
291     assert(d < getRank());
292     *out = &pointers[d];
293   }
294   void getIndices(std::vector<I> **out, uint64_t d) override {
295     assert(d < getRank());
296     *out = &indices[d];
297   }
298   void getValues(std::vector<V> **out) override { *out = &values; }
299 
300   /// Partially specialize lexicographic insertions based on template types.
301   void lexInsert(uint64_t *cursor, V val) override {
302     // First, wrap up pending insertion path.
303     uint64_t diff = 0;
304     uint64_t top = 0;
305     if (!values.empty()) {
306       diff = lexDiff(cursor);
307       endPath(diff + 1);
308       top = idx[diff] + 1;
309     }
310     // Then continue with insertion path.
311     insPath(cursor, diff, top, val);
312   }
313 
314   /// Partially specialize expanded insertions based on template types.
315   /// Note that this method resets the values/filled-switch array back
316   /// to all-zero/false while only iterating over the nonzero elements.
317   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
318                  uint64_t count) override {
319     if (count == 0)
320       return;
321     // Sort.
322     std::sort(added, added + count);
323     // Restore insertion path for first insert.
324     uint64_t rank = getRank();
325     uint64_t index = added[0];
326     cursor[rank - 1] = index;
327     lexInsert(cursor, values[index]);
328     assert(filled[index]);
329     values[index] = 0;
330     filled[index] = false;
331     // Subsequent insertions are quick.
332     for (uint64_t i = 1; i < count; i++) {
333       assert(index < added[i] && "non-lexicographic insertion");
334       index = added[i];
335       cursor[rank - 1] = index;
336       insPath(cursor, rank - 1, added[i - 1] + 1, values[index]);
337       assert(filled[index]);
      values[index] = 0;
339       filled[index] = false;
340     }
341   }
342 
343   /// Finalizes lexicographic insertions.
344   void endInsert() override {
345     if (values.empty())
346       endDim(0);
347     else
348       endPath(0);
349   }
350 
351   /// Returns this sparse tensor storage scheme as a new memory-resident
352   /// sparse tensor in coordinate scheme with the given dimension order.
353   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
354     // Restore original order of the dimension sizes and allocate coordinate
355     // scheme with desired new ordering specified in perm.
356     uint64_t rank = getRank();
357     std::vector<uint64_t> orgsz(rank);
358     for (uint64_t r = 0; r < rank; r++)
359       orgsz[rev[r]] = sizes[r];
360     SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
361         rank, orgsz.data(), perm, values.size());
    // Populate the coordinate scheme, restoring the old ordering and applying
    // the new ordering. Rather than applying both reorderings during the
    // recursion, we compute the combined permutation in advance.
365     std::vector<uint64_t> reord(rank);
366     for (uint64_t r = 0; r < rank; r++)
367       reord[r] = perm[rev[r]];
368     toCOO(tensor, reord, 0, 0);
369     assert(tensor->getElements().size() == values.size());
370     return tensor;
371   }
372 
373   /// Factory method. Constructs a sparse tensor storage scheme with the given
374   /// dimensions, permutation, and per-dimension dense/sparse annotations,
375   /// using the coordinate scheme tensor for the initial contents if provided.
376   /// In the latter case, the coordinate scheme must respect the same
377   /// permutation as is desired for the new sparse tensor storage.
378   static SparseTensorStorage<P, I, V> *
379   newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
380                   const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
381     SparseTensorStorage<P, I, V> *n = nullptr;
382     if (tensor) {
383       assert(tensor->getRank() == rank);
384       for (uint64_t r = 0; r < rank; r++)
385         assert(sizes[r] == 0 || tensor->getSizes()[perm[r]] == sizes[r]);
386       tensor->sort(); // sort lexicographically
387       n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
388                                            tensor);
389       delete tensor;
390     } else {
391       std::vector<uint64_t> permsz(rank);
392       for (uint64_t r = 0; r < rank; r++)
393         permsz[perm[r]] = sizes[r];
394       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
395     }
396     return n;
397   }
398 
399 private:
400   /// Initializes sparse tensor storage scheme from a memory-resident sparse
401   /// tensor in coordinate scheme. This method prepares the pointers and
402   /// indices arrays under the given per-dimension dense/sparse annotations.
403   void fromCOO(SparseTensorCOO<V> *tensor, uint64_t lo, uint64_t hi,
404                uint64_t d) {
405     const std::vector<Element<V>> &elements = tensor->getElements();
406     // Once dimensions are exhausted, insert the numerical values.
407     assert(d <= getRank());
408     if (d == getRank()) {
409       assert(lo < hi && hi <= elements.size());
410       values.push_back(elements[lo].value);
411       return;
412     }
413     // Visit all elements in this interval.
414     uint64_t full = 0;
415     while (lo < hi) {
416       assert(lo < elements.size() && hi <= elements.size());
417       // Find segment in interval with same index elements in this dimension.
418       uint64_t i = elements[lo].indices[d];
419       uint64_t seg = lo + 1;
420       while (seg < hi && elements[seg].indices[d] == i)
421         seg++;
422       // Handle segment in interval for sparse or dense dimension.
423       if (isCompressedDim(d)) {
424         indices[d].push_back(i);
425       } else {
        // For dense storage we must fill in all the zero values between
        // the previous element (from the last iteration of this loop)
        // and the current element.
429         for (; full < i; full++)
430           endDim(d + 1);
431         full++;
432       }
433       fromCOO(tensor, lo, seg, d + 1);
434       // And move on to next segment in interval.
435       lo = seg;
436     }
437     // Finalize the sparse pointer structure at this dimension.
438     if (isCompressedDim(d)) {
439       pointers[d].push_back(indices[d].size());
440     } else {
441       // For dense storage we must fill in all the zero values after
442       // the last element.
443       for (uint64_t sz = sizes[d]; full < sz; full++)
444         endDim(d + 1);
445     }
446   }
447 
448   /// Stores the sparse tensor storage scheme into a memory-resident sparse
449   /// tensor in coordinate scheme.
450   void toCOO(SparseTensorCOO<V> *tensor, std::vector<uint64_t> &reord,
451              uint64_t pos, uint64_t d) {
452     assert(d <= getRank());
453     if (d == getRank()) {
454       assert(pos < values.size());
455       tensor->add(idx, values[pos]);
456     } else if (isCompressedDim(d)) {
457       // Sparse dimension.
458       for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
459         idx[reord[d]] = indices[d][ii];
460         toCOO(tensor, reord, ii, d + 1);
461       }
462     } else {
463       // Dense dimension.
464       for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
465         idx[reord[d]] = i;
466         toCOO(tensor, reord, off + i, d + 1);
467       }
468     }
469   }
470 
  /// Ends a deeper dimension that has not been visited before.
472   void endDim(uint64_t d) {
473     assert(d <= getRank());
474     if (d == getRank()) {
475       values.push_back(0);
476     } else if (isCompressedDim(d)) {
477       pointers[d].push_back(indices[d].size());
478     } else {
479       for (uint64_t full = 0, sz = sizes[d]; full < sz; full++)
480         endDim(d + 1);
481     }
482   }
483 
484   /// Wraps up a single insertion path, inner to outer.
485   void endPath(uint64_t diff) {
486     uint64_t rank = getRank();
487     assert(diff <= rank);
488     for (uint64_t i = 0; i < rank - diff; i++) {
489       uint64_t d = rank - i - 1;
490       if (isCompressedDim(d)) {
491         pointers[d].push_back(indices[d].size());
492       } else {
493         for (uint64_t full = idx[d] + 1, sz = sizes[d]; full < sz; full++)
494           endDim(d + 1);
495       }
496     }
497   }
498 
499   /// Continues a single insertion path, outer to inner.
500   void insPath(uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
501     uint64_t rank = getRank();
502     assert(diff < rank);
503     for (uint64_t d = diff; d < rank; d++) {
504       uint64_t i = cursor[d];
505       if (isCompressedDim(d)) {
506         indices[d].push_back(i);
507       } else {
508         for (uint64_t full = top; full < i; full++)
509           endDim(d + 1);
510       }
511       top = 0;
512       idx[d] = i;
513     }
514     values.push_back(val);
515   }
516 
  /// Finds the lexicographically first dimension in which the cursor differs
  /// from the current index.
518   uint64_t lexDiff(uint64_t *cursor) {
519     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
520       if (cursor[r] > idx[r])
521         return r;
522       else
523         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
525     return -1u;
526   }
527 
528   /// Returns true if dimension is compressed.
  inline bool isCompressedDim(uint64_t d) const {
    return !pointers[d].empty();
  }
532 
533 private:
534   std::vector<uint64_t> sizes; // per-dimension sizes
535   std::vector<uint64_t> rev;   // "reverse" permutation
536   std::vector<uint64_t> idx;   // index cursor
537   std::vector<std::vector<P>> pointers;
538   std::vector<std::vector<I>> indices;
539   std::vector<V> values;
540 };
541 
542 /// Helper to convert string to lower case.
543 static char *toLower(char *token) {
544   for (char *c = token; *c; c++)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
546   return token;
547 }
548 
549 /// Read the MME header of a general sparse matrix of type real.
550 static void readMMEHeader(FILE *file, char *name, uint64_t *idata,
551                           bool *is_symmetric) {
552   char line[1025];
553   char header[64];
554   char object[64];
555   char format[64];
556   char field[64];
557   char symmetry[64];
558   // Read header line.
559   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
560              symmetry) != 5) {
561     fprintf(stderr, "Corrupt header in %s\n", name);
562     exit(1);
563   }
564   *is_symmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
565   // Make sure this is a general sparse matrix.
566   if (strcmp(toLower(header), "%%matrixmarket") ||
567       strcmp(toLower(object), "matrix") ||
568       strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
569       (strcmp(toLower(symmetry), "general") && !(*is_symmetric))) {
570     fprintf(stderr,
571             "Cannot find a general sparse matrix with type real in %s\n", name);
572     exit(1);
573   }
574   // Skip comments.
575   while (1) {
576     if (!fgets(line, 1025, file)) {
577       fprintf(stderr, "Cannot find data in %s\n", name);
578       exit(1);
579     }
580     if (line[0] != '%')
581       break;
582   }
583   // Next line contains M N NNZ.
584   idata[0] = 2; // rank
585   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
586              idata + 1) != 3) {
587     fprintf(stderr, "Cannot find size in %s\n", name);
588     exit(1);
589   }
590 }
591 
592 /// Read the "extended" FROSTT header. Although not part of the documented
593 /// format, we assume that the file starts with optional comments followed
594 /// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
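/// For example, a rank-3 tensor with 5 nonzeros and dimension sizes
/// 10 x 20 x 30 would start, after any '#' comment lines, with
///   3 5
///   10 20 30
/// before listing the nonzero elements themselves.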
596 static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
597   char line[1025];
598   // Skip comments.
599   while (1) {
600     if (!fgets(line, 1025, file)) {
601       fprintf(stderr, "Cannot find data in %s\n", name);
602       exit(1);
603     }
604     if (line[0] != '#')
605       break;
606   }
607   // Next line contains RANK and NNZ.
608   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
609     fprintf(stderr, "Cannot find metadata in %s\n", name);
610     exit(1);
611   }
612   // Followed by a line with the dimension sizes (one per rank).
613   for (uint64_t r = 0; r < idata[0]; r++) {
614     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
615       fprintf(stderr, "Cannot find dimension size %s\n", name);
616       exit(1);
617     }
618   }
619 }
620 
621 /// Reads a sparse tensor with the given filename into a memory-resident
622 /// sparse tensor in coordinate scheme.
623 template <typename V>
624 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
625                                                const uint64_t *sizes,
626                                                const uint64_t *perm) {
627   // Open the file.
628   FILE *file = fopen(filename, "r");
629   if (!file) {
630     fprintf(stderr, "Cannot find %s\n", filename);
631     exit(1);
632   }
633   // Perform some file format dependent set up.
634   uint64_t idata[512];
635   bool is_symmetric = false;
636   if (strstr(filename, ".mtx")) {
637     readMMEHeader(file, filename, idata, &is_symmetric);
638   } else if (strstr(filename, ".tns")) {
639     readExtFROSTTHeader(file, filename, idata);
640   } else {
641     fprintf(stderr, "Unknown format %s\n", filename);
642     exit(1);
643   }
644   // Prepare sparse tensor object with per-dimension sizes
645   // and the number of nonzeros as initial capacity.
646   assert(rank == idata[0] && "rank mismatch");
647   uint64_t nnz = idata[1];
648   for (uint64_t r = 0; r < rank; r++)
649     assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
650            "dimension size mismatch");
651   SparseTensorCOO<V> *tensor =
652       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
  // Read all nonzero elements.
654   std::vector<uint64_t> indices(rank);
655   for (uint64_t k = 0; k < nnz; k++) {
656     uint64_t idx = -1u;
657     for (uint64_t r = 0; r < rank; r++) {
658       if (fscanf(file, "%" PRIu64, &idx) != 1) {
659         fprintf(stderr, "Cannot find next index in %s\n", filename);
660         exit(1);
661       }
662       // Add 0-based index.
663       indices[perm[r]] = idx - 1;
664     }
665     // The external formats always store the numerical values with the type
666     // double, but we cast these values to the sparse tensor object type.
667     double value;
668     if (fscanf(file, "%lg\n", &value) != 1) {
669       fprintf(stderr, "Cannot find next value in %s\n", filename);
670       exit(1);
671     }
672     tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing them. In the future, we may want to make symmetry
    // implicit for storage reasons.
676     if (is_symmetric && indices[0] != indices[1])
677       tensor->add({indices[1], indices[0]}, value);
678   }
679   // Close the file and return tensor.
680   fclose(file);
681   return tensor;
682 }
683 
684 } // namespace
685 
686 extern "C" {
687 
688 /// This type is used in the public API at all places where MLIR expects
689 /// values with the built-in type "index". For now, we simply assume that
690 /// type is 64-bit, but targets with different "index" bit widths should link
691 /// with an alternatively built runtime support library.
692 // TODO: support such targets?
693 typedef uint64_t index_t;
694 
695 //===----------------------------------------------------------------------===//
696 //
697 // Public API with methods that operate on MLIR buffers (memrefs) to interact
698 // with sparse tensors, which are only visible as opaque pointers externally.
699 // These methods should be used exclusively by MLIR compiler-generated code.
700 //
701 // Some macro magic is used to generate implementations for all required type
702 // combinations that can be called from MLIR compiler-generated code.
703 //
704 //===----------------------------------------------------------------------===//
705 
706 #define CASE(p, i, v, P, I, V)                                                 \
707   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
708     SparseTensorCOO<V> *tensor = nullptr;                                      \
709     if (action <= Action::kFromCOO) {                                          \
710       if (action == Action::kFromFile) {                                       \
711         char *filename = static_cast<char *>(ptr);                             \
712         tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm);          \
713       } else if (action == Action::kFromCOO) {                                 \
714         tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
715       } else {                                                                 \
716         assert(action == Action::kEmpty);                                      \
717       }                                                                        \
718       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,  \
719                                                            sparsity, tensor);  \
720     } else if (action == Action::kEmptyCOO) {                                  \
721       return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);        \
722     } else {                                                                   \
723       tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);  \
724       if (action == Action::kToIterator) {                                     \
725         tensor->startIterator();                                               \
726       } else {                                                                 \
727         assert(action == Action::kToCOO);                                      \
728       }                                                                        \
729       return tensor;                                                           \
730     }                                                                          \
731   }
732 
733 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
734 
735 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
736   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
737     assert(ref &&tensor);                                                      \
738     std::vector<TYPE> *v;                                                      \
739     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
740     ref->basePtr = ref->data = v->data();                                      \
741     ref->offset = 0;                                                           \
742     ref->sizes[0] = v->size();                                                 \
743     ref->strides[0] = 1;                                                       \
744   }
745 
746 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
747   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
748                            index_t d) {                                        \
749     assert(ref &&tensor);                                                      \
750     std::vector<TYPE> *v;                                                      \
751     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
752     ref->basePtr = ref->data = v->data();                                      \
753     ref->offset = 0;                                                           \
754     ref->sizes[0] = v->size();                                                 \
755     ref->strides[0] = 1;                                                       \
756   }
757 
758 #define IMPL_ADDELT(NAME, TYPE)                                                \
759   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
760                             StridedMemRefType<index_t, 1> *iref,               \
761                             StridedMemRefType<index_t, 1> *pref) {             \
762     assert(tensor &&iref &&pref);                                              \
763     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
764     assert(iref->sizes[0] == pref->sizes[0]);                                  \
765     const index_t *indx = iref->data + iref->offset;                           \
766     const index_t *perm = pref->data + pref->offset;                           \
767     uint64_t isize = iref->sizes[0];                                           \
768     std::vector<index_t> indices(isize);                                       \
769     for (uint64_t r = 0; r < isize; r++)                                       \
770       indices[perm[r]] = indx[r];                                              \
771     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
772     return tensor;                                                             \
773   }
774 
775 #define IMPL_GETNEXT(NAME, V)                                                  \
776   bool _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *iref,  \
777                            StridedMemRefType<V, 0> *vref) {                    \
778     assert(tensor &&iref &&vref);                                              \
779     assert(iref->strides[0] == 1);                                             \
780     index_t *indx = iref->data + iref->offset;                                 \
781     V *value = vref->data + vref->offset;                                      \
782     const uint64_t isize = iref->sizes[0];                                     \
783     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
784     const Element<V> *elem = iter->getNext();                                  \
785     if (elem == nullptr) {                                                     \
786       delete iter;                                                             \
787       return false;                                                            \
788     }                                                                          \
789     for (uint64_t r = 0; r < isize; r++)                                       \
790       indx[r] = elem->indices[r];                                              \
791     *value = elem->value;                                                      \
792     return true;                                                               \
793   }
794 
795 #define IMPL_LEXINSERT(NAME, V)                                                \
796   void _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *cref,  \
797                            V val) {                                            \
798     assert(tensor &&cref);                                                     \
799     assert(cref->strides[0] == 1);                                             \
800     index_t *cursor = cref->data + cref->offset;                               \
801     assert(cursor);                                                            \
802     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
803   }
804 
805 #define IMPL_EXPINSERT(NAME, V)                                                \
806   void _mlir_ciface_##NAME(                                                    \
807       void *tensor, StridedMemRefType<index_t, 1> *cref,                       \
808       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
809       StridedMemRefType<index_t, 1> *aref, index_t count) {                    \
810     assert(tensor &&cref &&vref &&fref &&aref);                                \
811     assert(cref->strides[0] == 1);                                             \
812     assert(vref->strides[0] == 1);                                             \
813     assert(fref->strides[0] == 1);                                             \
814     assert(aref->strides[0] == 1);                                             \
815     assert(vref->sizes[0] == fref->sizes[0]);                                  \
816     index_t *cursor = cref->data + cref->offset;                               \
817     V *values = vref->data + vref->offset;                                     \
818     bool *filled = fref->data + fref->offset;                                  \
819     index_t *added = aref->data + aref->offset;                                \
820     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
821         cursor, values, filled, added, count);                                 \
822   }
823 
/// Constructs a new sparse tensor. This is the "Swiss army knife"
825 /// method for materializing sparse tensors into the computation.
826 ///
827 /// Action:
828 /// kEmpty = returns empty storage to fill later
829 /// kFromFile = returns storage, where ptr contains filename to read
830 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
831 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
832 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
833 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
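///
/// A typical lifetime from compiler-generated code, for illustration: obtain
/// an empty coordinate scheme with kEmptyCOO, populate it through the addElt
/// methods below, and pass it back with kFromCOO to build the actual storage
/// (which takes ownership of and deletes the coordinate scheme). Results are
/// read back either with kToCOO or with kToIterator followed by getNext calls.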
834 void *
835 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
836                              StridedMemRefType<index_t, 1> *sref,
837                              StridedMemRefType<index_t, 1> *pref,
838                              OverheadType ptrTp, OverheadType indTp,
839                              PrimaryType valTp, Action action, void *ptr) {
840   assert(aref && sref && pref);
841   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
842          pref->strides[0] == 1);
843   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
844   const DimLevelType *sparsity = aref->data + aref->offset;
845   const index_t *sizes = sref->data + sref->offset;
846   const index_t *perm = pref->data + pref->offset;
847   uint64_t rank = aref->sizes[0];
848 
849   // Double matrices with all combinations of overhead storage.
850   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
851        uint64_t, double);
852   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
853        uint32_t, double);
854   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
855        uint16_t, double);
856   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
857        uint8_t, double);
858   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
859        uint64_t, double);
860   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
861        uint32_t, double);
862   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
863        uint16_t, double);
864   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
865        uint8_t, double);
866   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
867        uint64_t, double);
868   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
869        uint32_t, double);
870   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
871        uint16_t, double);
872   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
873        uint8_t, double);
874   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
875        uint64_t, double);
876   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
877        uint32_t, double);
878   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
879        uint16_t, double);
880   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
881        uint8_t, double);
882 
883   // Float matrices with all combinations of overhead storage.
884   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
885        uint64_t, float);
886   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
887        uint32_t, float);
888   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
889        uint16_t, float);
890   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
891        uint8_t, float);
892   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
893        uint64_t, float);
894   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
895        uint32_t, float);
896   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
897        uint16_t, float);
898   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
899        uint8_t, float);
900   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
901        uint64_t, float);
902   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
903        uint32_t, float);
904   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
905        uint16_t, float);
906   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
907        uint8_t, float);
908   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
909        uint64_t, float);
910   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
911        uint32_t, float);
912   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
913        uint16_t, float);
914   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
915        uint8_t, float);
916 
917   // Integral matrices with both overheads of the same type.
918   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
919   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
920   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
921   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
922   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
923   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
924   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
925   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
926   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
927   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
928   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
929   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
930   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
931 
932   // Unsupported case (add above if needed).
933   fputs("unsupported combination of types\n", stderr);
934   exit(1);
935 }
936 
937 /// Methods that provide direct access to pointers.
938 IMPL_GETOVERHEAD(sparsePointers, index_t, getPointers)
939 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
940 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
941 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
942 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
943 
944 /// Methods that provide direct access to indices.
945 IMPL_GETOVERHEAD(sparseIndices, index_t, getIndices)
946 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
947 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
948 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
949 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
950 
951 /// Methods that provide direct access to values.
952 IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
953 IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
954 IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
955 IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
956 IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
957 IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
958 
959 /// Helper to add value to coordinate scheme, one per value type.
960 IMPL_ADDELT(addEltF64, double)
961 IMPL_ADDELT(addEltF32, float)
962 IMPL_ADDELT(addEltI64, int64_t)
963 IMPL_ADDELT(addEltI32, int32_t)
964 IMPL_ADDELT(addEltI16, int16_t)
965 IMPL_ADDELT(addEltI8, int8_t)
966 
967 /// Helper to enumerate elements of coordinate scheme, one per value type.
968 IMPL_GETNEXT(getNextF64, double)
969 IMPL_GETNEXT(getNextF32, float)
970 IMPL_GETNEXT(getNextI64, int64_t)
971 IMPL_GETNEXT(getNextI32, int32_t)
972 IMPL_GETNEXT(getNextI16, int16_t)
973 IMPL_GETNEXT(getNextI8, int8_t)
974 
/// Helper to insert elements in lexicographic index order, one per value type.
976 IMPL_LEXINSERT(lexInsertF64, double)
977 IMPL_LEXINSERT(lexInsertF32, float)
978 IMPL_LEXINSERT(lexInsertI64, int64_t)
979 IMPL_LEXINSERT(lexInsertI32, int32_t)
980 IMPL_LEXINSERT(lexInsertI16, int16_t)
981 IMPL_LEXINSERT(lexInsertI8, int8_t)
982 
983 /// Helper to insert using expansion, one per value type.
984 IMPL_EXPINSERT(expInsertF64, double)
985 IMPL_EXPINSERT(expInsertF32, float)
986 IMPL_EXPINSERT(expInsertI64, int64_t)
987 IMPL_EXPINSERT(expInsertI32, int32_t)
988 IMPL_EXPINSERT(expInsertI16, int16_t)
989 IMPL_EXPINSERT(expInsertI8, int8_t)
990 
991 #undef CASE
992 #undef IMPL_SPARSEVALUES
993 #undef IMPL_GETOVERHEAD
994 #undef IMPL_ADDELT
995 #undef IMPL_GETNEXT
996 #undef IMPL_LEXINSERT
997 #undef IMPL_EXPINSERT
998 
999 //===----------------------------------------------------------------------===//
1000 //
1001 // Public API with methods that accept C-style data structures to interact
1002 // with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used by MLIR compiler-generated code as well as by an
// external runtime that wants to interact with MLIR compiler-generated code.
1005 //
1006 //===----------------------------------------------------------------------===//
1007 
1008 /// Helper method to read a sparse tensor filename from the environment,
1009 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
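/// For example, with TENSOR0 set to "matrix.mtx" in the environment, a call
/// to getTensorFilename(0) returns that string (or NULL when the variable
/// is not set).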
1010 char *getTensorFilename(index_t id) {
1011   char var[80];
  snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
1013   char *env = getenv(var);
1014   return env;
1015 }
1016 
1017 /// Returns size of sparse tensor in given dimension.
1018 index_t sparseDimSize(void *tensor, index_t d) {
1019   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1020 }
1021 
1022 /// Finalizes lexicographic insertions.
1023 void endInsert(void *tensor) {
1024   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1025 }
1026 
1027 /// Releases sparse tensor storage.
1028 void delSparseTensor(void *tensor) {
1029   delete static_cast<SparseTensorStorageBase *>(tensor);
1030 }
1031 
1032 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
1033 /// data structures. The expected parameters are:
1034 ///
1035 ///   rank:    rank of tensor
1036 ///   nse:     number of specified elements (usually the nonzeros)
1037 ///   shape:   array with dimension size for each rank
1038 ///   values:  a "nse" array with values for all specified elements
1039 ///   indices: a flat "nse x rank" array with indices for all specified elements
1040 ///
1041 /// For example, the sparse matrix
1042 ///     | 1.0 0.0 0.0 |
1043 ///     | 0.0 5.0 3.0 |
1044 /// can be passed as
1045 ///      rank    = 2
1046 ///      nse     = 3
1047 ///      shape   = [2, 3]
1048 ///      values  = [1.0, 5.0, 3.0]
1049 ///      indices = [ 0, 0,  1, 1,  1, 2]
1050 //
1051 // TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
1052 //
1053 void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
1054                                 double *values, uint64_t *indices) {
1055   // Setup all-dims compressed and default ordering.
1056   std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
1057   std::vector<uint64_t> perm(rank);
1058   std::iota(perm.begin(), perm.end(), 0);
1059   // Convert external format to internal COO.
1060   SparseTensorCOO<double> *tensor = SparseTensorCOO<double>::newSparseTensorCOO(
1061       rank, shape, perm.data(), nse);
1062   std::vector<uint64_t> idx(rank);
1063   for (uint64_t i = 0, base = 0; i < nse; i++) {
1064     for (uint64_t r = 0; r < rank; r++)
1065       idx[r] = indices[base + r];
1066     tensor->add(idx, values[i]);
1067     base += rank;
1068   }
1069   // Return sparse tensor storage format as opaque pointer.
1070   return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
1071       rank, shape, perm.data(), sparse.data(), tensor);
1072 }
1073 
1074 /// Converts a sparse tensor to COO-flavored format expressed using C-style
1075 /// data structures. The expected output parameters are pointers for these
1076 /// values:
1077 ///
1078 ///   rank:    rank of tensor
1079 ///   nse:     number of specified elements (usually the nonzeros)
1080 ///   shape:   array with dimension size for each rank
1081 ///   values:  a "nse" array with values for all specified elements
1082 ///   indices: a flat "nse x rank" array with indices for all specified elements
1083 ///
1084 /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
1085 /// from convertToMLIRSparseTensor.
1086 ///
1087 //  TODO: Currently, values are copied from SparseTensorStorage to
1088 //  SparseTensorCOO, then to the output. We may want to reduce the number of
1089 //  copies.
1090 //
1091 //  TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
1092 //
1093 void convertFromMLIRSparseTensor(void *tensor, uint64_t *p_rank,
1094                                  uint64_t *p_nse, uint64_t **p_shape,
1095                                  double **p_values, uint64_t **p_indices) {
1096   SparseTensorStorage<uint64_t, uint64_t, double> *sparse_tensor =
1097       static_cast<SparseTensorStorage<uint64_t, uint64_t, double> *>(tensor);
1098   uint64_t rank = sparse_tensor->getRank();
1099   std::vector<uint64_t> perm(rank);
1100   std::iota(perm.begin(), perm.end(), 0);
1101   SparseTensorCOO<double> *coo = sparse_tensor->toCOO(perm.data());
1102 
1103   const std::vector<Element<double>> &elements = coo->getElements();
1104   uint64_t nse = elements.size();
1105 
1106   uint64_t *shape = new uint64_t[rank];
1107   for (uint64_t i = 0; i < rank; i++)
1108     shape[i] = coo->getSizes()[i];
1109 
1110   double *values = new double[nse];
1111   uint64_t *indices = new uint64_t[rank * nse];
1112 
1113   for (uint64_t i = 0, base = 0; i < nse; i++) {
1114     values[i] = elements[i].value;
1115     for (uint64_t j = 0; j < rank; j++)
1116       indices[base + j] = elements[i].indices[j];
1117     base += rank;
1118   }
1119 
1120   delete coo;
1121   *p_rank = rank;
1122   *p_nse = nse;
1123   *p_shape = shape;
1124   *p_values = values;
1125   *p_indices = indices;
1126 }
1127 } // extern "C"
1128 
1129 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1130