//===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library that is useful
// for sparse tensor manipulations. The functionality provided in this library
// is meant to simplify benchmarking, testing, and debugging MLIR code that
// operates on sparse tensors. The provided functionality is **not** part
// of core MLIR, however.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>

//===----------------------------------------------------------------------===//
//
// Internal support for storing and reading sparse tensors.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by index (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is only visible
// externally as an opaque pointer.
//
//===----------------------------------------------------------------------===//

namespace {

static constexpr int kColWidth = 1025;

/// A version of `operator*` on `uint64_t` which checks for overflows.
static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
  assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
         "Integer overflow");
  return lhs * rhs;
}
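
// For example (illustrative): checkedMul(1ULL << 32, 1ULL << 32) exceeds
// the uint64_t range and trips the assertion in debug builds; with
// assertions disabled, the product silently wraps around.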

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer to a shared index pool rather than e.g. a direct
/// vector, since that (1) reduces the per-element memory footprint, and
/// (2) centralizes the memory reservation and (re)allocation to one place.
template <typename V>
struct Element {
  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
  uint64_t *indices; // pointer into shared index pool
  V value;
};
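
// For illustration (not itself part of the code): after add({0,3}, a) and
// add({1,2}, b) on a rank-2 SparseTensorCOO (defined below), the shared
// pool is
//   indices = [0, 3, 1, 2]
// with the first element's `indices` pointer referring to offset 0 of the
// pool and the second element's pointer to offset 2.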

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO {
public:
  SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
      : sizes(szs) {
    if (capacity) {
      elements.reserve(capacity);
      indices.reserve(capacity * getRank());
    }
  }

  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t *base = indices.data();
    uint64_t size = indices.size();
    uint64_t rank = getRank();
    assert(rank == ind.size());
    for (uint64_t r = 0; r < rank; r++) {
      assert(ind[r] < sizes[r]); // within bounds
      indices.push_back(ind[r]);
    }
    // This base only changes if indices were reallocated. In that case, we
    // need to correct all previous pointers into the vector. Note that this
    // only happens if we did not set the initial capacity right, and then only
    // for every internal vector reallocation (which with the doubling rule
    // should only incur an amortized linear overhead).
    uint64_t *newBase = indices.data();
    if (newBase != base) {
      for (uint64_t i = 0, n = elements.size(); i < n; i++)
        elements[i].indices = newBase + (elements[i].indices - base);
      base = newBase;
    }
    // Add element as (pointer into shared index pool, value) pair.
    elements.emplace_back(base + size, val);
  }

  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    // TODO: we may want to cache an `isSorted` bit, to avoid
    // unnecessary/redundant sorting.
    std::sort(elements.begin(), elements.end(),
              [this](const Element<V> &e1, const Element<V> &e2) {
                uint64_t rank = getRank();
                for (uint64_t r = 0; r < rank; r++) {
                  if (e1.indices[r] == e2.indices[r])
                    continue;
                  return e1.indices[r] < e2.indices[r];
                }
                return false;
              });
  }

  /// Returns rank.
  uint64_t getRank() const { return sizes.size(); }

  /// Getter for sizes array.
  const std::vector<uint64_t> &getSizes() const { return sizes; }

  /// Getter for elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }

  /// Get the next element.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  ///
  /// Precondition: `sizes` and `perm` must be valid for `rank`.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *sizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++) {
      assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
      permsz[perm[r]] = sizes[r];
    }
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  const std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<Element<V>> elements;  // all COO elements
  std::vector<uint64_t> indices;     // shared index pool
  bool iteratorLocked = false;
  unsigned iteratorPos = 0;
};
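
// A minimal usage sketch of the coordinate scheme (illustrative only; the
// sizes and values are made up for this example):
//
//   uint64_t sizes[] = {3, 4};
//   uint64_t perm[] = {0, 1}; // identity dimension ordering
//   auto *coo = SparseTensorCOO<double>::newSparseTensorCOO(
//       2, sizes, perm, /*capacity=*/2);
//   coo->add({0, 1}, 3.0);
//   coo->add({2, 3}, 4.0);
//   coo->sort(); // put elements in lexicographic index order
//   delete coo;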

/// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
/// takes responsibility for all the `<P,I,V>`-independent aspects
/// of the tensor (e.g., shape, sparsity, permutation).  In addition,
/// we use function overloading to implement "partial" method
/// specialization, which the C-API relies on to catch type errors
/// arising from our use of opaque pointers.
class SparseTensorStorageBase {
public:
  /// Constructs a new storage object.  The `perm` maps the tensor's
  /// semantic-ordering of dimensions to this object's storage-order.
  /// The `szs` and `sparsity` arrays are already in storage-order.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
  SparseTensorStorageBase(const std::vector<uint64_t> &szs,
                          const uint64_t *perm, const DimLevelType *sparsity)
      : dimSizes(szs), rev(getRank()),
        dimTypes(sparsity, sparsity + getRank()) {
    const uint64_t rank = getRank();
    // Validate parameters.
    assert(rank > 0 && "Trivial shape is unsupported");
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      assert((dimTypes[r] == DimLevelType::kDense ||
              dimTypes[r] == DimLevelType::kCompressed) &&
             "Unsupported DimLevelType");
    }
    // Construct the "reverse" (i.e., inverse) permutation.
    for (uint64_t r = 0; r < rank; r++)
      rev[perm[r]] = r;
  }

  virtual ~SparseTensorStorageBase() = default;

  /// Get the rank of the tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Getter for the dimension-sizes array, in storage-order.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Safely lookup the size of the given (storage-order) dimension.
  uint64_t getDimSize(uint64_t d) const {
    assert(d < getRank());
    return dimSizes[d];
  }

  /// Getter for the "reverse" permutation, which maps this object's
  /// storage-order to the tensor's semantic-order.
  const std::vector<uint64_t> &getRev() const { return rev; }

  /// Getter for the dimension-types array, in storage-order.
  const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }

  /// Safely check if the (storage-order) dimension uses compressed storage.
  bool isCompressedDim(uint64_t d) const {
    assert(d < getRank());
    return (dimTypes[d] == DimLevelType::kCompressed);
  }

  /// Overhead storage.
  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }

  /// Primary storage.
  virtual void getValues(std::vector<double> **) { fatal("valf64"); }
  virtual void getValues(std::vector<float> **) { fatal("valf32"); }
  virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }

  /// Element-wise insertion in lexicographic index order.
  virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
  virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
  virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
  virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(const uint64_t *, int16_t) { fatal("insi16"); }
  virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }

  /// Expanded insertion.
  virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
    fatal("expf64");
  }
  virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
    fatal("expf32");
  }
  virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi64");
  }
  virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi32");
  }
  virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi16");
  }
  virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi8");
  }

  /// Finishes insertion.
  virtual void endInsert() = 0;

private:
  static void fatal(const char *tp) {
    fprintf(stderr, "unsupported %s\n", tp);
    exit(1);
  }

  const std::vector<uint64_t> dimSizes;
  std::vector<uint64_t> rev;
  const std::vector<DimLevelType> dimTypes;
};
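
// For example (illustrative): SparseTensorStorage<uint64_t, uint64_t, float>
// overrides getValues(std::vector<float> **), so a caller that mistakenly
// presents a std::vector<double> ** for such a tensor resolves to the base
// class overload above and aborts with "unsupported valf64".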

/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this method provides
/// a convenient "one-size-fits-all" solution that simply takes an input tensor
/// and annotations to implement all required setup in a general manner.
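///
/// For example (illustrative), with annotations (dense, compressed), the
/// identity permutation, and the 3x4 matrix
///   [ [0, 1, 0, 2],
///     [0, 0, 0, 0],
///     [3, 0, 0, 4] ]
/// the constructor below produces the classical CSR layout
///   pointers[1] = [0, 2, 2, 4]
///   indices[1]  = [1, 3, 0, 3]
///   values      = [1, 2, 3, 4]
/// with no overhead storage for the dense dimension 0.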
template <typename P, typename I, typename V>
class SparseTensorStorage : public SparseTensorStorageBase {
public:
  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the coordinate scheme tensor for the initial contents if provided.
  ///
  /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
  SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
                      const DimLevelType *sparsity,
                      SparseTensorCOO<V> *coo = nullptr)
      : SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
        indices(getRank()), idx(getRank()) {
    // Provide hints on capacity of pointers and indices.
    // TODO: needs much fine-tuning based on actual sparsity; currently
    //       we reserve pointer/index space based on all previous dense
    //       dimensions, which works well up to the first sparse dimension;
    //       but we should really use nnz and the dense/sparse distribution.
    bool allDense = true;
    uint64_t sz = 1;
    for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
      if (isCompressedDim(r)) {
        // TODO: Take a parameter between 1 and `sizes[r]`, and multiply
        // `sz` by that before reserving. (For now we just use 1.)
        pointers[r].reserve(sz + 1);
        pointers[r].push_back(0);
        indices[r].reserve(sz);
        sz = 1;
        allDense = false;
      } else { // Dense dimension.
        sz = checkedMul(sz, getDimSizes()[r]);
      }
    }
    // Then assign contents from coordinate scheme tensor if provided.
    if (coo) {
      // Ensure both preconditions of `fromCOO`.
      assert(coo->getSizes() == getDimSizes() && "Tensor size mismatch");
      coo->sort();
      // Now actually insert the `elements`.
      const std::vector<Element<V>> &elements = coo->getElements();
      uint64_t nnz = elements.size();
      values.reserve(nnz);
      fromCOO(elements, 0, nnz, 0);
    } else if (allDense) {
      values.resize(sz, 0);
    }
  }

  ~SparseTensorStorage() override = default;

  /// Partially specialize these getter methods based on template types.
  void getPointers(std::vector<P> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &pointers[d];
  }
  void getIndices(std::vector<I> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &indices[d];
  }
  void getValues(std::vector<V> **out) override { *out = &values; }

  /// Partially specialize lexicographical insertions based on template types.
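  /// Insertions must arrive in lexicographic index order; for example
  /// (illustrative), for a rank-2 tensor it is valid to insert at {0,2}
  /// and then at {1,0}, whereas inserting at {1,0} and then at {0,2}
  /// trips the "non-lexicographic insertion" assertion in lexDiff().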
  void lexInsert(const uint64_t *cursor, V val) override {
    // First, wrap up pending insertion path.
    uint64_t diff = 0;
    uint64_t top = 0;
    if (!values.empty()) {
      diff = lexDiff(cursor);
      endPath(diff + 1);
      top = idx[diff] + 1;
    }
    // Then continue with insertion path.
    insPath(cursor, diff, top, val);
  }

  /// Partially specialize expanded insertions based on template types.
  /// Note that this method resets the values/filled-switch array back
  /// to all-zero/false while only iterating over the nonzero elements.
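  ///
  /// For example (illustrative): for a rank-2 tensor with cursor = {i, *},
  /// count = 2, and added = [3, 7], this method inserts the elements
  /// ({i,3}, values[3]) and ({i,7}, values[7]), and resets values[3],
  /// values[7], filled[3], and filled[7] back to zero/false.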
  void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
                 uint64_t count) override {
    if (count == 0)
      return;
    // Sort.
    std::sort(added, added + count);
    // Restore insertion path for first insert.
    const uint64_t lastDim = getRank() - 1;
    uint64_t index = added[0];
    cursor[lastDim] = index;
    lexInsert(cursor, values[index]);
    assert(filled[index]);
    values[index] = 0;
    filled[index] = false;
    // Subsequent insertions are quick.
    for (uint64_t i = 1; i < count; i++) {
      assert(index < added[i] && "non-lexicographic insertion");
      index = added[i];
      cursor[lastDim] = index;
      insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
      assert(filled[index]);
      values[index] = 0;
      filled[index] = false;
    }
  }

  /// Finalizes lexicographic insertions.
  void endInsert() override {
    if (values.empty())
      finalizeSegment(0);
    else
      endPath(0);
  }

  /// Returns this sparse tensor storage scheme as a new memory-resident
  /// sparse tensor in coordinate scheme with the given dimension order.
  ///
  /// Precondition: `perm` must be valid for `getRank()`.
  SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
    // Restore original order of the dimension sizes and allocate coordinate
    // scheme with desired new ordering specified in perm.
    const uint64_t rank = getRank();
    const auto &rev = getRev();
    const auto &sizes = getDimSizes();
    std::vector<uint64_t> orgsz(rank);
    for (uint64_t r = 0; r < rank; r++)
      orgsz[rev[r]] = sizes[r];
    SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO(
        rank, orgsz.data(), perm, values.size());
    // Populate coordinate scheme restored from old ordering and changed with
    // new ordering. Rather than applying both reorderings during the
    // recursion, we compute the combined permutation in advance.
    std::vector<uint64_t> reord(rank);
    for (uint64_t r = 0; r < rank; r++)
      reord[r] = perm[rev[r]];
    toCOO(*coo, reord, 0, 0);
    // TODO: This assertion assumes there are no stored zeros,
    // or if there are then that we don't filter them out.
    // Cf., <https://github.com/llvm/llvm-project/issues/54179>
    assert(coo->getElements().size() == values.size());
    return coo;
  }

  /// Factory method. Constructs a sparse tensor storage scheme with the given
  /// dimensions, permutation, and per-dimension dense/sparse annotations,
  /// using the coordinate scheme tensor for the initial contents if provided.
  /// In the latter case, the coordinate scheme must respect the same
  /// permutation as is desired for the new sparse tensor storage.
  ///
  /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
                  const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
    SparseTensorStorage<P, I, V> *n = nullptr;
    if (coo) {
      assert(coo->getRank() == rank && "Tensor rank mismatch");
      const auto &coosz = coo->getSizes();
      for (uint64_t r = 0; r < rank; r++)
        assert(shape[r] == 0 || shape[r] == coosz[perm[r]]);
      n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
    } else {
      std::vector<uint64_t> permsz(rank);
      for (uint64_t r = 0; r < rank; r++) {
        assert(shape[r] > 0 && "Dimension size zero has trivial storage");
        permsz[perm[r]] = shape[r];
      }
      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
    }
    return n;
  }

private:
  /// Appends an arbitrary new position to `pointers[d]`.  This method
  /// checks that `pos` is representable in the `P` type; however, it
  /// does not check that `pos` is semantically valid (i.e., larger than
  /// the previous position and smaller than `indices[d].capacity()`).
  void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
    assert(isCompressedDim(d));
    assert(pos <= std::numeric_limits<P>::max() &&
           "Pointer value is too large for the P-type");
    pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
  }

  /// Appends index `i` to dimension `d`, in the semantically general
  /// sense.  For non-dense dimensions, that means appending to the
  /// `indices[d]` array, checking that `i` is representable in the `I`
  /// type; however, we do not verify other semantic requirements (e.g.,
  /// that `i` is in bounds for `sizes[d]`, and not previously occurring
  /// in the same segment).  For dense dimensions, this method instead
  /// appends the appropriate number of zeros to the `values` array,
  /// where `full` is the number of "entries" already written to `values`
  /// for this segment (aka one after the highest index previously appended).
  void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
    if (isCompressedDim(d)) {
      assert(i <= std::numeric_limits<I>::max() &&
             "Index value is too large for the I-type");
      indices[d].push_back(static_cast<I>(i));
    } else { // Dense dimension.
      assert(i >= full && "Index was already filled");
      if (i == full)
        return; // Short-circuit, since it'll be a nop.
      if (d + 1 == getRank())
        values.insert(values.end(), i - full, 0);
      else
        finalizeSegment(d + 1, 0, i - full);
    }
  }

  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the pointers and
  /// indices arrays under the given per-dimension dense/sparse annotations.
  ///
  /// Preconditions:
  /// (1) the `elements` must be lexicographically sorted.
  /// (2) the indices of every element are valid for `sizes` (equal rank
  ///     and pointwise less-than).
  void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
               uint64_t hi, uint64_t d) {
    // Once dimensions are exhausted, insert the numerical values.
    assert(d <= getRank() && hi <= elements.size());
    if (d == getRank()) {
      assert(lo < hi);
      values.push_back(elements[lo].value);
      return;
    }
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
      // Find segment in interval with same index elements in this dimension.
      uint64_t i = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == i)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      appendIndex(d, full, i);
      full = i + 1;
      fromCOO(elements, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    finalizeSegment(d, full);
  }

  /// Stores the sparse tensor storage scheme into a memory-resident sparse
  /// tensor in coordinate scheme.
  void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
             uint64_t pos, uint64_t d) {
    assert(d <= getRank());
    if (d == getRank()) {
      assert(pos < values.size());
      tensor.add(idx, values[pos]);
    } else if (isCompressedDim(d)) {
      // Sparse dimension.
      for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
        idx[reord[d]] = indices[d][ii];
        toCOO(tensor, reord, ii, d + 1);
      }
    } else {
      // Dense dimension.
      const uint64_t sz = getDimSizes()[d];
      const uint64_t off = pos * sz;
      for (uint64_t i = 0; i < sz; i++) {
        idx[reord[d]] = i;
        toCOO(tensor, reord, off + i, d + 1);
      }
    }
  }

  /// Finalize the sparse pointer structure at this dimension.
  void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
    if (count == 0)
      return; // Short-circuit, since it'll be a nop.
    if (isCompressedDim(d)) {
      appendPointer(d, indices[d].size(), count);
    } else { // Dense dimension.
      const uint64_t sz = getDimSizes()[d];
      assert(sz >= full && "Segment is overfull");
      count = checkedMul(count, sz - full);
      // For dense storage we must enumerate all the remaining coordinates
      // in this dimension (i.e., coordinates after the last non-zero
      // element), and either fill in their zero values or else recurse
      // to finalize some deeper dimension.
      if (d + 1 == getRank())
        values.insert(values.end(), count, 0);
      else
        finalizeSegment(d + 1, 0, count);
    }
  }

  /// Wraps up a single insertion path, inner to outer.
  void endPath(uint64_t diff) {
    uint64_t rank = getRank();
    assert(diff <= rank);
    for (uint64_t i = 0; i < rank - diff; i++) {
      const uint64_t d = rank - i - 1;
      finalizeSegment(d, idx[d] + 1);
    }
  }

  /// Continues a single insertion path, outer to inner.
  void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
    uint64_t rank = getRank();
    assert(diff < rank);
    for (uint64_t d = diff; d < rank; d++) {
      uint64_t i = cursor[d];
      appendIndex(d, top, i);
      top = 0;
      idx[d] = i;
    }
    values.push_back(val);
  }

  /// Finds the lexicographic differing dimension.
  uint64_t lexDiff(const uint64_t *cursor) const {
    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
      if (cursor[r] > idx[r])
        return r;
      else
        assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
    return -1u;
  }

private:
  std::vector<std::vector<P>> pointers;
  std::vector<std::vector<I>> indices;
  std::vector<V> values;
  std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
};

/// Helper to convert string to lower case.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = tolower(*c);
  return token;
}

/// Read the MME header of a general sparse matrix of type real.
static void readMMEHeader(FILE *file, char *filename, char *line,
                          uint64_t *idata, bool *isPattern, bool *isSymmetric) {
  char header[64];
  char object[64];
  char format[64];
  char field[64];
  char symmetry[64];
  // Read header line.
  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
             symmetry) != 5) {
    fprintf(stderr, "Corrupt header in %s\n", filename);
    exit(1);
  }
  // Set properties.
  *isPattern = (strcmp(toLower(field), "pattern") == 0);
  *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
  // Make sure this is a general sparse matrix.
  if (strcmp(toLower(header), "%%matrixmarket") ||
      strcmp(toLower(object), "matrix") ||
      strcmp(toLower(format), "coordinate") ||
      (strcmp(toLower(field), "real") && !(*isPattern)) ||
      (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
    fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
    exit(1);
  }
  // Skip comments.
  while (true) {
    if (!fgets(line, kColWidth, file)) {
      fprintf(stderr, "Cannot find data in %s\n", filename);
      exit(1);
    }
    if (line[0] != '%')
      break;
  }
  // Next line contains M N NNZ.
  idata[0] = 2; // rank
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3) {
    fprintf(stderr, "Cannot find size in %s\n", filename);
    exit(1);
  }
}
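
// For reference (an illustrative example, not produced by this library):
// a Matrix Market file accepted by readMMEHeader() starts like
//
//   %%MatrixMarket matrix coordinate real general
//   % optional comment lines
//   5 5 8
//
// where the last line gives the number of rows, columns, and nonzeros.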

/// Read the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments followed
/// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
                                uint64_t *idata) {
  // Skip comments.
  while (true) {
    if (!fgets(line, kColWidth, file)) {
      fprintf(stderr, "Cannot find data in %s\n", filename);
      exit(1);
    }
    if (line[0] != '#')
      break;
  }
  // Next line contains RANK and NNZ.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
    fprintf(stderr, "Cannot find metadata in %s\n", filename);
    exit(1);
  }
  // Followed by a line with the dimension sizes (one per rank).
  for (uint64_t r = 0; r < idata[0]; r++) {
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size %s\n", filename);
      exit(1);
    }
  }
  fgets(line, kColWidth, file); // end of line
}
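
// For reference (an illustrative example): an extended FROSTT file for a
// 2x3x4 tensor with five nonzeros starts like
//
//   # optional comment lines
//   3 5
//   2 3 4
//
// followed by five lines of the form "i j k value" with 1-based indices.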

/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme.
template <typename V>
static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
                                               const uint64_t *shape,
                                               const uint64_t *perm) {
  // Open the file.
  FILE *file = fopen(filename, "r");
  if (!file) {
    assert(filename && "Received nullptr for filename");
    fprintf(stderr, "Cannot find file %s\n", filename);
    exit(1);
  }
  // Perform some file format dependent set up.
  char line[kColWidth];
  uint64_t idata[512];
  bool isPattern = false;
  bool isSymmetric = false;
  if (strstr(filename, ".mtx")) {
    readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
  } else if (strstr(filename, ".tns")) {
    readExtFROSTTHeader(file, filename, line, idata);
  } else {
    fprintf(stderr, "Unknown format %s\n", filename);
    exit(1);
  }
  // Prepare sparse tensor object with per-dimension sizes
  // and the number of nonzeros as initial capacity.
  assert(rank == idata[0] && "rank mismatch");
  uint64_t nnz = idata[1];
  for (uint64_t r = 0; r < rank; r++)
    assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
           "dimension size mismatch");
  SparseTensorCOO<V> *tensor =
      SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
  // Read all nonzero elements.
  std::vector<uint64_t> indices(rank);
  for (uint64_t k = 0; k < nnz; k++) {
    if (!fgets(line, kColWidth, file)) {
      fprintf(stderr, "Cannot find next line of data in %s\n", filename);
      exit(1);
    }
    char *linePtr = line;
    for (uint64_t r = 0; r < rank; r++) {
      uint64_t idx = strtoul(linePtr, &linePtr, 10);
      // Add 0-based index.
      indices[perm[r]] = idx - 1;
    }
    // The external formats always store the numerical values with the type
    // double, but we cast these values to the sparse tensor object type.
    // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
    double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
    tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing them. In the future, we may want to make symmetry
    // implicit for storage reasons.
    if (isSymmetric && indices[0] != indices[1])
      tensor->add({indices[1], indices[0]}, value);
  }
  // Close the file and return tensor.
  fclose(file);
  return tensor;
}

/// Writes the sparse tensor to extended FROSTT format.
template <typename V>
static void outSparseTensor(void *tensor, void *dest, bool sort) {
  assert(tensor && dest);
  auto *coo = static_cast<SparseTensorCOO<V> *>(tensor);
  if (sort)
    coo->sort();
  char *filename = static_cast<char *>(dest);
  auto &sizes = coo->getSizes();
  auto &elements = coo->getElements();
  uint64_t rank = coo->getRank();
  uint64_t nnz = elements.size();
  std::fstream file;
  file.open(filename, std::ios_base::out | std::ios_base::trunc);
  assert(file.is_open());
  file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
  for (uint64_t r = 0; r < rank - 1; r++)
    file << sizes[r] << " ";
  file << sizes[rank - 1] << std::endl;
  for (uint64_t i = 0; i < nnz; i++) {
    auto &idx = elements[i].indices;
    for (uint64_t r = 0; r < rank; r++)
      file << (idx[r] + 1) << " ";
    file << elements[i].value << std::endl;
  }
  file.flush();
  file.close();
  assert(file.good());
}

/// Initializes sparse tensor from an external COO-flavored format.
template <typename V>
static SparseTensorStorage<uint64_t, uint64_t, V> *
toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
                   uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
  const DimLevelType *sparsity = (DimLevelType *)(sparse);
#ifndef NDEBUG
  // Verify that perm is a permutation of 0..(rank-1).
  std::vector<uint64_t> order(perm, perm + rank);
  std::sort(order.begin(), order.end());
  for (uint64_t i = 0; i < rank; ++i) {
    if (i != order[i]) {
      fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
      exit(1);
    }
  }

  // Verify that the sparsity values are supported.
  for (uint64_t i = 0; i < rank; ++i) {
    if (sparsity[i] != DimLevelType::kDense &&
        sparsity[i] != DimLevelType::kCompressed) {
      fprintf(stderr, "Unsupported sparsity value %d\n",
              static_cast<int>(sparsity[i]));
      exit(1);
    }
  }
#endif

  // Convert external format to internal COO.
  auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
  std::vector<uint64_t> idx(rank);
  for (uint64_t i = 0, base = 0; i < nse; i++) {
    for (uint64_t r = 0; r < rank; r++)
      idx[perm[r]] = indices[base + r];
    coo->add(idx, values[i]);
    base += rank;
  }
  // Return sparse tensor storage format as opaque pointer.
  auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
      rank, shape, perm, sparsity, coo);
  delete coo;
  return tensor;
}
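
// A minimal usage sketch (illustrative only; it assumes DimLevelType::kDense
// and DimLevelType::kCompressed map to the numeric values 0 and 1 in the
// header): a 2x2 matrix with nonzeros at (0,0) and (1,1).
//
//   uint64_t shape[] = {2, 2}, perm[] = {0, 1};
//   uint64_t indices[] = {0, 0, 1, 1}; // row-major (i, j) pairs
//   double values[] = {1.0, 2.0};
//   uint8_t sparse[] = {0, 1}; // dense rows, compressed columns
//   auto *tensor = toMLIRSparseTensor<double>(2, 2, shape, values, indices,
//                                             perm, sparse);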

/// Converts a sparse tensor to an external COO-flavored format.
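/// The caller obtains ownership of the returned `*pShape`, `*pValues`, and
/// `*pIndices` buffers, which are allocated here with `new[]`.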
template <typename V>
static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
                                 uint64_t **pShape, V **pValues,
                                 uint64_t **pIndices) {
  auto *sparseTensor =
      static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
  uint64_t rank = sparseTensor->getRank();
  std::vector<uint64_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());

  const std::vector<Element<V>> &elements = coo->getElements();
  uint64_t nse = elements.size();

  uint64_t *shape = new uint64_t[rank];
  for (uint64_t i = 0; i < rank; i++)
    shape[i] = coo->getSizes()[i];

  V *values = new V[nse];
  uint64_t *indices = new uint64_t[rank * nse];

  for (uint64_t i = 0, base = 0; i < nse; i++) {
    values[i] = elements[i].value;
    for (uint64_t j = 0; j < rank; j++)
      indices[base + j] = elements[i].indices[j];
    base += rank;
  }

  delete coo;
  *pRank = rank;
  *pNse = nse;
  *pShape = shape;
  *pValues = values;
  *pIndices = indices;
}
914efa15f41SAart Bik 
915be0a7e9fSMehdi Amini } // namespace
9168a91bc7bSHarrietAkot 
9178a91bc7bSHarrietAkot extern "C" {
9188a91bc7bSHarrietAkot 
9198a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
9208a91bc7bSHarrietAkot //
9218a91bc7bSHarrietAkot // Public API with methods that operate on MLIR buffers (memrefs) to interact
9228a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally.
9238a91bc7bSHarrietAkot // These methods should be used exclusively by MLIR compiler-generated code.
9248a91bc7bSHarrietAkot //
9258a91bc7bSHarrietAkot // Some macro magic is used to generate implementations for all required type
9268a91bc7bSHarrietAkot // combinations that can be called from MLIR compiler-generated code.
9278a91bc7bSHarrietAkot //
9288a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
9298a91bc7bSHarrietAkot 
9308a91bc7bSHarrietAkot #define CASE(p, i, v, P, I, V)                                                 \
9318a91bc7bSHarrietAkot   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
93263bdcaf9Swren romano     SparseTensorCOO<V> *coo = nullptr;                                         \
933845561ecSwren romano     if (action <= Action::kFromCOO) {                                          \
934845561ecSwren romano       if (action == Action::kFromFile) {                                       \
9358a91bc7bSHarrietAkot         char *filename = static_cast<char *>(ptr);                             \
93663bdcaf9Swren romano         coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
937845561ecSwren romano       } else if (action == Action::kFromCOO) {                                 \
93863bdcaf9Swren romano         coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
9398a91bc7bSHarrietAkot       } else {                                                                 \
940845561ecSwren romano         assert(action == Action::kEmpty);                                      \
9418a91bc7bSHarrietAkot       }                                                                        \
94263bdcaf9Swren romano       auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
94363bdcaf9Swren romano           rank, shape, perm, sparsity, coo);                                   \
94463bdcaf9Swren romano       if (action == Action::kFromFile)                                         \
94563bdcaf9Swren romano         delete coo;                                                            \
94663bdcaf9Swren romano       return tensor;                                                           \
947bb56c2b3SMehdi Amini     }                                                                          \
948bb56c2b3SMehdi Amini     if (action == Action::kEmptyCOO)                                           \
949d83a7068Swren romano       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
95063bdcaf9Swren romano     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
951845561ecSwren romano     if (action == Action::kToIterator) {                                       \
95263bdcaf9Swren romano       coo->startIterator();                                                    \
9538a91bc7bSHarrietAkot     } else {                                                                   \
954845561ecSwren romano       assert(action == Action::kToCOO);                                        \
9558a91bc7bSHarrietAkot     }                                                                          \
95663bdcaf9Swren romano     return coo;                                                                \
9578a91bc7bSHarrietAkot   }
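
// For instance (illustrative), the expansion of
//   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64,
//        uint64_t, uint64_t, double)
// handles every Action for tensors with 64-bit pointer/index overhead and
// double values, dispatching to SparseTensorStorage<uint64_t, uint64_t,
// double> and SparseTensorCOO<double>.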
9588a91bc7bSHarrietAkot 
959845561ecSwren romano #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
960845561ecSwren romano 
9618a91bc7bSHarrietAkot #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
9628a91bc7bSHarrietAkot   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
9634f2ec7f9SAart Bik     assert(ref && tensor);                                                     \
9648a91bc7bSHarrietAkot     std::vector<TYPE> *v;                                                      \
9658a91bc7bSHarrietAkot     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
9668a91bc7bSHarrietAkot     ref->basePtr = ref->data = v->data();                                      \
9678a91bc7bSHarrietAkot     ref->offset = 0;                                                           \
9688a91bc7bSHarrietAkot     ref->sizes[0] = v->size();                                                 \
9698a91bc7bSHarrietAkot     ref->strides[0] = 1;                                                       \
9708a91bc7bSHarrietAkot   }
9718a91bc7bSHarrietAkot 
9728a91bc7bSHarrietAkot #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
9738a91bc7bSHarrietAkot   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
974d2215e79SRainer Orth                            index_type d) {                                     \
9754f2ec7f9SAart Bik     assert(ref && tensor);                                                     \
9768a91bc7bSHarrietAkot     std::vector<TYPE> *v;                                                      \
9778a91bc7bSHarrietAkot     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
9788a91bc7bSHarrietAkot     ref->basePtr = ref->data = v->data();                                      \
9798a91bc7bSHarrietAkot     ref->offset = 0;                                                           \
9808a91bc7bSHarrietAkot     ref->sizes[0] = v->size();                                                 \
9818a91bc7bSHarrietAkot     ref->strides[0] = 1;                                                       \
9828a91bc7bSHarrietAkot   }
9838a91bc7bSHarrietAkot 
9848a91bc7bSHarrietAkot #define IMPL_ADDELT(NAME, TYPE)                                                \
9858a91bc7bSHarrietAkot   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
986d2215e79SRainer Orth                             StridedMemRefType<index_type, 1> *iref,            \
987d2215e79SRainer Orth                             StridedMemRefType<index_type, 1> *pref) {          \
9884f2ec7f9SAart Bik     assert(tensor && iref && pref);                                            \
9898a91bc7bSHarrietAkot     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
9908a91bc7bSHarrietAkot     assert(iref->sizes[0] == pref->sizes[0]);                                  \
991d2215e79SRainer Orth     const index_type *indx = iref->data + iref->offset;                        \
992d2215e79SRainer Orth     const index_type *perm = pref->data + pref->offset;                        \
9938a91bc7bSHarrietAkot     uint64_t isize = iref->sizes[0];                                           \
994d2215e79SRainer Orth     std::vector<index_type> indices(isize);                                    \
9958a91bc7bSHarrietAkot     for (uint64_t r = 0; r < isize; r++)                                       \
9968a91bc7bSHarrietAkot       indices[perm[r]] = indx[r];                                              \
9978a91bc7bSHarrietAkot     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
9988a91bc7bSHarrietAkot     return tensor;                                                             \
9998a91bc7bSHarrietAkot   }
10008a91bc7bSHarrietAkot 
10018a91bc7bSHarrietAkot #define IMPL_GETNEXT(NAME, V)                                                  \
1002d2215e79SRainer Orth   bool _mlir_ciface_##NAME(void *tensor,                                       \
1003d2215e79SRainer Orth                            StridedMemRefType<index_type, 1> *iref,             \
10048a91bc7bSHarrietAkot                            StridedMemRefType<V, 0> *vref) {                    \
10054f2ec7f9SAart Bik     assert(tensor && iref && vref);                                            \
10068a91bc7bSHarrietAkot     assert(iref->strides[0] == 1);                                             \
1007d2215e79SRainer Orth     index_type *indx = iref->data + iref->offset;                              \
1008c9f2beffSMehdi Amini     V *value = vref->data + vref->offset;                                      \
10098a91bc7bSHarrietAkot     const uint64_t isize = iref->sizes[0];                                     \
10108a91bc7bSHarrietAkot     auto *iter = static_cast<SparseTensorCOO<V> *>(tensor);                    \
10118a91bc7bSHarrietAkot     const Element<V> *elem = iter->getNext();                                  \
101263bdcaf9Swren romano     if (elem == nullptr)                                                       \
10138a91bc7bSHarrietAkot       return false;                                                            \
10148a91bc7bSHarrietAkot     for (uint64_t r = 0; r < isize; r++)                                       \
10158a91bc7bSHarrietAkot       indx[r] = elem->indices[r];                                              \
10168a91bc7bSHarrietAkot     *value = elem->value;                                                      \
10178a91bc7bSHarrietAkot     return true;                                                               \
10188a91bc7bSHarrietAkot   }
10198a91bc7bSHarrietAkot 
1020f66e5769SAart Bik #define IMPL_LEXINSERT(NAME, V)                                                \
1021d2215e79SRainer Orth   void _mlir_ciface_##NAME(void *tensor,                                       \
1022d2215e79SRainer Orth                            StridedMemRefType<index_type, 1> *cref, V val) {    \
10234f2ec7f9SAart Bik     assert(tensor && cref);                                                    \
1024f66e5769SAart Bik     assert(cref->strides[0] == 1);                                             \
1025d2215e79SRainer Orth     index_type *cursor = cref->data + cref->offset;                            \
1026f66e5769SAart Bik     assert(cursor);                                                            \
1027f66e5769SAart Bik     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
1028f66e5769SAart Bik   }
1029f66e5769SAart Bik 
10304f2ec7f9SAart Bik #define IMPL_EXPINSERT(NAME, V)                                                \
10314f2ec7f9SAart Bik   void _mlir_ciface_##NAME(                                                    \
1032d2215e79SRainer Orth       void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
10334f2ec7f9SAart Bik       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
1034d2215e79SRainer Orth       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
10354f2ec7f9SAart Bik     assert(tensor && cref && vref && fref && aref);                            \
10364f2ec7f9SAart Bik     assert(cref->strides[0] == 1);                                             \
10374f2ec7f9SAart Bik     assert(vref->strides[0] == 1);                                             \
10384f2ec7f9SAart Bik     assert(fref->strides[0] == 1);                                             \
10394f2ec7f9SAart Bik     assert(aref->strides[0] == 1);                                             \
10404f2ec7f9SAart Bik     assert(vref->sizes[0] == fref->sizes[0]);                                  \
1041d2215e79SRainer Orth     index_type *cursor = cref->data + cref->offset;                            \
1042c9f2beffSMehdi Amini     V *values = vref->data + vref->offset;                                     \
10434f2ec7f9SAart Bik     bool *filled = fref->data + fref->offset;                                  \
1044d2215e79SRainer Orth     index_type *added = aref->data + aref->offset;                             \
10454f2ec7f9SAart Bik     static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
10464f2ec7f9SAart Bik         cursor, values, filled, added, count);                                 \
10474f2ec7f9SAart Bik   }
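
// Note (our reading of the expanded access pattern; see
// SparseTensorStorageBase::expInsert): vref and fref form a dense expansion
// of the innermost dimension, aref lists the `count` positions that were
// actually written, and the callee compresses those entries back into
// sparse storage at the prefix given by the cursor.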
10484f2ec7f9SAart Bik 
1049d2215e79SRainer Orth // Assert that index_type is in fact uint64_t, so that
1050bc04a470Swren romano // _mlir_ciface_newSparseTensor can safely rewrite kIndex to kU64. This
1051bc04a470Swren romano // guarantees that this file cannot get out of sync with its header.
1052d2215e79SRainer Orth static_assert(std::is_same<index_type, uint64_t>::value,
1053d2215e79SRainer Orth               "Expected index_type == uint64_t");
1054bc04a470Swren romano 
10558a91bc7bSHarrietAkot /// Constructs a new sparse tensor. This is the "Swiss army knife"
10568a91bc7bSHarrietAkot /// method for materializing sparse tensors into the computation.
10578a91bc7bSHarrietAkot ///
1058845561ecSwren romano /// Action:
10598a91bc7bSHarrietAkot /// kEmpty = returns empty storage to fill later
10608a91bc7bSHarrietAkot /// kFromFile = returns storage, where ptr contains filename to read
10618a91bc7bSHarrietAkot /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
10628a91bc7bSHarrietAkot /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
10638a91bc7bSHarrietAkot /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
1064845561ecSwren romano /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
10658a91bc7bSHarrietAkot void *
1066845561ecSwren romano _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
1067d2215e79SRainer Orth                              StridedMemRefType<index_type, 1> *sref,
1068d2215e79SRainer Orth                              StridedMemRefType<index_type, 1> *pref,
1069845561ecSwren romano                              OverheadType ptrTp, OverheadType indTp,
1070845561ecSwren romano                              PrimaryType valTp, Action action, void *ptr) {
10718a91bc7bSHarrietAkot   assert(aref && sref && pref);
10728a91bc7bSHarrietAkot   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
10738a91bc7bSHarrietAkot          pref->strides[0] == 1);
10748a91bc7bSHarrietAkot   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
1075845561ecSwren romano   const DimLevelType *sparsity = aref->data + aref->offset;
1076d83a7068Swren romano   const index_type *shape = sref->data + sref->offset;
1077d2215e79SRainer Orth   const index_type *perm = pref->data + pref->offset;
10788a91bc7bSHarrietAkot   uint64_t rank = aref->sizes[0];
10798a91bc7bSHarrietAkot 
1080bc04a470Swren romano   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
1081bc04a470Swren romano   // This is safe because of the static_assert above.
1082bc04a470Swren romano   if (ptrTp == OverheadType::kIndex)
1083bc04a470Swren romano     ptrTp = OverheadType::kU64;
1084bc04a470Swren romano   if (indTp == OverheadType::kIndex)
1085bc04a470Swren romano     indTp = OverheadType::kU64;
1086bc04a470Swren romano 
10878a91bc7bSHarrietAkot   // Double matrices with all combinations of overhead storage.
1088845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1089845561ecSwren romano        uint64_t, double);
1090845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1091845561ecSwren romano        uint32_t, double);
1092845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1093845561ecSwren romano        uint16_t, double);
1094845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1095845561ecSwren romano        uint8_t, double);
1096845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1097845561ecSwren romano        uint64_t, double);
1098845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1099845561ecSwren romano        uint32_t, double);
1100845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1101845561ecSwren romano        uint16_t, double);
1102845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1103845561ecSwren romano        uint8_t, double);
1104845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1105845561ecSwren romano        uint64_t, double);
1106845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1107845561ecSwren romano        uint32_t, double);
1108845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1109845561ecSwren romano        uint16_t, double);
1110845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1111845561ecSwren romano        uint8_t, double);
1112845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1113845561ecSwren romano        uint64_t, double);
1114845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1115845561ecSwren romano        uint32_t, double);
1116845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1117845561ecSwren romano        uint16_t, double);
1118845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1119845561ecSwren romano        uint8_t, double);
11208a91bc7bSHarrietAkot 
11218a91bc7bSHarrietAkot   // Float matrices with all combinations of overhead storage.
1122845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1123845561ecSwren romano        uint64_t, float);
1124845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1125845561ecSwren romano        uint32_t, float);
1126845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1127845561ecSwren romano        uint16_t, float);
1128845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1129845561ecSwren romano        uint8_t, float);
1130845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1131845561ecSwren romano        uint64_t, float);
1132845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1133845561ecSwren romano        uint32_t, float);
1134845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1135845561ecSwren romano        uint16_t, float);
1136845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1137845561ecSwren romano        uint8_t, float);
1138845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1139845561ecSwren romano        uint64_t, float);
1140845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1141845561ecSwren romano        uint32_t, float);
1142845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1143845561ecSwren romano        uint16_t, float);
1144845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1145845561ecSwren romano        uint8_t, float);
1146845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1147845561ecSwren romano        uint64_t, float);
1148845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1149845561ecSwren romano        uint32_t, float);
1150845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1151845561ecSwren romano        uint16_t, float);
1152845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1153845561ecSwren romano        uint8_t, float);
11548a91bc7bSHarrietAkot 
1155845561ecSwren romano   // Integral matrices with both overheads of the same type.
1156845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1157845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1158845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1159845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1160845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1161845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1162845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1163845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1164845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1165845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1166845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1167845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1168845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
11698a91bc7bSHarrietAkot 
11708a91bc7bSHarrietAkot   // Unsupported case (add above if needed).
11718a91bc7bSHarrietAkot   fputs("unsupported combination of types\n", stderr);
11728a91bc7bSHarrietAkot   exit(1);
11738a91bc7bSHarrietAkot }
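
// A minimal usage sketch (illustrative only; names such as aref/sref/pref/
// iref stand for caller-prepared memref descriptors): materialize a tensor
// by requesting an empty coordinate scheme, adding elements, and converting:
//
//   void *coo = _mlir_ciface_newSparseTensor(aref, sref, pref, ptrTp, indTp,
//                                            valTp, Action::kEmptyCOO,
//                                            nullptr);
//   coo = _mlir_ciface_addEltF64(coo, 1.0, iref, pref); // once per element
//   void *tensor = _mlir_ciface_newSparseTensor(aref, sref, pref, ptrTp,
//                                               indTp, valTp,
//                                               Action::kFromCOO, coo);
//   delSparseTensorCOOF64(coo); // kFromCOO does not release the scheme
//   ...
//   delSparseTensor(tensor);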
11748a91bc7bSHarrietAkot 
11758a91bc7bSHarrietAkot /// Methods that provide direct access to pointers.
1176d2215e79SRainer Orth IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
11778a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
11788a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
11798a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
11808a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
11818a91bc7bSHarrietAkot 
11828a91bc7bSHarrietAkot /// Methods that provide direct access to indices.
1183d2215e79SRainer Orth IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
11848a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
11858a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
11868a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
11878a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
11888a91bc7bSHarrietAkot 
11898a91bc7bSHarrietAkot /// Methods that provide direct access to values.
11908a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
11918a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
11928a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
11938a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
11948a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
11958a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
11968a91bc7bSHarrietAkot 
11978a91bc7bSHarrietAkot /// Helper to add a value to the coordinate scheme, one per value type.
11988a91bc7bSHarrietAkot IMPL_ADDELT(addEltF64, double)
11998a91bc7bSHarrietAkot IMPL_ADDELT(addEltF32, float)
12008a91bc7bSHarrietAkot IMPL_ADDELT(addEltI64, int64_t)
12018a91bc7bSHarrietAkot IMPL_ADDELT(addEltI32, int32_t)
12028a91bc7bSHarrietAkot IMPL_ADDELT(addEltI16, int16_t)
12038a91bc7bSHarrietAkot IMPL_ADDELT(addEltI8, int8_t)
12048a91bc7bSHarrietAkot 
12058a91bc7bSHarrietAkot /// Helper to enumerate elements of the coordinate scheme, one per value type.
12068a91bc7bSHarrietAkot IMPL_GETNEXT(getNextF64, double)
12078a91bc7bSHarrietAkot IMPL_GETNEXT(getNextF32, float)
12088a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI64, int64_t)
12098a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI32, int32_t)
12108a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI16, int16_t)
12118a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI8, int8_t)
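
// Iteration sketch (illustrative): after obtaining an iterator via
// Action::kToIterator, call getNext until it returns false:
//
//   while (_mlir_ciface_getNextF64(iter, iref, vref)) {
//     // the indices are now in iref's buffer, the value in vref's buffer
//   }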
12128a91bc7bSHarrietAkot 
12136438783fSAart Bik /// Insert elements in lexicographical index order, one per value type.
1214f66e5769SAart Bik IMPL_LEXINSERT(lexInsertF64, double)
1215f66e5769SAart Bik IMPL_LEXINSERT(lexInsertF32, float)
1216f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI64, int64_t)
1217f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI32, int32_t)
1218f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI16, int16_t)
1219f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI8, int8_t)
1220f66e5769SAart Bik 
12216438783fSAart Bik /// Insert using expansion, one per value type.
12224f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertF64, double)
12234f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertF32, float)
12244f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI64, int64_t)
12254f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI32, int32_t)
12264f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI16, int16_t)
12274f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI8, int8_t)
12284f2ec7f9SAart Bik 
12298a91bc7bSHarrietAkot #undef CASE
12308a91bc7bSHarrietAkot #undef IMPL_SPARSEVALUES
12318a91bc7bSHarrietAkot #undef IMPL_GETOVERHEAD
12328a91bc7bSHarrietAkot #undef IMPL_ADDELT
12338a91bc7bSHarrietAkot #undef IMPL_GETNEXT
12344f2ec7f9SAart Bik #undef IMPL_LEXINSERT
12354f2ec7f9SAart Bik #undef IMPL_EXPINSERT
12366438783fSAart Bik 
12376438783fSAart Bik /// Outputs a sparse tensor, one function per value type.
12386438783fSAart Bik void outSparseTensorF64(void *tensor, void *dest, bool sort) {
12396438783fSAart Bik   return outSparseTensor<double>(tensor, dest, sort);
12406438783fSAart Bik }
12416438783fSAart Bik void outSparseTensorF32(void *tensor, void *dest, bool sort) {
12426438783fSAart Bik   return outSparseTensor<float>(tensor, dest, sort);
12436438783fSAart Bik }
12446438783fSAart Bik void outSparseTensorI64(void *tensor, void *dest, bool sort) {
12456438783fSAart Bik   return outSparseTensor<int64_t>(tensor, dest, sort);
12466438783fSAart Bik }
12476438783fSAart Bik void outSparseTensorI32(void *tensor, void *dest, bool sort) {
12486438783fSAart Bik   return outSparseTensor<int32_t>(tensor, dest, sort);
12496438783fSAart Bik }
12506438783fSAart Bik void outSparseTensorI16(void *tensor, void *dest, bool sort) {
12516438783fSAart Bik   return outSparseTensor<int16_t>(tensor, dest, sort);
12526438783fSAart Bik }
12536438783fSAart Bik void outSparseTensorI8(void *tensor, void *dest, bool sort) {
12546438783fSAart Bik   return outSparseTensor<int8_t>(tensor, dest, sort);
12556438783fSAart Bik }
12568a91bc7bSHarrietAkot 
12578a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
12588a91bc7bSHarrietAkot //
12598a91bc7bSHarrietAkot // Public API with methods that accept C-style data structures to interact
12608a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally.
12618a91bc7bSHarrietAkot // These methods can be used both by MLIR compiler-generated code as well as by
12628a91bc7bSHarrietAkot // an external runtime that wants to interact with MLIR compiler-generated code.
12638a91bc7bSHarrietAkot //
12648a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
12658a91bc7bSHarrietAkot 
12668a91bc7bSHarrietAkot /// Helper method to read a sparse tensor filename from the environment,
12678a91bc7bSHarrietAkot /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1268d2215e79SRainer Orth char *getTensorFilename(index_type id) {
12698a91bc7bSHarrietAkot   char var[80];
12708a91bc7bSHarrietAkot   snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
12718a91bc7bSHarrietAkot   char *env = getenv(var);
12723734c078Swren romano   if (!env) {
12733734c078Swren romano     fprintf(stderr, "Environment variable %s is not set\n", var);
12743734c078Swren romano     exit(1);
12753734c078Swren romano   }
12768a91bc7bSHarrietAkot   return env;
12778a91bc7bSHarrietAkot }
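
// For example (illustrative): with the environment variable
//   TENSOR0=/data/matrix.mtx
// a call to getTensorFilename(0) returns the string "/data/matrix.mtx".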
12788a91bc7bSHarrietAkot 
12798a91bc7bSHarrietAkot /// Returns the size of the sparse tensor in the given dimension.
1280d2215e79SRainer Orth index_type sparseDimSize(void *tensor, index_type d) {
12818a91bc7bSHarrietAkot   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
12828a91bc7bSHarrietAkot }
12838a91bc7bSHarrietAkot 
1284f66e5769SAart Bik /// Finalizes lexicographic insertions.
1285f66e5769SAart Bik void endInsert(void *tensor) {
1286f66e5769SAart Bik   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1287f66e5769SAart Bik }
1288f66e5769SAart Bik 
12898a91bc7bSHarrietAkot /// Releases sparse tensor storage.
12908a91bc7bSHarrietAkot void delSparseTensor(void *tensor) {
12918a91bc7bSHarrietAkot   delete static_cast<SparseTensorStorageBase *>(tensor);
12928a91bc7bSHarrietAkot }
12938a91bc7bSHarrietAkot 
129463bdcaf9Swren romano /// Releases sparse tensor coordinate scheme.
129563bdcaf9Swren romano #define IMPL_DELCOO(VNAME, V)                                                  \
129663bdcaf9Swren romano   void delSparseTensorCOO##VNAME(void *coo) {                                  \
129763bdcaf9Swren romano     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
129863bdcaf9Swren romano   }
129963bdcaf9Swren romano IMPL_DELCOO(F64, double)
130063bdcaf9Swren romano IMPL_DELCOO(F32, float)
130163bdcaf9Swren romano IMPL_DELCOO(I64, int64_t)
130263bdcaf9Swren romano IMPL_DELCOO(I32, int32_t)
130363bdcaf9Swren romano IMPL_DELCOO(I16, int16_t)
130463bdcaf9Swren romano IMPL_DELCOO(I8, int8_t)
130563bdcaf9Swren romano #undef IMPL_DELCOO
130663bdcaf9Swren romano 
13078a91bc7bSHarrietAkot /// Initializes sparse tensor from a COO-flavored format expressed using C-style
13088a91bc7bSHarrietAkot /// data structures. The expected parameters are:
13098a91bc7bSHarrietAkot ///
13108a91bc7bSHarrietAkot ///   rank:    rank of tensor
13118a91bc7bSHarrietAkot ///   nse:     number of specified elements (usually the nonzeros)
13128a91bc7bSHarrietAkot ///   shape:   array with the size of each dimension
13138a91bc7bSHarrietAkot ///   values:  an "nse" array with values for all specified elements
13148a91bc7bSHarrietAkot ///   indices: a flat "nse x rank" array with indices for all specified elements
131520eaa88fSBixia Zheng ///   perm:    the permutation of the dimensions in the storage
131620eaa88fSBixia Zheng ///   sparse:  the sparsity for the dimensions
13178a91bc7bSHarrietAkot ///
13188a91bc7bSHarrietAkot /// For example, the sparse matrix
13198a91bc7bSHarrietAkot ///     | 1.0 0.0 0.0 |
13208a91bc7bSHarrietAkot ///     | 0.0 5.0 3.0 |
13218a91bc7bSHarrietAkot /// can be passed as
13228a91bc7bSHarrietAkot ///      rank    = 2
13238a91bc7bSHarrietAkot ///      nse     = 3
13248a91bc7bSHarrietAkot ///      shape   = [2, 3]
13258a91bc7bSHarrietAkot ///      values  = [1.0, 5.0, 3.0]
13268a91bc7bSHarrietAkot ///      indices = [ 0, 0,  1, 1,  1, 2]
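///      perm    = [0, 1]              (identity ordering, for example)
///      sparse  = [dense, compressed] (per-dimension annotations, for example)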
13278a91bc7bSHarrietAkot //
132820eaa88fSBixia Zheng // TODO: generalize beyond 64-bit indices.
13298a91bc7bSHarrietAkot //
13306438783fSAart Bik void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
133120eaa88fSBixia Zheng                                    double *values, uint64_t *indices,
133220eaa88fSBixia Zheng                                    uint64_t *perm, uint8_t *sparse) {
133320eaa88fSBixia Zheng   return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
133420eaa88fSBixia Zheng                                     sparse);
13358a91bc7bSHarrietAkot }
13366438783fSAart Bik void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
133720eaa88fSBixia Zheng                                    float *values, uint64_t *indices,
133820eaa88fSBixia Zheng                                    uint64_t *perm, uint8_t *sparse) {
133920eaa88fSBixia Zheng   return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
134020eaa88fSBixia Zheng                                    sparse);
13418a91bc7bSHarrietAkot }
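
// A call sketch for the documented example above (illustrative; the actual
// per-dimension sparsity encoding is defined by DimLevelType in the header):
//
//   uint64_t shape[] = {2, 3};
//   double values[] = {1.0, 5.0, 3.0};
//   uint64_t indices[] = {0, 0, 1, 1, 1, 2};
//   uint64_t perm[] = {0, 1}; // identity dimension ordering
//   uint8_t sparse[2] = { /* per-dimension DimLevelType values */ };
//   void *tensor =
//       convertToMLIRSparseTensorF64(2, 3, shape, values, indices, perm,
//                                    sparse);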
13428a91bc7bSHarrietAkot 
13432f49e6b0SBixia Zheng /// Converts a sparse tensor to COO-flavored format expressed using C-style
13442f49e6b0SBixia Zheng /// data structures. The output parameters are pointers through which these
13452f49e6b0SBixia Zheng /// values are returned:
13462f49e6b0SBixia Zheng ///
13472f49e6b0SBixia Zheng ///   rank:    rank of tensor
13482f49e6b0SBixia Zheng ///   nse:     number of specified elements (usually the nonzeros)
13492f49e6b0SBixia Zheng ///   shape:   array with the size of each dimension
13502f49e6b0SBixia Zheng ///   values:  an "nse" array with values for all specified elements
13512f49e6b0SBixia Zheng ///   indices: a flat "nse x rank" array with indices for all specified elements
13522f49e6b0SBixia Zheng ///
13532f49e6b0SBixia Zheng /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
13542f49e6b0SBixia Zheng /// from convertToMLIRSparseTensor.
13552f49e6b0SBixia Zheng ///
13562f49e6b0SBixia Zheng //  TODO: Currently, values are copied from SparseTensorStorage to
13572f49e6b0SBixia Zheng //  SparseTensorCOO, then to the output. We may want to reduce the number of
13582f49e6b0SBixia Zheng //  copies.
13592f49e6b0SBixia Zheng //
13606438783fSAart Bik // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
13616438783fSAart Bik // compressed
13622f49e6b0SBixia Zheng //
13636438783fSAart Bik void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
13646438783fSAart Bik                                     uint64_t *pNse, uint64_t **pShape,
13656438783fSAart Bik                                     double **pValues, uint64_t **pIndices) {
13666438783fSAart Bik   fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
13672f49e6b0SBixia Zheng }
13686438783fSAart Bik void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
13696438783fSAart Bik                                     uint64_t *pNse, uint64_t **pShape,
13706438783fSAart Bik                                     float **pValues, uint64_t **pIndices) {
13716438783fSAart Bik   fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
13722f49e6b0SBixia Zheng }
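
// Note on ownership (per fromMLIRSparseTensor above): the shape, values, and
// indices arrays returned through the output parameters are allocated with
// new[] on the caller's behalf, so the caller is responsible for releasing
// them, e.g.:
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   convertFromMLIRSparseTensorF64(tensor, &rank, &nse, &shape, &values,
//                                  &indices);
//   ...
//   delete[] shape;
//   delete[] values;
//   delete[] indices;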
1373efa15f41SAart Bik 
13748a91bc7bSHarrietAkot } // extern "C"
13758a91bc7bSHarrietAkot 
13768a91bc7bSHarrietAkot #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1377