1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 
20 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cctype>
25 #include <cinttypes>
26 #include <cstdio>
27 #include <cstdlib>
28 #include <cstring>
29 #include <numeric>
30 #include <vector>
31 
32 //===----------------------------------------------------------------------===//
33 //
34 // Internal support for storing and reading sparse tensors.
35 //
36 // The following memory-resident sparse storage schemes are supported:
37 //
38 // (a) A coordinate scheme for temporarily storing and lexicographically
39 //     sorting a sparse tensor by index (SparseTensorCOO).
40 //
41 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
43 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
44 //
45 // The following external formats are supported:
46 //
47 // (1) Matrix Market Exchange (MME): *.mtx
48 //     https://math.nist.gov/MatrixMarket/formats.html
49 //
50 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
51 //     http://frostt.io/tensors/file-formats.html
52 //
53 // Two public APIs are supported:
54 //
55 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
56 //     tensors. These methods should be used exclusively by MLIR
57 //     compiler-generated code.
58 //
59 // (II) Methods that accept C-style data structures to interact with sparse
60 //      tensors. These methods can be used by any external runtime that wants
61 //      to interact with MLIR compiler-generated code.
62 //
// In both cases (I) and (II), the SparseTensorStorage format is only
// visible externally as an opaque pointer.
65 //
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 
70 /// A sparse tensor element in coordinate scheme (value and indices).
71 /// For example, a rank-1 vector element would look like
72 ///   ({i}, a[i])
73 /// and a rank-5 tensor element like
74 ///   ({i,j,k,l,m}, a[i,j,k,l,m])
75 template <typename V>
76 struct Element {
  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val) {}
78   std::vector<uint64_t> indices;
79   V value;
80 };
81 
82 /// A memory-resident sparse tensor in coordinate scheme (collection of
83 /// elements). This data structure is used to read a sparse tensor from
84 /// any external format into memory and sort the elements lexicographically
85 /// by indices before passing it back to the client (most packed storage
86 /// formats require the elements to appear in lexicographic index order).
87 template <typename V>
88 struct SparseTensorCOO {
89 public:
90   SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
91       : sizes(szs), iteratorLocked(false), iteratorPos(0) {
92     if (capacity)
93       elements.reserve(capacity);
94   }
95   /// Adds element as indices and value.
96   void add(const std::vector<uint64_t> &ind, V val) {
97     assert(!iteratorLocked && "Attempt to add() after startIterator()");
98     uint64_t rank = getRank();
99     assert(rank == ind.size());
100     for (uint64_t r = 0; r < rank; r++)
101       assert(ind[r] < sizes[r]); // within bounds
102     elements.emplace_back(ind, val);
103   }
104   /// Sorts elements lexicographically by index.
105   void sort() {
106     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
107     std::sort(elements.begin(), elements.end(), lexOrder);
108   }
109   /// Returns rank.
110   uint64_t getRank() const { return sizes.size(); }
111   /// Getter for sizes array.
112   const std::vector<uint64_t> &getSizes() const { return sizes; }
113   /// Getter for elements array.
114   const std::vector<Element<V>> &getElements() const { return elements; }
115 
116   /// Switch into iterator mode.
117   void startIterator() {
118     iteratorLocked = true;
119     iteratorPos = 0;
120   }
121   /// Get the next element.
122   const Element<V> *getNext() {
123     assert(iteratorLocked && "Attempt to getNext() before startIterator()");
124     if (iteratorPos < elements.size())
125       return &(elements[iteratorPos++]);
126     iteratorLocked = false;
127     return nullptr;
128   }
129 
130   /// Factory method. Permutes the original dimensions according to
131   /// the given ordering and expects subsequent add() calls to honor
132   /// that same ordering for the given indices. The result is a
133   /// fully permuted coordinate scheme.
134   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
135                                                 const uint64_t *sizes,
136                                                 const uint64_t *perm,
137                                                 uint64_t capacity = 0) {
138     std::vector<uint64_t> permsz(rank);
139     for (uint64_t r = 0; r < rank; r++)
140       permsz[perm[r]] = sizes[r];
141     return new SparseTensorCOO<V>(permsz, capacity);
142   }
143 
144 private:
145   /// Returns true if indices of e1 < indices of e2.
146   static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
147     uint64_t rank = e1.indices.size();
148     assert(rank == e2.indices.size());
149     for (uint64_t r = 0; r < rank; r++) {
150       if (e1.indices[r] == e2.indices[r])
151         continue;
152       return e1.indices[r] < e2.indices[r];
153     }
154     return false;
155   }
156   const std::vector<uint64_t> sizes; // per-dimension sizes
157   std::vector<Element<V>> elements;
158   bool iteratorLocked;
  uint64_t iteratorPos;
160 };
161 
162 /// Abstract base class of sparse tensor storage. Note that we use
163 /// function overloading to implement "partial" method specialization.
164 class SparseTensorStorageBase {
165 public:
166   // Dimension size query.
167   virtual uint64_t getDimSize(uint64_t) = 0;
168 
169   // Overhead storage.
170   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
171   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
172   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
173   virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
174   virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
175   virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
176   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
177   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
178 
179   // Primary storage.
180   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
181   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
182   virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
183   virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
184   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
185   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
186 
187   // Element-wise insertion in lexicographic index order.
188   virtual void lexInsert(uint64_t *, double) { fatal("insf64"); }
189   virtual void lexInsert(uint64_t *, float) { fatal("insf32"); }
190   virtual void lexInsert(uint64_t *, int64_t) { fatal("insi64"); }
191   virtual void lexInsert(uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(uint64_t *, int16_t) { fatal("insi16"); }
193   virtual void lexInsert(uint64_t *, int8_t) { fatal("insi8"); }
194   virtual void endInsert() = 0;
195 
196   virtual ~SparseTensorStorageBase() {}
197 
198 private:
199   void fatal(const char *tp) {
200     fprintf(stderr, "unsupported %s\n", tp);
201     exit(1);
202   }
203 };
204 
205 /// A memory-resident sparse tensor using a storage scheme based on
206 /// per-dimension sparse/dense annotations. This data structure provides a
207 /// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this class provides
209 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
210 /// and annotations to implement all required setup in a general manner.
211 template <typename P, typename I, typename V>
212 class SparseTensorStorage : public SparseTensorStorageBase {
213 public:
214   /// Constructs a sparse tensor storage scheme with the given dimensions,
215   /// permutation, and per-dimension dense/sparse annotations, using
216   /// the coordinate scheme tensor for the initial contents if provided.
217   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
218                       const DimLevelType *sparsity,
219                       SparseTensorCOO<V> *tensor = nullptr)
220       : sizes(szs), rev(getRank()), idx(getRank()), pointers(getRank()),
221         indices(getRank()) {
222     uint64_t rank = getRank();
223     // Store "reverse" permutation.
224     for (uint64_t r = 0; r < rank; r++)
225       rev[perm[r]] = r;
226     // Provide hints on capacity of pointers and indices.
227     // TODO: needs fine-tuning based on sparsity
228     bool allDense = true;
229     uint64_t sz = 1;
230     for (uint64_t r = 0; r < rank; r++) {
231       sz *= sizes[r];
232       if (sparsity[r] == DimLevelType::kCompressed) {
233         pointers[r].reserve(sz + 1);
234         indices[r].reserve(sz);
235         sz = 1;
236         allDense = false;
237       } else {
238         assert(sparsity[r] == DimLevelType::kDense &&
239                "singleton not yet supported");
240       }
241     }
242     // Prepare sparse pointer structures for all dimensions.
243     for (uint64_t r = 0; r < rank; r++)
244       if (sparsity[r] == DimLevelType::kCompressed)
245         pointers[r].push_back(0);
246     // Then assign contents from coordinate scheme tensor if provided.
247     if (tensor) {
248       uint64_t nnz = tensor->getElements().size();
249       values.reserve(nnz);
250       fromCOO(tensor, sparsity, 0, nnz, 0);
251     } else {
252       if (allDense)
253         values.resize(sz, 0);
254       for (uint64_t r = 0; r < rank; r++)
255         idx[r] = -1u;
256     }
257   }
258 
259   virtual ~SparseTensorStorage() {}
260 
261   /// Get the rank of the tensor.
262   uint64_t getRank() const { return sizes.size(); }
263 
264   /// Get the size in the given dimension of the tensor.
265   uint64_t getDimSize(uint64_t d) override {
266     assert(d < getRank());
267     return sizes[d];
268   }
269 
270   /// Partially specialize these getter methods based on template types.
271   void getPointers(std::vector<P> **out, uint64_t d) override {
272     assert(d < getRank());
273     *out = &pointers[d];
274   }
275   void getIndices(std::vector<I> **out, uint64_t d) override {
276     assert(d < getRank());
277     *out = &indices[d];
278   }
279   void getValues(std::vector<V> **out) override { *out = &values; }
280 
281   /// Partially specialize lexicographic insertions based on template types.
282   // TODO: 1-dim tensors only for now, generalize soon
283   void lexInsert(uint64_t *cursor, V val) override {
284     assert((idx[0] == -1u || idx[0] < cursor[0]) && "not lexicographic");
285     indices[0].push_back(cursor[0]);
286     values.push_back(val);
287     idx[0] = cursor[0];
288   }
289 
290   /// Finalizes lexicographic insertions.
291   void endInsert() override { pointers[0].push_back(indices[0].size()); }
292 
293   /// Returns this sparse tensor storage scheme as a new memory-resident
294   /// sparse tensor in coordinate scheme with the given dimension order.
295   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
296     // Restore original order of the dimension sizes and allocate coordinate
297     // scheme with desired new ordering specified in perm.
298     uint64_t rank = getRank();
299     std::vector<uint64_t> orgsz(rank);
300     for (uint64_t r = 0; r < rank; r++)
301       orgsz[rev[r]] = sizes[r];
302     SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
303         rank, orgsz.data(), perm, values.size());
    // Populate the coordinate scheme, restoring the original dimension order
    // and then applying the new ordering given in perm. Rather than applying
    // both reorderings during the recursion, we compute the combined
    // permutation in advance.
307     std::vector<uint64_t> reord(rank);
308     for (uint64_t r = 0; r < rank; r++)
309       reord[r] = perm[rev[r]];
310     toCOO(tensor, reord, 0, 0);
311     assert(tensor->getElements().size() == values.size());
312     return tensor;
313   }
314 
315   /// Factory method. Constructs a sparse tensor storage scheme with the given
316   /// dimensions, permutation, and per-dimension dense/sparse annotations,
317   /// using the coordinate scheme tensor for the initial contents if provided.
318   /// In the latter case, the coordinate scheme must respect the same
319   /// permutation as is desired for the new sparse tensor storage.
320   static SparseTensorStorage<P, I, V> *
321   newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
322                   const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
323     SparseTensorStorage<P, I, V> *n = nullptr;
324     if (tensor) {
325       assert(tensor->getRank() == rank);
326       for (uint64_t r = 0; r < rank; r++)
327         assert(sizes[r] == 0 || tensor->getSizes()[perm[r]] == sizes[r]);
328       tensor->sort(); // sort lexicographically
329       n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
330                                            tensor);
331       delete tensor;
332     } else {
333       std::vector<uint64_t> permsz(rank);
334       for (uint64_t r = 0; r < rank; r++)
335         permsz[perm[r]] = sizes[r];
336       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
337     }
338     return n;
339   }
340 
341 private:
342   /// Initializes sparse tensor storage scheme from a memory-resident sparse
343   /// tensor in coordinate scheme. This method prepares the pointers and
344   /// indices arrays under the given per-dimension dense/sparse annotations.
345   void fromCOO(SparseTensorCOO<V> *tensor, const DimLevelType *sparsity,
346                uint64_t lo, uint64_t hi, uint64_t d) {
347     const std::vector<Element<V>> &elements = tensor->getElements();
348     // Once dimensions are exhausted, insert the numerical values.
349     assert(d <= getRank());
350     if (d == getRank()) {
351       assert(lo >= hi || lo < elements.size());
352       values.push_back(lo < hi ? elements[lo].value : 0);
353       return;
354     }
355     // Visit all elements in this interval.
356     uint64_t full = 0;
357     while (lo < hi) {
358       assert(lo < elements.size() && hi <= elements.size());
      // Find the segment of elements in this interval that share the same
      // index in this dimension.
360       uint64_t i = elements[lo].indices[d];
361       uint64_t seg = lo + 1;
362       while (seg < hi && elements[seg].indices[d] == i)
363         seg++;
364       // Handle segment in interval for sparse or dense dimension.
365       if (sparsity[d] == DimLevelType::kCompressed) {
366         indices[d].push_back(i);
367       } else {
        // For dense storage we must fill in all the zero values between
        // the previous element (from the previous iteration of this loop)
        // and the current element.
371         for (; full < i; full++)
372           fromCOO(tensor, sparsity, 0, 0, d + 1); // pass empty
373         full++;
374       }
375       fromCOO(tensor, sparsity, lo, seg, d + 1);
376       // And move on to next segment in interval.
377       lo = seg;
378     }
379     // Finalize the sparse pointer structure at this dimension.
380     if (sparsity[d] == DimLevelType::kCompressed) {
381       pointers[d].push_back(indices[d].size());
382     } else {
383       // For dense storage we must fill in all the zero values after
384       // the last element.
385       for (uint64_t sz = sizes[d]; full < sz; full++)
386         fromCOO(tensor, sparsity, 0, 0, d + 1); // pass empty
387     }
388   }
389 
390   /// Stores the sparse tensor storage scheme into a memory-resident sparse
391   /// tensor in coordinate scheme.
392   void toCOO(SparseTensorCOO<V> *tensor, std::vector<uint64_t> &reord,
393              uint64_t pos, uint64_t d) {
394     assert(d <= getRank());
395     if (d == getRank()) {
396       assert(pos < values.size());
397       tensor->add(idx, values[pos]);
398     } else if (pointers[d].empty()) {
399       // Dense dimension.
400       for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
401         idx[reord[d]] = i;
402         toCOO(tensor, reord, off + i, d + 1);
403       }
404     } else {
405       // Sparse dimension.
406       for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
407         idx[reord[d]] = indices[d][ii];
408         toCOO(tensor, reord, ii, d + 1);
409       }
410     }
411   }
412 
413 private:
414   std::vector<uint64_t> sizes; // per-dimension sizes
415   std::vector<uint64_t> rev;   // "reverse" permutation
416   std::vector<uint64_t> idx;   // index cursor
417   std::vector<std::vector<P>> pointers;
418   std::vector<std::vector<I>> indices;
419   std::vector<V> values;
420 };
421 
422 /// Helper to convert string to lower case.
423 static char *toLower(char *token) {
424   for (char *c = token; *c; c++)
    *c = tolower(static_cast<unsigned char>(*c));
426   return token;
427 }
428 
429 /// Read the MME header of a general sparse matrix of type real.
430 static void readMMEHeader(FILE *file, char *name, uint64_t *idata) {
431   char line[1025];
432   char header[64];
433   char object[64];
434   char format[64];
435   char field[64];
436   char symmetry[64];
437   // Read header line.
438   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
439              symmetry) != 5) {
440     fprintf(stderr, "Corrupt header in %s\n", name);
441     exit(1);
442   }
443   // Make sure this is a general sparse matrix.
444   if (strcmp(toLower(header), "%%matrixmarket") ||
445       strcmp(toLower(object), "matrix") ||
446       strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
447       strcmp(toLower(symmetry), "general")) {
448     fprintf(stderr,
449             "Cannot find a general sparse matrix with type real in %s\n", name);
450     exit(1);
451   }
452   // Skip comments.
453   while (1) {
454     if (!fgets(line, 1025, file)) {
455       fprintf(stderr, "Cannot find data in %s\n", name);
456       exit(1);
457     }
458     if (line[0] != '%')
459       break;
460   }
461   // Next line contains M N NNZ.
462   idata[0] = 2; // rank
463   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
464              idata + 1) != 3) {
465     fprintf(stderr, "Cannot find size in %s\n", name);
466     exit(1);
467   }
468 }
469 
470 /// Read the "extended" FROSTT header. Although not part of the documented
471 /// format, we assume that the file starts with optional comments followed
472 /// by two lines that define the rank, the number of nonzeros, and the
473 /// dimensions sizes (one per rank) of the sparse tensor.
474 static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
475   char line[1025];
476   // Skip comments.
477   while (1) {
478     if (!fgets(line, 1025, file)) {
479       fprintf(stderr, "Cannot find data in %s\n", name);
480       exit(1);
481     }
482     if (line[0] != '#')
483       break;
484   }
485   // Next line contains RANK and NNZ.
486   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
487     fprintf(stderr, "Cannot find metadata in %s\n", name);
488     exit(1);
489   }
490   // Followed by a line with the dimension sizes (one per rank).
491   for (uint64_t r = 0; r < idata[0]; r++) {
492     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
493       fprintf(stderr, "Cannot find dimension size %s\n", name);
494       exit(1);
495     }
496   }
497 }
498 
499 /// Reads a sparse tensor with the given filename into a memory-resident
500 /// sparse tensor in coordinate scheme.
501 template <typename V>
502 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
503                                                const uint64_t *sizes,
504                                                const uint64_t *perm) {
505   // Open the file.
506   FILE *file = fopen(filename, "r");
507   if (!file) {
508     fprintf(stderr, "Cannot find %s\n", filename);
509     exit(1);
510   }
511   // Perform some file format dependent set up.
512   uint64_t idata[512];
513   if (strstr(filename, ".mtx")) {
514     readMMEHeader(file, filename, idata);
515   } else if (strstr(filename, ".tns")) {
516     readExtFROSTTHeader(file, filename, idata);
517   } else {
518     fprintf(stderr, "Unknown format %s\n", filename);
519     exit(1);
520   }
521   // Prepare sparse tensor object with per-dimension sizes
522   // and the number of nonzeros as initial capacity.
523   assert(rank == idata[0] && "rank mismatch");
524   uint64_t nnz = idata[1];
525   for (uint64_t r = 0; r < rank; r++)
526     assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
527            "dimension size mismatch");
528   SparseTensorCOO<V> *tensor =
529       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
  // Read all nonzero elements.
531   std::vector<uint64_t> indices(rank);
532   for (uint64_t k = 0; k < nnz; k++) {
533     uint64_t idx = -1u;
534     for (uint64_t r = 0; r < rank; r++) {
535       if (fscanf(file, "%" PRIu64, &idx) != 1) {
536         fprintf(stderr, "Cannot find next index in %s\n", filename);
537         exit(1);
538       }
539       // Add 0-based index.
540       indices[perm[r]] = idx - 1;
541     }
542     // The external formats always store the numerical values with the type
543     // double, but we cast these values to the sparse tensor object type.
544     double value;
545     if (fscanf(file, "%lg\n", &value) != 1) {
546       fprintf(stderr, "Cannot find next value in %s\n", filename);
547       exit(1);
548     }
549     tensor->add(indices, value);
550   }
551   // Close the file and return tensor.
552   fclose(file);
553   return tensor;
554 }
555 
556 } // anonymous namespace
557 
558 extern "C" {
559 
560 /// This type is used in the public API at all places where MLIR expects
561 /// values with the built-in type "index". For now, we simply assume that
562 /// type is 64-bit, but targets with different "index" bit widths should link
563 /// with an alternatively built runtime support library.
564 // TODO: support such targets?
565 typedef uint64_t index_t;
566 
567 //===----------------------------------------------------------------------===//
568 //
569 // Public API with methods that operate on MLIR buffers (memrefs) to interact
570 // with sparse tensors, which are only visible as opaque pointers externally.
571 // These methods should be used exclusively by MLIR compiler-generated code.
572 //
573 // Some macro magic is used to generate implementations for all required type
574 // combinations that can be called from MLIR compiler-generated code.
575 //
576 //===----------------------------------------------------------------------===//
577 
578 #define CASE(p, i, v, P, I, V)                                                 \
579   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
580     SparseTensorCOO<V> *tensor = nullptr;                                      \
581     if (action <= Action::kFromCOO) {                                          \
582       if (action == Action::kFromFile) {                                       \
583         char *filename = static_cast<char *>(ptr);                             \
584         tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm);          \
585       } else if (action == Action::kFromCOO) {                                 \
586         tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
587       } else {                                                                 \
588         assert(action == Action::kEmpty);                                      \
589       }                                                                        \
590       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,  \
591                                                            sparsity, tensor);  \
592     } else if (action == Action::kEmptyCOO) {                                  \
593       return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);        \
594     } else {                                                                   \
595       tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);  \
596       if (action == Action::kToIterator) {                                     \
597         tensor->startIterator();                                               \
598       } else {                                                                 \
599         assert(action == Action::kToCOO);                                      \
600       }                                                                        \
601       return tensor;                                                           \
602     }                                                                          \
603   }
604 
605 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
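
// For exposition only: with the CASE macro above, a call such as
//
//   _mlir_ciface_newSparseTensor(..., OverheadType::kU32, OverheadType::kU32,
//                                PrimaryType::kF64, Action::kFromFile, path);
//
// dispatches to openSparseTensorCOO<double>() followed by
// SparseTensorStorage<uint32_t, uint32_t, double>::newSparseTensor().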
606 
607 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
608   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
609     assert(ref);                                                               \
610     assert(tensor);                                                            \
611     std::vector<TYPE> *v;                                                      \
612     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
613     ref->basePtr = ref->data = v->data();                                      \
614     ref->offset = 0;                                                           \
615     ref->sizes[0] = v->size();                                                 \
616     ref->strides[0] = 1;                                                       \
617   }
618 
619 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
620   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
621                            index_t d) {                                        \
622     assert(ref);                                                               \
623     assert(tensor);                                                            \
624     std::vector<TYPE> *v;                                                      \
625     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
626     ref->basePtr = ref->data = v->data();                                      \
627     ref->offset = 0;                                                           \
628     ref->sizes[0] = v->size();                                                 \
629     ref->strides[0] = 1;                                                       \
630   }
631 
632 #define IMPL_ADDELT(NAME, TYPE)                                                \
633   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
634                             StridedMemRefType<index_t, 1> *iref,               \
635                             StridedMemRefType<index_t, 1> *pref) {             \
636     assert(tensor);                                                            \
637     assert(iref);                                                              \
638     assert(pref);                                                              \
639     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
640     assert(iref->sizes[0] == pref->sizes[0]);                                  \
641     const index_t *indx = iref->data + iref->offset;                           \
642     const index_t *perm = pref->data + pref->offset;                           \
643     uint64_t isize = iref->sizes[0];                                           \
644     std::vector<index_t> indices(isize);                                       \
645     for (uint64_t r = 0; r < isize; r++)                                       \
646       indices[perm[r]] = indx[r];                                              \
647     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
648     return tensor;                                                             \
649   }
650 
651 #define IMPL_GETNEXT(NAME, V)                                                  \
652   bool _mlir_ciface_##NAME(void *tensor, StridedMemRefType<uint64_t, 1> *iref, \
653                            StridedMemRefType<V, 0> *vref) {                    \
654     assert(iref->strides[0] == 1);                                             \
655     uint64_t *indx = iref->data + iref->offset;                                \
656     V *value = vref->data + vref->offset;                                      \
657     const uint64_t isize = iref->sizes[0];                                     \
658     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
659     const Element<V> *elem = iter->getNext();                                  \
660     if (elem == nullptr) {                                                     \
661       delete iter;                                                             \
662       return false;                                                            \
663     }                                                                          \
664     for (uint64_t r = 0; r < isize; r++)                                       \
665       indx[r] = elem->indices[r];                                              \
666     *value = elem->value;                                                      \
667     return true;                                                               \
668   }
669 
670 #define IMPL_LEXINSERT(NAME, V)                                                \
671   void _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *cref,  \
672                            V val) {                                            \
673     assert(cref->strides[0] == 1);                                             \
674     uint64_t *cursor = cref->data + cref->offset;                              \
675     assert(cursor);                                                            \
676     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
677   }
678 
/// Constructs a new sparse tensor. This is the "Swiss army knife"
680 /// method for materializing sparse tensors into the computation.
681 ///
682 /// Action:
683 /// kEmpty = returns empty storage to fill later
684 /// kFromFile = returns storage, where ptr contains filename to read
685 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
686 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
687 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
688 /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
689 void *
690 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
691                              StridedMemRefType<index_t, 1> *sref,
692                              StridedMemRefType<index_t, 1> *pref,
693                              OverheadType ptrTp, OverheadType indTp,
694                              PrimaryType valTp, Action action, void *ptr) {
695   assert(aref && sref && pref);
696   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
697          pref->strides[0] == 1);
698   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
699   const DimLevelType *sparsity = aref->data + aref->offset;
700   const index_t *sizes = sref->data + sref->offset;
701   const index_t *perm = pref->data + pref->offset;
702   uint64_t rank = aref->sizes[0];
703 
704   // Double matrices with all combinations of overhead storage.
705   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
706        uint64_t, double);
707   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
708        uint32_t, double);
709   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
710        uint16_t, double);
711   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
712        uint8_t, double);
713   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
714        uint64_t, double);
715   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
716        uint32_t, double);
717   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
718        uint16_t, double);
719   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
720        uint8_t, double);
721   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
722        uint64_t, double);
723   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
724        uint32_t, double);
725   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
726        uint16_t, double);
727   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
728        uint8_t, double);
729   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
730        uint64_t, double);
731   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
732        uint32_t, double);
733   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
734        uint16_t, double);
735   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
736        uint8_t, double);
737 
738   // Float matrices with all combinations of overhead storage.
739   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
740        uint64_t, float);
741   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
742        uint32_t, float);
743   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
744        uint16_t, float);
745   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
746        uint8_t, float);
747   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
748        uint64_t, float);
749   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
750        uint32_t, float);
751   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
752        uint16_t, float);
753   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
754        uint8_t, float);
755   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
756        uint64_t, float);
757   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
758        uint32_t, float);
759   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
760        uint16_t, float);
761   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
762        uint8_t, float);
763   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
764        uint64_t, float);
765   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
766        uint32_t, float);
767   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
768        uint16_t, float);
769   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
770        uint8_t, float);
771 
772   // Integral matrices with both overheads of the same type.
773   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
774   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
775   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
776   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
777   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
778   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
779   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
780   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
781   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
782   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
783   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
784   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
785   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
786 
787   // Unsupported case (add above if needed).
788   fputs("unsupported combination of types\n", stderr);
789   exit(1);
790 }
791 
792 /// Methods that provide direct access to pointers.
793 IMPL_GETOVERHEAD(sparsePointers, index_t, getPointers)
794 IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
795 IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
796 IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
797 IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
798 
799 /// Methods that provide direct access to indices.
800 IMPL_GETOVERHEAD(sparseIndices, index_t, getIndices)
801 IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
802 IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
803 IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
804 IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
805 
806 /// Methods that provide direct access to values.
807 IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
808 IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
809 IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
810 IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
811 IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
812 IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
813 
814 /// Helper to add value to coordinate scheme, one per value type.
815 IMPL_ADDELT(addEltF64, double)
816 IMPL_ADDELT(addEltF32, float)
817 IMPL_ADDELT(addEltI64, int64_t)
818 IMPL_ADDELT(addEltI32, int32_t)
819 IMPL_ADDELT(addEltI16, int16_t)
820 IMPL_ADDELT(addEltI8, int8_t)
821 
822 /// Helper to enumerate elements of coordinate scheme, one per value type.
823 IMPL_GETNEXT(getNextF64, double)
824 IMPL_GETNEXT(getNextF32, float)
825 IMPL_GETNEXT(getNextI64, int64_t)
826 IMPL_GETNEXT(getNextI32, int32_t)
827 IMPL_GETNEXT(getNextI16, int16_t)
828 IMPL_GETNEXT(getNextI8, int8_t)
829 
/// Helper to insert elements in lexicographic index order, one per value type.
831 IMPL_LEXINSERT(lexInsertF64, double)
832 IMPL_LEXINSERT(lexInsertF32, float)
833 IMPL_LEXINSERT(lexInsertI64, int64_t)
834 IMPL_LEXINSERT(lexInsertI32, int32_t)
835 IMPL_LEXINSERT(lexInsertI16, int16_t)
836 IMPL_LEXINSERT(lexInsertI8, int8_t)
837 
#undef CASE
#undef CASE_SECSAME
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
#undef IMPL_ADDELT
#undef IMPL_GETNEXT
#undef IMPL_LEXINSERT
844 
845 //===----------------------------------------------------------------------===//
846 //
847 // Public API with methods that accept C-style data structures to interact
848 // with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used both by MLIR compiler-generated code and by an
// external runtime that wants to interact with MLIR compiler-generated code.
851 //
852 //===----------------------------------------------------------------------===//
853 
854 /// Helper method to read a sparse tensor filename from the environment,
855 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
char *getTensorFilename(index_t id) {
  char var[80];
  sprintf(var, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env) {
    fprintf(stderr, "Environment variable %s is not set\n", var);
    exit(1);
  }
  return env;
}
862 
863 /// Returns size of sparse tensor in given dimension.
864 index_t sparseDimSize(void *tensor, index_t d) {
865   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
866 }
867 
868 /// Finalizes lexicographic insertions.
869 void endInsert(void *tensor) {
870   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
871 }
872 
873 /// Releases sparse tensor storage.
874 void delSparseTensor(void *tensor) {
875   delete static_cast<SparseTensorStorageBase *>(tensor);
876 }
877 
878 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
879 /// data structures. The expected parameters are:
880 ///
881 ///   rank:    rank of tensor
882 ///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
884 ///   values:  a "nse" array with values for all specified elements
885 ///   indices: a flat "nse x rank" array with indices for all specified elements
886 ///
887 /// For example, the sparse matrix
888 ///     | 1.0 0.0 0.0 |
889 ///     | 0.0 5.0 3.0 |
890 /// can be passed as
891 ///      rank    = 2
892 ///      nse     = 3
893 ///      shape   = [2, 3]
894 ///      values  = [1.0, 5.0, 3.0]
895 ///      indices = [ 0, 0,  1, 1,  1, 2]
896 //
897 // TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
898 //
899 void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
900                                 double *values, uint64_t *indices) {
  // Set up all-dimensions-compressed annotations and the default ordering.
902   std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
903   std::vector<uint64_t> perm(rank);
904   std::iota(perm.begin(), perm.end(), 0);
905   // Convert external format to internal COO.
906   SparseTensorCOO<double> *tensor = SparseTensorCOO<double>::newSparseTensorCOO(
907       rank, shape, perm.data(), nse);
908   std::vector<uint64_t> idx(rank);
909   for (uint64_t i = 0, base = 0; i < nse; i++) {
910     for (uint64_t r = 0; r < rank; r++)
911       idx[r] = indices[base + r];
912     tensor->add(idx, values[i]);
913     base += rank;
914   }
915   // Return sparse tensor storage format as opaque pointer.
916   return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
917       rank, shape, perm.data(), sparse.data(), tensor);
918 }
919 
920 } // extern "C"
921 
922 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
923