//===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library that is useful
// for sparse tensor manipulations. The functionality provided in this library
// is meant to simplify benchmarking, testing, and debugging MLIR code that
// operates on sparse tensors. The provided functionality is **not** part
// of core MLIR, however.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>

//===----------------------------------------------------------------------===//
//
// Internal support for storing and reading sparse tensors.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by index (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors.
//     These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is visible
// externally only as an opaque pointer.
//
//===----------------------------------------------------------------------===//

namespace {

static constexpr int kColWidth = 1025;

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
template <typename V>
struct Element {
  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val){};
  std::vector<uint64_t> indices;
  V value;
  /// Returns true if indices of e1 < indices of e2.
  static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
    uint64_t rank = e1.indices.size();
    assert(rank == e2.indices.size());
    for (uint64_t r = 0; r < rank; r++) {
      if (e1.indices[r] == e2.indices[r])
        continue;
      return e1.indices[r] < e2.indices[r];
    }
    return false;
  }
};

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO {
public:
  SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
      : sizes(szs), iteratorLocked(false), iteratorPos(0) {
    if (capacity)
      elements.reserve(capacity);
  }
  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t rank = getRank();
    assert(rank == ind.size());
    for (uint64_t r = 0; r < rank; r++)
      assert(ind[r] < sizes[r]); // within bounds
    elements.emplace_back(ind, val);
  }
  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    // TODO: we may want to cache an `isSorted` bit, to avoid
    // unnecessary/redundant sorting.
    std::sort(elements.begin(), elements.end(), Element<V>::lexOrder);
  }
  /// Returns rank.
  uint64_t getRank() const { return sizes.size(); }
  /// Getter for sizes array.
  const std::vector<uint64_t> &getSizes() const { return sizes; }
  /// Getter for elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }
  /// Get the next element.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *sizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++) {
      assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
      permsz[perm[r]] = sizes[r];
    }
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  const std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<Element<V>> elements;
  bool iteratorLocked;
  unsigned iteratorPos;
};
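// The following sketch (not part of the library; sizes, indices, and values
// are made up for illustration) shows how a client builds and then consumes
// a small rank-2 COO tensor with this class:
//
//   uint64_t sizes[] = {3, 4}; // a 3x4 matrix
//   uint64_t perm[] = {0, 1};  // identity dimension ordering
//   SparseTensorCOO<double> *coo =
//       SparseTensorCOO<double>::newSparseTensorCOO(2, sizes, perm);
//   coo->add({2, 1}, 3.0);
//   coo->add({0, 3}, 1.5);
//   coo->sort();               // now (0,3) precedes (2,1)
//   coo->startIterator();
//   while (const Element<double> *e = coo->getNext())
//     /* consume e->indices and e->value */;
//   delete coo;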
/// Abstract base class of sparse tensor storage. Note that we use
/// function overloading to implement "partial" method specialization.
class SparseTensorStorageBase {
public:
  /// Dimension size query.
  virtual uint64_t getDimSize(uint64_t) const = 0;

  /// Overhead storage.
  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }

  /// Primary storage.
  virtual void getValues(std::vector<double> **) { fatal("valf64"); }
  virtual void getValues(std::vector<float> **) { fatal("valf32"); }
  virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }

  /// Element-wise insertion in lexicographic index order.
  virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
  virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
  virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
  virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(const uint64_t *, int16_t) { fatal("insi16"); }
  virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }

  /// Expanded insertion.
  virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
    fatal("expf64");
  }
  virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
    fatal("expf32");
  }
  virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi64");
  }
  virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi32");
  }
  virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi16");
  }
  virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi8");
  }

  /// Finishes insertion.
  virtual void endInsert() = 0;

  virtual ~SparseTensorStorageBase() = default;

private:
  static void fatal(const char *tp) {
    fprintf(stderr, "unsupported %s\n", tp);
    exit(1);
  }
};
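// Because the methods above are overloaded rather than templated, a caller
// holding only a SparseTensorStorageBase pointer selects the proper derived
// implementation purely through the static argument types; any combination
// the derived class does not override lands in a `fatal` stub. A minimal
// sketch (hypothetical setup, not part of the library):
//
//   SparseTensorStorageBase *base = ...; // some SparseTensorStorage<P,I,V>
//   std::vector<double> *v;
//   base->getValues(&v); // ok if V == double, aborts with "valf64" otherwise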
/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this method provides
/// a convenient "one-size-fits-all" solution that simply takes an input tensor
/// and annotations to implement all required setup in a general manner.
template <typename P, typename I, typename V>
class SparseTensorStorage : public SparseTensorStorageBase {
public:
  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the coordinate scheme tensor for the initial contents if provided.
  SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
                      const DimLevelType *sparsity,
                      SparseTensorCOO<V> *tensor = nullptr)
      : sizes(szs), rev(getRank()), idx(getRank()), pointers(getRank()),
        indices(getRank()) {
    uint64_t rank = getRank();
    // Store "reverse" permutation.
    for (uint64_t r = 0; r < rank; r++)
      rev[perm[r]] = r;
    // Provide hints on capacity of pointers and indices.
    // TODO: needs fine-tuning based on sparsity
    bool allDense = true;
    uint64_t sz = 1;
    for (uint64_t r = 0; r < rank; r++) {
      assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
      sz *= sizes[r];
      if (sparsity[r] == DimLevelType::kCompressed) {
        pointers[r].reserve(sz + 1);
        indices[r].reserve(sz);
        sz = 1;
        allDense = false;
        // Prepare the pointer structure. We cannot use `appendPointer`
        // here, because `isCompressedDim` won't work until after this
        // preparation has been done.
        pointers[r].push_back(0);
      } else {
        assert(sparsity[r] == DimLevelType::kDense &&
               "singleton not yet supported");
      }
    }
    // Then assign contents from coordinate scheme tensor if provided.
    if (tensor) {
      // Ensure both preconditions of `fromCOO`.
      assert(tensor->getSizes() == sizes && "Tensor size mismatch");
      tensor->sort();
      // Now actually insert the `elements`.
      const std::vector<Element<V>> &elements = tensor->getElements();
      uint64_t nnz = elements.size();
      values.reserve(nnz);
      fromCOO(elements, 0, nnz, 0);
    } else if (allDense) {
      values.resize(sz, 0);
    }
  }

  ~SparseTensorStorage() override = default;

  /// Get the rank of the tensor.
  uint64_t getRank() const { return sizes.size(); }

  /// Get the size of the given dimension of the tensor.
  uint64_t getDimSize(uint64_t d) const override {
    assert(d < getRank());
    return sizes[d];
  }

  /// Partially specialize these getter methods based on template types.
  void getPointers(std::vector<P> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &pointers[d];
  }
  void getIndices(std::vector<I> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &indices[d];
  }
  void getValues(std::vector<V> **out) override { *out = &values; }

  /// Partially specialize lexicographical insertions based on template types.
  void lexInsert(const uint64_t *cursor, V val) override {
    // First, wrap up pending insertion path.
    uint64_t diff = 0;
    uint64_t top = 0;
    if (!values.empty()) {
      diff = lexDiff(cursor);
      endPath(diff + 1);
      top = idx[diff] + 1;
    }
    // Then continue with insertion path.
    insPath(cursor, diff, top, val);
  }
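  // A sketch of the insertion protocol (hypothetical cursors for a rank-2
  // tensor; not part of the library): calls must arrive in strictly
  // increasing lexicographic index order, and endInsert() must follow last.
  //
  //   uint64_t c1[] = {0, 2};  tensor->lexInsert(c1, 1.0);
  //   uint64_t c2[] = {1, 0};  tensor->lexInsert(c2, 2.0);
  //       // lexDiff == 0, so endPath(1) wraps up the remainder of row 0
  //       // before insPath resumes from dimension 0.
  //   tensor->endInsert();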
  /// Partially specialize expanded insertions based on template types.
  /// Note that this method resets the values/filled-switch array back
  /// to all-zero/false while only iterating over the nonzero elements.
  void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
                 uint64_t count) override {
    if (count == 0)
      return;
    // Sort.
    std::sort(added, added + count);
    // Restore insertion path for first insert.
    const uint64_t lastDim = getRank() - 1;
    uint64_t index = added[0];
    cursor[lastDim] = index;
    lexInsert(cursor, values[index]);
    assert(filled[index]);
    values[index] = 0;
    filled[index] = false;
    // Subsequent insertions are quick.
    for (uint64_t i = 1; i < count; i++) {
      assert(index < added[i] && "non-lexicographic insertion");
      index = added[i];
      cursor[lastDim] = index;
      insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
      assert(filled[index]);
      values[index] = 0;
      filled[index] = false;
    }
  }

  /// Finalizes lexicographic insertions.
  void endInsert() override {
    if (values.empty())
      endDim(0);
    else
      endPath(0);
  }

  /// Returns this sparse tensor storage scheme as a new memory-resident
  /// sparse tensor in coordinate scheme with the given dimension order.
  SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
    // Restore original order of the dimension sizes and allocate coordinate
    // scheme with desired new ordering specified in perm.
    uint64_t rank = getRank();
    std::vector<uint64_t> orgsz(rank);
    for (uint64_t r = 0; r < rank; r++)
      orgsz[rev[r]] = sizes[r];
    SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
        rank, orgsz.data(), perm, values.size());
    // Populate coordinate scheme restored from old ordering and changed with
    // new ordering. Rather than applying both reorderings during the recursion,
    // we compute the combined permutation in advance.
    std::vector<uint64_t> reord(rank);
    for (uint64_t r = 0; r < rank; r++)
      reord[r] = perm[rev[r]];
    toCOO(*tensor, reord, 0, 0);
    assert(tensor->getElements().size() == values.size());
    return tensor;
  }

  /// Factory method. Constructs a sparse tensor storage scheme with the given
  /// dimensions, permutation, and per-dimension dense/sparse annotations,
  /// using the coordinate scheme tensor for the initial contents if provided.
  /// In the latter case, the coordinate scheme must respect the same
  /// permutation as is desired for the new sparse tensor storage.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                  const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
    SparseTensorStorage<P, I, V> *n = nullptr;
    if (tensor) {
      assert(tensor->getRank() == rank);
      for (uint64_t r = 0; r < rank; r++)
        assert(shape[r] == 0 || shape[r] == tensor->getSizes()[perm[r]]);
      n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
                                           tensor);
    } else {
      std::vector<uint64_t> permsz(rank);
      for (uint64_t r = 0; r < rank; r++) {
        assert(shape[r] > 0 && "Dimension size zero has trivial storage");
        permsz[perm[r]] = shape[r];
      }
      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
    }
    return n;
  }

private:
  /// Appends the next free position of `indices[d]` to `pointers[d]`.
  /// Thus, when called after inserting the last element of a segment,
  /// it will append the position where the next segment begins.
  inline void appendPointer(uint64_t d) {
    assert(isCompressedDim(d)); // Entails `d < getRank()`.
    uint64_t p = indices[d].size();
    assert(p <= std::numeric_limits<P>::max() &&
           "Pointer value is too large for the P-type");
    pointers[d].push_back(p); // Here is where we convert to `P`.
  }

  /// Appends the given index to `indices[d]`.
  inline void appendIndex(uint64_t d, uint64_t i) {
    assert(isCompressedDim(d)); // Entails `d < getRank()`.
    assert(i <= std::numeric_limits<I>::max() &&
           "Index value is too large for the I-type");
    indices[d].push_back(i); // Here is where we convert to `I`.
  }

  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the pointers and
  /// indices arrays under the given per-dimension dense/sparse annotations.
  ///
  /// Preconditions:
  /// (1) the `elements` must be lexicographically sorted.
  /// (2) the indices of every element are valid for `sizes` (equal rank
  ///     and pointwise less-than).
  void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
               uint64_t hi, uint64_t d) {
    // Once dimensions are exhausted, insert the numerical values.
    assert(d <= getRank() && hi <= elements.size());
    if (d == getRank()) {
      assert(lo < hi);
      values.push_back(elements[lo].value);
      return;
    }
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
      // Find segment in interval with same index elements in this dimension.
      uint64_t i = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == i)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      if (isCompressedDim(d)) {
        appendIndex(d, i);
      } else {
        // For dense storage we must fill in all the zero values between
        // the previous element (when last we ran this for-loop) and the
        // current element.
        for (; full < i; full++)
          endDim(d + 1);
        full++;
      }
      fromCOO(elements, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    if (isCompressedDim(d)) {
      appendPointer(d);
    } else {
      // For dense storage we must fill in all the zero values after
      // the last element.
      for (uint64_t sz = sizes[d]; full < sz; full++)
        endDim(d + 1);
    }
  }

  /// Stores the sparse tensor storage scheme into a memory-resident sparse
  /// tensor in coordinate scheme.
  void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
             uint64_t pos, uint64_t d) {
    assert(d <= getRank());
    if (d == getRank()) {
      assert(pos < values.size());
      tensor.add(idx, values[pos]);
    } else if (isCompressedDim(d)) {
      // Sparse dimension.
      for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
        idx[reord[d]] = indices[d][ii];
        toCOO(tensor, reord, ii, d + 1);
      }
    } else {
      // Dense dimension.
      for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
        idx[reord[d]] = i;
        toCOO(tensor, reord, off + i, d + 1);
      }
    }
  }

  /// Ends a deeper, never-before-seen dimension.
  void endDim(uint64_t d) {
    assert(d <= getRank());
    if (d == getRank()) {
      values.push_back(0);
    } else if (isCompressedDim(d)) {
      appendPointer(d);
    } else {
      for (uint64_t full = 0, sz = sizes[d]; full < sz; full++)
        endDim(d + 1);
    }
  }

  /// Wraps up a single insertion path, inner to outer.
  void endPath(uint64_t diff) {
    uint64_t rank = getRank();
    assert(diff <= rank);
    for (uint64_t i = 0; i < rank - diff; i++) {
      uint64_t d = rank - i - 1;
      if (isCompressedDim(d)) {
        appendPointer(d);
      } else {
        for (uint64_t full = idx[d] + 1, sz = sizes[d]; full < sz; full++)
          endDim(d + 1);
      }
    }
  }

  /// Continues a single insertion path, outer to inner.
  void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
    uint64_t rank = getRank();
    assert(diff < rank);
    for (uint64_t d = diff; d < rank; d++) {
      uint64_t i = cursor[d];
      if (isCompressedDim(d)) {
        appendIndex(d, i);
      } else {
        for (uint64_t full = top; full < i; full++)
          endDim(d + 1);
      }
      top = 0;
      idx[d] = i;
    }
    values.push_back(val);
  }

  /// Finds the lexicographically differing dimension.
  uint64_t lexDiff(const uint64_t *cursor) const {
    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
      if (cursor[r] > idx[r])
        return r;
      else
        assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
    return -1u;
  }

  /// Returns true if dimension is compressed.
  inline bool isCompressedDim(uint64_t d) const {
    assert(d < getRank());
    return (!pointers[d].empty());
  }

private:
  const std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<uint64_t> rev;         // "reverse" permutation
  std::vector<uint64_t> idx;         // index cursor
  std::vector<std::vector<P>> pointers;
  std::vector<std::vector<I>> indices;
  std::vector<V> values;
};
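// As a concrete (illustrative) example of the storage scheme, the 3x4 matrix
//
//   [ 0 0 1 2 ]
//   [ 0 0 0 0 ]
//   [ 3 0 0 0 ]
//
// stored with sparsity = {kDense, kCompressed} and the identity permutation
// yields the classic CSR encoding:
//
//   pointers[1] = { 0, 2, 2, 3 } // row r spans positions [ptr[r], ptr[r+1])
//   indices[1]  = { 2, 3, 0 }    // column of each stored entry
//   values      = { 1, 2, 3 }
//
// while pointers[0] and indices[0] stay empty because dimension 0 is dense.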
/// Helper to convert string to lower case.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = tolower(*c);
  return token;
}

/// Read the MME header of a general sparse matrix of type real.
static void readMMEHeader(FILE *file, char *filename, char *line,
                          uint64_t *idata, bool *isSymmetric) {
  char header[64];
  char object[64];
  char format[64];
  char field[64];
  char symmetry[64];
  // Read header line.
  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
             symmetry) != 5) {
    fprintf(stderr, "Corrupt header in %s\n", filename);
    exit(1);
  }
  *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
  // Make sure this is a general sparse matrix.
  if (strcmp(toLower(header), "%%matrixmarket") ||
      strcmp(toLower(object), "matrix") ||
      strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
      (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
    fprintf(stderr,
            "Cannot find a general sparse matrix with type real in %s\n",
            filename);
    exit(1);
  }
  // Skip comments.
  while (true) {
    if (!fgets(line, kColWidth, file)) {
      fprintf(stderr, "Cannot find data in %s\n", filename);
      exit(1);
    }
    if (line[0] != '%')
      break;
  }
  // Next line contains M N NNZ.
  idata[0] = 2; // rank
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3) {
    fprintf(stderr, "Cannot find size in %s\n", filename);
    exit(1);
  }
}
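// For reference, a minimal .mtx file accepted by this reader looks like
// (illustrative content):
//
//   %%MatrixMarket matrix coordinate real general
//   % optional comment lines
//   3 4 2
//   1 3 1.0
//   3 1 3.0
//
// i.e., the banner, optional '%' comments, an "M N NNZ" line, and then
// 1-based "i j value" lines, which openSparseTensorCOO below converts to
// 0-based indices.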
/// Read the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments followed
/// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
                                uint64_t *idata) {
  // Skip comments.
  while (true) {
    if (!fgets(line, kColWidth, file)) {
      fprintf(stderr, "Cannot find data in %s\n", filename);
      exit(1);
    }
    if (line[0] != '#')
      break;
  }
  // Next line contains RANK and NNZ.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
    fprintf(stderr, "Cannot find metadata in %s\n", filename);
    exit(1);
  }
  // Followed by a line with the dimension sizes (one per rank).
  for (uint64_t r = 0; r < idata[0]; r++) {
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size in %s\n", filename);
      exit(1);
    }
  }
  fgets(line, kColWidth, file); // end of line
}

/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme.
template <typename V>
static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
                                               const uint64_t *shape,
                                               const uint64_t *perm) {
  // Open the file.
  FILE *file = fopen(filename, "r");
  if (!file) {
    assert(filename && "Received nullptr for filename");
    fprintf(stderr, "Cannot find file %s\n", filename);
    exit(1);
  }
  // Perform some file format dependent set up.
  char line[kColWidth];
  uint64_t idata[512];
  bool isSymmetric = false;
  if (strstr(filename, ".mtx")) {
    readMMEHeader(file, filename, line, idata, &isSymmetric);
  } else if (strstr(filename, ".tns")) {
    readExtFROSTTHeader(file, filename, line, idata);
  } else {
    fprintf(stderr, "Unknown format %s\n", filename);
    exit(1);
  }
  // Prepare sparse tensor object with per-dimension sizes
  // and the number of nonzeros as initial capacity.
  assert(rank == idata[0] && "rank mismatch");
  uint64_t nnz = idata[1];
  for (uint64_t r = 0; r < rank; r++)
    assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
           "dimension size mismatch");
  SparseTensorCOO<V> *tensor =
      SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
  // Read all nonzero elements.
  std::vector<uint64_t> indices(rank);
  for (uint64_t k = 0; k < nnz; k++) {
    if (!fgets(line, kColWidth, file)) {
      fprintf(stderr, "Cannot find next line of data in %s\n", filename);
      exit(1);
    }
    char *linePtr = line;
    for (uint64_t r = 0; r < rank; r++) {
      uint64_t idx = strtoul(linePtr, &linePtr, 10);
      // Add 0-based index.
      indices[perm[r]] = idx - 1;
    }
    // The external formats always store the numerical values with the type
    // double, but we cast these values to the sparse tensor object type.
    double value = strtod(linePtr, &linePtr);
    tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing them. In the future, we may want to make symmetry
    // implicit for storage reasons.
    if (isSymmetric && indices[0] != indices[1])
      tensor->add({indices[1], indices[0]}, value);
  }
  // Close the file and return tensor.
  fclose(file);
  return tensor;
}

/// Writes the sparse tensor to extended FROSTT format.
template <typename V>
static void outSparseTensor(void *tensor, void *dest, bool sort) {
  assert(tensor && dest);
  auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
  if (sort)
    coo->sort();
  char *filename = static_cast<char *>(dest);
  auto &sizes = coo->getSizes();
  auto &elements = coo->getElements();
  uint64_t rank = coo->getRank();
  uint64_t nnz = elements.size();
  std::fstream file;
  file.open(filename, std::ios_base::out | std::ios_base::trunc);
  assert(file.is_open());
  file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
  for (uint64_t r = 0; r < rank - 1; r++)
    file << sizes[r] << " ";
  file << sizes[rank - 1] << std::endl;
  for (uint64_t i = 0; i < nnz; i++) {
    auto &idx = elements[i].indices;
    for (uint64_t r = 0; r < rank; r++)
      file << (idx[r] + 1) << " ";
    file << elements[i].value << std::endl;
  }
  file.flush();
  file.close();
  assert(file.good());
}
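// For reference, a minimal "extended" FROSTT file as expected by
// readExtFROSTTHeader above looks like (illustrative content for a 3x4
// matrix with two nonzeros; note the reader skips only '#'-prefixed
// comment lines):
//
//   # extended FROSTT format
//   2 2
//   3 4
//   1 3 1.5
//   3 1 3.0
//
// i.e., comments, a "RANK NNZ" line, the dimension sizes, and then 1-based
// "indices... value" lines, which is the layout outSparseTensor emits.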
/// Initializes sparse tensor from an external COO-flavored format.
template <typename V>
static SparseTensorStorage<uint64_t, uint64_t, V> *
toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
                   uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
  const DimLevelType *sparsity = (DimLevelType *)(sparse);
#ifndef NDEBUG
  // Verify that perm is a permutation of 0..(rank-1).
  std::vector<uint64_t> order(perm, perm + rank);
  std::sort(order.begin(), order.end());
  for (uint64_t i = 0; i < rank; ++i) {
    if (i != order[i]) {
      fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
      exit(1);
    }
  }

  // Verify that the sparsity values are supported.
  for (uint64_t i = 0; i < rank; ++i) {
    if (sparsity[i] != DimLevelType::kDense &&
        sparsity[i] != DimLevelType::kCompressed) {
      fprintf(stderr, "Unsupported sparsity value %d\n",
              static_cast<int>(sparsity[i]));
      exit(1);
    }
  }
#endif

  // Convert external format to internal COO.
  auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
  std::vector<uint64_t> idx(rank);
  for (uint64_t i = 0, base = 0; i < nse; i++) {
    for (uint64_t r = 0; r < rank; r++)
      idx[perm[r]] = indices[base + r];
    coo->add(idx, values[i]);
    base += rank;
  }
  // Return sparse tensor storage format as opaque pointer.
  auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
      rank, shape, perm, sparsity, coo);
  delete coo;
  return tensor;
}
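// A minimal sketch (all values hypothetical, and assuming the header's
// DimLevelType encodes kDense as 0 and kCompressed as 1) of handing a
// COO-flavored buffer to the conversion above; the result is the opaque
// storage pointer used by the rest of this API:
//
//   uint64_t shape[] = {3, 4};
//   double values[] = {1.5, 3.0};
//   uint64_t indices[] = {0, 2,  // indices of values[0]
//                         2, 0}; // indices of values[1]
//   uint64_t perm[] = {0, 1};
//   uint8_t sparse[] = {0, 1};   // kDense, kCompressed
//   void *tensor = toMLIRSparseTensor<double>(2, 2, shape, values,
//                                             indices, perm, sparse);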
/// Converts a sparse tensor to an external COO-flavored format.
template <typename V>
static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
                                 uint64_t **pShape, V **pValues,
                                 uint64_t **pIndices) {
  auto sparseTensor =
      static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
  uint64_t rank = sparseTensor->getRank();
  std::vector<uint64_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());

  const std::vector<Element<V>> &elements = coo->getElements();
  uint64_t nse = elements.size();

  uint64_t *shape = new uint64_t[rank];
  for (uint64_t i = 0; i < rank; i++)
    shape[i] = coo->getSizes()[i];

  V *values = new V[nse];
  uint64_t *indices = new uint64_t[rank * nse];

  for (uint64_t i = 0, base = 0; i < nse; i++) {
    values[i] = elements[i].value;
    for (uint64_t j = 0; j < rank; j++)
      indices[base + j] = elements[i].indices[j];
    base += rank;
  }

  delete coo;
  *pRank = rank;
  *pNse = nse;
  *pShape = shape;
  *pValues = values;
  *pIndices = indices;
}

} // namespace
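// Note on ownership for the conversion above: fromMLIRSparseTensor returns
// freshly allocated arrays, so the caller is expected to release them, e.g.
// (sketch with hypothetical variable names):
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   fromMLIRSparseTensor<double>(tensor, &rank, &nse, &shape, &values,
//                                &indices);
//   /* consume shape/values/indices */
//   delete[] shape;
//   delete[] values;
//   delete[] indices;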
extern "C" {

//===----------------------------------------------------------------------===//
//
// Public API with methods that operate on MLIR buffers (memrefs) to interact
// with sparse tensors, which are only visible as opaque pointers externally.
// These methods should be used exclusively by MLIR compiler-generated code.
//
// Some macro magic is used to generate implementations for all required type
// combinations that can be called from MLIR compiler-generated code.
//
//===----------------------------------------------------------------------===//

#define CASE(p, i, v, P, I, V)                                                 \
  if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
    SparseTensorCOO<V> *coo = nullptr;                                         \
    if (action <= Action::kFromCOO) {                                          \
      if (action == Action::kFromFile) {                                       \
        char *filename = static_cast<char *>(ptr);                             \
        coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
      } else if (action == Action::kFromCOO) {                                 \
        coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
      } else {                                                                 \
        assert(action == Action::kEmpty);                                      \
      }                                                                        \
      auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
          rank, shape, perm, sparsity, coo);                                   \
      if (action == Action::kFromFile)                                         \
        delete coo;                                                            \
      return tensor;                                                           \
    }                                                                          \
    if (action == Action::kEmptyCOO)                                           \
      return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
    coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
    if (action == Action::kToIterator) {                                       \
      coo->startIterator();                                                    \
    } else {                                                                   \
      assert(action == Action::kToCOO);                                        \
    }                                                                          \
    return coo;                                                                \
  }

#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)

#define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
    assert(ref &&tensor);                                                      \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
    ref->basePtr = ref->data = v->data();                                      \
    ref->offset = 0;                                                           \
    ref->sizes[0] = v->size();                                                 \
    ref->strides[0] = 1;                                                       \
  }

#define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
                           index_type d) {                                     \
    assert(ref &&tensor);                                                      \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
    ref->basePtr = ref->data = v->data();                                      \
    ref->offset = 0;                                                           \
    ref->sizes[0] = v->size();                                                 \
    ref->strides[0] = 1;                                                       \
  }

#define IMPL_ADDELT(NAME, TYPE)                                                \
  void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
                            StridedMemRefType<index_type, 1> *iref,            \
                            StridedMemRefType<index_type, 1> *pref) {          \
    assert(tensor &&iref &&pref);                                              \
    assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
    assert(iref->sizes[0] == pref->sizes[0]);                                  \
    const index_type *indx = iref->data + iref->offset;                        \
    const index_type *perm = pref->data + pref->offset;                        \
    uint64_t isize = iref->sizes[0];                                           \
    std::vector<index_type> indices(isize);                                    \
    for (uint64_t r = 0; r < isize; r++)                                       \
      indices[perm[r]] = indx[r];                                              \
    static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
    return tensor;                                                             \
  }

#define IMPL_GETNEXT(NAME, V)                                                  \
  bool _mlir_ciface_##NAME(void *tensor,                                       \
                           StridedMemRefType<index_type, 1> *iref,             \
                           StridedMemRefType<V, 0> *vref) {                    \
    assert(tensor &&iref &&vref);                                              \
    assert(iref->strides[0] == 1);                                             \
    index_type *indx = iref->data + iref->offset;                              \
    V *value = vref->data + vref->offset;                                      \
    const uint64_t isize = iref->sizes[0];                                     \
    auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
    const Element<V> *elem = iter->getNext();                                  \
    if (elem == nullptr)                                                       \
      return false;                                                            \
    for (uint64_t r = 0; r < isize; r++)                                       \
      indx[r] = elem->indices[r];                                              \
    *value = elem->value;                                                      \
    return true;                                                               \
  }

#define IMPL_LEXINSERT(NAME, V)                                                \
  void _mlir_ciface_##NAME(void *tensor,                                       \
                           StridedMemRefType<index_type, 1> *cref, V val) {    \
    assert(tensor &&cref);                                                     \
    assert(cref->strides[0] == 1);                                             \
    index_type *cursor = cref->data + cref->offset;                            \
    assert(cursor);                                                            \
    static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
  }

#define IMPL_EXPINSERT(NAME, V)                                                \
  void _mlir_ciface_##NAME(                                                    \
      void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
      StridedMemRefType<index_type, 1> *aref, index_type count) {              \
    assert(tensor &&cref &&vref &&fref &&aref);                                \
    assert(cref->strides[0] == 1);                                             \
    assert(vref->strides[0] == 1);                                             \
    assert(fref->strides[0] == 1);                                             \
    assert(aref->strides[0] == 1);                                             \
    assert(vref->sizes[0] == fref->sizes[0]);                                  \
    index_type *cursor = cref->data + cref->offset;                            \
    V *values = vref->data + vref->offset;                                     \
    bool *filled = fref->data + fref->offset;                                  \
    index_type *added = aref->data + aref->offset;                             \
    static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
        cursor, values, filled, added, count);                                 \
  }
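// The IMPL_GETNEXT functions are meant to be called in a simple fetch loop.
// A sketch for a hypothetical instantiation IMPL_GETNEXT(getNextF64, double)
// (the actual macro instantiations live elsewhere in this file):
//
//   StridedMemRefType<index_type, 1> iref = ...; // rank-sized index buffer
//   StridedMemRefType<double, 0> vref = ...;     // scalar value buffer
//   while (_mlir_ciface_getNextF64(iter, &iref, &vref))
//     /* consume iref.data[0..rank) and *vref.data */;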
// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
// can safely rewrite kIndex to kU64. We make this assertion to guarantee
// that this file cannot get out of sync with its header.
static_assert(std::is_same<index_type, uint64_t>::value,
              "Expected index_type == uint64_t");

/// Constructs a new sparse tensor. This is the "swiss army knife"
/// method for materializing sparse tensors into the computation.
///
/// Action:
/// kEmpty = returns empty storage to fill later
/// kFromFile = returns storage, where ptr contains filename to read
/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
void *
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
                             StridedMemRefType<index_type, 1> *sref,
                             StridedMemRefType<index_type, 1> *pref,
                             OverheadType ptrTp, OverheadType indTp,
                             PrimaryType valTp, Action action, void *ptr) {
  assert(aref && sref && pref);
  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
         pref->strides[0] == 1);
  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
  const DimLevelType *sparsity = aref->data + aref->offset;
  const index_type *shape = sref->data + sref->offset;
  const index_type *perm = pref->data + pref->offset;
  uint64_t rank = aref->sizes[0];

  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
  // This is safe because of the static_assert above.
  if (ptrTp == OverheadType::kIndex)
    ptrTp = OverheadType::kU64;
  if (indTp == OverheadType::kIndex)
    indTp = OverheadType::kU64;
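  // Each CASE line below expands to a type-guarded block; for instance the
  // first one behaves roughly like (sketch of the CASE macro expansion):
  //
  //   if (ptrTp == OverheadType::kU64 && indTp == OverheadType::kU64 &&
  //       valTp == PrimaryType::kF64) {
  //     // materialize a SparseTensorStorage<uint64_t, uint64_t, double>
  //     // or a SparseTensorCOO<double>, depending on `action`, and return
  //   }
  //
  // so at most one case fires per call; type combinations without a matching
  // case fall through past all the cases.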
  // Double matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
       uint64_t, double);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
       uint32_t, double);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
       uint16_t, double);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
       uint8_t, double);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
       uint64_t, double);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
       uint32_t, double);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
       uint16_t, double);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
       uint8_t, double);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
       uint64_t, double);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
       uint32_t, double);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
       uint16_t, double);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
       uint8_t, double);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
       uint64_t, double);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
       uint32_t, double);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
       uint16_t, double);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
       uint8_t, double);
  // Float matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
       uint64_t, float);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
       uint32_t, float);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
       uint16_t, float);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
       uint8_t, float);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
       uint64_t, float);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
       uint32_t, float);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
       uint16_t, float);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
       uint8_t, float);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
       uint64_t, float);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
       uint32_t, float);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
       uint16_t, float);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
       uint8_t, float);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
       uint64_t, float);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
       uint32_t, float);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
       uint16_t, float);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
       uint8_t, float);
  // Integral matrices with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);

  // Unsupported case (add above if needed).
  fputs("unsupported combination of types\n", stderr);
  exit(1);
}

/// Methods that provide direct access to pointers.
IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)

/// Methods that provide direct access to indices.
IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)

/// Methods that provide direct access to values.
IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
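// For illustration only (not part of the library): an external runtime can
// inspect a rank-2 tensor stored CSR-style through the accessors above. Each
// accessor fills a rank-1 memref descriptor that aliases (does not copy) the
// tensor's internal arrays; the exact signatures are generated by the
// IMPL_GETOVERHEAD/IMPL_SPARSEVALUES macros defined earlier in this file,
// and the names below are hypothetical.
//
//   StridedMemRefType<index_type, 1> ptrs, inds;
//   StridedMemRefType<double, 1> vals;
//   _mlir_ciface_sparsePointers(&ptrs, tensor, 1); // pointers of dimension 1
//   _mlir_ciface_sparseIndices(&inds, tensor, 1);  // indices of dimension 1
//   _mlir_ciface_sparseValuesF64(&vals, tensor);   // all stored values
//   for (uint64_t i = 0; i < sparseDimSize(tensor, 0); i++)
//     for (index_type k = ptrs.data[i]; k < ptrs.data[i + 1]; k++)
//       visit(i, inds.data[k], vals.data[k]); // row i, column inds.data[k]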
/// Helper to add value to coordinate scheme, one per value type.
IMPL_ADDELT(addEltF64, double)
IMPL_ADDELT(addEltF32, float)
IMPL_ADDELT(addEltI64, int64_t)
IMPL_ADDELT(addEltI32, int32_t)
IMPL_ADDELT(addEltI16, int16_t)
IMPL_ADDELT(addEltI8, int8_t)

/// Helper to enumerate elements of coordinate scheme, one per value type.
IMPL_GETNEXT(getNextF64, double)
IMPL_GETNEXT(getNextF32, float)
IMPL_GETNEXT(getNextI64, int64_t)
IMPL_GETNEXT(getNextI32, int32_t)
IMPL_GETNEXT(getNextI16, int16_t)
IMPL_GETNEXT(getNextI8, int8_t)

/// Insert elements in lexicographical index order, one per value type.
IMPL_LEXINSERT(lexInsertF64, double)
IMPL_LEXINSERT(lexInsertF32, float)
IMPL_LEXINSERT(lexInsertI64, int64_t)
IMPL_LEXINSERT(lexInsertI32, int32_t)
IMPL_LEXINSERT(lexInsertI16, int16_t)
IMPL_LEXINSERT(lexInsertI8, int8_t)

/// Insert using expansion, one per value type.
IMPL_EXPINSERT(expInsertF64, double)
IMPL_EXPINSERT(expInsertF32, float)
IMPL_EXPINSERT(expInsertI64, int64_t)
IMPL_EXPINSERT(expInsertI32, int32_t)
IMPL_EXPINSERT(expInsertI16, int16_t)
IMPL_EXPINSERT(expInsertI8, int8_t)

#undef CASE
#undef CASE_SECSAME
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
#undef IMPL_ADDELT
#undef IMPL_GETNEXT
#undef IMPL_LEXINSERT
#undef IMPL_EXPINSERT
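// For illustration only (not part of the library): the helpers above support
// a typical round trip in which a client builds a coordinate scheme and then
// freezes it into storage. Memref descriptor setup is elided and the names
// are hypothetical; iref wraps a rank-sized index array and pref the
// dimension permutation.
//
//   void *coo =
//       _mlir_ciface_newSparseTensor(&aref, &sref, &pref, ptrTp, indTp,
//                                    PrimaryType::kF64, Action::kEmptyCOO,
//                                    nullptr);
//   for (each nonzero (indices, v))
//     coo = _mlir_ciface_addEltF64(coo, v, &iref, &pref);
//   void *tensor =
//       _mlir_ciface_newSparseTensor(&aref, &sref, &pref, ptrTp, indTp,
//                                    PrimaryType::kF64, Action::kFromCOO, coo);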
/// Output a sparse tensor, one per value type.
void outSparseTensorF64(void *tensor, void *dest, bool sort) {
  return outSparseTensor<double>(tensor, dest, sort);
}
void outSparseTensorF32(void *tensor, void *dest, bool sort) {
  return outSparseTensor<float>(tensor, dest, sort);
}
void outSparseTensorI64(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int64_t>(tensor, dest, sort);
}
void outSparseTensorI32(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int32_t>(tensor, dest, sort);
}
void outSparseTensorI16(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int16_t>(tensor, dest, sort);
}
void outSparseTensorI8(void *tensor, void *dest, bool sort) {
  return outSparseTensor<int8_t>(tensor, dest, sort);
}

//===----------------------------------------------------------------------===//
//
// Public API with methods that accept C-style data structures to interact
// with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used both by MLIR compiler-generated code and by an
// external runtime that wants to interact with MLIR compiler-generated code.
//
//===----------------------------------------------------------------------===//

/// Helper method to read a sparse tensor filename from the environment,
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
char *getTensorFilename(index_type id) {
  char var[80];
  snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env) {
    fprintf(stderr, "Environment variable %s is not set\n", var);
    exit(1);
  }
  return env;
}

/// Returns the size of the sparse tensor in the given dimension.
index_type sparseDimSize(void *tensor, index_type d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

/// Finalizes lexicographic insertions.
void endInsert(void *tensor) {
  return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
}

/// Releases sparse tensor storage.
void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}
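// For illustration only (not part of the library): the typical lifecycle of
// a tensor filled through lexicographic insertion, using the functions above.
// The cursor and value setup is hypothetical; cursorRef wraps a rank-sized
// index_type array holding the current insertion coordinates.
//
//   void *tensor =
//       _mlir_ciface_newSparseTensor(..., Action::kEmpty, nullptr);
//   for (each element, visited in lexicographic index order)
//     _mlir_ciface_lexInsertF64(tensor, &cursorRef, value);
//   endInsert(tensor);                        // finalize the insertions
//   uint64_t nrows = sparseDimSize(tensor, 0); // query, use the tensor ...
//   delSparseTensor(tensor);                   // ... and release its storage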
/// Releases sparse tensor coordinate scheme.
#define IMPL_DELCOO(VNAME, V) \
  void delSparseTensorCOO##VNAME(void *coo) { \
    delete static_cast<SparseTensorCOO<V> *>(coo); \
  }
IMPL_DELCOO(F64, double)
IMPL_DELCOO(F32, float)
IMPL_DELCOO(I64, int64_t)
IMPL_DELCOO(I32, int32_t)
IMPL_DELCOO(I16, int16_t)
IMPL_DELCOO(I8, int8_t)
#undef IMPL_DELCOO

/// Initializes a sparse tensor from a COO-flavored format expressed using
/// C-style data structures. The expected parameters are:
///
///   rank:    rank of the tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
///   values:  an "nse" array with the values of all specified elements
///   indices: a flat "nse x rank" array with the indices of all specified
///            elements
///   perm:    the permutation of the dimensions in the storage
///   sparse:  the sparsity for the dimensions
///
/// For example, the sparse matrix
///   | 1.0 0.0 0.0 |
///   | 0.0 5.0 3.0 |
/// can be passed as
///   rank    = 2
///   nse     = 3
///   shape   = [2, 3]
///   values  = [1.0, 5.0, 3.0]
///   indices = [0, 0, 1, 1, 1, 2]
//
// TODO: generalize beyond 64-bit indices.
//
void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   double *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
                                    sparse);
}
void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   float *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
                                   sparse);
}
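// For illustration only (not part of the library): the documented example
// above materializes as the call below. The perm and sparse arrays are
// hypothetical choices (identity dimension ordering; the annotation encoding
// is assumed to follow DimLevelType, with 0 denoting a dense and 1 a
// compressed dimension).
//
//   uint64_t shape[]   = {2, 3};
//   double   values[]  = {1.0, 5.0, 3.0};
//   uint64_t indices[] = {0, 0, 1, 1, 1, 2};
//   uint64_t perm[]    = {0, 1};
//   uint8_t  sparse[]  = {0, 1}; // dim 0 dense, dim 1 compressed
//   void *tensor = convertToMLIRSparseTensorF64(/*rank=*/2, /*nse=*/3, shape,
//                                               values, indices, perm, sparse);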
/// Converts a sparse tensor to a COO-flavored format expressed using C-style
/// data structures. The expected output parameters are pointers for these
/// values:
///
///   rank:    rank of the tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
///   values:  an "nse" array with the values of all specified elements
///   indices: a flat "nse x rank" array with the indices of all specified
///            elements
///
/// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
/// from convertToMLIRSparseTensor.
///
// TODO: Currently, values are copied from SparseTensorStorage to
// SparseTensorCOO, then to the output. We may want to reduce the number of
// copies.
//
// TODO: generalize beyond 64-bit indices, no dimension ordering, and all
// dimensions compressed.
//
void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    double **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    float **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
}

} // extern "C"

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS