18a91bc7bSHarrietAkot //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===// 28a91bc7bSHarrietAkot // 38a91bc7bSHarrietAkot // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48a91bc7bSHarrietAkot // See https://llvm.org/LICENSE.txt for license information. 58a91bc7bSHarrietAkot // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68a91bc7bSHarrietAkot // 78a91bc7bSHarrietAkot //===----------------------------------------------------------------------===// 88a91bc7bSHarrietAkot // 98a91bc7bSHarrietAkot // This file implements a light-weight runtime support library that is useful 108a91bc7bSHarrietAkot // for sparse tensor manipulations. The functionality provided in this library 118a91bc7bSHarrietAkot // is meant to simplify benchmarking, testing, and debugging MLIR code that 128a91bc7bSHarrietAkot // operates on sparse tensors. The provided functionality is **not** part 138a91bc7bSHarrietAkot // of core MLIR, however. 148a91bc7bSHarrietAkot // 158a91bc7bSHarrietAkot //===----------------------------------------------------------------------===// 168a91bc7bSHarrietAkot 17845561ecSwren romano #include "mlir/ExecutionEngine/SparseTensorUtils.h" 188a91bc7bSHarrietAkot #include "mlir/ExecutionEngine/CRunnerUtils.h" 198a91bc7bSHarrietAkot 208a91bc7bSHarrietAkot #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS 218a91bc7bSHarrietAkot 228a91bc7bSHarrietAkot #include <algorithm> 238a91bc7bSHarrietAkot #include <cassert> 248a91bc7bSHarrietAkot #include <cctype> 258a91bc7bSHarrietAkot #include <cinttypes> 268a91bc7bSHarrietAkot #include <cstdio> 278a91bc7bSHarrietAkot #include <cstdlib> 288a91bc7bSHarrietAkot #include <cstring> 29efa15f41SAart Bik #include <fstream> 30efa15f41SAart Bik #include <iostream> 314d0a18d0Swren romano #include <limits> 328a91bc7bSHarrietAkot #include <numeric> 338a91bc7bSHarrietAkot #include <vector> 348a91bc7bSHarrietAkot 358a91bc7bSHarrietAkot 
//===----------------------------------------------------------------------===// 368a91bc7bSHarrietAkot // 378a91bc7bSHarrietAkot // Internal support for storing and reading sparse tensors. 388a91bc7bSHarrietAkot // 398a91bc7bSHarrietAkot // The following memory-resident sparse storage schemes are supported: 408a91bc7bSHarrietAkot // 418a91bc7bSHarrietAkot // (a) A coordinate scheme for temporarily storing and lexicographically 428a91bc7bSHarrietAkot // sorting a sparse tensor by index (SparseTensorCOO). 438a91bc7bSHarrietAkot // 448a91bc7bSHarrietAkot // (b) A "one-size-fits-all" sparse tensor storage scheme defined by 458a91bc7bSHarrietAkot // per-dimension sparse/dense annotations together with a dimension 468a91bc7bSHarrietAkot // ordering used by MLIR compiler-generated code (SparseTensorStorage). 478a91bc7bSHarrietAkot // 488a91bc7bSHarrietAkot // The following external formats are supported: 498a91bc7bSHarrietAkot // 508a91bc7bSHarrietAkot // (1) Matrix Market Exchange (MME): *.mtx 518a91bc7bSHarrietAkot // https://math.nist.gov/MatrixMarket/formats.html 528a91bc7bSHarrietAkot // 538a91bc7bSHarrietAkot // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns 548a91bc7bSHarrietAkot // http://frostt.io/tensors/file-formats.html 558a91bc7bSHarrietAkot // 568a91bc7bSHarrietAkot // Two public APIs are supported: 578a91bc7bSHarrietAkot // 588a91bc7bSHarrietAkot // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse 598a91bc7bSHarrietAkot // tensors. These methods should be used exclusively by MLIR 608a91bc7bSHarrietAkot // compiler-generated code. 618a91bc7bSHarrietAkot // 628a91bc7bSHarrietAkot // (II) Methods that accept C-style data structures to interact with sparse 638a91bc7bSHarrietAkot // tensors. These methods can be used by any external runtime that wants 648a91bc7bSHarrietAkot // to interact with MLIR compiler-generated code.
658a91bc7bSHarrietAkot // 668a91bc7bSHarrietAkot // In both cases (I) and (II), the SparseTensorStorage format is externally 678a91bc7bSHarrietAkot // only visible as an opaque pointer. 688a91bc7bSHarrietAkot // 698a91bc7bSHarrietAkot //===----------------------------------------------------------------------===// 708a91bc7bSHarrietAkot 718a91bc7bSHarrietAkot namespace { 728a91bc7bSHarrietAkot 7303fe15ceSAart Bik static constexpr int kColWidth = 1025; 7403fe15ceSAart Bik 7572ec2f76Swren romano /// A version of `operator*` on `uint64_t` which checks for overflows. 7672ec2f76Swren romano static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) { 7772ec2f76Swren romano assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) && 7872ec2f76Swren romano "Integer overflow"); 7972ec2f76Swren romano return lhs * rhs; 8072ec2f76Swren romano } 8172ec2f76Swren romano 828a91bc7bSHarrietAkot /// A sparse tensor element in coordinate scheme (value and indices). 838a91bc7bSHarrietAkot /// For example, a rank-1 vector element would look like 848a91bc7bSHarrietAkot /// ({i}, a[i]) 858a91bc7bSHarrietAkot /// and a rank-5 tensor element like 868a91bc7bSHarrietAkot /// ({i,j,k,l,m}, a[i,j,k,l,m]) 87*ccd047cbSAart Bik /// We use pointer to a shared index pool rather than e.g. a direct 88*ccd047cbSAart Bik /// vector since that (1) reduces the per-element memory footprint, and 89*ccd047cbSAart Bik /// (2) centralizes the memory reservation and (re)allocation to one place. 908a91bc7bSHarrietAkot template <typename V> 918a91bc7bSHarrietAkot struct Element { 92*ccd047cbSAart Bik Element(uint64_t *ind, V val) : indices(ind), value(val){}; 93*ccd047cbSAart Bik uint64_t *indices; // pointer into shared index pool 948a91bc7bSHarrietAkot V value; 958a91bc7bSHarrietAkot }; 968a91bc7bSHarrietAkot 978a91bc7bSHarrietAkot /// A memory-resident sparse tensor in coordinate scheme (collection of 988a91bc7bSHarrietAkot /// elements). 
This data structure is used to read a sparse tensor from 998a91bc7bSHarrietAkot /// any external format into memory and sort the elements lexicographically 1008a91bc7bSHarrietAkot /// by indices before passing it back to the client (most packed storage 1018a91bc7bSHarrietAkot /// formats require the elements to appear in lexicographic index order). 1028a91bc7bSHarrietAkot template <typename V> 1038a91bc7bSHarrietAkot struct SparseTensorCOO { 1048a91bc7bSHarrietAkot public: 1058a91bc7bSHarrietAkot SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity) 106db6796dfSMehdi Amini : sizes(szs) { 107*ccd047cbSAart Bik if (capacity) { 1088a91bc7bSHarrietAkot elements.reserve(capacity); 109*ccd047cbSAart Bik indices.reserve(capacity * getRank()); 1108a91bc7bSHarrietAkot } 111*ccd047cbSAart Bik } 112*ccd047cbSAart Bik 1138a91bc7bSHarrietAkot /// Adds element as indices and value. 1148a91bc7bSHarrietAkot void add(const std::vector<uint64_t> &ind, V val) { 1158a91bc7bSHarrietAkot assert(!iteratorLocked && "Attempt to add() after startIterator()"); 116*ccd047cbSAart Bik uint64_t *base = indices.data(); 117*ccd047cbSAart Bik uint64_t size = indices.size(); 1188a91bc7bSHarrietAkot uint64_t rank = getRank(); 1198a91bc7bSHarrietAkot assert(rank == ind.size()); 120*ccd047cbSAart Bik for (uint64_t r = 0; r < rank; r++) { 1218a91bc7bSHarrietAkot assert(ind[r] < sizes[r]); // within bounds 122*ccd047cbSAart Bik indices.push_back(ind[r]); 1238a91bc7bSHarrietAkot } 124*ccd047cbSAart Bik // This base only changes if indices were reallocated. In that case, we 125*ccd047cbSAart Bik // need to correct all previous pointers into the vector. Note that this 126*ccd047cbSAart Bik // only happens if we did not set the initial capacity right, and then only 127*ccd047cbSAart Bik // for every internal vector reallocation (which with the doubling rule 128*ccd047cbSAart Bik // should only incur an amortized linear overhead). 
129*ccd047cbSAart Bik uint64_t *new_base = indices.data(); 130*ccd047cbSAart Bik if (new_base != base) { 131*ccd047cbSAart Bik for (uint64_t i = 0, n = elements.size(); i < n; i++) 132*ccd047cbSAart Bik elements[i].indices = new_base + (elements[i].indices - base); 133*ccd047cbSAart Bik base = new_base; 134*ccd047cbSAart Bik } 135*ccd047cbSAart Bik // Add element as (pointer into shared index pool, value) pair. 136*ccd047cbSAart Bik elements.emplace_back(base + size, val); 137*ccd047cbSAart Bik } 138*ccd047cbSAart Bik 1398a91bc7bSHarrietAkot /// Sorts elements lexicographically by index. 1408a91bc7bSHarrietAkot void sort() { 1418a91bc7bSHarrietAkot assert(!iteratorLocked && "Attempt to sort() after startIterator()"); 142cf358253Swren romano // TODO: we may want to cache an `isSorted` bit, to avoid 143cf358253Swren romano // unnecessary/redundant sorting. 144*ccd047cbSAart Bik std::sort(elements.begin(), elements.end(), 145*ccd047cbSAart Bik [this](const Element<V> &e1, const Element<V> &e2) { 146*ccd047cbSAart Bik uint64_t rank = getRank(); 147*ccd047cbSAart Bik for (uint64_t r = 0; r < rank; r++) { 148*ccd047cbSAart Bik if (e1.indices[r] == e2.indices[r]) 149*ccd047cbSAart Bik continue; 150*ccd047cbSAart Bik return e1.indices[r] < e2.indices[r]; 1518a91bc7bSHarrietAkot } 152*ccd047cbSAart Bik return false; 153*ccd047cbSAart Bik }); 154*ccd047cbSAart Bik } 155*ccd047cbSAart Bik 1568a91bc7bSHarrietAkot /// Returns rank. 1578a91bc7bSHarrietAkot uint64_t getRank() const { return sizes.size(); } 158*ccd047cbSAart Bik 1598a91bc7bSHarrietAkot /// Getter for sizes array. 1608a91bc7bSHarrietAkot const std::vector<uint64_t> &getSizes() const { return sizes; } 161*ccd047cbSAart Bik 1628a91bc7bSHarrietAkot /// Getter for elements array. 1638a91bc7bSHarrietAkot const std::vector<Element<V>> &getElements() const { return elements; } 1648a91bc7bSHarrietAkot 1658a91bc7bSHarrietAkot /// Switch into iterator mode. 
1668a91bc7bSHarrietAkot void startIterator() { 1678a91bc7bSHarrietAkot iteratorLocked = true; 1688a91bc7bSHarrietAkot iteratorPos = 0; 1698a91bc7bSHarrietAkot } 170*ccd047cbSAart Bik 1718a91bc7bSHarrietAkot /// Get the next element. 1728a91bc7bSHarrietAkot const Element<V> *getNext() { 1738a91bc7bSHarrietAkot assert(iteratorLocked && "Attempt to getNext() before startIterator()"); 1748a91bc7bSHarrietAkot if (iteratorPos < elements.size()) 1758a91bc7bSHarrietAkot return &(elements[iteratorPos++]); 1768a91bc7bSHarrietAkot iteratorLocked = false; 1778a91bc7bSHarrietAkot return nullptr; 1788a91bc7bSHarrietAkot } 1798a91bc7bSHarrietAkot 1808a91bc7bSHarrietAkot /// Factory method. Permutes the original dimensions according to 1818a91bc7bSHarrietAkot /// the given ordering and expects subsequent add() calls to honor 1828a91bc7bSHarrietAkot /// that same ordering for the given indices. The result is a 1838a91bc7bSHarrietAkot /// fully permuted coordinate scheme. 1848d8b566fSwren romano /// 1858d8b566fSwren romano /// Precondition: `sizes` and `perm` must be valid for `rank`. 
1868a91bc7bSHarrietAkot static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank, 1878a91bc7bSHarrietAkot const uint64_t *sizes, 1888a91bc7bSHarrietAkot const uint64_t *perm, 1898a91bc7bSHarrietAkot uint64_t capacity = 0) { 1908a91bc7bSHarrietAkot std::vector<uint64_t> permsz(rank); 191d83a7068Swren romano for (uint64_t r = 0; r < rank; r++) { 192d83a7068Swren romano assert(sizes[r] > 0 && "Dimension size zero has trivial storage"); 1938a91bc7bSHarrietAkot permsz[perm[r]] = sizes[r]; 194d83a7068Swren romano } 1958a91bc7bSHarrietAkot return new SparseTensorCOO<V>(permsz, capacity); 1968a91bc7bSHarrietAkot } 1978a91bc7bSHarrietAkot 1988a91bc7bSHarrietAkot private: 1998a91bc7bSHarrietAkot const std::vector<uint64_t> sizes; // per-dimension sizes 200*ccd047cbSAart Bik std::vector<Element<V>> elements; // all COO elements 201*ccd047cbSAart Bik std::vector<uint64_t> indices; // shared index pool 202db6796dfSMehdi Amini bool iteratorLocked = false; 203db6796dfSMehdi Amini unsigned iteratorPos = 0; 2048a91bc7bSHarrietAkot }; 2058a91bc7bSHarrietAkot 2068d8b566fSwren romano /// Abstract base class for `SparseTensorStorage<P,I,V>`. This class 2078d8b566fSwren romano /// takes responsibility for all the `<P,I,V>`-independent aspects 2088d8b566fSwren romano /// of the tensor (e.g., shape, sparsity, permutation). In addition, 2098d8b566fSwren romano /// we use function overloading to implement "partial" method 2108d8b566fSwren romano /// specialization, which the C-API relies on to catch type errors 2118d8b566fSwren romano /// arising from our use of opaque pointers. 2128a91bc7bSHarrietAkot class SparseTensorStorageBase { 2138a91bc7bSHarrietAkot public: 2148d8b566fSwren romano /// Constructs a new storage object. The `perm` maps the tensor's 2158d8b566fSwren romano /// semantic-ordering of dimensions to this object's storage-order. 2168d8b566fSwren romano /// The `szs` and `sparsity` arrays are already in storage-order. 
2178d8b566fSwren romano /// 2188d8b566fSwren romano /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`. 2198d8b566fSwren romano SparseTensorStorageBase(const std::vector<uint64_t> &szs, 2208d8b566fSwren romano const uint64_t *perm, const DimLevelType *sparsity) 2218d8b566fSwren romano : dimSizes(szs), rev(getRank()), 2228d8b566fSwren romano dimTypes(sparsity, sparsity + getRank()) { 2238d8b566fSwren romano const uint64_t rank = getRank(); 2248d8b566fSwren romano // Validate parameters. 2258d8b566fSwren romano assert(rank > 0 && "Trivial shape is unsupported"); 2268d8b566fSwren romano for (uint64_t r = 0; r < rank; r++) { 2278d8b566fSwren romano assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage"); 2288d8b566fSwren romano assert((dimTypes[r] == DimLevelType::kDense || 2298d8b566fSwren romano dimTypes[r] == DimLevelType::kCompressed) && 2308d8b566fSwren romano "Unsupported DimLevelType"); 2318d8b566fSwren romano } 2328d8b566fSwren romano // Construct the "reverse" (i.e., inverse) permutation. 2338d8b566fSwren romano for (uint64_t r = 0; r < rank; r++) 2348d8b566fSwren romano rev[perm[r]] = r; 2358d8b566fSwren romano } 2368d8b566fSwren romano 2378d8b566fSwren romano virtual ~SparseTensorStorageBase() = default; 2388d8b566fSwren romano 2398d8b566fSwren romano /// Get the rank of the tensor. 2408d8b566fSwren romano uint64_t getRank() const { return dimSizes.size(); } 2418d8b566fSwren romano 2428d8b566fSwren romano /// Getter for the dimension-sizes array, in storage-order. 2438d8b566fSwren romano const std::vector<uint64_t> &getDimSizes() const { return dimSizes; } 2448d8b566fSwren romano 2458d8b566fSwren romano /// Safely lookup the size of the given (storage-order) dimension. 
2468d8b566fSwren romano uint64_t getDimSize(uint64_t d) const { 2478d8b566fSwren romano assert(d < getRank()); 2488d8b566fSwren romano return dimSizes[d]; 2498d8b566fSwren romano } 2508d8b566fSwren romano 2518d8b566fSwren romano /// Getter for the "reverse" permutation, which maps this object's 2528d8b566fSwren romano /// storage-order to the tensor's semantic-order. 2538d8b566fSwren romano const std::vector<uint64_t> &getRev() const { return rev; } 2548d8b566fSwren romano 2558d8b566fSwren romano /// Getter for the dimension-types array, in storage-order. 2568d8b566fSwren romano const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; } 2578d8b566fSwren romano 2588d8b566fSwren romano /// Safely check if the (storage-order) dimension uses compressed storage. 2598d8b566fSwren romano bool isCompressedDim(uint64_t d) const { 2608d8b566fSwren romano assert(d < getRank()); 2618d8b566fSwren romano return (dimTypes[d] == DimLevelType::kCompressed); 2628d8b566fSwren romano } 2638a91bc7bSHarrietAkot 2644f2ec7f9SAart Bik /// Overhead storage. 2658a91bc7bSHarrietAkot virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); } 2668a91bc7bSHarrietAkot virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); } 2678a91bc7bSHarrietAkot virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); } 2688a91bc7bSHarrietAkot virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); } 2698a91bc7bSHarrietAkot virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); } 2708a91bc7bSHarrietAkot virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); } 2718a91bc7bSHarrietAkot virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); } 2728a91bc7bSHarrietAkot virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); } 2738a91bc7bSHarrietAkot 2744f2ec7f9SAart Bik /// Primary storage. 
2758a91bc7bSHarrietAkot virtual void getValues(std::vector<double> **) { fatal("valf64"); } 2768a91bc7bSHarrietAkot virtual void getValues(std::vector<float> **) { fatal("valf32"); } 2778a91bc7bSHarrietAkot virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); } 2788a91bc7bSHarrietAkot virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); } 2798a91bc7bSHarrietAkot virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); } 2808a91bc7bSHarrietAkot virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); } 2818a91bc7bSHarrietAkot 2824f2ec7f9SAart Bik /// Element-wise insertion in lexicographic index order. 283c03fd1e6Swren romano virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); } 284c03fd1e6Swren romano virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); } 285c03fd1e6Swren romano virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); } 286c03fd1e6Swren romano virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); } 287c03fd1e6Swren romano virtual void lexInsert(const uint64_t *, int16_t) { fatal("ins16"); } 288c03fd1e6Swren romano virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); } 2894f2ec7f9SAart Bik 2904f2ec7f9SAart Bik /// Expanded insertion. 
2914f2ec7f9SAart Bik virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) { 2924f2ec7f9SAart Bik fatal("expf64"); 2934f2ec7f9SAart Bik } 2944f2ec7f9SAart Bik virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) { 2954f2ec7f9SAart Bik fatal("expf32"); 2964f2ec7f9SAart Bik } 2974f2ec7f9SAart Bik virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) { 2984f2ec7f9SAart Bik fatal("expi64"); 2994f2ec7f9SAart Bik } 3004f2ec7f9SAart Bik virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) { 3014f2ec7f9SAart Bik fatal("expi32"); 3024f2ec7f9SAart Bik } 3034f2ec7f9SAart Bik virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) { 3044f2ec7f9SAart Bik fatal("expi16"); 3054f2ec7f9SAart Bik } 3064f2ec7f9SAart Bik virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) { 3074f2ec7f9SAart Bik fatal("expi8"); 3084f2ec7f9SAart Bik } 3094f2ec7f9SAart Bik 3104f2ec7f9SAart Bik /// Finishes insertion. 311f66e5769SAart Bik virtual void endInsert() = 0; 312f66e5769SAart Bik 3138a91bc7bSHarrietAkot private: 31446bdacaaSwren romano static void fatal(const char *tp) { 3158a91bc7bSHarrietAkot fprintf(stderr, "unsupported %s\n", tp); 3168a91bc7bSHarrietAkot exit(1); 3178a91bc7bSHarrietAkot } 3188d8b566fSwren romano 3198d8b566fSwren romano const std::vector<uint64_t> dimSizes; 3208d8b566fSwren romano std::vector<uint64_t> rev; 3218d8b566fSwren romano const std::vector<DimLevelType> dimTypes; 3228a91bc7bSHarrietAkot }; 3238a91bc7bSHarrietAkot 3248a91bc7bSHarrietAkot /// A memory-resident sparse tensor using a storage scheme based on 3258a91bc7bSHarrietAkot /// per-dimension sparse/dense annotations. This data structure provides a 3268a91bc7bSHarrietAkot /// bufferized form of a sparse tensor type. 
In contrast to generating setup 3278a91bc7bSHarrietAkot /// methods for each differently annotated sparse tensor, this method provides 3288a91bc7bSHarrietAkot /// a convenient "one-size-fits-all" solution that simply takes an input tensor 3298a91bc7bSHarrietAkot /// and annotations to implement all required setup in a general manner. 3308a91bc7bSHarrietAkot template <typename P, typename I, typename V> 3318a91bc7bSHarrietAkot class SparseTensorStorage : public SparseTensorStorageBase { 3328a91bc7bSHarrietAkot public: 3338a91bc7bSHarrietAkot /// Constructs a sparse tensor storage scheme with the given dimensions, 3348a91bc7bSHarrietAkot /// permutation, and per-dimension dense/sparse annotations, using 3358a91bc7bSHarrietAkot /// the coordinate scheme tensor for the initial contents if provided. 3368d8b566fSwren romano /// 3378d8b566fSwren romano /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`. 3388a91bc7bSHarrietAkot SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm, 339f66e5769SAart Bik const DimLevelType *sparsity, 3408d8b566fSwren romano SparseTensorCOO<V> *coo = nullptr) 3418d8b566fSwren romano : SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()), 3428d8b566fSwren romano indices(getRank()), idx(getRank()) { 3438a91bc7bSHarrietAkot // Provide hints on capacity of pointers and indices. 344175b9af4SAart Bik // TODO: needs much fine-tuning based on actual sparsity; currently 345175b9af4SAart Bik // we reserve pointer/index space based on all previous dense 346175b9af4SAart Bik // dimensions, which works well up to first sparse dim; but 347175b9af4SAart Bik // we should really use nnz and dense/sparse distribution. 
348f66e5769SAart Bik bool allDense = true; 349f66e5769SAart Bik uint64_t sz = 1; 3508d8b566fSwren romano for (uint64_t r = 0, rank = getRank(); r < rank; r++) { 3518d8b566fSwren romano if (isCompressedDim(r)) { 3528d8b566fSwren romano // TODO: Take a parameter between 1 and `sizes[r]`, and multiply 3538d8b566fSwren romano // `sz` by that before reserving. (For now we just use 1.) 354f66e5769SAart Bik pointers[r].reserve(sz + 1); 3558d8b566fSwren romano pointers[r].push_back(0); 356f66e5769SAart Bik indices[r].reserve(sz); 357f66e5769SAart Bik sz = 1; 358f66e5769SAart Bik allDense = false; 3598d8b566fSwren romano } else { // Dense dimension. 3608d8b566fSwren romano sz = checkedMul(sz, getDimSizes()[r]); 3618a91bc7bSHarrietAkot } 3628a91bc7bSHarrietAkot } 3638a91bc7bSHarrietAkot // Then assign contents from coordinate scheme tensor if provided. 3648d8b566fSwren romano if (coo) { 3654d0a18d0Swren romano // Ensure both preconditions of `fromCOO`. 3668d8b566fSwren romano assert(coo->getSizes() == getDimSizes() && "Tensor size mismatch"); 3678d8b566fSwren romano coo->sort(); 3684d0a18d0Swren romano // Now actually insert the `elements`. 3698d8b566fSwren romano const std::vector<Element<V>> &elements = coo->getElements(); 370ceda1ae9Swren romano uint64_t nnz = elements.size(); 3718a91bc7bSHarrietAkot values.reserve(nnz); 372ceda1ae9Swren romano fromCOO(elements, 0, nnz, 0); 3731ce77b56SAart Bik } else if (allDense) { 374f66e5769SAart Bik values.resize(sz, 0); 3758a91bc7bSHarrietAkot } 3768a91bc7bSHarrietAkot } 3778a91bc7bSHarrietAkot 3780ae2e958SMehdi Amini ~SparseTensorStorage() override = default; 3798a91bc7bSHarrietAkot 380f66e5769SAart Bik /// Partially specialize these getter methods based on template types. 
3818a91bc7bSHarrietAkot void getPointers(std::vector<P> **out, uint64_t d) override { 3828a91bc7bSHarrietAkot assert(d < getRank()); 3838a91bc7bSHarrietAkot *out = &pointers[d]; 3848a91bc7bSHarrietAkot } 3858a91bc7bSHarrietAkot void getIndices(std::vector<I> **out, uint64_t d) override { 3868a91bc7bSHarrietAkot assert(d < getRank()); 3878a91bc7bSHarrietAkot *out = &indices[d]; 3888a91bc7bSHarrietAkot } 3898a91bc7bSHarrietAkot void getValues(std::vector<V> **out) override { *out = &values; } 3908a91bc7bSHarrietAkot 39103fe15ceSAart Bik /// Partially specialize lexicographical insertions based on template types. 392c03fd1e6Swren romano void lexInsert(const uint64_t *cursor, V val) override { 3931ce77b56SAart Bik // First, wrap up pending insertion path. 3941ce77b56SAart Bik uint64_t diff = 0; 3951ce77b56SAart Bik uint64_t top = 0; 3961ce77b56SAart Bik if (!values.empty()) { 3971ce77b56SAart Bik diff = lexDiff(cursor); 3981ce77b56SAart Bik endPath(diff + 1); 3991ce77b56SAart Bik top = idx[diff] + 1; 4001ce77b56SAart Bik } 4011ce77b56SAart Bik // Then continue with insertion path. 4021ce77b56SAart Bik insPath(cursor, diff, top, val); 403f66e5769SAart Bik } 404f66e5769SAart Bik 4054f2ec7f9SAart Bik /// Partially specialize expanded insertions based on template types. 4064f2ec7f9SAart Bik /// Note that this method resets the values/filled-switch array back 4074f2ec7f9SAart Bik /// to all-zero/false while only iterating over the nonzero elements. 4084f2ec7f9SAart Bik void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added, 4094f2ec7f9SAart Bik uint64_t count) override { 4104f2ec7f9SAart Bik if (count == 0) 4114f2ec7f9SAart Bik return; 4124f2ec7f9SAart Bik // Sort. 4134f2ec7f9SAart Bik std::sort(added, added + count); 4144f2ec7f9SAart Bik // Restore insertion path for first insert. 
4153bf2ba3bSwren romano const uint64_t lastDim = getRank() - 1; 4164f2ec7f9SAart Bik uint64_t index = added[0]; 4173bf2ba3bSwren romano cursor[lastDim] = index; 4184f2ec7f9SAart Bik lexInsert(cursor, values[index]); 4194f2ec7f9SAart Bik assert(filled[index]); 4204f2ec7f9SAart Bik values[index] = 0; 4214f2ec7f9SAart Bik filled[index] = false; 4224f2ec7f9SAart Bik // Subsequent insertions are quick. 4234f2ec7f9SAart Bik for (uint64_t i = 1; i < count; i++) { 4244f2ec7f9SAart Bik assert(index < added[i] && "non-lexicographic insertion"); 4254f2ec7f9SAart Bik index = added[i]; 4263bf2ba3bSwren romano cursor[lastDim] = index; 4273bf2ba3bSwren romano insPath(cursor, lastDim, added[i - 1] + 1, values[index]); 4284f2ec7f9SAart Bik assert(filled[index]); 4293bf2ba3bSwren romano values[index] = 0; 4304f2ec7f9SAart Bik filled[index] = false; 4314f2ec7f9SAart Bik } 4324f2ec7f9SAart Bik } 4334f2ec7f9SAart Bik 434f66e5769SAart Bik /// Finalizes lexicographic insertions. 4351ce77b56SAart Bik void endInsert() override { 4361ce77b56SAart Bik if (values.empty()) 43772ec2f76Swren romano finalizeSegment(0); 4381ce77b56SAart Bik else 4391ce77b56SAart Bik endPath(0); 4401ce77b56SAart Bik } 441f66e5769SAart Bik 4428a91bc7bSHarrietAkot /// Returns this sparse tensor storage scheme as a new memory-resident 4438a91bc7bSHarrietAkot /// sparse tensor in coordinate scheme with the given dimension order. 4448d8b566fSwren romano /// 4458d8b566fSwren romano /// Precondition: `perm` must be valid for `getRank()`. 4468a91bc7bSHarrietAkot SparseTensorCOO<V> *toCOO(const uint64_t *perm) { 4478a91bc7bSHarrietAkot // Restore original order of the dimension sizes and allocate coordinate 4488a91bc7bSHarrietAkot // scheme with desired new ordering specified in perm. 
4498d8b566fSwren romano const uint64_t rank = getRank(); 4508d8b566fSwren romano const auto &rev = getRev(); 4518d8b566fSwren romano const auto &sizes = getDimSizes(); 4528a91bc7bSHarrietAkot std::vector<uint64_t> orgsz(rank); 4538a91bc7bSHarrietAkot for (uint64_t r = 0; r < rank; r++) 4548a91bc7bSHarrietAkot orgsz[rev[r]] = sizes[r]; 4558d8b566fSwren romano SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO( 4568a91bc7bSHarrietAkot rank, orgsz.data(), perm, values.size()); 4578a91bc7bSHarrietAkot // Populate coordinate scheme restored from old ordering and changed with 4588a91bc7bSHarrietAkot // new ordering. Rather than applying both reorderings during the recursion, 4598a91bc7bSHarrietAkot // we compute the combine permutation in advance. 4608a91bc7bSHarrietAkot std::vector<uint64_t> reord(rank); 4618a91bc7bSHarrietAkot for (uint64_t r = 0; r < rank; r++) 4628a91bc7bSHarrietAkot reord[r] = perm[rev[r]]; 4638d8b566fSwren romano toCOO(*coo, reord, 0, 0); 4648d8b566fSwren romano // TODO: This assertion assumes there are no stored zeros, 4658d8b566fSwren romano // or if there are then that we don't filter them out. 4668d8b566fSwren romano // Cf., <https://github.com/llvm/llvm-project/issues/54179> 4678d8b566fSwren romano assert(coo->getElements().size() == values.size()); 4688d8b566fSwren romano return coo; 4698a91bc7bSHarrietAkot } 4708a91bc7bSHarrietAkot 4718a91bc7bSHarrietAkot /// Factory method. Constructs a sparse tensor storage scheme with the given 4728a91bc7bSHarrietAkot /// dimensions, permutation, and per-dimension dense/sparse annotations, 4738a91bc7bSHarrietAkot /// using the coordinate scheme tensor for the initial contents if provided. 4748a91bc7bSHarrietAkot /// In the latter case, the coordinate scheme must respect the same 4758a91bc7bSHarrietAkot /// permutation as is desired for the new sparse tensor storage. 4768d8b566fSwren romano /// 4778d8b566fSwren romano /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`. 
4788a91bc7bSHarrietAkot static SparseTensorStorage<P, I, V> * 479d83a7068Swren romano newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm, 4808d8b566fSwren romano const DimLevelType *sparsity, SparseTensorCOO<V> *coo) { 4818a91bc7bSHarrietAkot SparseTensorStorage<P, I, V> *n = nullptr; 4828d8b566fSwren romano if (coo) { 4838d8b566fSwren romano assert(coo->getRank() == rank && "Tensor rank mismatch"); 4848d8b566fSwren romano const auto &coosz = coo->getSizes(); 4858a91bc7bSHarrietAkot for (uint64_t r = 0; r < rank; r++) 4868d8b566fSwren romano assert(shape[r] == 0 || shape[r] == coosz[perm[r]]); 4878d8b566fSwren romano n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo); 4888a91bc7bSHarrietAkot } else { 4898a91bc7bSHarrietAkot std::vector<uint64_t> permsz(rank); 490d83a7068Swren romano for (uint64_t r = 0; r < rank; r++) { 491d83a7068Swren romano assert(shape[r] > 0 && "Dimension size zero has trivial storage"); 492d83a7068Swren romano permsz[perm[r]] = shape[r]; 493d83a7068Swren romano } 494f66e5769SAart Bik n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity); 4958a91bc7bSHarrietAkot } 4968a91bc7bSHarrietAkot return n; 4978a91bc7bSHarrietAkot } 4988a91bc7bSHarrietAkot 4998a91bc7bSHarrietAkot private: 50072ec2f76Swren romano /// Appends an arbitrary new position to `pointers[d]`. This method 50172ec2f76Swren romano /// checks that `pos` is representable in the `P` type; however, it 50272ec2f76Swren romano /// does not check that `pos` is semantically valid (i.e., larger than 50372ec2f76Swren romano /// the previous position and smaller than `indices[d].capacity()`). 
5048d8b566fSwren romano void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) { 50572ec2f76Swren romano assert(isCompressedDim(d)); 50672ec2f76Swren romano assert(pos <= std::numeric_limits<P>::max() && 5074d0a18d0Swren romano "Pointer value is too large for the P-type"); 50872ec2f76Swren romano pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos)); 5094d0a18d0Swren romano } 5104d0a18d0Swren romano 51172ec2f76Swren romano /// Appends index `i` to dimension `d`, in the semantically general 51272ec2f76Swren romano /// sense. For non-dense dimensions, that means appending to the 51372ec2f76Swren romano /// `indices[d]` array, checking that `i` is representable in the `I` 51472ec2f76Swren romano /// type; however, we do not verify other semantic requirements (e.g., 51572ec2f76Swren romano /// that `i` is in bounds for `sizes[d]`, and not previously occurring 51672ec2f76Swren romano /// in the same segment). For dense dimensions, this method instead 51772ec2f76Swren romano /// appends the appropriate number of zeros to the `values` array, 51872ec2f76Swren romano /// where `full` is the number of "entries" already written to `values` 51972ec2f76Swren romano /// for this segment (aka one after the highest index previously appended). 52072ec2f76Swren romano void appendIndex(uint64_t d, uint64_t full, uint64_t i) { 52172ec2f76Swren romano if (isCompressedDim(d)) { 5224d0a18d0Swren romano assert(i <= std::numeric_limits<I>::max() && 5234d0a18d0Swren romano "Index value is too large for the I-type"); 52472ec2f76Swren romano indices[d].push_back(static_cast<I>(i)); 52572ec2f76Swren romano } else { // Dense dimension. 52672ec2f76Swren romano assert(i >= full && "Index was already filled"); 52772ec2f76Swren romano if (i == full) 52872ec2f76Swren romano return; // Short-circuit, since it'll be a nop. 
52972ec2f76Swren romano if (d + 1 == getRank()) 53072ec2f76Swren romano values.insert(values.end(), i - full, 0); 53172ec2f76Swren romano else 53272ec2f76Swren romano finalizeSegment(d + 1, 0, i - full); 53372ec2f76Swren romano } 5344d0a18d0Swren romano } 5354d0a18d0Swren romano 5368a91bc7bSHarrietAkot /// Initializes sparse tensor storage scheme from a memory-resident sparse 5378a91bc7bSHarrietAkot /// tensor in coordinate scheme. This method prepares the pointers and 5388a91bc7bSHarrietAkot /// indices arrays under the given per-dimension dense/sparse annotations. 5394d0a18d0Swren romano /// 5404d0a18d0Swren romano /// Preconditions: 5414d0a18d0Swren romano /// (1) the `elements` must be lexicographically sorted. 5424d0a18d0Swren romano /// (2) the indices of every element are valid for `sizes` (equal rank 5434d0a18d0Swren romano /// and pointwise less-than). 544ceda1ae9Swren romano void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo, 545ceda1ae9Swren romano uint64_t hi, uint64_t d) { 5468a91bc7bSHarrietAkot // Once dimensions are exhausted, insert the numerical values. 547c4017f9dSwren romano assert(d <= getRank() && hi <= elements.size()); 5488a91bc7bSHarrietAkot if (d == getRank()) { 549c4017f9dSwren romano assert(lo < hi); 5501ce77b56SAart Bik values.push_back(elements[lo].value); 5518a91bc7bSHarrietAkot return; 5528a91bc7bSHarrietAkot } 5538a91bc7bSHarrietAkot // Visit all elements in this interval. 5548a91bc7bSHarrietAkot uint64_t full = 0; 555c4017f9dSwren romano while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`. 5568a91bc7bSHarrietAkot // Find segment in interval with same index elements in this dimension. 557f66e5769SAart Bik uint64_t i = elements[lo].indices[d]; 5588a91bc7bSHarrietAkot uint64_t seg = lo + 1; 559f66e5769SAart Bik while (seg < hi && elements[seg].indices[d] == i) 5608a91bc7bSHarrietAkot seg++; 5618a91bc7bSHarrietAkot // Handle segment in interval for sparse or dense dimension. 
56272ec2f76Swren romano appendIndex(d, full, i); 56372ec2f76Swren romano full = i + 1; 564ceda1ae9Swren romano fromCOO(elements, lo, seg, d + 1); 5658a91bc7bSHarrietAkot // And move on to next segment in interval. 5668a91bc7bSHarrietAkot lo = seg; 5678a91bc7bSHarrietAkot } 5688a91bc7bSHarrietAkot // Finalize the sparse pointer structure at this dimension. 56972ec2f76Swren romano finalizeSegment(d, full); 5708a91bc7bSHarrietAkot } 5718a91bc7bSHarrietAkot 5728a91bc7bSHarrietAkot /// Stores the sparse tensor storage scheme into a memory-resident sparse 5738a91bc7bSHarrietAkot /// tensor in coordinate scheme. 574ceda1ae9Swren romano void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord, 575f66e5769SAart Bik uint64_t pos, uint64_t d) { 5768a91bc7bSHarrietAkot assert(d <= getRank()); 5778a91bc7bSHarrietAkot if (d == getRank()) { 5788a91bc7bSHarrietAkot assert(pos < values.size()); 579ceda1ae9Swren romano tensor.add(idx, values[pos]); 5801ce77b56SAart Bik } else if (isCompressedDim(d)) { 5818a91bc7bSHarrietAkot // Sparse dimension. 5828a91bc7bSHarrietAkot for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) { 5838a91bc7bSHarrietAkot idx[reord[d]] = indices[d][ii]; 584f66e5769SAart Bik toCOO(tensor, reord, ii, d + 1); 5858a91bc7bSHarrietAkot } 5861ce77b56SAart Bik } else { 5871ce77b56SAart Bik // Dense dimension. 5888d8b566fSwren romano const uint64_t sz = getDimSizes()[d]; 5898d8b566fSwren romano const uint64_t off = pos * sz; 5908d8b566fSwren romano for (uint64_t i = 0; i < sz; i++) { 5911ce77b56SAart Bik idx[reord[d]] = i; 5921ce77b56SAart Bik toCOO(tensor, reord, off + i, d + 1); 5938a91bc7bSHarrietAkot } 5948a91bc7bSHarrietAkot } 5951ce77b56SAart Bik } 5961ce77b56SAart Bik 59772ec2f76Swren romano /// Finalize the sparse pointer structure at this dimension. 
59872ec2f76Swren romano void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) { 59972ec2f76Swren romano if (count == 0) 60072ec2f76Swren romano return; // Short-circuit, since it'll be a nop. 60172ec2f76Swren romano if (isCompressedDim(d)) { 60272ec2f76Swren romano appendPointer(d, indices[d].size(), count); 60372ec2f76Swren romano } else { // Dense dimension. 6048d8b566fSwren romano const uint64_t sz = getDimSizes()[d]; 60572ec2f76Swren romano assert(sz >= full && "Segment is overfull"); 6068d8b566fSwren romano count = checkedMul(count, sz - full); 60772ec2f76Swren romano // For dense storage we must enumerate all the remaining coordinates 60872ec2f76Swren romano // in this dimension (i.e., coordinates after the last non-zero 60972ec2f76Swren romano // element), and either fill in their zero values or else recurse 61072ec2f76Swren romano // to finalize some deeper dimension. 61172ec2f76Swren romano if (d + 1 == getRank()) 61272ec2f76Swren romano values.insert(values.end(), count, 0); 61372ec2f76Swren romano else 61472ec2f76Swren romano finalizeSegment(d + 1, 0, count); 6151ce77b56SAart Bik } 6161ce77b56SAart Bik } 6171ce77b56SAart Bik 6181ce77b56SAart Bik /// Wraps up a single insertion path, inner to outer. 6191ce77b56SAart Bik void endPath(uint64_t diff) { 6201ce77b56SAart Bik uint64_t rank = getRank(); 6211ce77b56SAart Bik assert(diff <= rank); 6221ce77b56SAart Bik for (uint64_t i = 0; i < rank - diff; i++) { 62372ec2f76Swren romano const uint64_t d = rank - i - 1; 62472ec2f76Swren romano finalizeSegment(d, idx[d] + 1); 6251ce77b56SAart Bik } 6261ce77b56SAart Bik } 6271ce77b56SAart Bik 6281ce77b56SAart Bik /// Continues a single insertion path, outer to inner. 
629c03fd1e6Swren romano void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) { 6301ce77b56SAart Bik uint64_t rank = getRank(); 6311ce77b56SAart Bik assert(diff < rank); 6321ce77b56SAart Bik for (uint64_t d = diff; d < rank; d++) { 6331ce77b56SAart Bik uint64_t i = cursor[d]; 63472ec2f76Swren romano appendIndex(d, top, i); 6351ce77b56SAart Bik top = 0; 6361ce77b56SAart Bik idx[d] = i; 6371ce77b56SAart Bik } 6381ce77b56SAart Bik values.push_back(val); 6391ce77b56SAart Bik } 6401ce77b56SAart Bik 6411ce77b56SAart Bik /// Finds the lexicographic differing dimension. 64246bdacaaSwren romano uint64_t lexDiff(const uint64_t *cursor) const { 6431ce77b56SAart Bik for (uint64_t r = 0, rank = getRank(); r < rank; r++) 6441ce77b56SAart Bik if (cursor[r] > idx[r]) 6451ce77b56SAart Bik return r; 6461ce77b56SAart Bik else 6471ce77b56SAart Bik assert(cursor[r] == idx[r] && "non-lexicographic insertion"); 6481ce77b56SAart Bik assert(0 && "duplication insertion"); 6491ce77b56SAart Bik return -1u; 6501ce77b56SAart Bik } 6511ce77b56SAart Bik 6528a91bc7bSHarrietAkot private: 6538a91bc7bSHarrietAkot std::vector<std::vector<P>> pointers; 6548a91bc7bSHarrietAkot std::vector<std::vector<I>> indices; 6558a91bc7bSHarrietAkot std::vector<V> values; 6568d8b566fSwren romano std::vector<uint64_t> idx; // index cursor for lexicographic insertion. 6578a91bc7bSHarrietAkot }; 6588a91bc7bSHarrietAkot 6598a91bc7bSHarrietAkot /// Helper to convert string to lower case. 6608a91bc7bSHarrietAkot static char *toLower(char *token) { 6618a91bc7bSHarrietAkot for (char *c = token; *c; c++) 6628a91bc7bSHarrietAkot *c = tolower(*c); 6638a91bc7bSHarrietAkot return token; 6648a91bc7bSHarrietAkot } 6658a91bc7bSHarrietAkot 6668a91bc7bSHarrietAkot /// Read the MME header of a general sparse matrix of type real. 
66703fe15ceSAart Bik static void readMMEHeader(FILE *file, char *filename, char *line, 66833e8ab8eSAart Bik uint64_t *idata, bool *isPattern, bool *isSymmetric) { 6698a91bc7bSHarrietAkot char header[64]; 6708a91bc7bSHarrietAkot char object[64]; 6718a91bc7bSHarrietAkot char format[64]; 6728a91bc7bSHarrietAkot char field[64]; 6738a91bc7bSHarrietAkot char symmetry[64]; 6748a91bc7bSHarrietAkot // Read header line. 6758a91bc7bSHarrietAkot if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field, 6768a91bc7bSHarrietAkot symmetry) != 5) { 67703fe15ceSAart Bik fprintf(stderr, "Corrupt header in %s\n", filename); 6788a91bc7bSHarrietAkot exit(1); 6798a91bc7bSHarrietAkot } 68033e8ab8eSAart Bik // Set properties 68133e8ab8eSAart Bik *isPattern = (strcmp(toLower(field), "pattern") == 0); 682bb56c2b3SMehdi Amini *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0); 6838a91bc7bSHarrietAkot // Make sure this is a general sparse matrix. 6848a91bc7bSHarrietAkot if (strcmp(toLower(header), "%%matrixmarket") || 6858a91bc7bSHarrietAkot strcmp(toLower(object), "matrix") || 68633e8ab8eSAart Bik strcmp(toLower(format), "coordinate") || 68733e8ab8eSAart Bik (strcmp(toLower(field), "real") && !(*isPattern)) || 688bb56c2b3SMehdi Amini (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) { 68933e8ab8eSAart Bik fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename); 6908a91bc7bSHarrietAkot exit(1); 6918a91bc7bSHarrietAkot } 6928a91bc7bSHarrietAkot // Skip comments. 693e5639b3fSMehdi Amini while (true) { 69403fe15ceSAart Bik if (!fgets(line, kColWidth, file)) { 69503fe15ceSAart Bik fprintf(stderr, "Cannot find data in %s\n", filename); 6968a91bc7bSHarrietAkot exit(1); 6978a91bc7bSHarrietAkot } 6988a91bc7bSHarrietAkot if (line[0] != '%') 6998a91bc7bSHarrietAkot break; 7008a91bc7bSHarrietAkot } 7018a91bc7bSHarrietAkot // Next line contains M N NNZ. 
7028a91bc7bSHarrietAkot idata[0] = 2; // rank 7038a91bc7bSHarrietAkot if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3, 7048a91bc7bSHarrietAkot idata + 1) != 3) { 70503fe15ceSAart Bik fprintf(stderr, "Cannot find size in %s\n", filename); 7068a91bc7bSHarrietAkot exit(1); 7078a91bc7bSHarrietAkot } 7088a91bc7bSHarrietAkot } 7098a91bc7bSHarrietAkot 7108a91bc7bSHarrietAkot /// Read the "extended" FROSTT header. Although not part of the documented 7118a91bc7bSHarrietAkot /// format, we assume that the file starts with optional comments followed 7128a91bc7bSHarrietAkot /// by two lines that define the rank, the number of nonzeros, and the 7138a91bc7bSHarrietAkot /// dimensions sizes (one per rank) of the sparse tensor. 71403fe15ceSAart Bik static void readExtFROSTTHeader(FILE *file, char *filename, char *line, 71503fe15ceSAart Bik uint64_t *idata) { 7168a91bc7bSHarrietAkot // Skip comments. 717e5639b3fSMehdi Amini while (true) { 71803fe15ceSAart Bik if (!fgets(line, kColWidth, file)) { 71903fe15ceSAart Bik fprintf(stderr, "Cannot find data in %s\n", filename); 7208a91bc7bSHarrietAkot exit(1); 7218a91bc7bSHarrietAkot } 7228a91bc7bSHarrietAkot if (line[0] != '#') 7238a91bc7bSHarrietAkot break; 7248a91bc7bSHarrietAkot } 7258a91bc7bSHarrietAkot // Next line contains RANK and NNZ. 7268a91bc7bSHarrietAkot if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) { 72703fe15ceSAart Bik fprintf(stderr, "Cannot find metadata in %s\n", filename); 7288a91bc7bSHarrietAkot exit(1); 7298a91bc7bSHarrietAkot } 7308a91bc7bSHarrietAkot // Followed by a line with the dimension sizes (one per rank). 
7318a91bc7bSHarrietAkot for (uint64_t r = 0; r < idata[0]; r++) { 7328a91bc7bSHarrietAkot if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) { 73303fe15ceSAart Bik fprintf(stderr, "Cannot find dimension size %s\n", filename); 7348a91bc7bSHarrietAkot exit(1); 7358a91bc7bSHarrietAkot } 7368a91bc7bSHarrietAkot } 73703fe15ceSAart Bik fgets(line, kColWidth, file); // end of line 7388a91bc7bSHarrietAkot } 7398a91bc7bSHarrietAkot 7408a91bc7bSHarrietAkot /// Reads a sparse tensor with the given filename into a memory-resident 7418a91bc7bSHarrietAkot /// sparse tensor in coordinate scheme. 7428a91bc7bSHarrietAkot template <typename V> 7438a91bc7bSHarrietAkot static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank, 744d83a7068Swren romano const uint64_t *shape, 7458a91bc7bSHarrietAkot const uint64_t *perm) { 7468a91bc7bSHarrietAkot // Open the file. 7478a91bc7bSHarrietAkot FILE *file = fopen(filename, "r"); 7488a91bc7bSHarrietAkot if (!file) { 7493734c078Swren romano assert(filename && "Received nullptr for filename"); 7503734c078Swren romano fprintf(stderr, "Cannot find file %s\n", filename); 7518a91bc7bSHarrietAkot exit(1); 7528a91bc7bSHarrietAkot } 7538a91bc7bSHarrietAkot // Perform some file format dependent set up. 
75403fe15ceSAart Bik char line[kColWidth]; 7558a91bc7bSHarrietAkot uint64_t idata[512]; 75633e8ab8eSAart Bik bool isPattern = false; 757bb56c2b3SMehdi Amini bool isSymmetric = false; 7588a91bc7bSHarrietAkot if (strstr(filename, ".mtx")) { 75933e8ab8eSAart Bik readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric); 7608a91bc7bSHarrietAkot } else if (strstr(filename, ".tns")) { 76103fe15ceSAart Bik readExtFROSTTHeader(file, filename, line, idata); 7628a91bc7bSHarrietAkot } else { 7638a91bc7bSHarrietAkot fprintf(stderr, "Unknown format %s\n", filename); 7648a91bc7bSHarrietAkot exit(1); 7658a91bc7bSHarrietAkot } 7668a91bc7bSHarrietAkot // Prepare sparse tensor object with per-dimension sizes 7678a91bc7bSHarrietAkot // and the number of nonzeros as initial capacity. 7688a91bc7bSHarrietAkot assert(rank == idata[0] && "rank mismatch"); 7698a91bc7bSHarrietAkot uint64_t nnz = idata[1]; 7708a91bc7bSHarrietAkot for (uint64_t r = 0; r < rank; r++) 771d83a7068Swren romano assert((shape[r] == 0 || shape[r] == idata[2 + r]) && 7728a91bc7bSHarrietAkot "dimension size mismatch"); 7738a91bc7bSHarrietAkot SparseTensorCOO<V> *tensor = 7748a91bc7bSHarrietAkot SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz); 7758a91bc7bSHarrietAkot // Read all nonzero elements. 7768a91bc7bSHarrietAkot std::vector<uint64_t> indices(rank); 7778a91bc7bSHarrietAkot for (uint64_t k = 0; k < nnz; k++) { 77803fe15ceSAart Bik if (!fgets(line, kColWidth, file)) { 77903fe15ceSAart Bik fprintf(stderr, "Cannot find next line of data in %s\n", filename); 7808a91bc7bSHarrietAkot exit(1); 7818a91bc7bSHarrietAkot } 78203fe15ceSAart Bik char *linePtr = line; 78303fe15ceSAart Bik for (uint64_t r = 0; r < rank; r++) { 78403fe15ceSAart Bik uint64_t idx = strtoul(linePtr, &linePtr, 10); 7858a91bc7bSHarrietAkot // Add 0-based index. 
7868a91bc7bSHarrietAkot indices[perm[r]] = idx - 1; 7878a91bc7bSHarrietAkot } 7888a91bc7bSHarrietAkot // The external formats always store the numerical values with the type 7898a91bc7bSHarrietAkot // double, but we cast these values to the sparse tensor object type. 79033e8ab8eSAart Bik // For a pattern tensor, we arbitrarily pick the value 1 for all entries. 79133e8ab8eSAart Bik double value = isPattern ? 1.0 : strtod(linePtr, &linePtr); 7928a91bc7bSHarrietAkot tensor->add(indices, value); 79302710413SBixia Zheng // We currently chose to deal with symmetric matrices by fully constructing 79402710413SBixia Zheng // them. In the future, we may want to make symmetry implicit for storage 79502710413SBixia Zheng // reasons. 796bb56c2b3SMehdi Amini if (isSymmetric && indices[0] != indices[1]) 79702710413SBixia Zheng tensor->add({indices[1], indices[0]}, value); 7988a91bc7bSHarrietAkot } 7998a91bc7bSHarrietAkot // Close the file and return tensor. 8008a91bc7bSHarrietAkot fclose(file); 8018a91bc7bSHarrietAkot return tensor; 8028a91bc7bSHarrietAkot } 8038a91bc7bSHarrietAkot 804efa15f41SAart Bik /// Writes the sparse tensor to extended FROSTT format. 
805efa15f41SAart Bik template <typename V> 80646bdacaaSwren romano static void outSparseTensor(void *tensor, void *dest, bool sort) { 8076438783fSAart Bik assert(tensor && dest); 8086438783fSAart Bik auto coo = static_cast<SparseTensorCOO<V> *>(tensor); 8096438783fSAart Bik if (sort) 8106438783fSAart Bik coo->sort(); 8116438783fSAart Bik char *filename = static_cast<char *>(dest); 8126438783fSAart Bik auto &sizes = coo->getSizes(); 8136438783fSAart Bik auto &elements = coo->getElements(); 8146438783fSAart Bik uint64_t rank = coo->getRank(); 815efa15f41SAart Bik uint64_t nnz = elements.size(); 816efa15f41SAart Bik std::fstream file; 817efa15f41SAart Bik file.open(filename, std::ios_base::out | std::ios_base::trunc); 818efa15f41SAart Bik assert(file.is_open()); 819efa15f41SAart Bik file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl; 820efa15f41SAart Bik for (uint64_t r = 0; r < rank - 1; r++) 821efa15f41SAart Bik file << sizes[r] << " "; 822efa15f41SAart Bik file << sizes[rank - 1] << std::endl; 823efa15f41SAart Bik for (uint64_t i = 0; i < nnz; i++) { 824efa15f41SAart Bik auto &idx = elements[i].indices; 825efa15f41SAart Bik for (uint64_t r = 0; r < rank; r++) 826efa15f41SAart Bik file << (idx[r] + 1) << " "; 827efa15f41SAart Bik file << elements[i].value << std::endl; 828efa15f41SAart Bik } 829efa15f41SAart Bik file.flush(); 830efa15f41SAart Bik file.close(); 831efa15f41SAart Bik assert(file.good()); 8326438783fSAart Bik } 8336438783fSAart Bik 8346438783fSAart Bik /// Initializes sparse tensor from an external COO-flavored format. 
8356438783fSAart Bik template <typename V> 83646bdacaaSwren romano static SparseTensorStorage<uint64_t, uint64_t, V> * 8376438783fSAart Bik toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values, 83820eaa88fSBixia Zheng uint64_t *indices, uint64_t *perm, uint8_t *sparse) { 83920eaa88fSBixia Zheng const DimLevelType *sparsity = (DimLevelType *)(sparse); 84020eaa88fSBixia Zheng #ifndef NDEBUG 84120eaa88fSBixia Zheng // Verify that perm is a permutation of 0..(rank-1). 84220eaa88fSBixia Zheng std::vector<uint64_t> order(perm, perm + rank); 84320eaa88fSBixia Zheng std::sort(order.begin(), order.end()); 8441e47888dSAart Bik for (uint64_t i = 0; i < rank; ++i) { 84520eaa88fSBixia Zheng if (i != order[i]) { 846988d4b0dSAart Bik fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank); 84720eaa88fSBixia Zheng exit(1); 84820eaa88fSBixia Zheng } 84920eaa88fSBixia Zheng } 85020eaa88fSBixia Zheng 85120eaa88fSBixia Zheng // Verify that the sparsity values are supported. 8521e47888dSAart Bik for (uint64_t i = 0; i < rank; ++i) { 85320eaa88fSBixia Zheng if (sparsity[i] != DimLevelType::kDense && 85420eaa88fSBixia Zheng sparsity[i] != DimLevelType::kCompressed) { 85520eaa88fSBixia Zheng fprintf(stderr, "Unsupported sparsity value %d\n", 85620eaa88fSBixia Zheng static_cast<int>(sparsity[i])); 85720eaa88fSBixia Zheng exit(1); 85820eaa88fSBixia Zheng } 85920eaa88fSBixia Zheng } 86020eaa88fSBixia Zheng #endif 86120eaa88fSBixia Zheng 8626438783fSAart Bik // Convert external format to internal COO. 
86363bdcaf9Swren romano auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse); 8646438783fSAart Bik std::vector<uint64_t> idx(rank); 8656438783fSAart Bik for (uint64_t i = 0, base = 0; i < nse; i++) { 8666438783fSAart Bik for (uint64_t r = 0; r < rank; r++) 867d8b229a1SAart Bik idx[perm[r]] = indices[base + r]; 86863bdcaf9Swren romano coo->add(idx, values[i]); 8696438783fSAart Bik base += rank; 8706438783fSAart Bik } 8716438783fSAart Bik // Return sparse tensor storage format as opaque pointer. 87263bdcaf9Swren romano auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor( 87363bdcaf9Swren romano rank, shape, perm, sparsity, coo); 87463bdcaf9Swren romano delete coo; 87563bdcaf9Swren romano return tensor; 8766438783fSAart Bik } 8776438783fSAart Bik 8786438783fSAart Bik /// Converts a sparse tensor to an external COO-flavored format. 8796438783fSAart Bik template <typename V> 88046bdacaaSwren romano static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse, 88146bdacaaSwren romano uint64_t **pShape, V **pValues, 88246bdacaaSwren romano uint64_t **pIndices) { 8836438783fSAart Bik auto sparseTensor = 8846438783fSAart Bik static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor); 8856438783fSAart Bik uint64_t rank = sparseTensor->getRank(); 8866438783fSAart Bik std::vector<uint64_t> perm(rank); 8876438783fSAart Bik std::iota(perm.begin(), perm.end(), 0); 8886438783fSAart Bik SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data()); 8896438783fSAart Bik 8906438783fSAart Bik const std::vector<Element<V>> &elements = coo->getElements(); 8916438783fSAart Bik uint64_t nse = elements.size(); 8926438783fSAart Bik 8936438783fSAart Bik uint64_t *shape = new uint64_t[rank]; 8946438783fSAart Bik for (uint64_t i = 0; i < rank; i++) 8956438783fSAart Bik shape[i] = coo->getSizes()[i]; 8966438783fSAart Bik 8976438783fSAart Bik V *values = new V[nse]; 8986438783fSAart Bik uint64_t *indices = new uint64_t[rank * 
nse]; 8996438783fSAart Bik 9006438783fSAart Bik for (uint64_t i = 0, base = 0; i < nse; i++) { 9016438783fSAart Bik values[i] = elements[i].value; 9026438783fSAart Bik for (uint64_t j = 0; j < rank; j++) 9036438783fSAart Bik indices[base + j] = elements[i].indices[j]; 9046438783fSAart Bik base += rank; 9056438783fSAart Bik } 9066438783fSAart Bik 9076438783fSAart Bik delete coo; 9086438783fSAart Bik *pRank = rank; 9096438783fSAart Bik *pNse = nse; 9106438783fSAart Bik *pShape = shape; 9116438783fSAart Bik *pValues = values; 9126438783fSAart Bik *pIndices = indices; 913efa15f41SAart Bik } 914efa15f41SAart Bik 915be0a7e9fSMehdi Amini } // namespace 9168a91bc7bSHarrietAkot 9178a91bc7bSHarrietAkot extern "C" { 9188a91bc7bSHarrietAkot 9198a91bc7bSHarrietAkot //===----------------------------------------------------------------------===// 9208a91bc7bSHarrietAkot // 9218a91bc7bSHarrietAkot // Public API with methods that operate on MLIR buffers (memrefs) to interact 9228a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally. 9238a91bc7bSHarrietAkot // These methods should be used exclusively by MLIR compiler-generated code. 9248a91bc7bSHarrietAkot // 9258a91bc7bSHarrietAkot // Some macro magic is used to generate implementations for all required type 9268a91bc7bSHarrietAkot // combinations that can be called from MLIR compiler-generated code. 
9278a91bc7bSHarrietAkot // 9288a91bc7bSHarrietAkot //===----------------------------------------------------------------------===// 9298a91bc7bSHarrietAkot 9308a91bc7bSHarrietAkot #define CASE(p, i, v, P, I, V) \ 9318a91bc7bSHarrietAkot if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \ 93263bdcaf9Swren romano SparseTensorCOO<V> *coo = nullptr; \ 933845561ecSwren romano if (action <= Action::kFromCOO) { \ 934845561ecSwren romano if (action == Action::kFromFile) { \ 9358a91bc7bSHarrietAkot char *filename = static_cast<char *>(ptr); \ 93663bdcaf9Swren romano coo = openSparseTensorCOO<V>(filename, rank, shape, perm); \ 937845561ecSwren romano } else if (action == Action::kFromCOO) { \ 93863bdcaf9Swren romano coo = static_cast<SparseTensorCOO<V> *>(ptr); \ 9398a91bc7bSHarrietAkot } else { \ 940845561ecSwren romano assert(action == Action::kEmpty); \ 9418a91bc7bSHarrietAkot } \ 94263bdcaf9Swren romano auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor( \ 94363bdcaf9Swren romano rank, shape, perm, sparsity, coo); \ 94463bdcaf9Swren romano if (action == Action::kFromFile) \ 94563bdcaf9Swren romano delete coo; \ 94663bdcaf9Swren romano return tensor; \ 947bb56c2b3SMehdi Amini } \ 948bb56c2b3SMehdi Amini if (action == Action::kEmptyCOO) \ 949d83a7068Swren romano return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm); \ 95063bdcaf9Swren romano coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \ 951845561ecSwren romano if (action == Action::kToIterator) { \ 95263bdcaf9Swren romano coo->startIterator(); \ 9538a91bc7bSHarrietAkot } else { \ 954845561ecSwren romano assert(action == Action::kToCOO); \ 9558a91bc7bSHarrietAkot } \ 95663bdcaf9Swren romano return coo; \ 9578a91bc7bSHarrietAkot } 9588a91bc7bSHarrietAkot 959845561ecSwren romano #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V) 960845561ecSwren romano 9618a91bc7bSHarrietAkot #define IMPL_SPARSEVALUES(NAME, TYPE, LIB) \ 9628a91bc7bSHarrietAkot void 
_mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \ 9634f2ec7f9SAart Bik assert(ref &&tensor); \ 9648a91bc7bSHarrietAkot std::vector<TYPE> *v; \ 9658a91bc7bSHarrietAkot static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v); \ 9668a91bc7bSHarrietAkot ref->basePtr = ref->data = v->data(); \ 9678a91bc7bSHarrietAkot ref->offset = 0; \ 9688a91bc7bSHarrietAkot ref->sizes[0] = v->size(); \ 9698a91bc7bSHarrietAkot ref->strides[0] = 1; \ 9708a91bc7bSHarrietAkot } 9718a91bc7bSHarrietAkot 9728a91bc7bSHarrietAkot #define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \ 9738a91bc7bSHarrietAkot void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \ 974d2215e79SRainer Orth index_type d) { \ 9754f2ec7f9SAart Bik assert(ref &&tensor); \ 9768a91bc7bSHarrietAkot std::vector<TYPE> *v; \ 9778a91bc7bSHarrietAkot static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d); \ 9788a91bc7bSHarrietAkot ref->basePtr = ref->data = v->data(); \ 9798a91bc7bSHarrietAkot ref->offset = 0; \ 9808a91bc7bSHarrietAkot ref->sizes[0] = v->size(); \ 9818a91bc7bSHarrietAkot ref->strides[0] = 1; \ 9828a91bc7bSHarrietAkot } 9838a91bc7bSHarrietAkot 9848a91bc7bSHarrietAkot #define IMPL_ADDELT(NAME, TYPE) \ 9858a91bc7bSHarrietAkot void *_mlir_ciface_##NAME(void *tensor, TYPE value, \ 986d2215e79SRainer Orth StridedMemRefType<index_type, 1> *iref, \ 987d2215e79SRainer Orth StridedMemRefType<index_type, 1> *pref) { \ 9884f2ec7f9SAart Bik assert(tensor &&iref &&pref); \ 9898a91bc7bSHarrietAkot assert(iref->strides[0] == 1 && pref->strides[0] == 1); \ 9908a91bc7bSHarrietAkot assert(iref->sizes[0] == pref->sizes[0]); \ 991d2215e79SRainer Orth const index_type *indx = iref->data + iref->offset; \ 992d2215e79SRainer Orth const index_type *perm = pref->data + pref->offset; \ 9938a91bc7bSHarrietAkot uint64_t isize = iref->sizes[0]; \ 994d2215e79SRainer Orth std::vector<index_type> indices(isize); \ 9958a91bc7bSHarrietAkot for (uint64_t r = 0; r < isize; r++) \ 9968a91bc7bSHarrietAkot 
indices[perm[r]] = indx[r]; \ 9978a91bc7bSHarrietAkot static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value); \ 9988a91bc7bSHarrietAkot return tensor; \ 9998a91bc7bSHarrietAkot } 10008a91bc7bSHarrietAkot 10018a91bc7bSHarrietAkot #define IMPL_GETNEXT(NAME, V) \ 1002d2215e79SRainer Orth bool _mlir_ciface_##NAME(void *tensor, \ 1003d2215e79SRainer Orth StridedMemRefType<index_type, 1> *iref, \ 10048a91bc7bSHarrietAkot StridedMemRefType<V, 0> *vref) { \ 10054f2ec7f9SAart Bik assert(tensor &&iref &&vref); \ 10068a91bc7bSHarrietAkot assert(iref->strides[0] == 1); \ 1007d2215e79SRainer Orth index_type *indx = iref->data + iref->offset; \ 1008c9f2beffSMehdi Amini V *value = vref->data + vref->offset; \ 10098a91bc7bSHarrietAkot const uint64_t isize = iref->sizes[0]; \ 10108a91bc7bSHarrietAkot auto iter = static_cast<SparseTensorCOO<V> *>(tensor); \ 10118a91bc7bSHarrietAkot const Element<V> *elem = iter->getNext(); \ 101263bdcaf9Swren romano if (elem == nullptr) \ 10138a91bc7bSHarrietAkot return false; \ 10148a91bc7bSHarrietAkot for (uint64_t r = 0; r < isize; r++) \ 10158a91bc7bSHarrietAkot indx[r] = elem->indices[r]; \ 10168a91bc7bSHarrietAkot *value = elem->value; \ 10178a91bc7bSHarrietAkot return true; \ 10188a91bc7bSHarrietAkot } 10198a91bc7bSHarrietAkot 1020f66e5769SAart Bik #define IMPL_LEXINSERT(NAME, V) \ 1021d2215e79SRainer Orth void _mlir_ciface_##NAME(void *tensor, \ 1022d2215e79SRainer Orth StridedMemRefType<index_type, 1> *cref, V val) { \ 10234f2ec7f9SAart Bik assert(tensor &&cref); \ 1024f66e5769SAart Bik assert(cref->strides[0] == 1); \ 1025d2215e79SRainer Orth index_type *cursor = cref->data + cref->offset; \ 1026f66e5769SAart Bik assert(cursor); \ 1027f66e5769SAart Bik static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val); \ 1028f66e5769SAart Bik } 1029f66e5769SAart Bik 10304f2ec7f9SAart Bik #define IMPL_EXPINSERT(NAME, V) \ 10314f2ec7f9SAart Bik void _mlir_ciface_##NAME( \ 1032d2215e79SRainer Orth void *tensor, 
StridedMemRefType<index_type, 1> *cref, \ 10334f2ec7f9SAart Bik StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \ 1034d2215e79SRainer Orth StridedMemRefType<index_type, 1> *aref, index_type count) { \ 10354f2ec7f9SAart Bik assert(tensor &&cref &&vref &&fref &&aref); \ 10364f2ec7f9SAart Bik assert(cref->strides[0] == 1); \ 10374f2ec7f9SAart Bik assert(vref->strides[0] == 1); \ 10384f2ec7f9SAart Bik assert(fref->strides[0] == 1); \ 10394f2ec7f9SAart Bik assert(aref->strides[0] == 1); \ 10404f2ec7f9SAart Bik assert(vref->sizes[0] == fref->sizes[0]); \ 1041d2215e79SRainer Orth index_type *cursor = cref->data + cref->offset; \ 1042c9f2beffSMehdi Amini V *values = vref->data + vref->offset; \ 10434f2ec7f9SAart Bik bool *filled = fref->data + fref->offset; \ 1044d2215e79SRainer Orth index_type *added = aref->data + aref->offset; \ 10454f2ec7f9SAart Bik static_cast<SparseTensorStorageBase *>(tensor)->expInsert( \ 10464f2ec7f9SAart Bik cursor, values, filled, added, count); \ 10474f2ec7f9SAart Bik } 10484f2ec7f9SAart Bik 1049d2215e79SRainer Orth // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor 1050bc04a470Swren romano // can safely rewrite kIndex to kU64. We make this assertion to guarantee 1051bc04a470Swren romano // that this file cannot get out of sync with its header. 1052d2215e79SRainer Orth static_assert(std::is_same<index_type, uint64_t>::value, 1053d2215e79SRainer Orth "Expected index_type == uint64_t"); 1054bc04a470Swren romano 10558a91bc7bSHarrietAkot /// Constructs a new sparse tensor. This is the "swiss army knife" 10568a91bc7bSHarrietAkot /// method for materializing sparse tensors into the computation. 
10578a91bc7bSHarrietAkot /// 1058845561ecSwren romano /// Action: 10598a91bc7bSHarrietAkot /// kEmpty = returns empty storage to fill later 10608a91bc7bSHarrietAkot /// kFromFile = returns storage, where ptr contains filename to read 10618a91bc7bSHarrietAkot /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign 10628a91bc7bSHarrietAkot /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO 10638a91bc7bSHarrietAkot /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO 1064845561ecSwren romano /// kToIterator = returns iterator from storage in ptr (call getNext() to use) 10658a91bc7bSHarrietAkot void * 1066845561ecSwren romano _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT 1067d2215e79SRainer Orth StridedMemRefType<index_type, 1> *sref, 1068d2215e79SRainer Orth StridedMemRefType<index_type, 1> *pref, 1069845561ecSwren romano OverheadType ptrTp, OverheadType indTp, 1070845561ecSwren romano PrimaryType valTp, Action action, void *ptr) { 10718a91bc7bSHarrietAkot assert(aref && sref && pref); 10728a91bc7bSHarrietAkot assert(aref->strides[0] == 1 && sref->strides[0] == 1 && 10738a91bc7bSHarrietAkot pref->strides[0] == 1); 10748a91bc7bSHarrietAkot assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]); 1075845561ecSwren romano const DimLevelType *sparsity = aref->data + aref->offset; 1076d83a7068Swren romano const index_type *shape = sref->data + sref->offset; 1077d2215e79SRainer Orth const index_type *perm = pref->data + pref->offset; 10788a91bc7bSHarrietAkot uint64_t rank = aref->sizes[0]; 10798a91bc7bSHarrietAkot 1080bc04a470Swren romano // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases. 1081bc04a470Swren romano // This is safe because of the static_assert above. 
1082bc04a470Swren romano if (ptrTp == OverheadType::kIndex) 1083bc04a470Swren romano ptrTp = OverheadType::kU64; 1084bc04a470Swren romano if (indTp == OverheadType::kIndex) 1085bc04a470Swren romano indTp = OverheadType::kU64; 1086bc04a470Swren romano 10878a91bc7bSHarrietAkot // Double matrices with all combinations of overhead storage. 1088845561ecSwren romano CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t, 1089845561ecSwren romano uint64_t, double); 1090845561ecSwren romano CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t, 1091845561ecSwren romano uint32_t, double); 1092845561ecSwren romano CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t, 1093845561ecSwren romano uint16_t, double); 1094845561ecSwren romano CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t, 1095845561ecSwren romano uint8_t, double); 1096845561ecSwren romano CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t, 1097845561ecSwren romano uint64_t, double); 1098845561ecSwren romano CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t, 1099845561ecSwren romano uint32_t, double); 1100845561ecSwren romano CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t, 1101845561ecSwren romano uint16_t, double); 1102845561ecSwren romano CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t, 1103845561ecSwren romano uint8_t, double); 1104845561ecSwren romano CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t, 1105845561ecSwren romano uint64_t, double); 1106845561ecSwren romano CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t, 1107845561ecSwren romano uint32_t, double); 1108845561ecSwren romano CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t, 1109845561ecSwren romano uint16_t, double); 1110845561ecSwren romano CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, 
uint16_t, 1111845561ecSwren romano uint8_t, double); 1112845561ecSwren romano CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t, 1113845561ecSwren romano uint64_t, double); 1114845561ecSwren romano CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t, 1115845561ecSwren romano uint32_t, double); 1116845561ecSwren romano CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t, 1117845561ecSwren romano uint16_t, double); 1118845561ecSwren romano CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t, 1119845561ecSwren romano uint8_t, double); 11208a91bc7bSHarrietAkot 11218a91bc7bSHarrietAkot // Float matrices with all combinations of overhead storage. 1122845561ecSwren romano CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t, 1123845561ecSwren romano uint64_t, float); 1124845561ecSwren romano CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t, 1125845561ecSwren romano uint32_t, float); 1126845561ecSwren romano CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t, 1127845561ecSwren romano uint16_t, float); 1128845561ecSwren romano CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t, 1129845561ecSwren romano uint8_t, float); 1130845561ecSwren romano CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t, 1131845561ecSwren romano uint64_t, float); 1132845561ecSwren romano CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t, 1133845561ecSwren romano uint32_t, float); 1134845561ecSwren romano CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t, 1135845561ecSwren romano uint16_t, float); 1136845561ecSwren romano CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t, 1137845561ecSwren romano uint8_t, float); 1138845561ecSwren romano CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t, 1139845561ecSwren romano uint64_t, float); 
1140845561ecSwren romano CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t, 1141845561ecSwren romano uint32_t, float); 1142845561ecSwren romano CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t, 1143845561ecSwren romano uint16_t, float); 1144845561ecSwren romano CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t, 1145845561ecSwren romano uint8_t, float); 1146845561ecSwren romano CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t, 1147845561ecSwren romano uint64_t, float); 1148845561ecSwren romano CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t, 1149845561ecSwren romano uint32_t, float); 1150845561ecSwren romano CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t, 1151845561ecSwren romano uint16_t, float); 1152845561ecSwren romano CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t, 1153845561ecSwren romano uint8_t, float); 11548a91bc7bSHarrietAkot 1155845561ecSwren romano // Integral matrices with both overheads of the same type. 
1156845561ecSwren romano CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t); 1157845561ecSwren romano CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t); 1158845561ecSwren romano CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t); 1159845561ecSwren romano CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t); 1160845561ecSwren romano CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t); 1161845561ecSwren romano CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t); 1162845561ecSwren romano CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t); 1163845561ecSwren romano CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t); 1164845561ecSwren romano CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t); 1165845561ecSwren romano CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t); 1166845561ecSwren romano CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t); 1167845561ecSwren romano CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t); 1168845561ecSwren romano CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t); 11698a91bc7bSHarrietAkot 11708a91bc7bSHarrietAkot // Unsupported case (add above if needed). 11718a91bc7bSHarrietAkot fputs("unsupported combination of types\n", stderr); 11728a91bc7bSHarrietAkot exit(1); 11738a91bc7bSHarrietAkot } 11748a91bc7bSHarrietAkot 11758a91bc7bSHarrietAkot /// Methods that provide direct access to pointers. 
1176d2215e79SRainer Orth IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers) 11778a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers) 11788a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers) 11798a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers) 11808a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers) 11818a91bc7bSHarrietAkot 11828a91bc7bSHarrietAkot /// Methods that provide direct access to indices. 1183d2215e79SRainer Orth IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices) 11848a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices) 11858a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices) 11868a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices) 11878a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices) 11888a91bc7bSHarrietAkot 11898a91bc7bSHarrietAkot /// Methods that provide direct access to values. 11908a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesF64, double, getValues) 11918a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesF32, float, getValues) 11928a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues) 11938a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues) 11948a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues) 11958a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues) 11968a91bc7bSHarrietAkot 11978a91bc7bSHarrietAkot /// Helper to add value to coordinate scheme, one per value type. 
11988a91bc7bSHarrietAkot IMPL_ADDELT(addEltF64, double) 11998a91bc7bSHarrietAkot IMPL_ADDELT(addEltF32, float) 12008a91bc7bSHarrietAkot IMPL_ADDELT(addEltI64, int64_t) 12018a91bc7bSHarrietAkot IMPL_ADDELT(addEltI32, int32_t) 12028a91bc7bSHarrietAkot IMPL_ADDELT(addEltI16, int16_t) 12038a91bc7bSHarrietAkot IMPL_ADDELT(addEltI8, int8_t) 12048a91bc7bSHarrietAkot 12058a91bc7bSHarrietAkot /// Helper to enumerate elements of coordinate scheme, one per value type. 12068a91bc7bSHarrietAkot IMPL_GETNEXT(getNextF64, double) 12078a91bc7bSHarrietAkot IMPL_GETNEXT(getNextF32, float) 12088a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI64, int64_t) 12098a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI32, int32_t) 12108a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI16, int16_t) 12118a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI8, int8_t) 12128a91bc7bSHarrietAkot 12136438783fSAart Bik /// Insert elements in lexicographical index order, one per value type. 1214f66e5769SAart Bik IMPL_LEXINSERT(lexInsertF64, double) 1215f66e5769SAart Bik IMPL_LEXINSERT(lexInsertF32, float) 1216f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI64, int64_t) 1217f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI32, int32_t) 1218f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI16, int16_t) 1219f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI8, int8_t) 1220f66e5769SAart Bik 12216438783fSAart Bik /// Insert using expansion, one per value type. 
12224f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertF64, double) 12234f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertF32, float) 12244f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI64, int64_t) 12254f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI32, int32_t) 12264f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI16, int16_t) 12274f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI8, int8_t) 12284f2ec7f9SAart Bik 12298a91bc7bSHarrietAkot #undef CASE 12308a91bc7bSHarrietAkot #undef IMPL_SPARSEVALUES 12318a91bc7bSHarrietAkot #undef IMPL_GETOVERHEAD 12328a91bc7bSHarrietAkot #undef IMPL_ADDELT 12338a91bc7bSHarrietAkot #undef IMPL_GETNEXT 12344f2ec7f9SAart Bik #undef IMPL_LEXINSERT 12354f2ec7f9SAart Bik #undef IMPL_EXPINSERT 12366438783fSAart Bik 12376438783fSAart Bik /// Output a sparse tensor, one per value type. 12386438783fSAart Bik void outSparseTensorF64(void *tensor, void *dest, bool sort) { 12396438783fSAart Bik return outSparseTensor<double>(tensor, dest, sort); 12406438783fSAart Bik } 12416438783fSAart Bik void outSparseTensorF32(void *tensor, void *dest, bool sort) { 12426438783fSAart Bik return outSparseTensor<float>(tensor, dest, sort); 12436438783fSAart Bik } 12446438783fSAart Bik void outSparseTensorI64(void *tensor, void *dest, bool sort) { 12456438783fSAart Bik return outSparseTensor<int64_t>(tensor, dest, sort); 12466438783fSAart Bik } 12476438783fSAart Bik void outSparseTensorI32(void *tensor, void *dest, bool sort) { 12486438783fSAart Bik return outSparseTensor<int32_t>(tensor, dest, sort); 12496438783fSAart Bik } 12506438783fSAart Bik void outSparseTensorI16(void *tensor, void *dest, bool sort) { 12516438783fSAart Bik return outSparseTensor<int16_t>(tensor, dest, sort); 12526438783fSAart Bik } 12536438783fSAart Bik void outSparseTensorI8(void *tensor, void *dest, bool sort) { 12546438783fSAart Bik return outSparseTensor<int8_t>(tensor, dest, sort); 12556438783fSAart Bik } 12568a91bc7bSHarrietAkot 12578a91bc7bSHarrietAkot 
//===----------------------------------------------------------------------===// 12588a91bc7bSHarrietAkot // 12598a91bc7bSHarrietAkot // Public API with methods that accept C-style data structures to interact 12608a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally. 12618a91bc7bSHarrietAkot // These methods can be used both by MLIR compiler-generated code as well as by 12628a91bc7bSHarrietAkot // an external runtime that wants to interact with MLIR compiler-generated code. 12638a91bc7bSHarrietAkot // 12648a91bc7bSHarrietAkot //===----------------------------------------------------------------------===// 12658a91bc7bSHarrietAkot 12668a91bc7bSHarrietAkot /// Helper method to read a sparse tensor filename from the environment, 12678a91bc7bSHarrietAkot /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc. 1268d2215e79SRainer Orth char *getTensorFilename(index_type id) { 12698a91bc7bSHarrietAkot char var[80]; 12708a91bc7bSHarrietAkot sprintf(var, "TENSOR%" PRIu64, id); 12718a91bc7bSHarrietAkot char *env = getenv(var); 12723734c078Swren romano if (!env) { 12733734c078Swren romano fprintf(stderr, "Environment variable %s is not set\n", var); 12743734c078Swren romano exit(1); 12753734c078Swren romano } 12768a91bc7bSHarrietAkot return env; 12778a91bc7bSHarrietAkot } 12788a91bc7bSHarrietAkot 12798a91bc7bSHarrietAkot /// Returns size of sparse tensor in given dimension. 1280d2215e79SRainer Orth index_type sparseDimSize(void *tensor, index_type d) { 12818a91bc7bSHarrietAkot return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d); 12828a91bc7bSHarrietAkot } 12838a91bc7bSHarrietAkot 1284f66e5769SAart Bik /// Finalizes lexicographic insertions. 1285f66e5769SAart Bik void endInsert(void *tensor) { 1286f66e5769SAart Bik return static_cast<SparseTensorStorageBase *>(tensor)->endInsert(); 1287f66e5769SAart Bik } 1288f66e5769SAart Bik 12898a91bc7bSHarrietAkot /// Releases sparse tensor storage. 
12908a91bc7bSHarrietAkot void delSparseTensor(void *tensor) { 12918a91bc7bSHarrietAkot delete static_cast<SparseTensorStorageBase *>(tensor); 12928a91bc7bSHarrietAkot } 12938a91bc7bSHarrietAkot 129463bdcaf9Swren romano /// Releases sparse tensor coordinate scheme. 129563bdcaf9Swren romano #define IMPL_DELCOO(VNAME, V) \ 129663bdcaf9Swren romano void delSparseTensorCOO##VNAME(void *coo) { \ 129763bdcaf9Swren romano delete static_cast<SparseTensorCOO<V> *>(coo); \ 129863bdcaf9Swren romano } 129963bdcaf9Swren romano IMPL_DELCOO(F64, double) 130063bdcaf9Swren romano IMPL_DELCOO(F32, float) 130163bdcaf9Swren romano IMPL_DELCOO(I64, int64_t) 130263bdcaf9Swren romano IMPL_DELCOO(I32, int32_t) 130363bdcaf9Swren romano IMPL_DELCOO(I16, int16_t) 130463bdcaf9Swren romano IMPL_DELCOO(I8, int8_t) 130563bdcaf9Swren romano #undef IMPL_DELCOO 130663bdcaf9Swren romano 13078a91bc7bSHarrietAkot /// Initializes sparse tensor from a COO-flavored format expressed using C-style 13088a91bc7bSHarrietAkot /// data structures. 
/// The expected parameters are:
///
///   rank:    rank of the tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
///   values:  a "nse" array with values for all specified elements
///   indices: a flat "nse x rank" array with indices for all specified elements
///   perm:    the permutation of the dimensions in the storage
///   sparse:  the sparsity annotation of each dimension
///
/// For example, the sparse matrix
///   | 1.0 0.0 0.0 |
///   | 0.0 5.0 3.0 |
/// can be passed as
///   rank    = 2
///   nse     = 3
///   shape   = [2, 3]
///   values  = [1.0, 5.0, 3.0]
///   indices = [0, 0, 1, 1, 1, 2]
///   perm    = [0, 1]   (identity dimension ordering)
///   sparse  = one sparsity annotation per dimension (two entries here)
//
// TODO: generalize beyond 64-bit indices.
13298a91bc7bSHarrietAkot // 13306438783fSAart Bik void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape, 133120eaa88fSBixia Zheng double *values, uint64_t *indices, 133220eaa88fSBixia Zheng uint64_t *perm, uint8_t *sparse) { 133320eaa88fSBixia Zheng return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm, 133420eaa88fSBixia Zheng sparse); 13358a91bc7bSHarrietAkot } 13366438783fSAart Bik void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape, 133720eaa88fSBixia Zheng float *values, uint64_t *indices, 133820eaa88fSBixia Zheng uint64_t *perm, uint8_t *sparse) { 133920eaa88fSBixia Zheng return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm, 134020eaa88fSBixia Zheng sparse); 13418a91bc7bSHarrietAkot } 13428a91bc7bSHarrietAkot 13432f49e6b0SBixia Zheng /// Converts a sparse tensor to COO-flavored format expressed using C-style 13442f49e6b0SBixia Zheng /// data structures. The expected output parameters are pointers for these 13452f49e6b0SBixia Zheng /// values: 13462f49e6b0SBixia Zheng /// 13472f49e6b0SBixia Zheng /// rank: rank of tensor 13482f49e6b0SBixia Zheng /// nse: number of specified elements (usually the nonzeros) 13492f49e6b0SBixia Zheng /// shape: array with dimension size for each rank 13502f49e6b0SBixia Zheng /// values: a "nse" array with values for all specified elements 13512f49e6b0SBixia Zheng /// indices: a flat "nse x rank" array with indices for all specified elements 13522f49e6b0SBixia Zheng /// 13532f49e6b0SBixia Zheng /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned 13542f49e6b0SBixia Zheng /// from convertToMLIRSparseTensor. 13552f49e6b0SBixia Zheng /// 13562f49e6b0SBixia Zheng // TODO: Currently, values are copied from SparseTensorStorage to 13572f49e6b0SBixia Zheng // SparseTensorCOO, then to the output. We may want to reduce the number of 13582f49e6b0SBixia Zheng // copies. 
13592f49e6b0SBixia Zheng // 13606438783fSAart Bik // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions 13616438783fSAart Bik // compressed 13622f49e6b0SBixia Zheng // 13636438783fSAart Bik void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank, 13646438783fSAart Bik uint64_t *pNse, uint64_t **pShape, 13656438783fSAart Bik double **pValues, uint64_t **pIndices) { 13666438783fSAart Bik fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices); 13672f49e6b0SBixia Zheng } 13686438783fSAart Bik void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank, 13696438783fSAart Bik uint64_t *pNse, uint64_t **pShape, 13706438783fSAart Bik float **pValues, uint64_t **pIndices) { 13716438783fSAart Bik fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices); 13722f49e6b0SBixia Zheng } 1373efa15f41SAart Bik 13748a91bc7bSHarrietAkot } // extern "C" 13758a91bc7bSHarrietAkot 13768a91bc7bSHarrietAkot #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS 1377