//===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library that is useful
// for sparse tensor manipulations. The functionality provided in this library
// is meant to simplify benchmarking, testing, and debugging MLIR code that
// operates on sparse tensors. The provided functionality is **not** part
// of core MLIR, however.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <vector>

//===----------------------------------------------------------------------===//
//
// Internal support for storing and reading sparse tensors.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by index (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is externally
// only visible as an opaque pointer.
//
//===----------------------------------------------------------------------===//
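
// For illustration only (not part of the library): consider the 2x3 matrix
//
//   | 1 0 2 |
//   | 0 3 0 |
//
// Scheme (a) stores it as a list of (indices, value) pairs, sorted
// lexicographically by index:
//
//   ({0,0}, 1), ({0,2}, 2), ({1,1}, 3)
//
// Scheme (b), with annotations (dense, compressed) and the identity dimension
// ordering, yields the familiar CSR layout:
//
//   pointers[1] = [0, 2, 3]   // one extra entry per row
//   indices[1]  = [0, 2, 1]   // column index of each stored value
//   values      = [1, 2, 3]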

namespace {

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
template <typename V>
struct Element {
  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val) {}
  std::vector<uint64_t> indices;
  V value;
};

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO {
public:
  SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
      : sizes(szs), iteratorLocked(false), iteratorPos(0) {
    if (capacity)
      elements.reserve(capacity);
  }
  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t rank = getRank();
    assert(rank == ind.size());
    for (uint64_t r = 0; r < rank; r++)
      assert(ind[r] < sizes[r]); // within bounds
    elements.emplace_back(ind, val);
  }
  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    std::sort(elements.begin(), elements.end(), lexOrder);
  }
  /// Returns rank.
  uint64_t getRank() const { return sizes.size(); }
  /// Getter for sizes array.
  const std::vector<uint64_t> &getSizes() const { return sizes; }
  /// Getter for elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }
  /// Get the next element.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *sizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++)
      permsz[perm[r]] = sizes[r];
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  /// Returns true if indices of e1 < indices of e2.
  static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
    uint64_t rank = e1.indices.size();
    assert(rank == e2.indices.size());
    for (uint64_t r = 0; r < rank; r++) {
      if (e1.indices[r] == e2.indices[r])
        continue;
      return e1.indices[r] < e2.indices[r];
    }
    return false;
  }
  const std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<Element<V>> elements;
  bool iteratorLocked;
  unsigned iteratorPos;
};
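
// A minimal usage sketch (for illustration only; the names `sizes`, `perm`,
// and `coo` are hypothetical stand-ins):
//
//   uint64_t sizes[] = {2, 3};  // 2x3 matrix
//   uint64_t perm[] = {0, 1};   // identity dimension ordering
//   auto *coo = SparseTensorCOO<double>::newSparseTensorCOO(2, sizes, perm);
//   coo->add({0, 2}, 2.0);
//   coo->add({0, 0}, 1.0);
//   coo->sort();                // now in lexicographic index order
//   coo->startIterator();
//   while (const Element<double> *e = coo->getNext())
//     ;                         // visit ({i,j}, a[i,j])
//   delete coo;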

/// Abstract base class of sparse tensor storage. Note that we use
/// function overloading to implement "partial" method specialization.
class SparseTensorStorageBase {
public:
  /// Dimension size query.
  virtual uint64_t getDimSize(uint64_t) = 0;

  /// Overhead storage.
  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }

  /// Primary storage.
  virtual void getValues(std::vector<double> **) { fatal("valf64"); }
  virtual void getValues(std::vector<float> **) { fatal("valf32"); }
  virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }

  /// Element-wise insertion in lexicographic index order.
  virtual void lexInsert(uint64_t *, double) { fatal("insf64"); }
  virtual void lexInsert(uint64_t *, float) { fatal("insf32"); }
  virtual void lexInsert(uint64_t *, int64_t) { fatal("insi64"); }
  virtual void lexInsert(uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(uint64_t *, int16_t) { fatal("insi16"); }
  virtual void lexInsert(uint64_t *, int8_t) { fatal("insi8"); }

  /// Expanded insertion.
  virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
    fatal("expf64");
  }
  virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
    fatal("expf32");
  }
  virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi64");
  }
  virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi32");
  }
  virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi16");
  }
  virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi8");
  }

  /// Finishes insertion.
  virtual void endInsert() = 0;

  virtual ~SparseTensorStorageBase() {}

private:
  void fatal(const char *tp) {
    fprintf(stderr, "unsupported %s\n", tp);
    exit(1);
  }
};
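
// For illustration only: the overload set above emulates partial method
// specialization. Compiler-generated code calls, e.g.,
//
//   std::vector<double> *v;
//   static_cast<SparseTensorStorageBase *>(ptr)->getValues(&v);
//
// and only a SparseTensorStorage<P, I, V> with V == double overrides that
// particular getValues overload; any other V falls through to the base class
// above, which reports "unsupported valf64" and exits.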

/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this method provides
/// a convenient "one-size-fits-all" solution that simply takes an input tensor
/// and annotations to implement all required setup in a general manner.
template <typename P, typename I, typename V>
class SparseTensorStorage : public SparseTensorStorageBase {
public:
  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the coordinate scheme tensor for the initial contents if provided.
  SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
                      const DimLevelType *sparsity,
                      SparseTensorCOO<V> *tensor = nullptr)
      : sizes(szs), rev(getRank()), idx(getRank()), pointers(getRank()),
        indices(getRank()) {
    uint64_t rank = getRank();
    // Store "reverse" permutation.
    for (uint64_t r = 0; r < rank; r++)
      rev[perm[r]] = r;
    // Provide hints on capacity of pointers and indices.
    // TODO: needs fine-tuning based on sparsity
    bool allDense = true;
    uint64_t sz = 1;
    for (uint64_t r = 0; r < rank; r++) {
      sz *= sizes[r];
      if (sparsity[r] == DimLevelType::kCompressed) {
        pointers[r].reserve(sz + 1);
        indices[r].reserve(sz);
        sz = 1;
        allDense = false;
      } else {
        assert(sparsity[r] == DimLevelType::kDense &&
               "singleton not yet supported");
      }
    }
    // Prepare sparse pointer structures for all dimensions.
    for (uint64_t r = 0; r < rank; r++)
      if (sparsity[r] == DimLevelType::kCompressed)
        pointers[r].push_back(0);
    // Then assign contents from coordinate scheme tensor if provided.
    if (tensor) {
      uint64_t nnz = tensor->getElements().size();
      values.reserve(nnz);
      fromCOO(tensor, 0, nnz, 0);
    } else if (allDense) {
      values.resize(sz, 0);
    }
  }
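
  // For illustration only: for the same 2x3 matrix used in the example near
  // the top of this file, annotating *both* dimensions kCompressed (a
  // DCSR-style layout) instead yields
  //
  //   pointers[0] = [0, 2]      // both rows are nonempty
  //   indices[0]  = [0, 1]
  //   pointers[1] = [0, 2, 3]
  //   indices[1]  = [0, 2, 1]
  //   values      = [1, 2, 3]
  //
  // i.e., every compressed dimension gets its own pointers/indices arrays,
  // while dense dimensions carry no overhead storage at all.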

  virtual ~SparseTensorStorage() {}

  /// Get the rank of the tensor.
  uint64_t getRank() const { return sizes.size(); }

  /// Get the size in the given dimension of the tensor.
  uint64_t getDimSize(uint64_t d) override {
    assert(d < getRank());
    return sizes[d];
  }

  /// Partially specialize these getter methods based on template types.
  void getPointers(std::vector<P> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &pointers[d];
  }
  void getIndices(std::vector<I> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &indices[d];
  }
  void getValues(std::vector<V> **out) override { *out = &values; }

  /// Partially specialize lexicographic insertions based on template types.
  void lexInsert(uint64_t *cursor, V val) override {
    // First, wrap up pending insertion path.
    uint64_t diff = 0;
    uint64_t top = 0;
    if (!values.empty()) {
      diff = lexDiff(cursor);
      endPath(diff + 1);
      top = idx[diff] + 1;
    }
    // Then continue with insertion path.
    insPath(cursor, diff, top, val);
  }

  /// Partially specialize expanded insertions based on template types.
  /// Note that this method resets the values/filled-switch array back
  /// to all-zero/false while only iterating over the nonzero elements.
  void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
                 uint64_t count) override {
    if (count == 0)
      return;
    // Sort.
    std::sort(added, added + count);
    // Restore insertion path for first insert.
    uint64_t rank = getRank();
    uint64_t index = added[0];
    cursor[rank - 1] = index;
    lexInsert(cursor, values[index]);
    assert(filled[index]);
    values[index] = 0;
    filled[index] = false;
    // Subsequent insertions are quick.
    for (uint64_t i = 1; i < count; i++) {
      assert(index < added[i] && "non-lexicographic insertion");
      index = added[i];
      cursor[rank - 1] = index;
      insPath(cursor, rank - 1, added[i - 1] + 1, values[index]);
      assert(filled[index]);
      values[index] = 0;
      filled[index] = false;
    }
  }

  /// Finalizes lexicographic insertions.
  void endInsert() override {
    if (values.empty())
      endDim(0);
    else
      endPath(0);
  }

  /// Returns this sparse tensor storage scheme as a new memory-resident
  /// sparse tensor in coordinate scheme with the given dimension order.
  SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
    // Restore original order of the dimension sizes and allocate coordinate
    // scheme with desired new ordering specified in perm.
    uint64_t rank = getRank();
    std::vector<uint64_t> orgsz(rank);
    for (uint64_t r = 0; r < rank; r++)
      orgsz[rev[r]] = sizes[r];
    SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
        rank, orgsz.data(), perm, values.size());
    // Populate coordinate scheme restored from old ordering and changed with
    // new ordering. Rather than applying both reorderings during the recursion,
    // we compute the combined permutation in advance.
    std::vector<uint64_t> reord(rank);
    for (uint64_t r = 0; r < rank; r++)
      reord[r] = perm[rev[r]];
    toCOO(tensor, reord, 0, 0);
    assert(tensor->getElements().size() == values.size());
    return tensor;
  }
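
  // For illustration only: suppose a 10x20 matrix was stored with
  // perm = {1, 0} (dimension 1 outermost). Then sizes = {20, 10} and
  // rev = {1, 0}, so the loop above recovers orgsz = {10, 20}. Requesting
  // toCOO with the identity permutation combines both reorderings into
  // reord = {1, 0}, which maps each in-storage coordinate back to its
  // original dimension on the fly during the recursion.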

  /// Factory method. Constructs a sparse tensor storage scheme with the given
  /// dimensions, permutation, and per-dimension dense/sparse annotations,
  /// using the coordinate scheme tensor for the initial contents if provided.
  /// In the latter case, the coordinate scheme must respect the same
  /// permutation as is desired for the new sparse tensor storage.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
                  const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
    SparseTensorStorage<P, I, V> *n = nullptr;
    if (tensor) {
      assert(tensor->getRank() == rank);
      for (uint64_t r = 0; r < rank; r++)
        assert(sizes[r] == 0 || tensor->getSizes()[perm[r]] == sizes[r]);
      tensor->sort(); // sort lexicographically
      n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
                                           tensor);
      delete tensor;
    } else {
      std::vector<uint64_t> permsz(rank);
      for (uint64_t r = 0; r < rank; r++)
        permsz[perm[r]] = sizes[r];
      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
    }
    return n;
  }

private:
  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the pointers and
  /// indices arrays under the given per-dimension dense/sparse annotations.
  void fromCOO(SparseTensorCOO<V> *tensor, uint64_t lo, uint64_t hi,
               uint64_t d) {
    const std::vector<Element<V>> &elements = tensor->getElements();
    // Once dimensions are exhausted, insert the numerical values.
    assert(d <= getRank());
    if (d == getRank()) {
      assert(lo < hi && hi <= elements.size());
      values.push_back(elements[lo].value);
      return;
    }
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) {
      assert(lo < elements.size() && hi <= elements.size());
      // Find segment in interval with same index elements in this dimension.
      uint64_t i = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == i)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      if (isCompressedDim(d)) {
        indices[d].push_back(i);
      } else {
        // For dense storage we must fill in all the zero values between
        // the previous element (when last we ran this for-loop) and the
        // current element.
        for (; full < i; full++)
          endDim(d + 1);
        full++;
      }
      fromCOO(tensor, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    if (isCompressedDim(d)) {
      pointers[d].push_back(indices[d].size());
    } else {
      // For dense storage we must fill in all the zero values after
      // the last element.
      for (uint64_t sz = sizes[d]; full < sz; full++)
        endDim(d + 1);
    }
  }

  /// Stores the sparse tensor storage scheme into a memory-resident sparse
  /// tensor in coordinate scheme.
  void toCOO(SparseTensorCOO<V> *tensor, std::vector<uint64_t> &reord,
             uint64_t pos, uint64_t d) {
    assert(d <= getRank());
    if (d == getRank()) {
      assert(pos < values.size());
      tensor->add(idx, values[pos]);
    } else if (isCompressedDim(d)) {
      // Sparse dimension.
      for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
        idx[reord[d]] = indices[d][ii];
        toCOO(tensor, reord, ii, d + 1);
      }
    } else {
      // Dense dimension.
      for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
        idx[reord[d]] = i;
        toCOO(tensor, reord, off + i, d + 1);
      }
    }
  }

  /// Ends a deeper, never seen before dimension.
  void endDim(uint64_t d) {
    assert(d <= getRank());
    if (d == getRank()) {
      values.push_back(0);
    } else if (isCompressedDim(d)) {
      pointers[d].push_back(indices[d].size());
    } else {
      for (uint64_t full = 0, sz = sizes[d]; full < sz; full++)
        endDim(d + 1);
    }
  }

  /// Wraps up a single insertion path, inner to outer.
  void endPath(uint64_t diff) {
    uint64_t rank = getRank();
    assert(diff <= rank);
    for (uint64_t i = 0; i < rank - diff; i++) {
      uint64_t d = rank - i - 1;
      if (isCompressedDim(d)) {
        pointers[d].push_back(indices[d].size());
      } else {
        for (uint64_t full = idx[d] + 1, sz = sizes[d]; full < sz; full++)
          endDim(d + 1);
      }
    }
  }
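
  // For illustration only: inserting in lexicographic order into a CSR-style
  // 2x3 matrix via lexInsert({0,0}, 1), lexInsert({0,2}, 2), and
  // lexInsert({1,1}, 3) proceeds as follows. The first call takes the insPath
  // fast path for all dimensions. The second call finds lexDiff = 1 (only the
  // column differs), so only the innermost dimension is re-entered. The third
  // call finds lexDiff = 0, wraps up row 0 with endPath (closing its pointers
  // entry), and then starts the path for row 1. A final endInsert() call
  // closes the last pending row.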

  /// Continues a single insertion path, outer to inner.
  void insPath(uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
    uint64_t rank = getRank();
    assert(diff < rank);
    for (uint64_t d = diff; d < rank; d++) {
      uint64_t i = cursor[d];
      if (isCompressedDim(d)) {
        indices[d].push_back(i);
      } else {
        for (uint64_t full = top; full < i; full++)
          endDim(d + 1);
      }
      top = 0;
      idx[d] = i;
    }
    values.push_back(val);
  }

  /// Finds the lexicographically differing dimension.
  uint64_t lexDiff(uint64_t *cursor) {
    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
      if (cursor[r] > idx[r])
        return r;
      else
        assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
    return -1u;
  }

  /// Returns true if dimension is compressed.
  inline bool isCompressedDim(uint64_t d) const {
    return (!pointers[d].empty());
  }

private:
  std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<uint64_t> rev;   // "reverse" permutation
  std::vector<uint64_t> idx;   // index cursor
  std::vector<std::vector<P>> pointers;
  std::vector<std::vector<I>> indices;
  std::vector<V> values;
};

/// Helper to convert string to lower case.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = tolower(*c);
  return token;
}
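
// For reference, a typical MME file accepted by the reader below starts like
// this (values shown are illustrative):
//
//   %%MatrixMarket matrix coordinate real general
//   % comments ...
//   2 3 3
//   1 1 1.0
//   1 3 2.0
//   2 2 3.0
//
// i.e., a header line, optional comments, a line with M N NNZ, and then one
// 1-based (row, col, value) triple per nonzero.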

/// Read the MME header of a general sparse matrix of type real.
static void readMMEHeader(FILE *file, char *name, uint64_t *idata,
                          bool *is_symmetric) {
  char line[1025];
  char header[64];
  char object[64];
  char format[64];
  char field[64];
  char symmetry[64];
  // Read header line.
  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
             symmetry) != 5) {
    fprintf(stderr, "Corrupt header in %s\n", name);
    exit(1);
  }
  *is_symmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
  // Make sure this is a general sparse matrix.
  if (strcmp(toLower(header), "%%matrixmarket") ||
      strcmp(toLower(object), "matrix") ||
      strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
      (strcmp(toLower(symmetry), "general") && !(*is_symmetric))) {
    fprintf(stderr,
            "Cannot find a general sparse matrix with type real in %s\n", name);
    exit(1);
  }
  // Skip comments.
  while (1) {
    if (!fgets(line, 1025, file)) {
      fprintf(stderr, "Cannot find data in %s\n", name);
      exit(1);
    }
    if (line[0] != '%')
      break;
  }
  // Next line contains M N NNZ.
  idata[0] = 2; // rank
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3) {
    fprintf(stderr, "Cannot find size in %s\n", name);
    exit(1);
  }
}

/// Read the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments followed
/// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
  char line[1025];
  // Skip comments.
  while (1) {
    if (!fgets(line, 1025, file)) {
      fprintf(stderr, "Cannot find data in %s\n", name);
      exit(1);
    }
    if (line[0] != '#')
      break;
  }
  // Next line contains RANK and NNZ.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
    fprintf(stderr, "Cannot find metadata in %s\n", name);
    exit(1);
  }
  // Followed by a line with the dimension sizes (one per rank).
  for (uint64_t r = 0; r < idata[0]; r++) {
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size %s\n", name);
      exit(1);
    }
  }
}
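
// For reference, the "extended" FROSTT header assumed above looks like this
// for a hypothetical 2x3x4 tensor with 5 nonzeros:
//
//   # comments ...
//   3 5
//   2 3 4
//   1 1 1 1.0
//   ...
//
// i.e., optional comments, a line with RANK and NNZ, a line with the
// dimension sizes, and then one 1-based (indices..., value) entry per line.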

/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme.
template <typename V>
static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
                                               const uint64_t *sizes,
                                               const uint64_t *perm) {
  // Open the file.
  FILE *file = fopen(filename, "r");
  if (!file) {
    fprintf(stderr, "Cannot find %s\n", filename);
    exit(1);
  }
  // Perform some file-format-dependent setup.
  uint64_t idata[512];
  bool is_symmetric = false;
  if (strstr(filename, ".mtx")) {
    readMMEHeader(file, filename, idata, &is_symmetric);
  } else if (strstr(filename, ".tns")) {
    readExtFROSTTHeader(file, filename, idata);
  } else {
    fprintf(stderr, "Unknown format %s\n", filename);
    exit(1);
  }
  // Prepare sparse tensor object with per-dimension sizes
  // and the number of nonzeros as initial capacity.
  assert(rank == idata[0] && "rank mismatch");
  uint64_t nnz = idata[1];
  for (uint64_t r = 0; r < rank; r++)
    assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
           "dimension size mismatch");
  SparseTensorCOO<V> *tensor =
      SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
  // Read all nonzero elements.
  std::vector<uint64_t> indices(rank);
  for (uint64_t k = 0; k < nnz; k++) {
    uint64_t idx = -1u;
    for (uint64_t r = 0; r < rank; r++) {
      if (fscanf(file, "%" PRIu64, &idx) != 1) {
        fprintf(stderr, "Cannot find next index in %s\n", filename);
        exit(1);
      }
      // Add 0-based index.
      indices[perm[r]] = idx - 1;
    }
    // The external formats always store the numerical values with the type
    // double, but we cast these values to the sparse tensor object type.
    double value;
    if (fscanf(file, "%lg\n", &value) != 1) {
      fprintf(stderr, "Cannot find next value in %s\n", filename);
      exit(1);
    }
    tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing them. In the future, we may want to make symmetry
    // implicit for storage reasons.
    if (is_symmetric && indices[0] != indices[1])
      tensor->add({indices[1], indices[0]}, value);
  }
  // Close the file and return tensor.
  fclose(file);
  return tensor;
}

} // namespace

extern "C" {

/// This type is used in the public API at all places where MLIR expects
/// values with the built-in type "index". For now, we simply assume that
/// type is 64-bit, but targets with different "index" bit widths should link
/// with an alternatively built runtime support library.
// TODO: support such targets?
typedef uint64_t index_t;

//===----------------------------------------------------------------------===//
//
// Public API with methods that operate on MLIR buffers (memrefs) to interact
// with sparse tensors, which are only visible as opaque pointers externally.
// These methods should be used exclusively by MLIR compiler-generated code.
//
// Some macro magic is used to generate implementations for all required type
// combinations that can be called from MLIR compiler-generated code.
//
//===----------------------------------------------------------------------===//

#define CASE(p, i, v, P, I, V)                                                 \
  if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
    SparseTensorCOO<V> *tensor = nullptr;                                      \
    if (action <= Action::kFromCOO) {                                          \
      if (action == Action::kFromFile) {                                       \
        char *filename = static_cast<char *>(ptr);                             \
        tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm);          \
      } else if (action == Action::kFromCOO) {                                 \
        tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
      } else {                                                                 \
        assert(action == Action::kEmpty);                                      \
      }                                                                        \
      return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,  \
                                                           sparsity, tensor);  \
    } else if (action == Action::kEmptyCOO) {                                  \
      return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);        \
    } else {                                                                   \
      tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);  \
      if (action == Action::kToIterator) {                                     \
        tensor->startIterator();                                               \
      } else {                                                                 \
        assert(action == Action::kToCOO);                                      \
      }                                                                        \
      return tensor;                                                           \
    }                                                                          \
  }

#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
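
// For illustration only: with ptrTp == OverheadType::kU64, indTp ==
// OverheadType::kU64, and valTp == PrimaryType::kF64, the matching CASE
// instantiation in _mlir_ciface_newSparseTensor below dispatches on the
// action. Action::kFromFile reads the file named by `ptr` into a
// SparseTensorCOO<double> and returns a new
// SparseTensorStorage<uint64_t, uint64_t, double>, whereas Action::kEmptyCOO
// simply returns a fresh, empty coordinate scheme to be filled via the
// addElt helpers.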

#define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
    assert(ref && tensor);                                                     \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
    ref->basePtr = ref->data = v->data();                                      \
    ref->offset = 0;                                                           \
    ref->sizes[0] = v->size();                                                 \
    ref->strides[0] = 1;                                                       \
  }

#define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
                           index_t d) {                                        \
    assert(ref && tensor);                                                     \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
    ref->basePtr = ref->data = v->data();                                      \
    ref->offset = 0;                                                           \
    ref->sizes[0] = v->size();                                                 \
    ref->strides[0] = 1;                                                       \
  }

#define IMPL_ADDELT(NAME, TYPE)                                                \
  void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
                            StridedMemRefType<index_t, 1> *iref,               \
                            StridedMemRefType<index_t, 1> *pref) {             \
    assert(tensor && iref && pref);                                            \
    assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
    assert(iref->sizes[0] == pref->sizes[0]);                                  \
    const index_t *indx = iref->data + iref->offset;                           \
    const index_t *perm = pref->data + pref->offset;                           \
    uint64_t isize = iref->sizes[0];                                           \
    std::vector<index_t> indices(isize);                                       \
    for (uint64_t r = 0; r < isize; r++)                                       \
      indices[perm[r]] = indx[r];                                              \
    static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
    return tensor;                                                             \
  }

#define IMPL_GETNEXT(NAME, V)                                                  \
  bool _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *iref,  \
                           StridedMemRefType<V, 0> *vref) {                    \
    assert(tensor && iref && vref);                                            \
    assert(iref->strides[0] == 1);                                             \
    index_t *indx = iref->data + iref->offset;                                 \
    V *value = vref->data + vref->offset;                                      \
    const uint64_t isize = iref->sizes[0];                                     \
    auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
    const Element<V> *elem = iter->getNext();                                  \
    if (elem == nullptr) {                                                     \
      delete iter;                                                             \
      return false;                                                            \
    }                                                                          \
    for (uint64_t r = 0; r < isize; r++)                                       \
      indx[r] = elem->indices[r];                                              \
    *value = elem->value;                                                      \
    return true;                                                               \
  }
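
// For illustration only: the later instantiation IMPL_GETNEXT(getNextF64,
// double) expands to a function
//
//   bool _mlir_ciface_getNextF64(void *tensor,
//                                StridedMemRefType<index_t, 1> *iref,
//                                StridedMemRefType<double, 0> *vref);
//
// that copies the indices and value of the next COO element into the supplied
// memrefs, returning false (and deleting the iterator) once it is exhausted.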

#define IMPL_LEXINSERT(NAME, V)                                                \
  void _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *cref,  \
                           V val) {                                            \
    assert(tensor && cref);                                                    \
    assert(cref->strides[0] == 1);                                             \
    index_t *cursor = cref->data + cref->offset;                               \
    assert(cursor);                                                            \
    static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
  }

#define IMPL_EXPINSERT(NAME, V)                                                \
  void _mlir_ciface_##NAME(                                                    \
      void *tensor, StridedMemRefType<index_t, 1> *cref,                       \
      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
      StridedMemRefType<index_t, 1> *aref, index_t count) {                    \
    assert(tensor && cref && vref && fref && aref);                            \
    assert(cref->strides[0] == 1);                                             \
    assert(vref->strides[0] == 1);                                             \
    assert(fref->strides[0] == 1);                                             \
    assert(aref->strides[0] == 1);                                             \
    assert(vref->sizes[0] == fref->sizes[0]);                                  \
    index_t *cursor = cref->data + cref->offset;                               \
    V *values = vref->data + vref->offset;                                     \
    bool *filled = fref->data + fref->offset;                                  \
    index_t *added = aref->data + aref->offset;                                \
    static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
        cursor, values, filled, added, count);                                 \
  }

/// Constructs a new sparse tensor. This is the "swiss army knife"
/// method for materializing sparse tensors into the computation.
///
/// Action:
///   kEmpty = returns empty storage to fill later
///   kFromFile = returns storage, where ptr contains filename to read
///   kFromCOO = returns storage, where ptr contains coordinate scheme to assign
///   kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
///   kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
///   kToIterator = returns iterator from storage in ptr (call getNext() to use)
void *
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
                             StridedMemRefType<index_t, 1> *sref,
                             StridedMemRefType<index_t, 1> *pref,
                             OverheadType ptrTp, OverheadType indTp,
                             PrimaryType valTp, Action action, void *ptr) {
  assert(aref && sref && pref);
  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
         pref->strides[0] == 1);
  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
  const DimLevelType *sparsity = aref->data + aref->offset;
  const index_t *sizes = sref->data + sref->offset;
  const index_t *perm = pref->data + pref->offset;
  uint64_t rank = aref->sizes[0];

  // Double matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
       uint64_t, double);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
       uint32_t, double);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
       uint16_t, double);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
       uint8_t, double);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
       uint64_t, double);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
       uint32_t, double);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
       uint16_t, double);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
       uint8_t, double);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
       uint64_t, double);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
       uint32_t, double);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
       uint16_t, double);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
       uint8_t, double);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
       uint64_t, double);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
       uint32_t, double);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
       uint16_t, double);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
       uint8_t, double);

  // Float matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
       uint64_t, float);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
       uint32_t, float);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
       uint16_t, float);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
       uint8_t, float);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
       uint64_t, float);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
       uint32_t, float);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
       uint16_t, float);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
       uint8_t, float);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
       uint64_t, float);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
       uint32_t, float);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
       uint16_t, float);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
       uint8_t, float);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
       uint64_t, float);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
       uint32_t, float);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
       uint16_t, float);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
       uint8_t, float);

  // Integral matrices with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);

  // Unsupported case (add above if needed).
  fputs("unsupported combination of types\n", stderr);
  exit(1);
}

/// Methods that provide direct access to pointers.
IMPL_GETOVERHEAD(sparsePointers, index_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)

/// Methods that provide direct access to indices.
IMPL_GETOVERHEAD(sparseIndices, index_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)

/// Methods that provide direct access to values.
IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)

/// Helper to add value to coordinate scheme, one per value type.
IMPL_ADDELT(addEltF64, double)
IMPL_ADDELT(addEltF32, float)
IMPL_ADDELT(addEltI64, int64_t)
IMPL_ADDELT(addEltI32, int32_t)
IMPL_ADDELT(addEltI16, int16_t)
IMPL_ADDELT(addEltI8, int8_t)

/// Helper to enumerate elements of coordinate scheme, one per value type.
IMPL_GETNEXT(getNextF64, double)
IMPL_GETNEXT(getNextF32, float)
IMPL_GETNEXT(getNextI64, int64_t)
IMPL_GETNEXT(getNextI32, int32_t)
IMPL_GETNEXT(getNextI16, int16_t)
IMPL_GETNEXT(getNextI8, int8_t)

/// Helper to insert elements in lexicographic index order, one per value type.
IMPL_LEXINSERT(lexInsertF64, double)
IMPL_LEXINSERT(lexInsertF32, float)
IMPL_LEXINSERT(lexInsertI64, int64_t)
IMPL_LEXINSERT(lexInsertI32, int32_t)
IMPL_LEXINSERT(lexInsertI16, int16_t)
IMPL_LEXINSERT(lexInsertI8, int8_t)

/// Helper to insert using expansion, one per value type.
IMPL_EXPINSERT(expInsertF64, double)
IMPL_EXPINSERT(expInsertF32, float)
IMPL_EXPINSERT(expInsertI64, int64_t)
IMPL_EXPINSERT(expInsertI32, int32_t)
IMPL_EXPINSERT(expInsertI16, int16_t)
IMPL_EXPINSERT(expInsertI8, int8_t)

#undef CASE
#undef CASE_SECSAME
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
#undef IMPL_ADDELT
#undef IMPL_GETNEXT
#undef IMPL_LEXINSERT
#undef IMPL_EXPINSERT

//===----------------------------------------------------------------------===//
//
// Public API with methods that accept C-style data structures to interact
// with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used both by MLIR compiler-generated code as well as
// by an external runtime that wants to interact with MLIR compiler-generated
// code.
//
//===----------------------------------------------------------------------===//

/// Helper method to read a sparse tensor filename from the environment,
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
char *getTensorFilename(index_t id) {
  char var[80];
  sprintf(var, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  return env;
}
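
// For illustration only: with the environment set up as, e.g.,
//
//   TENSOR0=/path/to/matrix.mtx ./my-sparse-app
//
// (the path and binary name are hypothetical), a call to getTensorFilename(0)
// returns "/path/to/matrix.mtx", which can then be passed as the opaque `ptr`
// to _mlir_ciface_newSparseTensor with Action::kFromFile.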

/// Returns size of sparse tensor in given dimension.
index_t sparseDimSize(void *tensor, index_t d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

/// Finalizes lexicographic insertions.
void endInsert(void *tensor) {
  return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
}

/// Releases sparse tensor storage.
void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}

/// Initializes sparse tensor from a COO-flavored format expressed using
/// C-style data structures. The expected parameters are:
///
///   rank:    rank of tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with dimension size for each rank
///   values:  a "nse" array with values for all specified elements
///   indices: a flat "nse x rank" array with indices for all specified elements
///
/// For example, the sparse matrix
///   | 1.0 0.0 0.0 |
///   | 0.0 5.0 3.0 |
/// can be passed as
///   rank    = 2
///   nse     = 3
///   shape   = [2, 3]
///   values  = [1.0, 5.0, 3.0]
///   indices = [ 0, 0,  1, 1,  1, 2]
//
// TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
//
void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
                                double *values, uint64_t *indices) {
  // Setup all-dims compressed and default ordering.
  std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
  std::vector<uint64_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  // Convert external format to internal COO.
  SparseTensorCOO<double> *tensor = SparseTensorCOO<double>::newSparseTensorCOO(
      rank, shape, perm.data(), nse);
  std::vector<uint64_t> idx(rank);
  for (uint64_t i = 0, base = 0; i < nse; i++) {
    for (uint64_t r = 0; r < rank; r++)
      idx[r] = indices[base + r];
    tensor->add(idx, values[i]);
    base += rank;
  }
  // Return sparse tensor storage format as opaque pointer.
  return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
      rank, shape, perm.data(), sparse.data(), tensor);
}

} // extern "C"

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS