//===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library that is useful
// for sparse tensor manipulations. The functionality provided in this library
// is meant to simplify benchmarking, testing, and debugging MLIR code that
// operates on sparse tensors. The provided functionality is **not** part
// of core MLIR, however.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensorUtils.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>

//===----------------------------------------------------------------------===//
//
// Internal support for storing and reading sparse tensors.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by index (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is visible
// externally only as an opaque pointer.
//
//===----------------------------------------------------------------------===//

namespace {

static constexpr int kColWidth = 1025;

/// A version of `operator*` on `uint64_t` which checks for overflows.
static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
  assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
         "Integer overflow");
  return lhs * rhs;
}

// This macro helps minimize repetition of this idiom, as well as ensuring
// we have some additional output indicating where the error is coming from.
// (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
// to track down whether an error is coming from our code vs somewhere else
// in MLIR.)
#define FATAL(...)                                                             \
  do {                                                                         \
    fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                        \
    exit(1);                                                                   \
  } while (0)

// TODO: try to unify this with `SparseTensorFile::assertMatchesShape`
// which is used by `openSparseTensorCOO`. It's easy enough to resolve
// the `std::vector` vs pointer mismatch for `dimSizes`; but it's trickier
// to resolve the presence/absence of `perm` (without introducing extra
// overhead), so perhaps the code duplication is unavoidable.
//
/// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
/// semantic-order to target-order) are a refinement of the desired `shape`
/// (in semantic-order).
///
/// Precondition: `perm` and `shape` must be valid for `rank`.
static inline void
assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
                              uint64_t rank, const uint64_t *perm,
                              const uint64_t *shape) {
  assert(perm && shape);
  assert(rank == dimSizes.size() && "Rank mismatch");
  for (uint64_t r = 0; r < rank; r++)
    assert((shape[r] == 0 || shape[r] == dimSizes[perm[r]]) &&
           "Dimension size mismatch");
}

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer to a shared index pool rather than, e.g., a direct
/// vector, since that (1) reduces the per-element memory footprint, and
/// (2) centralizes the memory reservation and (re)allocation to one place.
template <typename V>
struct Element final {
  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
  uint64_t *indices; // pointer into shared index pool
  V value;
};

/// The type of callback functions which receive an element. We avoid
/// packaging the coordinates and value together as an `Element` object
/// because this helps keep code somewhat cleaner.
template <typename V>
using ElementConsumer =
    const std::function<void(const std::vector<uint64_t> &, V)> &;

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO final {
public:
  SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
      : dimSizes(dimSizes) {
    if (capacity) {
      elements.reserve(capacity);
      indices.reserve(capacity * getRank());
    }
  }

  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t *base = indices.data();
    uint64_t size = indices.size();
    uint64_t rank = getRank();
    assert(ind.size() == rank && "Element rank mismatch");
    for (uint64_t r = 0; r < rank; r++) {
      assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
      indices.push_back(ind[r]);
    }
    // This base only changes if indices were reallocated. In that case, we
    // need to correct all previous pointers into the vector. Note that this
    // only happens if we did not set the initial capacity right, and then only
    // for every internal vector reallocation (which with the doubling rule
    // should only incur an amortized linear overhead).
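    // For intuition, a minimal sketch of the rebasing arithmetic below
    // (hypothetical values, not taken from this method): if an element's
    // `indices` pointer was `base + 7` and the pool reallocates to `newBase`,
    // then the element must be updated to point at `newBase + 7`, which is
    // exactly `newBase + (elements[i].indices - base)`.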
    uint64_t *newBase = indices.data();
    if (newBase != base) {
      for (uint64_t i = 0, n = elements.size(); i < n; i++)
        elements[i].indices = newBase + (elements[i].indices - base);
      base = newBase;
    }
    // Add element as (pointer into shared index pool, value) pair.
    elements.emplace_back(base + size, val);
  }

  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    // TODO: we may want to cache an `isSorted` bit, to avoid
    // unnecessary/redundant sorting.
    uint64_t rank = getRank();
    std::sort(elements.begin(), elements.end(),
              [rank](const Element<V> &e1, const Element<V> &e2) {
                for (uint64_t r = 0; r < rank; r++) {
                  if (e1.indices[r] == e2.indices[r])
                    continue;
                  return e1.indices[r] < e2.indices[r];
                }
                return false;
              });
  }

  /// Get the rank of the tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Getter for the dimension-sizes array.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Getter for the elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }

  /// Get the next element.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  ///
  /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *dimSizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      permsz[perm[r]] = dimSizes[r];
    }
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  const std::vector<uint64_t> dimSizes; // per-dimension sizes
  std::vector<Element<V>> elements;     // all COO elements
  std::vector<uint64_t> indices;        // shared index pool
  bool iteratorLocked = false;
  unsigned iteratorPos = 0;
};

// Forward.
template <typename V>
class SparseTensorEnumeratorBase;

// Helper macro for generating error messages when some
// `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
// and then the wrong "partial method specialization" is called.
#define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);

/// Abstract base class for `SparseTensorStorage<P,I,V>`. This class
/// takes responsibility for all the `<P,I,V>`-independent aspects
/// of the tensor (e.g., shape, sparsity, permutation). In addition,
/// we use function overloading to implement "partial" method
/// specialization, which the C-API relies on to catch type errors
/// arising from our use of opaque pointers.
class SparseTensorStorageBase {
public:
  /// Constructs a new storage object. The `perm` maps the tensor's
  /// semantic-ordering of dimensions to this object's storage-order.
  /// The `dimSizes` and `sparsity` arrays are already in storage-order.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
  SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
                          const uint64_t *perm, const DimLevelType *sparsity)
      : dimSizes(dimSizes), rev(getRank()),
        dimTypes(sparsity, sparsity + getRank()) {
    assert(perm && sparsity);
    const uint64_t rank = getRank();
    // Validate parameters.
    assert(rank > 0 && "Trivial shape is unsupported");
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      assert((dimTypes[r] == DimLevelType::kDense ||
              dimTypes[r] == DimLevelType::kCompressed) &&
             "Unsupported DimLevelType");
    }
    // Construct the "reverse" (i.e., inverse) permutation.
    for (uint64_t r = 0; r < rank; r++)
      rev[perm[r]] = r;
  }

  virtual ~SparseTensorStorageBase() = default;

  /// Get the rank of the tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Getter for the dimension-sizes array, in storage-order.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Safely lookup the size of the given (storage-order) dimension.
  uint64_t getDimSize(uint64_t d) const {
    assert(d < getRank());
    return dimSizes[d];
  }

  /// Getter for the "reverse" permutation, which maps this object's
  /// storage-order to the tensor's semantic-order.
  const std::vector<uint64_t> &getRev() const { return rev; }

  /// Getter for the dimension-types array, in storage-order.
  const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }

  /// Safely check if the (storage-order) dimension uses compressed storage.
  bool isCompressedDim(uint64_t d) const {
    assert(d < getRank());
    return (dimTypes[d] == DimLevelType::kCompressed);
  }

  /// Allocate a new enumerator.
#define DECL_NEWENUMERATOR(VNAME, V)                                           \
  virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
                             const uint64_t *) const {                         \
    FATAL_PIV("newEnumerator" #VNAME);                                         \
  }
  FOREVERY_V(DECL_NEWENUMERATOR)
#undef DECL_NEWENUMERATOR

  /// Overhead storage.
#define DECL_GETPOINTERS(PNAME, P)                                             \
  virtual void getPointers(std::vector<P> **, uint64_t) {                      \
    FATAL_PIV("getPointers" #PNAME);                                           \
  }
  FOREVERY_FIXED_O(DECL_GETPOINTERS)
#undef DECL_GETPOINTERS
#define DECL_GETINDICES(INAME, I)                                              \
  virtual void getIndices(std::vector<I> **, uint64_t) {                       \
    FATAL_PIV("getIndices" #INAME);                                            \
  }
  FOREVERY_FIXED_O(DECL_GETINDICES)
#undef DECL_GETINDICES

  /// Primary storage.
#define DECL_GETVALUES(VNAME, V)                                               \
  virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
  FOREVERY_V(DECL_GETVALUES)
#undef DECL_GETVALUES

  /// Element-wise insertion in lexicographic index order.
#define DECL_LEXINSERT(VNAME, V)                                               \
  virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
  FOREVERY_V(DECL_LEXINSERT)
#undef DECL_LEXINSERT

  /// Expanded insertion.
#define DECL_EXPINSERT(VNAME, V)                                               \
  virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
    FATAL_PIV("expInsert" #VNAME);                                             \
  }
  FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT

  /// Finishes insertion.
  virtual void endInsert() = 0;

protected:
  // Since this class is virtual, we must disallow public copying in
  // order to avoid "slicing". Since this class has data members,
  // that means making copying protected.
  // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
  SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
  // Copy-assignment would be implicitly deleted (because `dimSizes`
  // is const), so we explicitly delete it for clarity.
  SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;

private:
  const std::vector<uint64_t> dimSizes;
  std::vector<uint64_t> rev;
  const std::vector<DimLevelType> dimTypes;
};

#undef FATAL_PIV

// Forward.
template <typename P, typename I, typename V>
class SparseTensorEnumerator;

/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this method provides
/// a convenient "one-size-fits-all" solution that simply takes an input tensor
/// and annotations to implement all required setup in a general manner.
template <typename P, typename I, typename V>
class SparseTensorStorage final : public SparseTensorStorageBase {
  /// Private constructor to share code between the other constructors.
  /// Beware that the object is not necessarily guaranteed to be in a
  /// valid state after this constructor alone; e.g., `isCompressedDim(d)`
  /// doesn't entail `!(pointers[d].empty())`.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
  SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                      const uint64_t *perm, const DimLevelType *sparsity)
      : SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
        indices(getRank()), idx(getRank()) {}

public:
  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the coordinate scheme tensor for the initial contents if provided.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `dimSizes.size()`.
  SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                      const uint64_t *perm, const DimLevelType *sparsity,
                      SparseTensorCOO<V> *coo)
      : SparseTensorStorage(dimSizes, perm, sparsity) {
    // Provide hints on capacity of pointers and indices.
    // TODO: needs much fine-tuning based on actual sparsity; currently
    // we reserve pointer/index space based on all previous dense
    // dimensions, which works well up to first sparse dim; but
    // we should really use nnz and dense/sparse distribution.
    bool allDense = true;
    uint64_t sz = 1;
    for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
      if (isCompressedDim(r)) {
        // TODO: Take a parameter between 1 and `dimSizes[r]`, and multiply
        // `sz` by that before reserving. (For now we just use 1.)
        pointers[r].reserve(sz + 1);
        pointers[r].push_back(0);
        indices[r].reserve(sz);
        sz = 1;
        allDense = false;
      } else { // Dense dimension.
        sz = checkedMul(sz, getDimSizes()[r]);
      }
    }
    // Then assign contents from coordinate scheme tensor if provided.
    if (coo) {
      // Ensure both preconditions of `fromCOO`.
      assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
      coo->sort();
      // Now actually insert the `elements`.
      const std::vector<Element<V>> &elements = coo->getElements();
      uint64_t nnz = elements.size();
      values.reserve(nnz);
      fromCOO(elements, 0, nnz, 0);
    } else if (allDense) {
      values.resize(sz, 0);
    }
  }

  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the given sparse tensor for the initial contents.
  ///
  /// Preconditions:
  /// * `perm` and `sparsity` must be valid for `dimSizes.size()`.
  /// * The `tensor` must have the same value type `V`.
  SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
                      const uint64_t *perm, const DimLevelType *sparsity,
                      const SparseTensorStorageBase &tensor);

  ~SparseTensorStorage() final = default;

  /// Partially specialize these getter methods based on template types.
  void getPointers(std::vector<P> **out, uint64_t d) final {
    assert(d < getRank());
    *out = &pointers[d];
  }
  void getIndices(std::vector<I> **out, uint64_t d) final {
    assert(d < getRank());
    *out = &indices[d];
  }
  void getValues(std::vector<V> **out) final { *out = &values; }

  /// Partially specialize lexicographical insertions based on template types.
  void lexInsert(const uint64_t *cursor, V val) final {
    // First, wrap up pending insertion path.
    uint64_t diff = 0;
    uint64_t top = 0;
    if (!values.empty()) {
      diff = lexDiff(cursor);
      endPath(diff + 1);
      top = idx[diff] + 1;
    }
    // Then continue with insertion path.
    insPath(cursor, diff, top, val);
  }

  /// Partially specialize expanded insertions based on template types.
  /// Note that this method resets the values/filled-switch array back
  /// to all-zero/false while only iterating over the nonzero elements.
  void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
                 uint64_t count) final {
    if (count == 0)
      return;
    // Sort.
    std::sort(added, added + count);
    // Restore insertion path for first insert.
    const uint64_t lastDim = getRank() - 1;
    uint64_t index = added[0];
    cursor[lastDim] = index;
    lexInsert(cursor, values[index]);
    assert(filled[index]);
    values[index] = 0;
    filled[index] = false;
    // Subsequent insertions are quick.
    for (uint64_t i = 1; i < count; i++) {
      assert(index < added[i] && "non-lexicographic insertion");
      index = added[i];
      cursor[lastDim] = index;
      insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
      assert(filled[index]);
      values[index] = 0;
      filled[index] = false;
    }
  }

  /// Finalizes lexicographic insertions.
  void endInsert() final {
    if (values.empty())
      finalizeSegment(0);
    else
      endPath(0);
  }

  void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
                     const uint64_t *perm) const final {
    *out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
  }

  /// Returns this sparse tensor storage scheme as a new memory-resident
  /// sparse tensor in coordinate scheme with the given dimension order.
  ///
  /// Precondition: `perm` must be valid for `getRank()`.
  SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
    SparseTensorEnumeratorBase<V> *enumerator;
    newEnumerator(&enumerator, getRank(), perm);
    SparseTensorCOO<V> *coo =
        new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
    enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
      coo->add(ind, val);
    });
    // TODO: This assertion assumes there are no stored zeros,
    // or if there are then that we don't filter them out.
    // Cf., <https://github.com/llvm/llvm-project/issues/54179>
    assert(coo->getElements().size() == values.size());
    delete enumerator;
    return coo;
  }

  /// Factory method. Constructs a sparse tensor storage scheme with the given
  /// dimensions, permutation, and per-dimension dense/sparse annotations,
  /// using the coordinate scheme tensor for the initial contents if provided.
  /// In the latter case, the coordinate scheme must respect the same
  /// permutation as is desired for the new sparse tensor storage.
  ///
  /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                  const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
    SparseTensorStorage<P, I, V> *n = nullptr;
    if (coo) {
      const auto &coosz = coo->getDimSizes();
      assertPermutedSizesMatchShape(coosz, rank, perm, shape);
      n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
    } else {
      std::vector<uint64_t> permsz(rank);
      for (uint64_t r = 0; r < rank; r++) {
        assert(shape[r] > 0 && "Dimension size zero has trivial storage");
        permsz[perm[r]] = shape[r];
      }
      // We pass the null `coo` to ensure we select the intended constructor.
      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
    }
    return n;
  }

  /// Factory method. Constructs a sparse tensor storage scheme with
  /// the given dimensions, permutation, and per-dimension dense/sparse
  /// annotations, using the sparse tensor for the initial contents.
  ///
  /// Preconditions:
  /// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
  /// * The `tensor` must have the same value type `V`.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                  const DimLevelType *sparsity,
                  const SparseTensorStorageBase *source) {
    assert(source && "Got nullptr for source");
    SparseTensorEnumeratorBase<V> *enumerator;
    source->newEnumerator(&enumerator, rank, perm);
    const auto &permsz = enumerator->permutedSizes();
    assertPermutedSizesMatchShape(permsz, rank, perm, shape);
    auto *tensor =
        new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
    delete enumerator;
    return tensor;
  }

private:
  /// Appends an arbitrary new position to `pointers[d]`. This method
  /// checks that `pos` is representable in the `P` type; however, it
  /// does not check that `pos` is semantically valid (i.e., larger than
  /// the previous position and smaller than `indices[d].capacity()`).
  void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
    assert(isCompressedDim(d));
    assert(pos <= std::numeric_limits<P>::max() &&
           "Pointer value is too large for the P-type");
    pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
  }

  /// Appends index `i` to dimension `d`, in the semantically general
  /// sense. For non-dense dimensions, that means appending to the
  /// `indices[d]` array, checking that `i` is representable in the `I`
  /// type; however, we do not verify other semantic requirements (e.g.,
  /// that `i` is in bounds for `dimSizes[d]`, and not previously occurring
  /// in the same segment). For dense dimensions, this method instead
  /// appends the appropriate number of zeros to the `values` array,
  /// where `full` is the number of "entries" already written to `values`
  /// for this segment (aka one after the highest index previously appended).
  void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
    if (isCompressedDim(d)) {
      assert(i <= std::numeric_limits<I>::max() &&
             "Index value is too large for the I-type");
      indices[d].push_back(static_cast<I>(i));
    } else { // Dense dimension.
      assert(i >= full && "Index was already filled");
      if (i == full)
        return; // Short-circuit, since it'll be a nop.
      if (d + 1 == getRank())
        values.insert(values.end(), i - full, 0);
      else
        finalizeSegment(d + 1, 0, i - full);
    }
  }

  /// Writes the given coordinate to `indices[d][pos]`. This method
  /// checks that `i` is representable in the `I` type; however, it
  /// does not check that `i` is semantically valid (i.e., in bounds
  /// for `dimSizes[d]` and not elsewhere occurring in the same segment).
  void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
    assert(isCompressedDim(d));
    // Subscript assignment to `std::vector` requires that the `pos`-th
    // entry has been initialized; thus we must be sure to check `size()`
    // here, instead of `capacity()` as would be ideal.
    assert(pos < indices[d].size() && "Index position is out of bounds");
    assert(i <= std::numeric_limits<I>::max() &&
           "Index value is too large for the I-type");
    indices[d][pos] = static_cast<I>(i);
  }

  /// Computes the assembled-size associated with the `d`-th dimension,
  /// given the assembled-size associated with the `(d-1)`-th dimension.
  /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
  /// storage, as opposed to "dimension-sizes" which are the cardinality
  /// of coordinates for that dimension.
  ///
  /// Precondition: the `pointers[d]` array must be fully initialized
  /// before calling this method.
  uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
    if (isCompressedDim(d))
      return pointers[d][parentSz];
    // else if dense:
    return parentSz * getDimSizes()[d];
  }

  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the pointers and
  /// indices arrays under the given per-dimension dense/sparse annotations.
  ///
  /// Preconditions:
  /// (1) the `elements` must be lexicographically sorted.
  /// (2) the indices of every element are valid for `dimSizes` (equal rank
  ///     and pointwise less-than).
  void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
               uint64_t hi, uint64_t d) {
    uint64_t rank = getRank();
    assert(d <= rank && hi <= elements.size());
    // Once dimensions are exhausted, insert the numerical values.
    if (d == rank) {
      assert(lo < hi);
      values.push_back(elements[lo].value);
      return;
    }
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
      // Find segment in interval with same index elements in this dimension.
      uint64_t i = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == i)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      appendIndex(d, full, i);
      full = i + 1;
      fromCOO(elements, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    finalizeSegment(d, full);
  }

  /// Finalize the sparse pointer structure at this dimension.
  void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
    if (count == 0)
      return; // Short-circuit, since it'll be a nop.
    if (isCompressedDim(d)) {
      appendPointer(d, indices[d].size(), count);
    } else { // Dense dimension.
      const uint64_t sz = getDimSizes()[d];
      assert(sz >= full && "Segment is overfull");
      count = checkedMul(count, sz - full);
      // For dense storage we must enumerate all the remaining coordinates
      // in this dimension (i.e., coordinates after the last non-zero
      // element), and either fill in their zero values or else recurse
      // to finalize some deeper dimension.
      if (d + 1 == getRank())
        values.insert(values.end(), count, 0);
      else
        finalizeSegment(d + 1, 0, count);
    }
  }

  /// Wraps up a single insertion path, inner to outer.
  void endPath(uint64_t diff) {
    uint64_t rank = getRank();
    assert(diff <= rank);
    for (uint64_t i = 0; i < rank - diff; i++) {
      const uint64_t d = rank - i - 1;
      finalizeSegment(d, idx[d] + 1);
    }
  }

  /// Continues a single insertion path, outer to inner.
  void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
    uint64_t rank = getRank();
    assert(diff < rank);
    for (uint64_t d = diff; d < rank; d++) {
      uint64_t i = cursor[d];
      appendIndex(d, top, i);
      top = 0;
      idx[d] = i;
    }
    values.push_back(val);
  }

  /// Finds the lexicographic differing dimension.
  uint64_t lexDiff(const uint64_t *cursor) const {
    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
      if (cursor[r] > idx[r])
        return r;
      else
        assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
    return -1u;
  }

  // Allow `SparseTensorEnumerator` to access the data-members (to avoid
  // the cost of virtual-function dispatch in inner loops), without
  // making them public to other client code.
  friend class SparseTensorEnumerator<P, I, V>;

  std::vector<std::vector<P>> pointers;
  std::vector<std::vector<I>> indices;
  std::vector<V> values;
  std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
};

/// A (higher-order) function object for enumerating the elements of some
/// `SparseTensorStorage` under a permutation.
/// That is, the `forallElements`
/// method encapsulates the loop-nest for enumerating the elements of
/// the source tensor (in whatever order is best for the source tensor),
/// and applies a permutation to the coordinates/indices before handing
/// each element to the callback. A single enumerator object can be
/// freely reused for several calls to `forallElements`, just so long
/// as each call is sequential with respect to one another.
///
/// N.B., this class stores a reference to the `SparseTensorStorageBase`
/// passed to the constructor; thus, objects of this class must not
/// outlive the sparse tensor they depend on.
///
/// Design Note: The reason we define this class instead of simply using
/// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
/// the `<P,I>` template parameters from MLIR client code (to simplify the
/// type parameters used for direct sparse-to-sparse conversion). And the
/// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
/// than simply using this class, is to avoid the cost of virtual-method
/// dispatch within the loop-nest.
template <typename V>
class SparseTensorEnumeratorBase {
public:
  /// Constructs an enumerator with the given permutation for mapping
  /// the semantic-ordering of dimensions to the desired target-ordering.
  ///
  /// Preconditions:
  /// * the `tensor` must have the same `V` value type.
  /// * `perm` must be valid for `rank`.
  SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
                             uint64_t rank, const uint64_t *perm)
      : src(tensor), permsz(src.getRev().size()), reord(getRank()),
        cursor(getRank()) {
    assert(perm && "Received nullptr for permutation");
    assert(rank == getRank() && "Permutation rank mismatch");
    const auto &rev = src.getRev();           // source-order -> semantic-order
    const auto &dimSizes = src.getDimSizes(); // in source storage-order
    for (uint64_t s = 0; s < rank; s++) {     // `s` source storage-order
      uint64_t t = perm[rev[s]];              // `t` target-order
      reord[s] = t;
      permsz[t] = dimSizes[s];
    }
  }

  virtual ~SparseTensorEnumeratorBase() = default;

  // We disallow copying to help avoid leaking the `src` reference.
  // (In addition to avoiding the problem of slicing.)
  SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
  SparseTensorEnumeratorBase &
  operator=(const SparseTensorEnumeratorBase &) = delete;

  /// Returns the source/target tensor's rank. (The source-rank and
  /// target-rank are always equal since we only support permutations.
  /// Though once we add support for other dimension mappings, this
  /// method will have to be split in two.)
  uint64_t getRank() const { return permsz.size(); }

  /// Returns the target tensor's dimension sizes.
  const std::vector<uint64_t> &permutedSizes() const { return permsz; }

  /// Enumerates all elements of the source tensor, permutes their
  /// indices, and passes the permuted element to the callback.
  /// The callback must not store the cursor reference directly,
  /// since this function reuses the storage. Instead, the callback
  /// must copy it if it wants to keep it.
  virtual void forallElements(ElementConsumer<V> yield) = 0;

protected:
  const SparseTensorStorageBase &src;
  std::vector<uint64_t> permsz; // in target order.
  std::vector<uint64_t> reord;  // source storage-order -> target order.
  std::vector<uint64_t> cursor; // in target order.
};

template <typename P, typename I, typename V>
class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
  using Base = SparseTensorEnumeratorBase<V>;

public:
  /// Constructs an enumerator with the given permutation for mapping
  /// the semantic-ordering of dimensions to the desired target-ordering.
  ///
  /// Precondition: `perm` must be valid for `rank`.
  SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
                         uint64_t rank, const uint64_t *perm)
      : Base(tensor, rank, perm) {}

  ~SparseTensorEnumerator() final = default;

  void forallElements(ElementConsumer<V> yield) final {
    forallElements(yield, 0, 0);
  }

private:
  /// The recursive component of the public `forallElements`.
  void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
                      uint64_t d) {
    // Recover the `<P,I,V>` type parameters of `src`.
    const auto &src =
        static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
    if (d == Base::getRank()) {
      assert(parentPos < src.values.size() &&
             "Value position is out of bounds");
      // TODO: <https://github.com/llvm/llvm-project/issues/54179>
      yield(this->cursor, src.values[parentPos]);
    } else if (src.isCompressedDim(d)) {
      // Look up the bounds of the `d`-level segment determined by the
      // `d-1`-level position `parentPos`.
      const std::vector<P> &pointersD = src.pointers[d];
      assert(parentPos + 1 < pointersD.size() &&
             "Parent pointer position is out of bounds");
      const uint64_t pstart = static_cast<uint64_t>(pointersD[parentPos]);
      const uint64_t pstop = static_cast<uint64_t>(pointersD[parentPos + 1]);
      // Loop-invariant code for looking up the `d`-level coordinates/indices.
      const std::vector<I> &indicesD = src.indices[d];
      assert(pstop <= indicesD.size() && "Index position is out of bounds");
      uint64_t &cursorReordD = this->cursor[this->reord[d]];
      for (uint64_t pos = pstart; pos < pstop; pos++) {
        cursorReordD = static_cast<uint64_t>(indicesD[pos]);
        forallElements(yield, pos, d + 1);
      }
    } else { // Dense dimension.
      const uint64_t sz = src.getDimSizes()[d];
      const uint64_t pstart = parentPos * sz;
      uint64_t &cursorReordD = this->cursor[this->reord[d]];
      for (uint64_t i = 0; i < sz; i++) {
        cursorReordD = i;
        forallElements(yield, pstart + i, d + 1);
      }
    }
  }
};

/// Statistics regarding the number of nonzero subtensors in
/// a source tensor, for direct sparse=>sparse conversion a la
/// <https://arxiv.org/abs/2001.02609>.
///
/// N.B., this class stores references to the parameters passed to
/// the constructor; thus, objects of this class must not outlive
/// those parameters.
class SparseTensorNNZ final {
public:
  /// Allocate the statistics structure for the desired sizes and
  /// sparsity (in the target tensor's storage-order). This constructor
  /// does not actually populate the statistics, however; for that see
  /// `initialize`.
  ///
  /// Precondition: `dimSizes` must not contain zeros.
  SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
                  const std::vector<DimLevelType> &sparsity)
      : dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
    assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
    bool uncompressed = true;
    uint64_t sz = 1; // the product of `dimSizes[i]` for all `i` strictly less than `r`.
    for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
      switch (dimTypes[r]) {
      case DimLevelType::kCompressed:
        assert(uncompressed &&
               "Multiple compressed layers not currently supported");
        uncompressed = false;
        nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
        break;
      case DimLevelType::kDense:
        assert(uncompressed &&
               "Dense after compressed not currently supported");
        break;
      case DimLevelType::kSingleton:
        // Singleton after Compressed causes no problems for allocating
        // `nnz` nor for the yieldPos loop. This remains true even
        // when adding support for multiple compressed dimensions or
        // for dense-after-compressed.
        break;
      }
      sz = checkedMul(sz, dimSizes[r]);
    }
  }

  // We disallow copying to help avoid leaking the stored references.
  SparseTensorNNZ(const SparseTensorNNZ &) = delete;
  SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;

  /// Returns the rank of the target tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Enumerate the source tensor to fill in the statistics. The
  /// enumerator should already incorporate the permutation (from
  /// semantic-order to the target storage-order).
  template <typename V>
  void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
    assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
    assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
    enumerator.forallElements(
        [this](const std::vector<uint64_t> &ind, V) { add(ind); });
  }

  /// The type of callback functions which receive an nnz-statistic.
  using NNZConsumer = const std::function<void(uint64_t)> &;

  /// Lexicographically enumerates all indices for dimensions strictly
  /// less than `stopDim`, and passes their nnz statistic to the callback.
  /// Since our use-case only requires the statistic not the coordinates
  /// themselves, we do not bother to construct those coordinates.
  /// Lexicographically enumerates all indices for dimensions strictly
  /// less than `stopDim`, and passes their nnz statistic to the callback.
  /// Since our use-case only requires the statistic and not the coordinates
  /// themselves, we do not bother to construct those coordinates.
  void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
    assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
    assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
           "Cannot look up non-compressed dimensions");
    forallIndices(yield, stopDim, 0, 0);
  }

private:
  /// Adds a new element (i.e., increments its statistics). We use
  /// a method rather than inlining into the lambda in `initialize`,
  /// to avoid spurious templating over `V`. And this method is private
  /// to avoid needing to re-assert validity of `ind` (which is guaranteed
  /// by `forallElements`).
  void add(const std::vector<uint64_t> &ind) {
    uint64_t parentPos = 0;
    for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
      if (dimTypes[r] == DimLevelType::kCompressed)
        nnz[r][parentPos]++;
      parentPos = parentPos * dimSizes[r] + ind[r];
    }
  }

  /// Recursive component of the public `forallIndices`.
  void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
                     uint64_t d) const {
    assert(d <= stopDim);
    if (d == stopDim) {
      assert(parentPos < nnz[d].size() && "Cursor is out of range");
      yield(nnz[d][parentPos]);
    } else {
      const uint64_t sz = dimSizes[d];
      const uint64_t pstart = parentPos * sz;
      for (uint64_t i = 0; i < sz; i++)
        forallIndices(yield, stopDim, pstart + i, d + 1);
    }
  }

  // All of these are in the target storage-order.
  const std::vector<uint64_t> &dimSizes;
  const std::vector<DimLevelType> &dimTypes;
  std::vector<std::vector<uint64_t>> nnz;
};

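// A sketch of the construction below, on hypothetical numbers (this comment
// is illustrative only). The constructor proceeds in three phases:
// (1) count nonzeros per segment via SparseTensorNNZ, storing the running
// totals in `pointers[r]`; (2) the yieldPos loop, which reuses
// `pointers[r][parentPos]` as a write-cursor into `indices[r]`/`values`
// and increments it per insertion; and (3) the finalizeYieldPos loop,
// which shifts the displaced cursors back down one slot. For a compressed
// dimension with segment sizes {2, 0, 1}:
//
//   after (1): pointers[r] = { 0, 2, 2, 3 }
//   after (2): pointers[r] = { 2, 2, 3, 3 }  // each cursor fully advanced
//   after (3): pointers[r] = { 0, 2, 2, 3 }  // shifted back into place
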
template <typename P, typename I, typename V>
SparseTensorStorage<P, I, V>::SparseTensorStorage(
    const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
    const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
    : SparseTensorStorage(dimSizes, perm, sparsity) {
  SparseTensorEnumeratorBase<V> *enumerator;
  tensor.newEnumerator(&enumerator, getRank(), perm);
  {
    // Initialize the statistics structure.
    SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
    nnz.initialize(*enumerator);
    // Initialize "pointers" overhead (and allocate "indices", "values").
    uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
    for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
      if (isCompressedDim(r)) {
        pointers[r].reserve(parentSz + 1);
        pointers[r].push_back(0);
        uint64_t currentPos = 0;
        nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
          currentPos += n;
          appendPointer(r, currentPos);
        });
        assert(pointers[r].size() == parentSz + 1 &&
               "Final pointers size doesn't match allocated size");
        // That assertion entails `assembledSize(parentSz, r)`
        // is now in a valid state. That is, `pointers[r][parentSz]`
        // equals the present value of `currentPos`, which is the
        // correct assembled-size for `indices[r]`.
      }
      // Update assembled-size for the next iteration.
      parentSz = assembledSize(parentSz, r);
      // Ideally we need only `indices[r].reserve(parentSz)`, however
      // the `std::vector` implementation forces us to initialize it too.
      // That is, in the yieldPos loop we need random-access assignment
      // to `indices[r]`; however, `std::vector`'s subscript-assignment
      // only allows assigning to already-initialized positions.
      if (isCompressedDim(r))
        indices[r].resize(parentSz, 0);
    }
    values.resize(parentSz, 0); // Both allocate and zero-initialize.
  }
  // The yieldPos loop.
  enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
    uint64_t parentSz = 1, parentPos = 0;
    for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
      if (isCompressedDim(r)) {
        // If `parentPos == parentSz` then it's valid as an array-lookup;
        // however, it's semantically invalid here since that entry
        // does not represent a segment of `indices[r]`. Moreover, that
        // entry must be immutable for `assembledSize` to remain valid.
        assert(parentPos < parentSz && "Pointers position is out of bounds");
        const uint64_t currentPos = pointers[r][parentPos];
        // This increment won't overflow the `P` type, since it can't
        // exceed the original value of `pointers[r][parentPos+1]`
        // which was already verified to be within bounds for `P`
        // when it was written to the array.
        pointers[r][parentPos]++;
        writeIndex(r, currentPos, ind[r]);
        parentPos = currentPos;
      } else { // Dense dimension.
        parentPos = parentPos * getDimSizes()[r] + ind[r];
      }
      parentSz = assembledSize(parentSz, r);
    }
    assert(parentPos < values.size() && "Value position is out of bounds");
    values[parentPos] = val;
  });
  // No longer need the enumerator, so we'll delete it ASAP.
  delete enumerator;
  // The finalizeYieldPos loop.
  for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
    if (isCompressedDim(r)) {
      assert(parentSz == pointers[r].size() - 1 &&
             "Actual pointers size doesn't match the expected size");
      // Can't check all of them, but at least we can check the last one.
      assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
             "Pointers got corrupted");
      // TODO: optimize this by using `memmove` or similar.
      for (uint64_t n = 0; n < parentSz; n++) {
        const uint64_t parentPos = parentSz - n;
        pointers[r][parentPos] = pointers[r][parentPos - 1];
      }
      pointers[r][0] = 0;
    }
    parentSz = assembledSize(parentSz, r);
  }
}

/// Helper to convert string to lower case.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = tolower(*c);
  return token;
}

/// This class abstracts over the information stored in file headers,
/// as well as providing the buffers and methods for parsing those headers.
class SparseTensorFile final {
public:
  explicit SparseTensorFile(char *filename) : filename(filename) {
    assert(filename && "Received nullptr for filename");
  }

  // Disallows copying, to avoid duplicating the `file` pointer.
  SparseTensorFile(const SparseTensorFile &) = delete;
  SparseTensorFile &operator=(const SparseTensorFile &) = delete;

  // This dtor tries to avoid leaking the `file`. (Though it's better
  // to call `closeFile` explicitly when possible, since there are
  // circumstances where dtors are not called reliably.)
  ~SparseTensorFile() { closeFile(); }

  /// Opens the file for reading.
  void openFile() {
    if (file)
      FATAL("Already opened file %s\n", filename);
    file = fopen(filename, "r");
    if (!file)
      FATAL("Cannot find file %s\n", filename);
  }

  /// Closes the file.
  void closeFile() {
    if (file) {
      fclose(file);
      file = nullptr;
    }
  }

  // TODO(wrengr/bixia): figure out how to reorganize the element-parsing
  // loop of `openSparseTensorCOO` into methods of this class, so we can
  // avoid leaking access to the `line` pointer (both for general hygiene
  // and because we can't mark it const due to the second argument of
  // `strtoul`/`strtod` being `char * *restrict` rather than
  // `char const* *restrict`).
  //
  /// Attempts to read a line from the file.
  char *readLine() {
    if (fgets(line, kColWidth, file))
      return line;
    FATAL("Cannot read next line of %s\n", filename);
  }

  /// Reads and parses the file's header.
  void readHeader() {
    assert(file && "Attempt to readHeader() before openFile()");
    if (strstr(filename, ".mtx"))
      readMMEHeader();
    else if (strstr(filename, ".tns"))
      readExtFROSTTHeader();
    else
      FATAL("Unknown format %s\n", filename);
    assert(isValid && "Failed to read the header");
  }

  /// Gets the MME "pattern" property setting. Is only valid after
  /// parsing the header.
  bool isPattern() const {
    assert(isValid && "Attempt to isPattern() before readHeader()");
    return isPattern_;
  }

  /// Gets the MME "symmetric" property setting. Is only valid after
  /// parsing the header.
  bool isSymmetric() const {
    assert(isValid && "Attempt to isSymmetric() before readHeader()");
    return isSymmetric_;
  }

  /// Gets the rank of the tensor. Is only valid after parsing the header.
  uint64_t getRank() const {
    assert(isValid && "Attempt to getRank() before readHeader()");
    return idata[0];
  }

  /// Gets the number of nonzeros. Is only valid after parsing the header.
  uint64_t getNNZ() const {
    assert(isValid && "Attempt to getNNZ() before readHeader()");
    return idata[1];
  }

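// For example (hypothetical values): after reading the header of a 4x6
// matrix with 8 nonzeros, the metadata buffer would hold
//
//   idata = { 2 /*rank*/, 8 /*nnz*/, 4, 6 /*dim sizes*/, ... }
//
// which is the layout the accessors above and below read from.
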
  /// Gets the dimension-sizes array. The pointer itself is always
  /// valid; however, the values stored therein are only valid after
  /// parsing the header.
  const uint64_t *getDimSizes() const { return idata + 2; }

  /// Safely gets the size of the given dimension. Is only valid
  /// after parsing the header.
  uint64_t getDimSize(uint64_t d) const {
    assert(d < getRank());
    return idata[2 + d];
  }

  /// Asserts that the shape subsumes the actual dimension sizes. Is only
  /// valid after parsing the header.
  void assertMatchesShape(uint64_t rank, const uint64_t *shape) const {
    assert(rank == getRank() && "Rank mismatch");
    for (uint64_t r = 0; r < rank; r++)
      assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
             "Dimension size mismatch");
  }

private:
  void readMMEHeader();
  void readExtFROSTTHeader();

  const char *filename;
  FILE *file = nullptr;
  bool isValid = false;
  bool isPattern_ = false;
  bool isSymmetric_ = false;
  uint64_t idata[512];
  char line[kColWidth];
};

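// For reference, a minimal MME file accepted by `readMMEHeader` below
// (the contents are illustrative, not an official sample):
//
//   %%MatrixMarket matrix coordinate real general
//   % comments
//   4 6 2
//   1 1 1.5
//   4 6 2.5
//
// i.e. a banner giving object/format/field/symmetry, then M N NNZ on the
// first non-comment line, then one 1-based entry per line.
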
/// Reads the MME header of a general sparse matrix of type real.
void SparseTensorFile::readMMEHeader() {
  char header[64];
  char object[64];
  char format[64];
  char field[64];
  char symmetry[64];
  // Read the header line.
  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
             symmetry) != 5)
    FATAL("Corrupt header in %s\n", filename);
  // Set properties.
  isPattern_ = (strcmp(toLower(field), "pattern") == 0);
  isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0);
  // Make sure this is a general sparse matrix.
  if (strcmp(toLower(header), "%%matrixmarket") ||
      strcmp(toLower(object), "matrix") ||
      strcmp(toLower(format), "coordinate") ||
      (strcmp(toLower(field), "real") && !isPattern_) ||
      (strcmp(toLower(symmetry), "general") && !isSymmetric_))
    FATAL("Cannot find a general sparse matrix in %s\n", filename);
  // Skip comments.
  while (true) {
    readLine();
    if (line[0] != '%')
      break;
  }
  // The next line contains M N NNZ.
  idata[0] = 2; // rank
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3)
    FATAL("Cannot find size in %s\n", filename);
  isValid = true;
}

/// Reads the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments followed
/// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
void SparseTensorFile::readExtFROSTTHeader() {
  // Skip comments.
  while (true) {
    readLine();
    if (line[0] != '#')
      break;
  }
  // The next line contains RANK and NNZ.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
    FATAL("Cannot find metadata in %s\n", filename);
  // Followed by a line with the dimension sizes (one per rank).
  for (uint64_t r = 0; r < idata[0]; r++)
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
      FATAL("Cannot find dimension size %s\n", filename);
  readLine(); // end of line
  isValid = true;
}

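// Likewise, a minimal "extended" FROSTT file accepted by
// `readExtFROSTTHeader` above (illustrative contents):
//
//   # comments
//   3 2
//   2 3 4
//   1 1 1 1.0
//   2 3 4 2.0
//
// i.e. rank and nnz on one line, then the dimension sizes, then one
// 1-based entry per line.
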
/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme.
template <typename V>
static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
                                               const uint64_t *shape,
                                               const uint64_t *perm) {
  SparseTensorFile stfile(filename);
  stfile.openFile();
  stfile.readHeader();
  stfile.assertMatchesShape(rank, shape);
  // Prepare sparse tensor object with per-dimension sizes
  // and the number of nonzeros as initial capacity.
  uint64_t nnz = stfile.getNNZ();
  auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, stfile.getDimSizes(),
                                                     perm, nnz);
  // Read all nonzero elements.
  std::vector<uint64_t> indices(rank);
  for (uint64_t k = 0; k < nnz; k++) {
    char *linePtr = stfile.readLine();
    for (uint64_t r = 0; r < rank; r++) {
      uint64_t idx = strtoul(linePtr, &linePtr, 10);
      // Add 0-based index.
      indices[perm[r]] = idx - 1;
    }
    // The external formats always store the numerical values with the type
    // double, but we cast these values to the sparse tensor object type.
    // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
    double value = stfile.isPattern() ? 1.0 : strtod(linePtr, &linePtr);
    // TODO: <https://github.com/llvm/llvm-project/issues/54179>
    coo->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing them. In the future, we may want to make symmetry
    // implicit for storage reasons.
    if (stfile.isSymmetric() && indices[0] != indices[1])
      coo->add({indices[1], indices[0]}, value);
  }
  // Close the file and return the tensor.
  stfile.closeFile();
  return coo;
}

/// Writes the sparse tensor to `dest` in extended FROSTT format.
template <typename V>
static void outSparseTensor(void *tensor, void *dest, bool sort) {
  assert(tensor && dest);
  auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
  if (sort)
    coo->sort();
  char *filename = static_cast<char *>(dest);
  auto &dimSizes = coo->getDimSizes();
  auto &elements = coo->getElements();
  uint64_t rank = coo->getRank();
  uint64_t nnz = elements.size();
  std::fstream file;
  file.open(filename, std::ios_base::out | std::ios_base::trunc);
  assert(file.is_open());
  file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
  for (uint64_t r = 0; r < rank - 1; r++)
    file << dimSizes[r] << " ";
  file << dimSizes[rank - 1] << std::endl;
  for (uint64_t i = 0; i < nnz; i++) {
    auto &idx = elements[i].indices;
    for (uint64_t r = 0; r < rank; r++)
      file << (idx[r] + 1) << " ";
    file << elements[i].value << std::endl;
  }
  file.flush();
  file.close();
  assert(file.good());
}

/// Initializes a sparse tensor from an external COO-flavored format.
template <typename V>
static SparseTensorStorage<uint64_t, uint64_t, V> *
toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
                   uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
  const DimLevelType *sparsity = (DimLevelType *)(sparse);
#ifndef NDEBUG
  // Verify that perm is a permutation of 0..(rank-1).
  std::vector<uint64_t> order(perm, perm + rank);
  std::sort(order.begin(), order.end());
  for (uint64_t i = 0; i < rank; ++i)
    if (i != order[i])
      FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);

  // Verify that the sparsity values are supported.
  for (uint64_t i = 0; i < rank; ++i)
    if (sparsity[i] != DimLevelType::kDense &&
        sparsity[i] != DimLevelType::kCompressed)
      FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
#endif

  // Convert external format to internal COO.
  auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
  std::vector<uint64_t> idx(rank);
  for (uint64_t i = 0, base = 0; i < nse; i++) {
    for (uint64_t r = 0; r < rank; r++)
      idx[perm[r]] = indices[base + r];
    coo->add(idx, values[i]);
    base += rank;
  }
  // Return sparse tensor storage format as opaque pointer.
  auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
      rank, shape, perm, sparsity, coo);
  delete coo;
  return tensor;
}

/// Converts a sparse tensor to an external COO-flavored format.
template <typename V>
static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
                                 uint64_t **pShape, V **pValues,
                                 uint64_t **pIndices) {
  assert(tensor);
  auto sparseTensor =
      static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
  uint64_t rank = sparseTensor->getRank();
  std::vector<uint64_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());

  const std::vector<Element<V>> &elements = coo->getElements();
  uint64_t nse = elements.size();

  uint64_t *shape = new uint64_t[rank];
  for (uint64_t i = 0; i < rank; i++)
    shape[i] = coo->getDimSizes()[i];

  V *values = new V[nse];
  uint64_t *indices = new uint64_t[rank * nse];

  for (uint64_t i = 0, base = 0; i < nse; i++) {
    values[i] = elements[i].value;
    for (uint64_t j = 0; j < rank; j++)
      indices[base + j] = elements[i].indices[j];
    base += rank;
  }

  delete coo;
  *pRank = rank;
  *pNse = nse;
  *pShape = shape;
  *pValues = values;
  *pIndices = indices;
}

} // anonymous namespace

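// For orientation (a hypothetical call, not code in this file): after
// obtaining an opaque tensor `t` with element type double, the export
// path above would be used as
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   fromMLIRSparseTensor<double>(t, &rank, &nse, &shape, &values, &indices);
//
// where `shape`, `values`, and `indices` are freshly `new[]`-allocated
// and become the caller's responsibility to delete.
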
extern "C" {

//===----------------------------------------------------------------------===//
//
// Public functions which operate on MLIR buffers (memrefs) to interact
// with sparse tensors (which are only visible as opaque pointers externally).
//
//===----------------------------------------------------------------------===//

#define CASE(p, i, v, P, I, V)                                                 \
  if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
    SparseTensorCOO<V> *coo = nullptr;                                         \
    if (action <= Action::kFromCOO) {                                          \
      if (action == Action::kFromFile) {                                       \
        char *filename = static_cast<char *>(ptr);                             \
        coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
      } else if (action == Action::kFromCOO) {                                 \
        coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
      } else {                                                                 \
        assert(action == Action::kEmpty);                                      \
      }                                                                        \
      auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
          rank, shape, perm, sparsity, coo);                                   \
      if (action == Action::kFromFile)                                         \
        delete coo;                                                            \
      return tensor;                                                           \
    }                                                                          \
    if (action == Action::kSparseToSparse) {                                   \
      auto *tensor = static_cast<SparseTensorStorageBase *>(ptr);              \
      return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm,  \
                                                           sparsity, tensor);  \
    }                                                                          \
    if (action == Action::kEmptyCOO)                                           \
      return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
    coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
    if (action == Action::kToIterator) {                                       \
      coo->startIterator();                                                    \
    } else {                                                                   \
      assert(action == Action::kToCOO);                                        \
    }                                                                          \
    return coo;                                                                \
  }

#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)

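// Each CASE line below expands, roughly, to one monomorphic branch of the
// dispatch; e.g. (a sketch of the expansion, eliding the action handling):
//
//   if (ptrTp == OverheadType::kU64 && indTp == OverheadType::kU64 &&
//       valTp == PrimaryType::kF64) {
//     ... SparseTensorStorage<uint64_t, uint64_t, double> ...
//   }
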
// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
// can safely rewrite kIndex to kU64. We make this assertion to guarantee
// that this file cannot get out of sync with its header.
static_assert(std::is_same<index_type, uint64_t>::value,
              "Expected index_type == uint64_t");

void *
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
                             StridedMemRefType<index_type, 1> *sref,
                             StridedMemRefType<index_type, 1> *pref,
                             OverheadType ptrTp, OverheadType indTp,
                             PrimaryType valTp, Action action, void *ptr) {
  assert(aref && sref && pref);
  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
         pref->strides[0] == 1);
  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
  const DimLevelType *sparsity = aref->data + aref->offset;
  const index_type *shape = sref->data + sref->offset;
  const index_type *perm = pref->data + pref->offset;
  uint64_t rank = aref->sizes[0];

  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
  // This is safe because of the static_assert above.
  if (ptrTp == OverheadType::kIndex)
    ptrTp = OverheadType::kU64;
  if (indTp == OverheadType::kIndex)
    indTp = OverheadType::kU64;

  // Double matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
       uint64_t, double);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
       uint32_t, double);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
       uint16_t, double);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
       uint8_t, double);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
       uint64_t, double);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
       uint32_t, double);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
       uint16_t, double);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
       uint8_t, double);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
       uint64_t, double);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
       uint32_t, double);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
       uint16_t, double);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
       uint8_t, double);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
       uint64_t, double);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
       uint32_t, double);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
       uint16_t, double);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
       uint8_t, double);

  // Float matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
       uint64_t, float);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
       uint32_t, float);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
       uint16_t, float);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
       uint8_t, float);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
       uint64_t, float);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
       uint32_t, float);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
       uint16_t, float);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
       uint8_t, float);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
       uint64_t, float);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
       uint32_t, float);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
       uint16_t, float);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
       uint8_t, float);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
       uint64_t, float);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
       uint32_t, float);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
       uint16_t, float);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
       uint8_t, float);

  // Two-byte floats with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);

  // Integral matrices with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);

  // Complex matrices with wide overhead.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);

  // Unsupported case (add above if needed).
  // TODO: better pretty-printing of enum values!
  FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
        static_cast<int>(ptrTp), static_cast<int>(indTp),
        static_cast<int>(valTp));
}
#undef CASE
#undef CASE_SECSAME

#define IMPL_SPARSEVALUES(VNAME, V)                                            \
  void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
                                        void *tensor) {                        \
    assert(ref &&tensor);                                                      \
    std::vector<V> *v;                                                         \
    static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
    ref->basePtr = ref->data = v->data();                                      \
    ref->offset = 0;                                                           \
    ref->sizes[0] = v->size();                                                 \
    ref->strides[0] = 1;                                                       \
  }
FOREVERY_V(IMPL_SPARSEVALUES)
#undef IMPL_SPARSEVALUES

#define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
                           index_type d) {                                     \
    assert(ref &&tensor);                                                      \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
    ref->basePtr = ref->data = v->data();                                      \
    ref->offset = 0;                                                           \
    ref->sizes[0] = v->size();                                                 \
    ref->strides[0] = 1;                                                       \
  }
#define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
  IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
FOREVERY_O(IMPL_SPARSEPOINTERS)
#undef IMPL_SPARSEPOINTERS

#define IMPL_SPARSEINDICES(INAME, I)                                           \
  IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
FOREVERY_O(IMPL_SPARSEINDICES)
#undef IMPL_SPARSEINDICES
#undef IMPL_GETOVERHEAD

#define IMPL_ADDELT(VNAME, V)                                                  \
  void *_mlir_ciface_addElt##VNAME(void *coo, V value,                         \
                                   StridedMemRefType<index_type, 1> *iref,     \
                                   StridedMemRefType<index_type, 1> *pref) {   \
    assert(coo &&iref &&pref);                                                 \
    assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
    assert(iref->sizes[0] == pref->sizes[0]);                                  \
    const index_type *indx = iref->data + iref->offset;                        \
    const index_type *perm = pref->data + pref->offset;                        \
    uint64_t isize = iref->sizes[0];                                           \
    std::vector<index_type> indices(isize);                                    \
    for (uint64_t r = 0; r < isize; r++)                                       \
      indices[perm[r]] = indx[r];                                              \
    static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value);               \
    return coo;                                                                \
  }
FOREVERY_SIMPLEX_V(IMPL_ADDELT)
IMPL_ADDELT(C64, complex64)

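// For instance, `FOREVERY_SIMPLEX_V(IMPL_ADDELT)` above stamps out entry
// points such as `_mlir_ciface_addEltF64`, so a hypothetical call site
// (illustrative only) looks like
//
//   coo = _mlir_ciface_addEltF64(coo, 3.0, iref, pref);
//
// which permutes the indices in `iref` by `pref` before appending.
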
// Marked static because it's not part of the public API.
// NOTE: the `static` keyword confuses clang-format here, causing
// the strange indentation of the `_mlir_ciface_addEltC32` prototype.
// In C++11 we can add a semicolon after the call to `IMPL_ADDELT`
// and that will correct clang-format. Alas, this file is compiled
// in C++98 mode where that semicolon is illegal (and there's no portable
// macro magic to license a no-op semicolon at the top level).
static IMPL_ADDELT(C32ABI, complex32)
#undef IMPL_ADDELT
void *_mlir_ciface_addEltC32(void *coo, float r, float i,
                             StridedMemRefType<index_type, 1> *iref,
                             StridedMemRefType<index_type, 1> *pref) {
  return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
}

#define IMPL_GETNEXT(VNAME, V)                                                 \
  bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
                                   StridedMemRefType<index_type, 1> *iref,     \
                                   StridedMemRefType<V, 0> *vref) {            \
    assert(coo &&iref &&vref);                                                 \
    assert(iref->strides[0] == 1);                                             \
    index_type *indx = iref->data + iref->offset;                              \
    V *value = vref->data + vref->offset;                                      \
    const uint64_t isize = iref->sizes[0];                                     \
    const Element<V> *elem =                                                   \
        static_cast<SparseTensorCOO<V> *>(coo)->getNext();                     \
    if (elem == nullptr)                                                       \
      return false;                                                            \
    for (uint64_t r = 0; r < isize; r++)                                       \
      indx[r] = elem->indices[r];                                              \
    *value = elem->value;                                                      \
    return true;                                                               \
  }
FOREVERY_V(IMPL_GETNEXT)
#undef IMPL_GETNEXT

#define IMPL_LEXINSERT(VNAME, V)                                               \
  void _mlir_ciface_lexInsert##VNAME(                                          \
      void *tensor, StridedMemRefType<index_type, 1> *cref, V val) {           \
    assert(tensor &&cref);                                                     \
    assert(cref->strides[0] == 1);                                             \
    index_type *cursor = cref->data + cref->offset;                            \
    assert(cursor);                                                            \
    static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
  }
FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
IMPL_LEXINSERT(C64, complex64)
// Marked static because it's not part of the public API.
// NOTE: see the note for `_mlir_ciface_addEltC32ABI`.
static IMPL_LEXINSERT(C32ABI, complex32)
#undef IMPL_LEXINSERT
void _mlir_ciface_lexInsertC32(void *tensor,
                               StridedMemRefType<index_type, 1> *cref,
                               float r, float i) {
  _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
}

#define IMPL_EXPINSERT(VNAME, V)                                               \
  void _mlir_ciface_expInsert##VNAME(                                          \
      void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
      StridedMemRefType<index_type, 1> *aref, index_type count) {              \
    assert(tensor &&cref &&vref &&fref &&aref);                                \
    assert(cref->strides[0] == 1);                                             \
    assert(vref->strides[0] == 1);                                             \
    assert(fref->strides[0] == 1);                                             \
    assert(aref->strides[0] == 1);                                             \
    assert(vref->sizes[0] == fref->sizes[0]);                                  \
    index_type *cursor = cref->data + cref->offset;                            \
    V *values = vref->data + vref->offset;                                     \
    bool *filled = fref->data + fref->offset;                                  \
    index_type *added = aref->data + aref->offset;                             \
    static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
        cursor, values, filled, added, count);                                 \
  }
FOREVERY_V(IMPL_EXPINSERT)
#undef IMPL_EXPINSERT

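// A rough sketch of the expanded-insert contract (inferred from how the
// compiler-generated code uses it; illustrative, not normative): `values`
// and `filled` form a dense scratch row for the innermost dimension,
// `added` records which `count` positions were actually written, and
// expInsert() compresses those entries into the tensor at the prefix
// given by `cursor`, resetting the scratch state for reuse.
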
//===----------------------------------------------------------------------===//
//
// Public functions which accept only C-style data structures to interact
// with sparse tensors (which are only visible as opaque pointers externally).
//
//===----------------------------------------------------------------------===//

index_type sparseDimSize(void *tensor, index_type d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

void endInsert(void *tensor) {
  return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
}

#define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
  void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \
    return outSparseTensor<V>(coo, dest, sort);                                \
  }
FOREVERY_V(IMPL_OUTSPARSETENSOR)
#undef IMPL_OUTSPARSETENSOR

void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}

#define IMPL_DELCOO(VNAME, V)                                                  \
  void delSparseTensorCOO##VNAME(void *coo) {                                  \
    delete static_cast<SparseTensorCOO<V> *>(coo);                             \
  }
FOREVERY_V(IMPL_DELCOO)
#undef IMPL_DELCOO

char *getTensorFilename(index_type id) {
  char var[80];
  sprintf(var, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env)
    FATAL("Environment variable %s is not set\n", var);
  return env;
}

void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
  assert(out && "Received nullptr for out-parameter");
  SparseTensorFile stfile(filename);
  stfile.openFile();
  stfile.readHeader();
  stfile.closeFile();
  const uint64_t rank = stfile.getRank();
  const uint64_t *dimSizes = stfile.getDimSizes();
  out->reserve(rank);
  out->assign(dimSizes, dimSizes + rank);
}

// TODO: generalize beyond 64-bit indices.
#define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
  void *convertToMLIRSparseTensor##VNAME(                                      \
      uint64_t rank, uint64_t nse, uint64_t *shape, V *values,                 \
      uint64_t *indices, uint64_t *perm, uint8_t *sparse) {                    \
    return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm,      \
                                 sparse);                                      \
  }
FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
#undef IMPL_CONVERTTOMLIRSPARSETENSOR

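// A hypothetical external caller (illustrative only): build a 2x2 identity
// with both dimensions compressed, then release it when done.
//
//   uint64_t shape[] = {2, 2}, perm[] = {0, 1};
//   uint64_t indices[] = {0, 0, 1, 1}; // rank-major coordinates
//   double values[] = {1.0, 1.0};
//   uint8_t sparse[] = {static_cast<uint8_t>(DimLevelType::kCompressed),
//                       static_cast<uint8_t>(DimLevelType::kCompressed)};
//   void *tensor = convertToMLIRSparseTensorF64(2, 2, shape, values,
//                                               indices, perm, sparse);
//   ...
//   delSparseTensor(tensor);
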
// TODO: Currently, values are copied from SparseTensorStorage to
// SparseTensorCOO, then to the output. We may want to reduce the number
// of copies.
//
// TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
// compressed.
#define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V)                             \
  void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank,       \
                                          uint64_t *pNse, uint64_t **pShape,   \
                                          V **pValues, uint64_t **pIndices) {  \
    fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices);   \
  }
FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
#undef IMPL_CONVERTFROMMLIRSPARSETENSOR

} // extern "C"

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS