1 //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library that is useful
10 // for sparse tensor manipulations. The functionality provided in this library
11 // is meant to simplify benchmarking, testing, and debugging MLIR code that
12 // operates on sparse tensors. The provided functionality is **not** part
13 // of core MLIR, however.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
18
19 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
20
21 #include <algorithm>
22 #include <cassert>
23 #include <cctype>
24 #include <cstdio>
25 #include <cstdlib>
26 #include <cstring>
27 #include <fstream>
28 #include <functional>
29 #include <iostream>
30 #include <limits>
31 #include <numeric>
32
33 //===----------------------------------------------------------------------===//
34 //
35 // Internal support for storing and reading sparse tensors.
36 //
37 // The following memory-resident sparse storage schemes are supported:
38 //
39 // (a) A coordinate scheme for temporarily storing and lexicographically
40 // sorting a sparse tensor by index (SparseTensorCOO).
41 //
42 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
43 // per-dimension sparse/dense annnotations together with a dimension
44 // ordering used by MLIR compiler-generated code (SparseTensorStorage).
45 //
46 // The following external formats are supported:
47 //
48 // (1) Matrix Market Exchange (MME): *.mtx
49 // https://math.nist.gov/MatrixMarket/formats.html
50 //
51 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
52 // http://frostt.io/tensors/file-formats.html
53 //
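// For a quick visual reference (a minimal hand-written sample, not a file
// shipped with MLIR), a 3x3 general real matrix with two nonzeros in MME
// form would look like:
//
//   %%MatrixMarket matrix coordinate real general
//   3 3 2
//   1 2 5.0
//   3 1 7.0
//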
54 // Two public APIs are supported:
55 //
56 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
57 // tensors. These methods should be used exclusively by MLIR
58 // compiler-generated code.
59 //
60 // (II) Methods that accept C-style data structures to interact with sparse
61 // tensors. These methods can be used by any external runtime that wants
62 // to interact with MLIR compiler-generated code.
63 //
// In both cases (I) and (II), the SparseTensorStorage format is visible
// externally only as an opaque pointer.
66 //
67 //===----------------------------------------------------------------------===//
68
69 namespace {
70
71 static constexpr int kColWidth = 1025;
72
73 /// A version of `operator*` on `uint64_t` which checks for overflows.
static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
75 assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
76 "Integer overflow");
77 return lhs * rhs;
78 }
79
80 // This macro helps minimize repetition of this idiom, as well as ensuring
81 // we have some additional output indicating where the error is coming from.
82 // (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
83 // to track down whether an error is coming from our code vs somewhere else
84 // in MLIR.)
85 #define FATAL(...) \
86 do { \
87 fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__); \
88 exit(1); \
89 } while (0)
90
91 // TODO: try to unify this with `SparseTensorFile::assertMatchesShape`
92 // which is used by `openSparseTensorCOO`. It's easy enough to resolve
93 // the `std::vector` vs pointer mismatch for `dimSizes`; but it's trickier
94 // to resolve the presence/absence of `perm` (without introducing extra
95 // overhead), so perhaps the code duplication is unavoidable.
96 //
97 /// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
98 /// semantic-order to target-order) are a refinement of the desired `shape`
99 /// (in semantic-order).
100 ///
101 /// Precondition: `perm` and `shape` must be valid for `rank`.
102 static inline void
assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
                              uint64_t rank, const uint64_t *perm,
                              const uint64_t *shape) {
106 assert(perm && shape);
107 assert(rank == dimSizes.size() && "Rank mismatch");
108 for (uint64_t r = 0; r < rank; r++)
109 assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
110 "Dimension size mismatch");
111 }
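
// For example (illustrative numbers only): with semantic `shape` = {10, 20}
// and `perm` = {1, 0}, the target-order `dimSizes` would be {20, 10}, and
// the assertion `shape[r] == dimSizes[perm[r]]` holds for r = 0, 1.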
112
113 /// A sparse tensor element in coordinate scheme (value and indices).
114 /// For example, a rank-1 vector element would look like
115 /// ({i}, a[i])
116 /// and a rank-5 tensor element like
117 /// ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer into a shared index pool rather than, e.g., a direct
/// vector, since that (1) reduces the per-element memory footprint and
/// (2) centralizes the memory reservation and (re)allocation in one place.
121 template <typename V>
122 struct Element final {
Element(uint64_t *ind, V val) : indices(ind), value(val){};
124 uint64_t *indices; // pointer into shared index pool
125 V value;
126 };
127
128 /// The type of callback functions which receive an element. We avoid
129 /// packaging the coordinates and value together as an `Element` object
130 /// because this helps keep code somewhat cleaner.
131 template <typename V>
132 using ElementConsumer =
133 const std::function<void(const std::vector<uint64_t> &, V)> &;
134
135 /// A memory-resident sparse tensor in coordinate scheme (collection of
136 /// elements). This data structure is used to read a sparse tensor from
137 /// any external format into memory and sort the elements lexicographically
138 /// by indices before passing it back to the client (most packed storage
139 /// formats require the elements to appear in lexicographic index order).
140 template <typename V>
141 struct SparseTensorCOO final {
142 public:
SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
    : dimSizes(dimSizes) {
145 if (capacity) {
146 elements.reserve(capacity);
147 indices.reserve(capacity * getRank());
148 }
149 }
150
151 /// Adds element as indices and value.
void add(const std::vector<uint64_t> &ind, V val) {
153 assert(!iteratorLocked && "Attempt to add() after startIterator()");
154 uint64_t *base = indices.data();
155 uint64_t size = indices.size();
156 uint64_t rank = getRank();
157 assert(ind.size() == rank && "Element rank mismatch");
158 for (uint64_t r = 0; r < rank; r++) {
159 assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
160 indices.push_back(ind[r]);
161 }
162 // This base only changes if indices were reallocated. In that case, we
163 // need to correct all previous pointers into the vector. Note that this
164 // only happens if we did not set the initial capacity right, and then only
165 // for every internal vector reallocation (which with the doubling rule
166 // should only incur an amortized linear overhead).
167 uint64_t *newBase = indices.data();
168 if (newBase != base) {
169 for (uint64_t i = 0, n = elements.size(); i < n; i++)
170 elements[i].indices = newBase + (elements[i].indices - base);
171 base = newBase;
172 }
173 // Add element as (pointer into shared index pool, value) pair.
174 elements.emplace_back(base + size, val);
175 }
176
177 /// Sorts elements lexicographically by index.
void sort() {
179 assert(!iteratorLocked && "Attempt to sort() after startIterator()");
180 // TODO: we may want to cache an `isSorted` bit, to avoid
181 // unnecessary/redundant sorting.
182 uint64_t rank = getRank();
183 std::sort(elements.begin(), elements.end(),
184 [rank](const Element<V> &e1, const Element<V> &e2) {
185 for (uint64_t r = 0; r < rank; r++) {
186 if (e1.indices[r] == e2.indices[r])
187 continue;
188 return e1.indices[r] < e2.indices[r];
189 }
190 return false;
191 });
192 }
193
194 /// Get the rank of the tensor.
uint64_t getRank() const { return dimSizes.size(); }
196
197 /// Getter for the dimension-sizes array.
const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
199
200 /// Getter for the elements array.
const std::vector<Element<V>> &getElements() const { return elements; }
202
203 /// Switch into iterator mode.
void startIterator() {
205 iteratorLocked = true;
206 iteratorPos = 0;
207 }
208
209 /// Get the next element.
const Element<V> *getNext() {
211 assert(iteratorLocked && "Attempt to getNext() before startIterator()");
212 if (iteratorPos < elements.size())
213 return &(elements[iteratorPos++]);
214 iteratorLocked = false;
215 return nullptr;
216 }
217
218 /// Factory method. Permutes the original dimensions according to
219 /// the given ordering and expects subsequent add() calls to honor
220 /// that same ordering for the given indices. The result is a
221 /// fully permuted coordinate scheme.
222 ///
223 /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                              const uint64_t *dimSizes,
                                              const uint64_t *perm,
                                              uint64_t capacity = 0) {
228 std::vector<uint64_t> permsz(rank);
229 for (uint64_t r = 0; r < rank; r++) {
230 assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
231 permsz[perm[r]] = dimSizes[r];
232 }
233 return new SparseTensorCOO<V>(permsz, capacity);
234 }
235
236 private:
237 const std::vector<uint64_t> dimSizes; // per-dimension sizes
238 std::vector<Element<V>> elements; // all COO elements
239 std::vector<uint64_t> indices; // shared index pool
240 bool iteratorLocked = false;
241 unsigned iteratorPos = 0;
242 };
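
// For illustration only (a hedged sketch, not code used by the library):
// a client might build and drain a small 3x4 COO matrix as follows, where
// `consume` stands for some hypothetical per-element callback.
//
//   uint64_t dimSizes[] = {3, 4};
//   uint64_t perm[] = {0, 1}; // identity ordering
//   SparseTensorCOO<double> *coo =
//       SparseTensorCOO<double>::newSparseTensorCOO(2, dimSizes, perm, 2);
//   coo->add({2, 3}, 1.0);
//   coo->add({0, 1}, 2.0);
//   coo->sort(); // ({0,1}, 2.0) now precedes ({2,3}, 1.0)
//   coo->startIterator();
//   while (const Element<double> *e = coo->getNext())
//     consume(e->indices, e->value);
//   delete coo;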
243
244 // Forward.
245 template <typename V>
246 class SparseTensorEnumeratorBase;
247
248 // Helper macro for generating error messages when some
249 // `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
250 // and then the wrong "partial method specialization" is called.
251 #define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
252
253 /// Abstract base class for `SparseTensorStorage<P,I,V>`. This class
254 /// takes responsibility for all the `<P,I,V>`-independent aspects
255 /// of the tensor (e.g., shape, sparsity, permutation). In addition,
256 /// we use function overloading to implement "partial" method
257 /// specialization, which the C-API relies on to catch type errors
258 /// arising from our use of opaque pointers.
259 class SparseTensorStorageBase {
260 public:
261 /// Constructs a new storage object. The `perm` maps the tensor's
262 /// semantic-ordering of dimensions to this object's storage-order.
263 /// The `dimSizes` and `sparsity` arrays are already in storage-order.
264 ///
265 /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
                        const uint64_t *perm, const DimLevelType *sparsity)
    : dimSizes(dimSizes), rev(getRank()),
      dimTypes(sparsity, sparsity + getRank()) {
270 assert(perm && sparsity);
271 const uint64_t rank = getRank();
272 // Validate parameters.
273 assert(rank > 0 && "Trivial shape is unsupported");
274 for (uint64_t r = 0; r < rank; r++) {
275 assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
276 assert((dimTypes[r] == DimLevelType::kDense ||
277 dimTypes[r] == DimLevelType::kCompressed) &&
278 "Unsupported DimLevelType");
279 }
280 // Construct the "reverse" (i.e., inverse) permutation.
281 for (uint64_t r = 0; r < rank; r++)
282 rev[perm[r]] = r;
283 }
284
285 virtual ~SparseTensorStorageBase() = default;
286
287 /// Get the rank of the tensor.
uint64_t getRank() const { return dimSizes.size(); }
289
290 /// Getter for the dimension-sizes array, in storage-order.
const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
292
293 /// Safely lookup the size of the given (storage-order) dimension.
uint64_t getDimSize(uint64_t d) const {
295 assert(d < getRank());
296 return dimSizes[d];
297 }
298
299 /// Getter for the "reverse" permutation, which maps this object's
300 /// storage-order to the tensor's semantic-order.
const std::vector<uint64_t> &getRev() const { return rev; }
302
303 /// Getter for the dimension-types array, in storage-order.
const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
305
306 /// Safely check if the (storage-order) dimension uses compressed storage.
bool isCompressedDim(uint64_t d) const {
308 assert(d < getRank());
309 return (dimTypes[d] == DimLevelType::kCompressed);
310 }
311
312 /// Allocate a new enumerator.
313 #define DECL_NEWENUMERATOR(VNAME, V) \
314 virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
315 const uint64_t *) const { \
316 FATAL_PIV("newEnumerator" #VNAME); \
317 }
318 FOREVERY_V(DECL_NEWENUMERATOR)
319 #undef DECL_NEWENUMERATOR
320
321 /// Overhead storage.
322 #define DECL_GETPOINTERS(PNAME, P) \
323 virtual void getPointers(std::vector<P> **, uint64_t) { \
324 FATAL_PIV("getPointers" #PNAME); \
325 }
326 FOREVERY_FIXED_O(DECL_GETPOINTERS)
327 #undef DECL_GETPOINTERS
328 #define DECL_GETINDICES(INAME, I) \
329 virtual void getIndices(std::vector<I> **, uint64_t) { \
330 FATAL_PIV("getIndices" #INAME); \
331 }
332 FOREVERY_FIXED_O(DECL_GETINDICES)
333 #undef DECL_GETINDICES
334
335 /// Primary storage.
336 #define DECL_GETVALUES(VNAME, V) \
337 virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
338 FOREVERY_V(DECL_GETVALUES)
339 #undef DECL_GETVALUES
340
341 /// Element-wise insertion in lexicographic index order.
342 #define DECL_LEXINSERT(VNAME, V) \
343 virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
344 FOREVERY_V(DECL_LEXINSERT)
345 #undef DECL_LEXINSERT
346
347 /// Expanded insertion.
348 #define DECL_EXPINSERT(VNAME, V) \
349 virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
350 FATAL_PIV("expInsert" #VNAME); \
351 }
352 FOREVERY_V(DECL_EXPINSERT)
353 #undef DECL_EXPINSERT
354
355 /// Finishes insertion.
356 virtual void endInsert() = 0;
357
358 protected:
359 // Since this class is virtual, we must disallow public copying in
360 // order to avoid "slicing". Since this class has data members,
361 // that means making copying protected.
362 // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
363 SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
364 // Copy-assignment would be implicitly deleted (because `dimSizes`
365 // is const), so we explicitly delete it for clarity.
366 SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
367
368 private:
369 const std::vector<uint64_t> dimSizes;
370 std::vector<uint64_t> rev;
371 const std::vector<DimLevelType> dimTypes;
372 };
373
374 #undef FATAL_PIV
375
376 // Forward.
377 template <typename P, typename I, typename V>
378 class SparseTensorEnumerator;
379
380 /// A memory-resident sparse tensor using a storage scheme based on
381 /// per-dimension sparse/dense annotations. This data structure provides a
382 /// bufferized form of a sparse tensor type. In contrast to generating setup
383 /// methods for each differently annotated sparse tensor, this method provides
384 /// a convenient "one-size-fits-all" solution that simply takes an input tensor
385 /// and annotations to implement all required setup in a general manner.
386 template <typename P, typename I, typename V>
387 class SparseTensorStorage final : public SparseTensorStorageBase {
388 /// Private constructor to share code between the other constructors.
389 /// Beware that the object is not necessarily guaranteed to be in a
390 /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
391 /// doesn't entail `!(pointers[d].empty())`.
392 ///
393 /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                    const uint64_t *perm, const DimLevelType *sparsity)
    : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
      indices(getRank()), idx(getRank()) {}
398
399 public:
400 /// Constructs a sparse tensor storage scheme with the given dimensions,
401 /// permutation, and per-dimension dense/sparse annotations, using
402 /// the coordinate scheme tensor for the initial contents if provided.
403 ///
404 /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                    const uint64_t *perm, const DimLevelType *sparsity,
                    SparseTensorCOO<V> *coo)
    : SparseTensorStorage(dimSizes, perm, sparsity) {
409 // Provide hints on capacity of pointers and indices.
410 // TODO: needs much fine-tuning based on actual sparsity; currently
411 // we reserve pointer/index space based on all previous dense
412 // dimensions, which works well up to first sparse dim; but
413 // we should really use nnz and dense/sparse distribution.
414 bool allDense = true;
415 uint64_t sz = 1;
416 for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
417 if (isCompressedDim(r)) {
418 // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
419 // `sz` by that before reserving. (For now we just use 1.)
420 pointers[r].reserve(sz + 1);
421 pointers[r].push_back(0);
422 indices[r].reserve(sz);
423 sz = 1;
424 allDense = false;
425 } else { // Dense dimension.
426 sz = checkedMul(sz, getDimSizes()[r]);
427 }
428 }
429 // Then assign contents from coordinate scheme tensor if provided.
430 if (coo) {
431 // Ensure both preconditions of `fromCOO`.
432 assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
433 coo->sort();
434 // Now actually insert the `elements`.
435 const std::vector<Element<V>> &elements = coo->getElements();
436 uint64_t nnz = elements.size();
437 values.reserve(nnz);
438 fromCOO(elements, 0, nnz, 0);
439 } else if (allDense) {
440 values.resize(sz, 0);
441 }
442 }
443
444 /// Constructs a sparse tensor storage scheme with the given dimensions,
445 /// permutation, and per-dimension dense/sparse annotations, using
446 /// the given sparse tensor for the initial contents.
447 ///
448 /// Preconditions:
449 /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
450 /// * The `tensor` must have the same value type `V`.
451 SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
452 const uint64_t *perm, const DimLevelType *sparsity,
453 const SparseTensorStorageBase &tensor);
454
455 ~SparseTensorStorage() final = default;
456
457 /// Partially specialize these getter methods based on template types.
void getPointers(std::vector<P> **out, uint64_t d) final {
459 assert(d < getRank());
460 *out = &pointers[d];
461 }
void getIndices(std::vector<I> **out, uint64_t d) final {
463 assert(d < getRank());
464 *out = &indices[d];
465 }
void getValues(std::vector<V> **out) final { *out = &values; }
467
468 /// Partially specialize lexicographical insertions based on template types.
void lexInsert(const uint64_t *cursor, V val) final {
470 // First, wrap up pending insertion path.
471 uint64_t diff = 0;
472 uint64_t top = 0;
473 if (!values.empty()) {
474 diff = lexDiff(cursor);
475 endPath(diff + 1);
476 top = idx[diff] + 1;
477 }
478 // Then continue with insertion path.
479 insPath(cursor, diff, top, val);
480 }
481
482 /// Partially specialize expanded insertions based on template types.
483 /// Note that this method resets the values/filled-switch array back
484 /// to all-zero/false while only iterating over the nonzero elements.
void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
               uint64_t count) final {
487 if (count == 0)
488 return;
489 // Sort.
490 std::sort(added, added + count);
491 // Restore insertion path for first insert.
492 const uint64_t lastDim = getRank() - 1;
493 uint64_t index = added[0];
494 cursor[lastDim] = index;
495 lexInsert(cursor, values[index]);
496 assert(filled[index]);
497 values[index] = 0;
498 filled[index] = false;
499 // Subsequent insertions are quick.
500 for (uint64_t i = 1; i < count; i++) {
501 assert(index < added[i] && "non-lexicographic insertion");
502 index = added[i];
503 cursor[lastDim] = index;
504 insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
505 assert(filled[index]);
506 values[index] = 0;
507 filled[index] = false;
508 }
509 }
510
511 /// Finalizes lexicographic insertions.
void endInsert() final {
513 if (values.empty())
514 finalizeSegment(0);
515 else
516 endPath(0);
517 }
518
void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
                   const uint64_t *perm) const final {
521 *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
522 }
523
524 /// Returns this sparse tensor storage scheme as a new memory-resident
525 /// sparse tensor in coordinate scheme with the given dimension order.
526 ///
527 /// Precondition: `perm` must be valid for `getRank()`.
SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
529 SparseTensorEnumeratorBase<V> *enumerator;
530 newEnumerator(&enumerator, getRank(), perm);
531 SparseTensorCOO<V> *coo =
532 new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
533 enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
534 coo->add(ind, val);
535 });
536 // TODO: This assertion assumes there are no stored zeros,
537 // or if there are then that we don't filter them out.
538 // Cf., <https://github.com/llvm/llvm-project/issues/54179>
539 assert(coo->getElements().size() == values.size());
540 delete enumerator;
541 return coo;
542 }
543
544 /// Factory method. Constructs a sparse tensor storage scheme with the given
545 /// dimensions, permutation, and per-dimension dense/sparse annotations,
546 /// using the coordinate scheme tensor for the initial contents if provided.
547 /// In the latter case, the coordinate scheme must respect the same
548 /// permutation as is desired for the new sparse tensor storage.
549 ///
550 /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
551 static SparseTensorStorage<P, I, V> *
newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
554 SparseTensorStorage<P, I, V> *n = nullptr;
555 if (coo) {
556 const auto &coosz = coo->getDimSizes();
557 assertPermutedSizesMatchShape(coosz, rank, perm, shape);
558 n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
559 } else {
560 std::vector<uint64_t> permsz(rank);
561 for (uint64_t r = 0; r < rank; r++) {
562 assert(shape[r] > 0 && "Dimension size zero has trivial storage");
563 permsz[perm[r]] = shape[r];
564 }
565 // We pass the null `coo` to ensure we select the intended constructor.
566 n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
567 }
568 return n;
569 }
570
571 /// Factory method. Constructs a sparse tensor storage scheme with
572 /// the given dimensions, permutation, and per-dimension dense/sparse
573 /// annotations, using the sparse tensor for the initial contents.
574 ///
575 /// Preconditions:
576 /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
577 /// * The `tensor` must have the same value type `V`.
578 static SparseTensorStorage<P, I, V> *
newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                const DimLevelType *sparsity,
                const SparseTensorStorageBase *source) {
582 assert(source && "Got nullptr for source");
583 SparseTensorEnumeratorBase<V> *enumerator;
584 source->newEnumerator(&enumerator, rank, perm);
585 const auto &permsz = enumerator->permutedSizes();
586 assertPermutedSizesMatchShape(permsz, rank, perm, shape);
587 auto *tensor =
588 new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
589 delete enumerator;
590 return tensor;
591 }
592
593 private:
594 /// Appends an arbitrary new position to `pointers[d]`. This method
595 /// checks that `pos` is representable in the `P` type; however, it
596 /// does not check that `pos` is semantically valid (i.e., larger than
597 /// the previous position and smaller than `indices[d].capacity()`).
void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
599 assert(isCompressedDim(d));
600 assert(pos <= std::numeric_limits<P>::max() &&
601 "Pointer value is too large for the P-type");
602 pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
603 }
604
605 /// Appends index `i` to dimension `d`, in the semantically general
606 /// sense. For non-dense dimensions, that means appending to the
607 /// `indices[d]` array, checking that `i` is representable in the `I`
608 /// type; however, we do not verify other semantic requirements (e.g.,
609 /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
610 /// in the same segment). For dense dimensions, this method instead
611 /// appends the appropriate number of zeros to the `values` array,
612 /// where `full` is the number of "entries" already written to `values`
613 /// for this segment (aka one after the highest index previously appended).
void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
615 if (isCompressedDim(d)) {
616 assert(i <= std::numeric_limits<I>::max() &&
617 "Index value is too large for the I-type");
618 indices[d].push_back(static_cast<I>(i));
619 } else { // Dense dimension.
620 assert(i >= full && "Index was already filled");
621 if (i == full)
622 return; // Short-circuit, since it'll be a nop.
623 if (d + 1 == getRank())
624 values.insert(values.end(), i - full, 0);
625 else
626 finalizeSegment(d + 1, 0, i - full);
627 }
628 }
629
630 /// Writes the given coordinate to `indices[d][pos]`. This method
631 /// checks that `i` is representable in the `I` type; however, it
632 /// does not check that `i` is semantically valid (i.e., in bounds
633 /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
635 assert(isCompressedDim(d));
636 // Subscript assignment to `std::vector` requires that the `pos`-th
637 // entry has been initialized; thus we must be sure to check `size()`
638 // here, instead of `capacity()` as would be ideal.
639 assert(pos < indices[d].size() && "Index position is out of bounds");
640 assert(i <= std::numeric_limits<I>::max() &&
641 "Index value is too large for the I-type");
642 indices[d][pos] = static_cast<I>(i);
643 }
644
645 /// Computes the assembled-size associated with the `d`-th dimension,
646 /// given the assembled-size associated with the `(d-1)`-th dimension.
647 /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
648 /// storage, as opposed to "dimension-sizes" which are the cardinality
649 /// of coordinates for that dimension.
650 ///
651 /// Precondition: the `pointers[d]` array must be fully initialized
652 /// before calling this method.
uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
654 if (isCompressedDim(d))
655 return pointers[d][parentSz];
656 // else if dense:
657 return parentSz * getDimSizes()[d];
658 }
659
660 /// Initializes sparse tensor storage scheme from a memory-resident sparse
661 /// tensor in coordinate scheme. This method prepares the pointers and
662 /// indices arrays under the given per-dimension dense/sparse annotations.
663 ///
664 /// Preconditions:
665 /// (1) the `elements` must be lexicographically sorted.
666 /// (2) the indices of every element are valid for `dimSizes` (equal rank
667 /// and pointwise less-than).
void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
             uint64_t hi, uint64_t d) {
670 uint64_t rank = getRank();
671 assert(d <= rank && hi <= elements.size());
672 // Once dimensions are exhausted, insert the numerical values.
673 if (d == rank) {
674 assert(lo < hi);
675 values.push_back(elements[lo].value);
676 return;
677 }
678 // Visit all elements in this interval.
679 uint64_t full = 0;
680 while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
681 // Find segment in interval with same index elements in this dimension.
682 uint64_t i = elements[lo].indices[d];
683 uint64_t seg = lo + 1;
684 while (seg < hi && elements[seg].indices[d] == i)
685 seg++;
686 // Handle segment in interval for sparse or dense dimension.
687 appendIndex(d, full, i);
688 full = i + 1;
689 fromCOO(elements, lo, seg, d + 1);
690 // And move on to next segment in interval.
691 lo = seg;
692 }
693 // Finalize the sparse pointer structure at this dimension.
694 finalizeSegment(d, full);
695 }
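
// As a worked example of what `fromCOO` assembles (values chosen here for
// illustration): for a 3x4 matrix with dimension 0 dense and dimension 1
// compressed (i.e., CSR) holding a[0][1] = 2, a[1][0] = 3, a[1][3] = 4,
// the resulting storage is
//   pointers[1] = [0, 1, 3, 3]
//   indices[1]  = [1, 0, 3]
//   values      = [2, 3, 4]
// while pointers[0]/indices[0] remain empty because dimension 0 is dense.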
696
697 /// Finalize the sparse pointer structure at this dimension.
void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
699 if (count == 0)
700 return; // Short-circuit, since it'll be a nop.
701 if (isCompressedDim(d)) {
702 appendPointer(d, indices[d].size(), count);
703 } else { // Dense dimension.
704 const uint64_t sz = getDimSizes()[d];
705 assert(sz >= full && "Segment is overfull");
706 count = checkedMul(count, sz - full);
707 // For dense storage we must enumerate all the remaining coordinates
708 // in this dimension (i.e., coordinates after the last non-zero
709 // element), and either fill in their zero values or else recurse
710 // to finalize some deeper dimension.
711 if (d + 1 == getRank())
712 values.insert(values.end(), count, 0);
713 else
714 finalizeSegment(d + 1, 0, count);
715 }
716 }
717
718 /// Wraps up a single insertion path, inner to outer.
void endPath(uint64_t diff) {
720 uint64_t rank = getRank();
721 assert(diff <= rank);
722 for (uint64_t i = 0; i < rank - diff; i++) {
723 const uint64_t d = rank - i - 1;
724 finalizeSegment(d, idx[d] + 1);
725 }
726 }
727
728 /// Continues a single insertion path, outer to inner.
void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
730 uint64_t rank = getRank();
731 assert(diff < rank);
732 for (uint64_t d = diff; d < rank; d++) {
733 uint64_t i = cursor[d];
734 appendIndex(d, top, i);
735 top = 0;
736 idx[d] = i;
737 }
738 values.push_back(val);
739 }
740
741 /// Finds the lexicographic differing dimension.
uint64_t lexDiff(const uint64_t *cursor) const {
743 for (uint64_t r = 0, rank = getRank(); r < rank; r++)
744 if (cursor[r] > idx[r])
745 return r;
746 else
747 assert(cursor[r] == idx[r] && "non-lexicographic insertion");
748 assert(0 && "duplication insertion");
749 return -1u;
750 }
751
752 // Allow `SparseTensorEnumerator` to access the data-members (to avoid
753 // the cost of virtual-function dispatch in inner loops), without
754 // making them public to other client code.
755 friend class SparseTensorEnumerator<P, I, V>;
756
757 std::vector<std::vector<P>> pointers;
758 std::vector<std::vector<I>> indices;
759 std::vector<V> values;
760 std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
761 };
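
// For illustration only (a hedged sketch of how compiler-generated code
// conceptually drives this class; the template arguments are arbitrary):
//
//   DimLevelType sparsity[] = {DimLevelType::kDense,
//                              DimLevelType::kCompressed};
//   uint64_t shape[] = {3, 4}, perm[] = {0, 1};
//   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, double>::
//       newSparseTensor(2, shape, perm, sparsity, coo); // `coo` as above
//   std::vector<uint64_t> *ptrs;
//   tensor->getPointers(&ptrs, 1); // CSR-style row pointers for dim 1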
762
763 /// A (higher-order) function object for enumerating the elements of some
764 /// `SparseTensorStorage` under a permutation. That is, the `forallElements`
765 /// method encapsulates the loop-nest for enumerating the elements of
766 /// the source tensor (in whatever order is best for the source tensor),
767 /// and applies a permutation to the coordinates/indices before handing
768 /// each element to the callback. A single enumerator object can be
769 /// freely reused for several calls to `forallElements`, just so long
770 /// as each call is sequential with respect to one another.
771 ///
772 /// N.B., this class stores a reference to the `SparseTensorStorageBase`
773 /// passed to the constructor; thus, objects of this class must not
774 /// outlive the sparse tensor they depend on.
775 ///
776 /// Design Note: The reason we define this class instead of simply using
777 /// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
778 /// the `<P,I>` template parameters from MLIR client code (to simplify the
779 /// type parameters used for direct sparse-to-sparse conversion). And the
780 /// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
781 /// than simply using this class, is to avoid the cost of virtual-method
782 /// dispatch within the loop-nest.
783 template <typename V>
784 class SparseTensorEnumeratorBase {
785 public:
786 /// Constructs an enumerator with the given permutation for mapping
787 /// the semantic-ordering of dimensions to the desired target-ordering.
788 ///
789 /// Preconditions:
790 /// * the `tensor` must have the same `V` value type.
791 /// * `perm` must be valid for `rank`.
SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
                           uint64_t rank, const uint64_t *perm)
    : src(tensor), permsz(src.getRev().size()), reord(getRank()),
      cursor(getRank()) {
796 assert(perm && "Received nullptr for permutation");
797 assert(rank == getRank() && "Permutation rank mismatch");
798 const auto &rev = src.getRev(); // source-order -> semantic-order
799 const auto &dimSizes = src.getDimSizes(); // in source storage-order
800 for (uint64_t s = 0; s < rank; s++) { // `s` source storage-order
801 uint64_t t = perm[rev[s]]; // `t` target-order
802 reord[s] = t;
803 permsz[t] = dimSizes[s];
804 }
805 }
806
807 virtual ~SparseTensorEnumeratorBase() = default;
808
809 // We disallow copying to help avoid leaking the `src` reference.
810 // (In addition to avoiding the problem of slicing.)
811 SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
812 SparseTensorEnumeratorBase &
813 operator=(const SparseTensorEnumeratorBase &) = delete;
814
815 /// Returns the source/target tensor's rank. (The source-rank and
816 /// target-rank are always equal since we only support permutations.
817 /// Though once we add support for other dimension mappings, this
818 /// method will have to be split in two.)
uint64_t getRank() const { return permsz.size(); }
820
821 /// Returns the target tensor's dimension sizes.
const std::vector<uint64_t> &permutedSizes() const { return permsz; }
823
824 /// Enumerates all elements of the source tensor, permutes their
825 /// indices, and passes the permuted element to the callback.
826 /// The callback must not store the cursor reference directly,
827 /// since this function reuses the storage. Instead, the callback
/// must copy it if it wants to keep it.
829 virtual void forallElements(ElementConsumer<V> yield) = 0;
830
831 protected:
832 const SparseTensorStorageBase &src;
833 std::vector<uint64_t> permsz; // in target order.
834 std::vector<uint64_t> reord; // source storage-order -> target order.
835 std::vector<uint64_t> cursor; // in target order.
836 };
837
838 template <typename P, typename I, typename V>
839 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
840 using Base = SparseTensorEnumeratorBase<V>;
841
842 public:
843 /// Constructs an enumerator with the given permutation for mapping
844 /// the semantic-ordering of dimensions to the desired target-ordering.
845 ///
846 /// Precondition: `perm` must be valid for `rank`.
SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
                       uint64_t rank, const uint64_t *perm)
    : Base(tensor, rank, perm) {}
850
851 ~SparseTensorEnumerator() final = default;
852
void forallElements(ElementConsumer<V> yield) final {
854 forallElements(yield, 0, 0);
855 }
856
857 private:
858 /// The recursive component of the public `forallElements`.
void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
                    uint64_t d) {
861 // Recover the `<P,I,V>` type parameters of `src`.
862 const auto &src =
863 static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
864 if (d == Base::getRank()) {
865 assert(parentPos < src.values.size() &&
866 "Value position is out of bounds");
867 // TODO: <https://github.com/llvm/llvm-project/issues/54179>
868 yield(this->cursor, src.values[parentPos]);
869 } else if (src.isCompressedDim(d)) {
870 // Look up the bounds of the `d`-level segment determined by the
871 // `d-1`-level position `parentPos`.
872 const std::vector<P> &pointersD = src.pointers[d];
873 assert(parentPos + 1 < pointersD.size() &&
874 "Parent pointer position is out of bounds");
875 const uint64_t pstart = static_cast<uint64_t>(pointersD[parentPos]);
876 const uint64_t pstop = static_cast<uint64_t>(pointersD[parentPos + 1]);
877 // Loop-invariant code for looking up the `d`-level coordinates/indices.
878 const std::vector<I> &indicesD = src.indices[d];
879 assert(pstop <= indicesD.size() && "Index position is out of bounds");
880 uint64_t &cursorReordD = this->cursor[this->reord[d]];
881 for (uint64_t pos = pstart; pos < pstop; pos++) {
882 cursorReordD = static_cast<uint64_t>(indicesD[pos]);
883 forallElements(yield, pos, d + 1);
884 }
885 } else { // Dense dimension.
886 const uint64_t sz = src.getDimSizes()[d];
887 const uint64_t pstart = parentPos * sz;
888 uint64_t &cursorReordD = this->cursor[this->reord[d]];
889 for (uint64_t i = 0; i < sz; i++) {
890 cursorReordD = i;
891 forallElements(yield, pstart + i, d + 1);
892 }
893 }
894 }
895 };
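
// For illustration only (hedged sketch): enumerating all elements of a
// rank-2 `tensor` in transposed order could look like
//
//   uint64_t perm[] = {1, 0}; // swap the two dimensions
//   SparseTensorEnumeratorBase<double> *e;
//   tensor->newEnumerator(&e, 2, perm);
//   e->forallElements([](const std::vector<uint64_t> &ind, double v) {
//     // `ind` arrives already permuted into the target order.
//   });
//   delete e;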
896
897 /// Statistics regarding the number of nonzero subtensors in
898 /// a source tensor, for direct sparse=>sparse conversion a la
899 /// <https://arxiv.org/abs/2001.02609>.
900 ///
901 /// N.B., this class stores references to the parameters passed to
902 /// the constructor; thus, objects of this class must not outlive
903 /// those parameters.
904 class SparseTensorNNZ final {
905 public:
906 /// Allocate the statistics structure for the desired sizes and
907 /// sparsity (in the target tensor's storage-order). This constructor
908 /// does not actually populate the statistics, however; for that see
909 /// `initialize`.
910 ///
911 /// Precondition: `dimSizes` must not contain zeros.
SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
                const std::vector<DimLevelType> &sparsity)
    : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
915 assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
916 bool uncompressed = true;
917 (void)uncompressed;
918 uint64_t sz = 1; // the product of all `dimSizes` strictly less than `r`.
919 for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
920 switch (dimTypes[r]) {
921 case DimLevelType::kCompressed:
922 assert(uncompressed &&
923 "Multiple compressed layers not currently supported");
924 uncompressed = false;
925 nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
926 break;
927 case DimLevelType::kDense:
928 assert(uncompressed &&
929 "Dense after compressed not currently supported");
930 break;
931 case DimLevelType::kSingleton:
932 // Singleton after Compressed causes no problems for allocating
933 // `nnz` nor for the yieldPos loop. This remains true even
934 // when adding support for multiple compressed dimensions or
935 // for dense-after-compressed.
936 break;
937 }
938 sz = checkedMul(sz, dimSizes[r]);
939 }
940 }
941
942 // We disallow copying to help avoid leaking the stored references.
943 SparseTensorNNZ(const SparseTensorNNZ &) = delete;
944 SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
945
946 /// Returns the rank of the target tensor.
uint64_t getRank() const { return dimSizes.size(); }
948
949 /// Enumerate the source tensor to fill in the statistics. The
950 /// enumerator should already incorporate the permutation (from
951 /// semantic-order to the target storage-order).
952 template <typename V>
void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
954 assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
955 assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
956 enumerator.forallElements(
957 [this](const std::vector<uint64_t> &ind, V) { add(ind); });
958 }
959
960 /// The type of callback functions which receive an nnz-statistic.
961 using NNZConsumer = const std::function<void(uint64_t)> &;
962
/// Lexicographically enumerates all indices for dimensions strictly
964 /// less than `stopDim`, and passes their nnz statistic to the callback.
965 /// Since our use-case only requires the statistic not the coordinates
966 /// themselves, we do not bother to construct those coordinates.
void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
968 assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
969 assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
970 "Cannot look up non-compressed dimensions");
971 forallIndices(yield, stopDim, 0, 0);
972 }
973
974 private:
975 /// Adds a new element (i.e., increment its statistics). We use
976 /// a method rather than inlining into the lambda in `initialize`,
977 /// to avoid spurious templating over `V`. And this method is private
978 /// to avoid needing to re-assert validity of `ind` (which is guaranteed
979 /// by `forallElements`).
void add(const std::vector<uint64_t> &ind) {
981 uint64_t parentPos = 0;
982 for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
983 if (dimTypes[r] == DimLevelType::kCompressed)
984 nnz[r][parentPos]++;
985 parentPos = parentPos * dimSizes[r] + ind[r];
986 }
987 }
988
989 /// Recursive component of the public `forallIndices`.
void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
                   uint64_t d) const {
992 assert(d <= stopDim);
993 if (d == stopDim) {
994 assert(parentPos < nnz[d].size() && "Cursor is out of range");
995 yield(nnz[d][parentPos]);
996 } else {
997 const uint64_t sz = dimSizes[d];
998 const uint64_t pstart = parentPos * sz;
999 for (uint64_t i = 0; i < sz; i++)
1000 forallIndices(yield, stopDim, pstart + i, d + 1);
1001 }
1002 }
1003
1004 // All of these are in the target storage-order.
1005 const std::vector<uint64_t> &dimSizes;
1006 const std::vector<DimLevelType> &dimTypes;
1007 std::vector<std::vector<uint64_t>> nnz;
1008 };
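
// For example (illustrative numbers only): for a 3x4 matrix with dimension 0
// dense and dimension 1 compressed whose nonzeros sit in rows 0, 1, and 1,
// the populated statistic is nnz[1] = {1, 2, 0}, i.e., the number of stored
// entries under each row.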
1009
1010 template <typename P, typename I, typename V>
SparseTensorStorage<P, I, V>::SparseTensorStorage(
    const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
    const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
    : SparseTensorStorage(dimSizes, perm, sparsity) {
1015 SparseTensorEnumeratorBase<V> *enumerator;
1016 tensor.newEnumerator(&enumerator, getRank(), perm);
1017 {
1018 // Initialize the statistics structure.
1019 SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
1020 nnz.initialize(*enumerator);
1021 // Initialize "pointers" overhead (and allocate "indices", "values").
1022 uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
1023 for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1024 if (isCompressedDim(r)) {
1025 pointers[r].reserve(parentSz + 1);
1026 pointers[r].push_back(0);
1027 uint64_t currentPos = 0;
nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
1029 currentPos += n;
1030 appendPointer(r, currentPos);
1031 });
1032 assert(pointers[r].size() == parentSz + 1 &&
1033 "Final pointers size doesn't match allocated size");
1034 // That assertion entails `assembledSize(parentSz, r)`
1035 // is now in a valid state. That is, `pointers[r][parentSz]`
1036 // equals the present value of `currentPos`, which is the
1037 // correct assembled-size for `indices[r]`.
1038 }
1039 // Update assembled-size for the next iteration.
1040 parentSz = assembledSize(parentSz, r);
1041 // Ideally we need only `indices[r].reserve(parentSz)`, however
1042 // the `std::vector` implementation forces us to initialize it too.
1043 // That is, in the yieldPos loop we need random-access assignment
1044 // to `indices[r]`; however, `std::vector`'s subscript-assignment
1045 // only allows assigning to already-initialized positions.
1046 if (isCompressedDim(r))
1047 indices[r].resize(parentSz, 0);
1048 }
1049 values.resize(parentSz, 0); // Both allocate and zero-initialize.
1050 }
1051 // The yieldPos loop
1052 enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
1053 uint64_t parentSz = 1, parentPos = 0;
1054 for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
1055 if (isCompressedDim(r)) {
1056 // If `parentPos == parentSz` then it's valid as an array-lookup;
1057 // however, it's semantically invalid here since that entry
1058 // does not represent a segment of `indices[r]`. Moreover, that
1059 // entry must be immutable for `assembledSize` to remain valid.
1060 assert(parentPos < parentSz && "Pointers position is out of bounds");
1061 const uint64_t currentPos = pointers[r][parentPos];
1062 // This increment won't overflow the `P` type, since it can't
1063 // exceed the original value of `pointers[r][parentPos+1]`
1064 // which was already verified to be within bounds for `P`
1065 // when it was written to the array.
1066 pointers[r][parentPos]++;
1067 writeIndex(r, currentPos, ind[r]);
1068 parentPos = currentPos;
1069 } else { // Dense dimension.
1070 parentPos = parentPos * getDimSizes()[r] + ind[r];
1071 }
1072 parentSz = assembledSize(parentSz, r);
1073 }
1074 assert(parentPos < values.size() && "Value position is out of bounds");
1075 values[parentPos] = val;
1076 });
1077 // No longer need the enumerator, so we'll delete it ASAP.
1078 delete enumerator;
1079 // The finalizeYieldPos loop
1080 for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
1081 if (isCompressedDim(r)) {
1082 assert(parentSz == pointers[r].size() - 1 &&
1083 "Actual pointers size doesn't match the expected size");
1084 // Can't check all of them, but at least we can check the last one.
1085 assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
1086 "Pointers got corrupted");
1087 // TODO: optimize this by using `memmove` or similar.
1088 for (uint64_t n = 0; n < parentSz; n++) {
1089 const uint64_t parentPos = parentSz - n;
1090 pointers[r][parentPos] = pointers[r][parentPos - 1];
1091 }
1092 pointers[r][0] = 0;
1093 }
1094 parentSz = assembledSize(parentSz, r);
1095 }
1096 }
1097
1098 /// Helper to convert string to lower case.
static char *toLower(char *token) {
1100 for (char *c = token; *c; c++)
1101 *c = tolower(*c);
1102 return token;
1103 }
1104
1105 /// This class abstracts over the information stored in file headers,
1106 /// as well as providing the buffers and methods for parsing those headers.
1107 class SparseTensorFile final {
1108 public:
1109 enum class ValueKind {
1110 kInvalid = 0,
1111 kPattern = 1,
1112 kReal = 2,
1113 kInteger = 3,
1114 kComplex = 4,
1115 kUndefined = 5
1116 };
1117
explicit SparseTensorFile(char *filename) : filename(filename) {
1119 assert(filename && "Received nullptr for filename");
1120 }
1121
1122 // Disallows copying, to avoid duplicating the `file` pointer.
1123 SparseTensorFile(const SparseTensorFile &) = delete;
1124 SparseTensorFile &operator=(const SparseTensorFile &) = delete;
1125
1126 // This dtor tries to avoid leaking the `file`. (Though it's better
1127 // to call `closeFile` explicitly when possible, since there are
1128 // circumstances where dtors are not called reliably.)
~SparseTensorFile() { closeFile(); }
1130
1131 /// Opens the file for reading.
void openFile() {
1133 if (file)
1134 FATAL("Already opened file %s\n", filename);
1135 file = fopen(filename, "r");
1136 if (!file)
1137 FATAL("Cannot find file %s\n", filename);
1138 }
1139
1140 /// Closes the file.
void closeFile() {
1142 if (file) {
1143 fclose(file);
1144 file = nullptr;
1145 }
1146 }
1147
1148 // TODO(wrengr/bixia): figure out how to reorganize the element-parsing
1149 // loop of `openSparseTensorCOO` into methods of this class, so we can
1150 // avoid leaking access to the `line` pointer (both for general hygiene
1151 // and because we can't mark it const due to the second argument of
// `strtoul`/`strtod` being `char * *restrict` rather than
1153 // `char const* *restrict`).
1154 //
1155 /// Attempts to read a line from the file.
char *readLine() {
1157 if (fgets(line, kColWidth, file))
1158 return line;
1159 FATAL("Cannot read next line of %s\n", filename);
1160 }
1161
1162 /// Reads and parses the file's header.
void readHeader() {
1164 assert(file && "Attempt to readHeader() before openFile()");
1165 if (strstr(filename, ".mtx"))
1166 readMMEHeader();
1167 else if (strstr(filename, ".tns"))
1168 readExtFROSTTHeader();
1169 else
1170 FATAL("Unknown format %s\n", filename);
1171 assert(isValid() && "Failed to read the header");
1172 }
1173
ValueKind getValueKind() const { return valueKind_; }
1175
bool isValid() const { return valueKind_ != ValueKind::kInvalid; }
1177
1178 /// Gets the MME "pattern" property setting. Is only valid after
1179 /// parsing the header.
bool isPattern() const {
1181 assert(isValid() && "Attempt to isPattern() before readHeader()");
1182 return valueKind_ == ValueKind::kPattern;
1183 }
1184
1185 /// Gets the MME "symmetric" property setting. Is only valid after
1186 /// parsing the header.
bool isSymmetric() const {
1188 assert(isValid() && "Attempt to isSymmetric() before readHeader()");
1189 return isSymmetric_;
1190 }
1191
1192 /// Gets the rank of the tensor. Is only valid after parsing the header.
uint64_t getRank() const {
1194 assert(isValid() && "Attempt to getRank() before readHeader()");
1195 return idata[0];
1196 }
1197
1198 /// Gets the number of non-zeros. Is only valid after parsing the header.
uint64_t getNNZ() const {
1200 assert(isValid() && "Attempt to getNNZ() before readHeader()");
1201 return idata[1];
1202 }
1203
1204 /// Gets the dimension-sizes array. The pointer itself is always
1205 /// valid; however, the values stored therein are only valid after
1206 /// parsing the header.
const uint64_t *getDimSizes() const { return idata + 2; }
1208
1209 /// Safely gets the size of the given dimension. Is only valid
1210 /// after parsing the header.
uint64_t getDimSize(uint64_t d) const {
1212 assert(d < getRank());
1213 return idata[2 + d];
1214 }
1215
1216 /// Asserts the shape subsumes the actual dimension sizes. Is only
1217 /// valid after parsing the header.
void assertMatchesShape(uint64_t rank, const uint64_t *shape) const {
1219 assert(rank == getRank() && "Rank mismatch");
1220 for (uint64_t r = 0; r < rank; r++)
1221 assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
1222 "Dimension size mismatch");
1223 }
1224
1225 private:
1226 void readMMEHeader();
1227 void readExtFROSTTHeader();
1228
1229 const char *filename;
1230 FILE *file = nullptr;
1231 ValueKind valueKind_ = ValueKind::kInvalid;
1232 bool isSymmetric_ = false;
1233 uint64_t idata[512];
1234 char line[kColWidth];
1235 };
1236
1237 /// Read the MME header of a general sparse matrix of type real.
void SparseTensorFile::readMMEHeader() {
1239 char header[64];
1240 char object[64];
1241 char format[64];
1242 char field[64];
1243 char symmetry[64];
1244 // Read header line.
1245 if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
1246 symmetry) != 5)
1247 FATAL("Corrupt header in %s\n", filename);
// Process `field`, which specifies "pattern" or the data type of the values.
1249 if (strcmp(toLower(field), "pattern") == 0)
1250 valueKind_ = ValueKind::kPattern;
1251 else if (strcmp(toLower(field), "real") == 0)
1252 valueKind_ = ValueKind::kReal;
1253 else if (strcmp(toLower(field), "integer") == 0)
1254 valueKind_ = ValueKind::kInteger;
1255 else if (strcmp(toLower(field), "complex") == 0)
1256 valueKind_ = ValueKind::kComplex;
1257 else
1258 FATAL("Unexpected header field value in %s\n", filename);
1259
1260 // Set properties.
1261 isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0);
1262 // Make sure this is a general sparse matrix.
1263 if (strcmp(toLower(header), "%%matrixmarket") ||
1264 strcmp(toLower(object), "matrix") ||
1265 strcmp(toLower(format), "coordinate") ||
1266 (strcmp(toLower(symmetry), "general") && !isSymmetric_))
1267 FATAL("Cannot find a general sparse matrix in %s\n", filename);
1268 // Skip comments.
1269 while (true) {
1270 readLine();
1271 if (line[0] != '%')
1272 break;
1273 }
1274 // Next line contains M N NNZ.
1275 idata[0] = 2; // rank
1276 if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
1277 idata + 1) != 3)
1278 FATAL("Cannot find size in %s\n", filename);
1279 }
1280
1281 /// Read the "extended" FROSTT header. Although not part of the documented
1282 /// format, we assume that the file starts with optional comments followed
1283 /// by two lines that define the rank, the number of nonzeros, and the
// dimension sizes (one per rank) of the sparse tensor.
void SparseTensorFile::readExtFROSTTHeader() {
1286 // Skip comments.
1287 while (true) {
1288 readLine();
1289 if (line[0] != '#')
1290 break;
1291 }
1292 // Next line contains RANK and NNZ.
1293 if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
1294 FATAL("Cannot find metadata in %s\n", filename);
1295 // Followed by a line with the dimension sizes (one per rank).
1296 for (uint64_t r = 0; r < idata[0]; r++)
1297 if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
1298 FATAL("Cannot find dimension size %s\n", filename);
1299 readLine(); // end of line
1300 // The FROSTT format does not define the data type of the nonzero elements.
1301 valueKind_ = ValueKind::kUndefined;
1302 }
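
// For a quick visual reference (a minimal hand-written sample, not a file
// shipped with MLIR), an extended FROSTT file for a 2x3x4 tensor with two
// nonzeros could look like:
//
//   # optional comments
//   3 2
//   2 3 4
//   1 2 3 5.0
//   2 1 4 6.0
//
// i.e., rank and nnz on the first data line, then the dimension sizes,
// then one 1-based coordinate tuple plus value per nonzero.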
1303
1304 // Adds a value to a tensor in coordinate scheme. If is_symmetric_value is true,
1305 // also adds the value to its symmetric location.
1306 template <typename T, typename V>
static inline void addValue(T *coo, V value,
                            const std::vector<uint64_t> indices,
                            bool is_symmetric_value) {
1310 // TODO: <https://github.com/llvm/llvm-project/issues/54179>
1311 coo->add(indices, value);
1312 // We currently choose to deal with symmetric matrices by fully constructing
1313 // them. In the future, we may want to make symmetry implicit for storage
1314 // reasons.
1315 if (is_symmetric_value)
1316 coo->add({indices[1], indices[0]}, value);
1317 }
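
// For instance, when openSparseTensorCOO below reads a symmetric MME matrix
// that stores only one triangle, an off-diagonal entry such as (3,1) = 2.5
// arrives here once and is materialized twice: coo->add({2, 0}, 2.5) and
// coo->add({0, 2}, 2.5). Diagonal entries are added only once, since the
// caller passes is_symmetric_value = false when both indices coincide.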
1318
1319 // Reads an element of a complex type for the current indices in coordinate
1320 // scheme.
1321 template <typename V>
readCOOValue(SparseTensorCOO<std::complex<V>> * coo,const std::vector<uint64_t> indices,char ** linePtr,bool is_pattern,bool add_symmetric_value)1322 static inline void readCOOValue(SparseTensorCOO<std::complex<V>> *coo,
1323 const std::vector<uint64_t> indices,
1324 char **linePtr, bool is_pattern,
1325 bool add_symmetric_value) {
1326 // Read two values to make a complex. The external formats always store
1327 // numerical values with the type double, but we cast these values to the
1328 // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
1329 // value 1 for all entries.
1330 V re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
1331 V im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
1332 std::complex<V> value = {re, im};
1333 addValue(coo, value, indices, add_symmetric_value);
1334 }
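
// For example, with field "complex" in the MME header, a data line such as
// "2 3 1.5 -0.5" contributes the value 1.5 - 0.5i at the 0-based indices
// {1, 2} computed by the caller.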
1335
1336 // Reads an element of a non-complex type for the current indices in coordinate
1337 // scheme.
1338 template <typename V,
1339 typename std::enable_if<
1340 !std::is_same<std::complex<float>, V>::value &&
1341 !std::is_same<std::complex<double>, V>::value>::type * = nullptr>
readCOOValue(SparseTensorCOO<V> * coo,const std::vector<uint64_t> indices,char ** linePtr,bool is_pattern,bool is_symmetric_value)1342 static void inline readCOOValue(SparseTensorCOO<V> *coo,
1343 const std::vector<uint64_t> indices,
1344 char **linePtr, bool is_pattern,
1345 bool is_symmetric_value) {
1346 // The external formats always store these numerical values with the type
1347 // double, but we cast these values to the sparse tensor object type.
1348 // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
1349 double value = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
1350 addValue(coo, value, indices, is_symmetric_value);
1351 }
1352
1353 /// Reads a sparse tensor with the given filename into a memory-resident
1354 /// sparse tensor in coordinate scheme.
1355 template <typename V>
1356 static SparseTensorCOO<V> *
1357 openSparseTensorCOO(char *filename, uint64_t rank, const uint64_t *shape,
1358 const uint64_t *perm, PrimaryType valTp) {
1359 SparseTensorFile stfile(filename);
1360 stfile.openFile();
1361 stfile.readHeader();
1362 // Check tensor element type against the value type in the input file.
1363 SparseTensorFile::ValueKind valueKind = stfile.getValueKind();
1364 bool tensorIsInteger =
1365 (valTp >= PrimaryType::kI64 && valTp <= PrimaryType::kI8);
1366 bool tensorIsReal = (valTp >= PrimaryType::kF64 && valTp <= PrimaryType::kI8);
1367 if ((valueKind == SparseTensorFile::ValueKind::kReal && tensorIsInteger) ||
1368 (valueKind == SparseTensorFile::ValueKind::kComplex && tensorIsReal)) {
1369 FATAL("Tensor element type %d not compatible with values in file %s\n",
1370 static_cast<int>(valTp), filename);
1371 }
1372 stfile.assertMatchesShape(rank, shape);
1373 // Prepare sparse tensor object with per-dimension sizes
1374 // and the number of nonzeros as initial capacity.
1375 uint64_t nnz = stfile.getNNZ();
1376 auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, stfile.getDimSizes(),
1377 perm, nnz);
1378 // Read all nonzero elements.
1379 std::vector<uint64_t> indices(rank);
1380 for (uint64_t k = 0; k < nnz; k++) {
1381 char *linePtr = stfile.readLine();
1382 for (uint64_t r = 0; r < rank; r++) {
1383 uint64_t idx = strtoul(linePtr, &linePtr, 10);
1384 // Add 0-based index.
1385 indices[perm[r]] = idx - 1;
1386 }
1387 readCOOValue(coo, indices, &linePtr, stfile.isPattern(),
1388 stfile.isSymmetric() && indices[0] != indices[1]);
1389 }
1390 // Close the file and return tensor.
1391 stfile.closeFile();
1392 return coo;
1393 }
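
// A minimal usage sketch (assuming a hypothetical matrix file "ab.mtx"); a
// shape entry of 0 accepts whatever size the file declares for that
// dimension, and perm is the identity dimension ordering here:
//
//   char name[] = "ab.mtx";
//   uint64_t shape[2] = {0, 0};
//   uint64_t perm[2] = {0, 1};
//   SparseTensorCOO<double> *coo =
//       openSparseTensorCOO<double>(name, 2, shape, perm, PrimaryType::kF64);
//   ...
//   delete coo;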
1394
1395 /// Writes the sparse tensor to `dest` in extended FROSTT format.
1396 template <typename V>
1397 static void outSparseTensor(void *tensor, void *dest, bool sort) {
1398 assert(tensor && dest);
1399 auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
1400 if (sort)
1401 coo->sort();
1402 char *filename = static_cast<char *>(dest);
1403 auto &dimSizes = coo->getDimSizes();
1404 auto &elements = coo->getElements();
1405 uint64_t rank = coo->getRank();
1406 uint64_t nnz = elements.size();
1407 std::fstream file;
1408 file.open(filename, std::ios_base::out | std::ios_base::trunc);
1409 assert(file.is_open());
1410 file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
1411 for (uint64_t r = 0; r < rank - 1; r++)
1412 file << dimSizes[r] << " ";
1413 file << dimSizes[rank - 1] << std::endl;
1414 for (uint64_t i = 0; i < nnz; i++) {
1415 auto &idx = elements[i].indices;
1416 for (uint64_t r = 0; r < rank; r++)
1417 file << (idx[r] + 1) << " ";
1418 file << elements[i].value << std::endl;
1419 }
1420 file.flush();
1421 file.close();
1422 assert(file.good());
1423 }
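
// The emitted file mirrors the extended FROSTT layout read above; e.g., a
// sorted 2 x 2 identity matrix of doubles would be written as:
//
//   ; extended FROSTT format
//   2 2
//   2 2
//   1 1 1
//   2 2 1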
1424
1425 /// Initializes a sparse tensor from an external COO-flavored format.
1426 template <typename V>
1427 static SparseTensorStorage<uint64_t, uint64_t, V> *
1428 toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
1429 uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
1430 const DimLevelType *sparsity = (DimLevelType *)(sparse);
1431 #ifndef NDEBUG
1432 // Verify that perm is a permutation of 0..(rank-1).
1433 std::vector<uint64_t> order(perm, perm + rank);
1434 std::sort(order.begin(), order.end());
1435 for (uint64_t i = 0; i < rank; ++i)
1436 if (i != order[i])
1437 FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
1438
1439 // Verify that the sparsity values are supported.
1440 for (uint64_t i = 0; i < rank; ++i)
1441 if (sparsity[i] != DimLevelType::kDense &&
1442 sparsity[i] != DimLevelType::kCompressed)
1443 FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
1444 #endif
1445
1446 // Convert external format to internal COO.
1447 auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
1448 std::vector<uint64_t> idx(rank);
1449 for (uint64_t i = 0, base = 0; i < nse; i++) {
1450 for (uint64_t r = 0; r < rank; r++)
1451 idx[perm[r]] = indices[base + r];
1452 coo->add(idx, values[i]);
1453 base += rank;
1454 }
1455 // Return sparse tensor storage format as opaque pointer.
1456 auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
1457 rank, shape, perm, sparsity, coo);
1458 delete coo;
1459 return tensor;
1460 }
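
// A sketch of the expected external format (assumed example with 64-bit
// indices): a 2 x 3 matrix with nonzeros (0,0) = 1.5 and (1,2) = 2.5, both
// dimensions compressed, identity dimension ordering:
//
//   uint64_t shape[]   = {2, 3};
//   double   values[]  = {1.5, 2.5};
//   uint64_t indices[] = {0, 0,   // indices of values[0]
//                         1, 2};  // indices of values[1]
//   uint64_t perm[]    = {0, 1};
//   uint8_t  sparse[]  = {static_cast<uint8_t>(DimLevelType::kCompressed),
//                         static_cast<uint8_t>(DimLevelType::kCompressed)};
//   auto *tensor =
//       toMLIRSparseTensor<double>(2, 2, shape, values, indices, perm, sparse);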
1461
1462 /// Converts a sparse tensor to an external COO-flavored format.
1463 template <typename V>
1464 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
1465 uint64_t **pShape, V **pValues,
1466 uint64_t **pIndices) {
1467 assert(tensor);
1468 auto sparseTensor =
1469 static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
1470 uint64_t rank = sparseTensor->getRank();
1471 std::vector<uint64_t> perm(rank);
1472 std::iota(perm.begin(), perm.end(), 0);
1473 SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
1474
1475 const std::vector<Element<V>> &elements = coo->getElements();
1476 uint64_t nse = elements.size();
1477
1478 uint64_t *shape = new uint64_t[rank];
1479 for (uint64_t i = 0; i < rank; i++)
1480 shape[i] = coo->getDimSizes()[i];
1481
1482 V *values = new V[nse];
1483 uint64_t *indices = new uint64_t[rank * nse];
1484
1485 for (uint64_t i = 0, base = 0; i < nse; i++) {
1486 values[i] = elements[i].value;
1487 for (uint64_t j = 0; j < rank; j++)
1488 indices[base + j] = elements[i].indices[j];
1489 base += rank;
1490 }
1491
1492 delete coo;
1493 *pRank = rank;
1494 *pNse = nse;
1495 *pShape = shape;
1496 *pValues = values;
1497 *pIndices = indices;
1498 }
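
// The out-parameters transfer ownership: *pShape, *pValues, and *pIndices are
// allocated with new[] above, so a caller holding an opaque `tensor` pointer
// is expected to release them eventually, e.g. (sketch):
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   fromMLIRSparseTensor<double>(tensor, &rank, &nse, &shape, &values,
//                                &indices);
//   ...
//   delete[] shape;
//   delete[] values;
//   delete[] indices;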
1499
1500 } // anonymous namespace
1501
1502 extern "C" {
1503
1504 //===----------------------------------------------------------------------===//
1505 //
1506 // Public functions which operate on MLIR buffers (memrefs) to interact
1507 // with sparse tensors (which are only visible as opaque pointers externally).
1508 //
1509 //===----------------------------------------------------------------------===//
1510
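// The CASE macro below dispatches on the triple of runtime type codes
// (pointer overhead, index overhead, value type): when the codes match
// (p, i, v), the branch is instantiated with the concrete C++ types (P, I, V)
// and performs the requested Action (create an empty tensor or COO, read a
// tensor from file, convert to or from COO, start an iterator, or convert
// between sparse storage schemes).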
1511 #define CASE(p, i, v, P, I, V) \
1512 if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
1513 SparseTensorCOO<V> *coo = nullptr; \
1514 if (action <= Action::kFromCOO) { \
1515 if (action == Action::kFromFile) { \
1516 char *filename = static_cast<char *>(ptr); \
1517 coo = openSparseTensorCOO<V>(filename, rank, shape, perm, v); \
1518 } else if (action == Action::kFromCOO) { \
1519 coo = static_cast<SparseTensorCOO<V> *>(ptr); \
1520 } else { \
1521 assert(action == Action::kEmpty); \
1522 } \
1523 auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor( \
1524 rank, shape, perm, sparsity, coo); \
1525 if (action == Action::kFromFile) \
1526 delete coo; \
1527 return tensor; \
1528 } \
1529 if (action == Action::kSparseToSparse) { \
1530 auto *tensor = static_cast<SparseTensorStorageBase *>(ptr); \
1531 return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm, \
1532 sparsity, tensor); \
1533 } \
1534 if (action == Action::kEmptyCOO) \
1535 return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm); \
1536 coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \
1537 if (action == Action::kToIterator) { \
1538 coo->startIterator(); \
1539 } else { \
1540 assert(action == Action::kToCOO); \
1541 } \
1542 return coo; \
1543 }
1544
1545 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
1546
1547 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
1548 // can safely rewrite kIndex to kU64. We make this assertion to guarantee
1549 // that this file cannot get out of sync with its header.
1550 static_assert(std::is_same<index_type, uint64_t>::value,
1551 "Expected index_type == uint64_t");
1552
1553 void *
1554 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1555 StridedMemRefType<index_type, 1> *sref,
1556 StridedMemRefType<index_type, 1> *pref,
1557 OverheadType ptrTp, OverheadType indTp,
1558 PrimaryType valTp, Action action, void *ptr) {
1559 assert(aref && sref && pref);
1560 assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
1561 pref->strides[0] == 1);
1562 assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1563 const DimLevelType *sparsity = aref->data + aref->offset;
1564 const index_type *shape = sref->data + sref->offset;
1565 const index_type *perm = pref->data + pref->offset;
1566 uint64_t rank = aref->sizes[0];
1567
1568 // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1569 // This is safe because of the static_assert above.
1570 if (ptrTp == OverheadType::kIndex)
1571 ptrTp = OverheadType::kU64;
1572 if (indTp == OverheadType::kIndex)
1573 indTp = OverheadType::kU64;
1574
1575 // Double matrices with all combinations of overhead storage.
1576 CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1577 uint64_t, double);
1578 CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1579 uint32_t, double);
1580 CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1581 uint16_t, double);
1582 CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1583 uint8_t, double);
1584 CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1585 uint64_t, double);
1586 CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1587 uint32_t, double);
1588 CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1589 uint16_t, double);
1590 CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1591 uint8_t, double);
1592 CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1593 uint64_t, double);
1594 CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1595 uint32_t, double);
1596 CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1597 uint16_t, double);
1598 CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1599 uint8_t, double);
1600 CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1601 uint64_t, double);
1602 CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1603 uint32_t, double);
1604 CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1605 uint16_t, double);
1606 CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1607 uint8_t, double);
1608
1609 // Float matrices with all combinations of overhead storage.
1610 CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1611 uint64_t, float);
1612 CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1613 uint32_t, float);
1614 CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1615 uint16_t, float);
1616 CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1617 uint8_t, float);
1618 CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1619 uint64_t, float);
1620 CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1621 uint32_t, float);
1622 CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1623 uint16_t, float);
1624 CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1625 uint8_t, float);
1626 CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1627 uint64_t, float);
1628 CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1629 uint32_t, float);
1630 CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1631 uint16_t, float);
1632 CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1633 uint8_t, float);
1634 CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1635 uint64_t, float);
1636 CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1637 uint32_t, float);
1638 CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1639 uint16_t, float);
1640 CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1641 uint8_t, float);
1642
1643 // Two-byte floats with both overheads of the same type.
1644 CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
1645 CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
1646 CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
1647 CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
1648 CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
1649 CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
1650 CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
1651 CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);
1652
1653 // Integral matrices with both overheads of the same type.
1654 CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1655 CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1656 CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1657 CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1658 CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
1659 CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1660 CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1661 CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1662 CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
1663 CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1664 CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1665 CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1666 CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
1667 CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1668 CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1669 CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
1670
1671 // Complex matrices with wide overhead.
1672 CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
1673 CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
1674
1675 // Unsupported case (add above if needed).
1676 // TODO: better pretty-printing of enum values!
1677 FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
1678 static_cast<int>(ptrTp), static_cast<int>(indTp),
1679 static_cast<int>(valTp));
1680 }
1681 #undef CASE
1682 #undef CASE_SECSAME
1683
1684 #define IMPL_SPARSEVALUES(VNAME, V) \
1685 void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref, \
1686 void *tensor) { \
1687 assert(ref &&tensor); \
1688 std::vector<V> *v; \
1689 static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \
1690 ref->basePtr = ref->data = v->data(); \
1691 ref->offset = 0; \
1692 ref->sizes[0] = v->size(); \
1693 ref->strides[0] = 1; \
1694 }
1695 FOREVERY_V(IMPL_SPARSEVALUES)
1696 #undef IMPL_SPARSEVALUES
1697
1698 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \
1699 void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \
1700 index_type d) { \
1701 assert(ref &&tensor); \
1702 std::vector<TYPE> *v; \
1703 static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d); \
1704 ref->basePtr = ref->data = v->data(); \
1705 ref->offset = 0; \
1706 ref->sizes[0] = v->size(); \
1707 ref->strides[0] = 1; \
1708 }
1709 #define IMPL_SPARSEPOINTERS(PNAME, P) \
1710 IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
1711 FOREVERY_O(IMPL_SPARSEPOINTERS)
1712 #undef IMPL_SPARSEPOINTERS
1713
1714 #define IMPL_SPARSEINDICES(INAME, I) \
1715 IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
1716 FOREVERY_O(IMPL_SPARSEINDICES)
1717 #undef IMPL_SPARSEINDICES
1718 #undef IMPL_GETOVERHEAD
1719
1720 #define IMPL_ADDELT(VNAME, V) \
1721 void *_mlir_ciface_addElt##VNAME(void *coo, StridedMemRefType<V, 0> *vref, \
1722 StridedMemRefType<index_type, 1> *iref, \
1723 StridedMemRefType<index_type, 1> *pref) { \
1724 assert(coo &&vref &&iref &&pref); \
1725 assert(iref->strides[0] == 1 && pref->strides[0] == 1); \
1726 assert(iref->sizes[0] == pref->sizes[0]); \
1727 const index_type *indx = iref->data + iref->offset; \
1728 const index_type *perm = pref->data + pref->offset; \
1729 uint64_t isize = iref->sizes[0]; \
1730 std::vector<index_type> indices(isize); \
1731 for (uint64_t r = 0; r < isize; r++) \
1732 indices[perm[r]] = indx[r]; \
1733 V *value = vref->data + vref->offset; \
1734 static_cast<SparseTensorCOO<V> *>(coo)->add(indices, *value); \
1735 return coo; \
1736 }
1737 FOREVERY_V(IMPL_ADDELT)
1738 #undef IMPL_ADDELT
1739
1740 #define IMPL_GETNEXT(VNAME, V) \
1741 bool _mlir_ciface_getNext##VNAME(void *coo, \
1742 StridedMemRefType<index_type, 1> *iref, \
1743 StridedMemRefType<V, 0> *vref) { \
1744 assert(coo &&iref &&vref); \
1745 assert(iref->strides[0] == 1); \
1746 index_type *indx = iref->data + iref->offset; \
1747 V *value = vref->data + vref->offset; \
1748 const uint64_t isize = iref->sizes[0]; \
1749 const Element<V> *elem = \
1750 static_cast<SparseTensorCOO<V> *>(coo)->getNext(); \
1751 if (elem == nullptr) \
1752 return false; \
1753 for (uint64_t r = 0; r < isize; r++) \
1754 indx[r] = elem->indices[r]; \
1755 *value = elem->value; \
1756 return true; \
1757 }
1758 FOREVERY_V(IMPL_GETNEXT)
1759 #undef IMPL_GETNEXT
1760
1761 #define IMPL_LEXINSERT(VNAME, V) \
1762 void _mlir_ciface_lexInsert##VNAME(void *tensor, \
1763 StridedMemRefType<index_type, 1> *cref, \
1764 StridedMemRefType<V, 0> *vref) { \
1765 assert(tensor &&cref &&vref); \
1766 assert(cref->strides[0] == 1); \
1767 index_type *cursor = cref->data + cref->offset; \
1768 assert(cursor); \
1769 V *value = vref->data + vref->offset; \
1770 static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, *value); \
1771 }
1772 FOREVERY_V(IMPL_LEXINSERT)
1773 #undef IMPL_LEXINSERT
1774
1775 #define IMPL_EXPINSERT(VNAME, V) \
1776 void _mlir_ciface_expInsert##VNAME( \
1777 void *tensor, StridedMemRefType<index_type, 1> *cref, \
1778 StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
1779 StridedMemRefType<index_type, 1> *aref, index_type count) { \
1780 assert(tensor &&cref &&vref &&fref &&aref); \
1781 assert(cref->strides[0] == 1); \
1782 assert(vref->strides[0] == 1); \
1783 assert(fref->strides[0] == 1); \
1784 assert(aref->strides[0] == 1); \
1785 assert(vref->sizes[0] == fref->sizes[0]); \
1786 index_type *cursor = cref->data + cref->offset; \
1787 V *values = vref->data + vref->offset; \
1788 bool *filled = fref->data + fref->offset; \
1789 index_type *added = aref->data + aref->offset; \
1790 static_cast<SparseTensorStorageBase *>(tensor)->expInsert( \
1791 cursor, values, filled, added, count); \
1792 }
1793 FOREVERY_V(IMPL_EXPINSERT)
1794 #undef IMPL_EXPINSERT
1795
1796 //===----------------------------------------------------------------------===//
1797 //
1798 // Public functions which accept only C-style data structures to interact
1799 // with sparse tensors (which are only visible as opaque pointers externally).
1800 //
1801 //===----------------------------------------------------------------------===//
1802
1803 index_type sparseDimSize(void *tensor, index_type d) {
1804 return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
1805 }
1806
1807 void endInsert(void *tensor) {
1808 return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1809 }
1810
1811 #define IMPL_OUTSPARSETENSOR(VNAME, V) \
1812 void outSparseTensor##VNAME(void *coo, void *dest, bool sort) { \
1813 return outSparseTensor<V>(coo, dest, sort); \
1814 }
1815 FOREVERY_V(IMPL_OUTSPARSETENSOR)
1816 #undef IMPL_OUTSPARSETENSOR
1817
1818 void delSparseTensor(void *tensor) {
1819 delete static_cast<SparseTensorStorageBase *>(tensor);
1820 }
1821
1822 #define IMPL_DELCOO(VNAME, V) \
1823 void delSparseTensorCOO##VNAME(void *coo) { \
1824 delete static_cast<SparseTensorCOO<V> *>(coo); \
1825 }
1826 FOREVERY_V(IMPL_DELCOO)
1827 #undef IMPL_DELCOO
1828
1829 char *getTensorFilename(index_type id) {
1830 char var[80];
1831 sprintf(var, "TENSOR%" PRIu64, id);
1832 char *env = getenv(var);
1833 if (!env)
1834 FATAL("Environment variable %s is not set\n", var);
1835 return env;
1836 }
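
// For example (illustrative), a driver that calls getTensorFilename(0)
// relies on the environment providing the path, e.g.:
//
//   TENSOR0=/path/to/ab.mtx ./some-sparse-driver
//
// where "some-sparse-driver" stands for whatever binary links this library.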
1837
1838 void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
1839 assert(out && "Received nullptr for out-parameter");
1840 SparseTensorFile stfile(filename);
1841 stfile.openFile();
1842 stfile.readHeader();
1843 stfile.closeFile();
1844 const uint64_t rank = stfile.getRank();
1845 const uint64_t *dimSizes = stfile.getDimSizes();
1846 out->reserve(rank);
1847 out->assign(dimSizes, dimSizes + rank);
1848 }
1849
1850 // TODO: generalize beyond 64-bit indices.
1851 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V) \
1852 void *convertToMLIRSparseTensor##VNAME( \
1853 uint64_t rank, uint64_t nse, uint64_t *shape, V *values, \
1854 uint64_t *indices, uint64_t *perm, uint8_t *sparse) { \
1855 return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm, \
1856 sparse); \
1857 }
1858 FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
1859 #undef IMPL_CONVERTTOMLIRSPARSETENSOR
1860
1861 // TODO: Currently, values are copied from SparseTensorStorage to
1862 // SparseTensorCOO, then to the output. We may want to reduce the number
1863 // of copies.
1864 //
1865 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
1866 // compressed
1867 #define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V) \
1868 void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank, \
1869 uint64_t *pNse, uint64_t **pShape, \
1870 V **pValues, uint64_t **pIndices) { \
1871 fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices); \
1872 }
1873 FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
1874 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR
1875
1876 } // extern "C"
1877
1878 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1879