//===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library that is useful
// for sparse tensor manipulations. The functionality provided in this library
// is meant to simplify benchmarking, testing, and debugging MLIR code that
// operates on sparse tensors. The provided functionality is **not** part
// of core MLIR, however.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <vector>

//===----------------------------------------------------------------------===//
//
// Internal support for storing and reading sparse tensors.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by index (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is visible
// externally only as an opaque pointer.
//
//===----------------------------------------------------------------------===//

namespace {

static constexpr int kColWidth = 1025;

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
template <typename V>
struct Element {
  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val){};
  std::vector<uint64_t> indices;
  V value;
};

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO {
public:
  SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
      : sizes(szs), iteratorLocked(false), iteratorPos(0) {
    if (capacity)
      elements.reserve(capacity);
  }
  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t rank = getRank();
    assert(rank == ind.size());
    for (uint64_t r = 0; r < rank; r++)
      assert(ind[r] < sizes[r]); // within bounds
    elements.emplace_back(ind, val);
  }
  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    std::sort(elements.begin(), elements.end(), lexOrder);
  }
  /// Returns rank.
  uint64_t getRank() const { return sizes.size(); }
  /// Getter for sizes array.
  const std::vector<uint64_t> &getSizes() const { return sizes; }
  /// Getter for elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }
  /// Get the next element.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *sizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++)
      permsz[perm[r]] = sizes[r];
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  /// Returns true if indices of e1 < indices of e2.
  static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
    uint64_t rank = e1.indices.size();
    assert(rank == e2.indices.size());
    for (uint64_t r = 0; r < rank; r++) {
      if (e1.indices[r] == e2.indices[r])
        continue;
      return e1.indices[r] < e2.indices[r];
    }
    return false;
  }
  const std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<Element<V>> elements;
  bool iteratorLocked;
  unsigned iteratorPos;
};
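
// A minimal usage sketch of the coordinate scheme (hypothetical client code,
// for illustration only):
//
//   uint64_t sizes[] = {3, 4}; // 3x4 matrix
//   uint64_t perm[] = {0, 1};  // identity dimension ordering
//   SparseTensorCOO<double> *coo =
//       SparseTensorCOO<double>::newSparseTensorCOO(2, sizes, perm);
//   coo->add({2, 3}, 1.0);
//   coo->add({0, 1}, 2.0);
//   coo->sort();               // (0,1) now precedes (2,3)
//   coo->startIterator();
//   while (const Element<double> *e = coo->getNext()) {
//     // visit e->indices and e->value
//   }
//   delete coo;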

/// Abstract base class of sparse tensor storage. Note that we use
/// function overloading to implement "partial" method specialization.
class SparseTensorStorageBase {
public:
  /// Dimension size query.
  virtual uint64_t getDimSize(uint64_t) = 0;

  /// Overhead storage.
  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }

  /// Primary storage.
  virtual void getValues(std::vector<double> **) { fatal("valf64"); }
  virtual void getValues(std::vector<float> **) { fatal("valf32"); }
  virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }

  /// Element-wise insertion in lexicographic index order.
  virtual void lexInsert(uint64_t *, double) { fatal("insf64"); }
  virtual void lexInsert(uint64_t *, float) { fatal("insf32"); }
  virtual void lexInsert(uint64_t *, int64_t) { fatal("insi64"); }
  virtual void lexInsert(uint64_t *, int32_t) { fatal("insi32"); }
  virtual void lexInsert(uint64_t *, int16_t) { fatal("insi16"); }
  virtual void lexInsert(uint64_t *, int8_t) { fatal("insi8"); }

  /// Expanded insertion.
  virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
    fatal("expf64");
  }
  virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
    fatal("expf32");
  }
  virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi64");
  }
  virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi32");
  }
  virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi16");
  }
  virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi8");
  }

  /// Finishes insertion.
  virtual void endInsert() = 0;

  virtual ~SparseTensorStorageBase() {}

private:
  void fatal(const char *tp) {
    fprintf(stderr, "unsupported %s\n", tp);
    exit(1);
  }
};
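
// For illustration, a client holding only the abstract base pointer obtains
// typed views through the matching overload; if the concrete subclass was
// instantiated with different P/I/V types, the call falls through to fatal().
// A sketch (hypothetical client code):
//
//   SparseTensorStorageBase *base = ...; // some SparseTensorStorage<P,I,V>
//   std::vector<uint64_t> *p;
//   base->getPointers(&p, 0); // ok iff P == uint64_t
//   std::vector<double> *v;
//   base->getValues(&v);      // ok iff V == double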

/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this method provides
/// a convenient "one-size-fits-all" solution that simply takes an input tensor
/// and annotations to implement all required setup in a general manner.
template <typename P, typename I, typename V>
class SparseTensorStorage : public SparseTensorStorageBase {
public:
  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the coordinate scheme tensor for the initial contents if provided.
  SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
                      const DimLevelType *sparsity,
                      SparseTensorCOO<V> *tensor = nullptr)
      : sizes(szs), rev(getRank()), idx(getRank()), pointers(getRank()),
        indices(getRank()) {
    uint64_t rank = getRank();
    // Store "reverse" permutation.
    for (uint64_t r = 0; r < rank; r++)
      rev[perm[r]] = r;
    // Provide hints on capacity of pointers and indices.
    // TODO: needs fine-tuning based on sparsity
    bool allDense = true;
    uint64_t sz = 1;
    for (uint64_t r = 0; r < rank; r++) {
      sz *= sizes[r];
      if (sparsity[r] == DimLevelType::kCompressed) {
        pointers[r].reserve(sz + 1);
        indices[r].reserve(sz);
        sz = 1;
        allDense = false;
      } else {
        assert(sparsity[r] == DimLevelType::kDense &&
               "singleton not yet supported");
      }
    }
    // Prepare sparse pointer structures for all dimensions.
    for (uint64_t r = 0; r < rank; r++)
      if (sparsity[r] == DimLevelType::kCompressed)
        pointers[r].push_back(0);
    // Then assign contents from coordinate scheme tensor if provided.
    if (tensor) {
      uint64_t nnz = tensor->getElements().size();
      values.reserve(nnz);
      fromCOO(tensor, 0, nnz, 0);
    } else if (allDense) {
      values.resize(sz, 0);
    }
  }

  virtual ~SparseTensorStorage() {}

  /// Get the rank of the tensor.
  uint64_t getRank() const { return sizes.size(); }

  /// Get the size in the given dimension of the tensor.
  uint64_t getDimSize(uint64_t d) override {
    assert(d < getRank());
    return sizes[d];
  }

  /// Partially specialize these getter methods based on template types.
  void getPointers(std::vector<P> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &pointers[d];
  }
  void getIndices(std::vector<I> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &indices[d];
  }
  void getValues(std::vector<V> **out) override { *out = &values; }

  /// Partially specialize lexicographical insertions based on template types.
  void lexInsert(uint64_t *cursor, V val) override {
    // First, wrap up pending insertion path.
    uint64_t diff = 0;
    uint64_t top = 0;
    if (!values.empty()) {
      diff = lexDiff(cursor);
      endPath(diff + 1);
      top = idx[diff] + 1;
    }
    // Then continue with insertion path.
    insPath(cursor, diff, top, val);
  }
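
  // A minimal sketch of the expected call pattern (hypothetical client code;
  // indices must arrive in strictly increasing lexicographic order, and
  // endInsert() finalizes the storage):
  //
  //   uint64_t cursor[2] = {0, 2};
  //   tensor->lexInsert(cursor, 1.0); // insert at (0,2)
  //   cursor[0] = 1; cursor[1] = 0;
  //   tensor->lexInsert(cursor, 2.0); // (1,0) follows (0,2) lexicographically
  //   tensor->endInsert();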

  /// Partially specialize expanded insertions based on template types.
  /// Note that this method resets the values/filled-switch array back
  /// to all-zero/false while only iterating over the nonzero elements.
  void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
                 uint64_t count) override {
    if (count == 0)
      return;
    // Sort.
    std::sort(added, added + count);
    // Restore insertion path for first insert.
    uint64_t rank = getRank();
    uint64_t index = added[0];
    cursor[rank - 1] = index;
    lexInsert(cursor, values[index]);
    assert(filled[index]);
    values[index] = 0;
    filled[index] = false;
    // Subsequent insertions are quick.
    for (uint64_t i = 1; i < count; i++) {
      assert(index < added[i] && "non-lexicographic insertion");
      index = added[i];
      cursor[rank - 1] = index;
      insPath(cursor, rank - 1, added[i - 1] + 1, values[index]);
      assert(filled[index]);
      values[index] = 0;
      filled[index] = false;
    }
  }
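
  // A sketch of the expanded-access pattern this method expects (hypothetical
  // client code; compiler-generated kernels maintain these scratch buffers,
  // sized to the innermost dimension):
  //
  //   double vals[100] = {};       // dense value per column
  //   bool filled[100] = {};       // true where a value was written
  //   uint64_t added[100];         // unordered list of written columns
  //   uint64_t count = 0;
  //   vals[42] = 3.0; filled[42] = true; added[count++] = 42;
  //   vals[7] = 1.0;  filled[7] = true;  added[count++] = 7;
  //   uint64_t cursor[2] = {5, 0}; // row 5; column slot set by expInsert
  //   tensor->expInsert(cursor, vals, filled, added, count);
  //   // on return, vals[] and filled[] are reset at the added positions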

  /// Finalizes lexicographic insertions.
  void endInsert() override {
    if (values.empty())
      endDim(0);
    else
      endPath(0);
  }

  /// Returns this sparse tensor storage scheme as a new memory-resident
  /// sparse tensor in coordinate scheme with the given dimension order.
  SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
    // Restore original order of the dimension sizes and allocate coordinate
    // scheme with desired new ordering specified in perm.
    uint64_t rank = getRank();
    std::vector<uint64_t> orgsz(rank);
    for (uint64_t r = 0; r < rank; r++)
      orgsz[rev[r]] = sizes[r];
    SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
        rank, orgsz.data(), perm, values.size());
    // Populate coordinate scheme restored from old ordering and changed with
    // new ordering. Rather than applying both reorderings during the recursion,
    // we compute the combined permutation in advance.
    std::vector<uint64_t> reord(rank);
    for (uint64_t r = 0; r < rank; r++)
      reord[r] = perm[rev[r]];
    toCOO(tensor, reord, 0, 0);
    assert(tensor->getElements().size() == values.size());
    return tensor;
  }

  /// Factory method. Constructs a sparse tensor storage scheme with the given
  /// dimensions, permutation, and per-dimension dense/sparse annotations,
  /// using the coordinate scheme tensor for the initial contents if provided.
  /// In the latter case, the coordinate scheme must respect the same
  /// permutation as is desired for the new sparse tensor storage.
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
                  const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
    SparseTensorStorage<P, I, V> *n = nullptr;
    if (tensor) {
      assert(tensor->getRank() == rank);
      for (uint64_t r = 0; r < rank; r++)
        assert(sizes[r] == 0 || tensor->getSizes()[perm[r]] == sizes[r]);
      tensor->sort(); // sort lexicographically
      n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
                                           tensor);
      delete tensor;
    } else {
      std::vector<uint64_t> permsz(rank);
      for (uint64_t r = 0; r < rank; r++)
        permsz[perm[r]] = sizes[r];
      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
    }
    return n;
  }
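
  // A minimal sketch of building storage from a coordinate scheme
  // (hypothetical client code; `coo` is a previously built
  // SparseTensorCOO<double>* whose sizes already honor the permutation):
  //
  //   uint64_t sizes[] = {3, 4};
  //   uint64_t perm[] = {0, 1};
  //   DimLevelType sparsity[] = {DimLevelType::kDense,
  //                              DimLevelType::kCompressed}; // CSR-like
  //   SparseTensorStorage<uint64_t, uint64_t, double> *st =
  //       SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
  //           2, sizes, perm, sparsity, coo); // consumes and deletes coo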

private:
  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the pointers and
  /// indices arrays under the given per-dimension dense/sparse annotations.
  void fromCOO(SparseTensorCOO<V> *tensor, uint64_t lo, uint64_t hi,
               uint64_t d) {
    const std::vector<Element<V>> &elements = tensor->getElements();
    // Once dimensions are exhausted, insert the numerical values.
    assert(d <= getRank());
    if (d == getRank()) {
      assert(lo < hi && hi <= elements.size());
      values.push_back(elements[lo].value);
      return;
    }
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) {
      assert(lo < elements.size() && hi <= elements.size());
      // Find segment in interval with same index elements in this dimension.
      uint64_t i = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == i)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      if (isCompressedDim(d)) {
        indices[d].push_back(i);
      } else {
        // For dense storage we must fill in all the zero values between
        // the previous element (when last we ran this for-loop) and the
        // current element.
        for (; full < i; full++)
          endDim(d + 1);
        full++;
      }
      fromCOO(tensor, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    if (isCompressedDim(d)) {
      pointers[d].push_back(indices[d].size());
    } else {
      // For dense storage we must fill in all the zero values after
      // the last element.
      for (uint64_t sz = sizes[d]; full < sz; full++)
        endDim(d + 1);
    }
  }
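
  // For example, with annotations (kDense, kCompressed), a sorted COO holding
  // the 3x4 matrix
  //
  //   [ [0, a, 0, 0],
  //     [0, 0, 0, 0],
  //     [b, 0, c, 0] ]
  //
  // is converted into the familiar CSR-style arrays (one compressed level):
  //
  //   pointers[1] = { 0, 1, 1, 3 }
  //   indices[1]  = { 1, 0, 2 }
  //   values      = { a, b, c }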

  /// Stores the sparse tensor storage scheme into a memory-resident sparse
  /// tensor in coordinate scheme.
  void toCOO(SparseTensorCOO<V> *tensor, std::vector<uint64_t> &reord,
             uint64_t pos, uint64_t d) {
    assert(d <= getRank());
    if (d == getRank()) {
      assert(pos < values.size());
      tensor->add(idx, values[pos]);
    } else if (isCompressedDim(d)) {
      // Sparse dimension.
      for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
        idx[reord[d]] = indices[d][ii];
        toCOO(tensor, reord, ii, d + 1);
      }
    } else {
      // Dense dimension.
      for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
        idx[reord[d]] = i;
        toCOO(tensor, reord, off + i, d + 1);
      }
    }
  }

  /// Ends a deeper, never-before-seen dimension.
  void endDim(uint64_t d) {
    assert(d <= getRank());
    if (d == getRank()) {
      values.push_back(0);
    } else if (isCompressedDim(d)) {
      pointers[d].push_back(indices[d].size());
    } else {
      for (uint64_t full = 0, sz = sizes[d]; full < sz; full++)
        endDim(d + 1);
    }
  }

  /// Wraps up a single insertion path, inner to outer.
  void endPath(uint64_t diff) {
    uint64_t rank = getRank();
    assert(diff <= rank);
    for (uint64_t i = 0; i < rank - diff; i++) {
      uint64_t d = rank - i - 1;
      if (isCompressedDim(d)) {
        pointers[d].push_back(indices[d].size());
      } else {
        for (uint64_t full = idx[d] + 1, sz = sizes[d]; full < sz; full++)
          endDim(d + 1);
      }
    }
  }

  /// Continues a single insertion path, outer to inner.
  void insPath(uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
    uint64_t rank = getRank();
    assert(diff < rank);
    for (uint64_t d = diff; d < rank; d++) {
      uint64_t i = cursor[d];
      if (isCompressedDim(d)) {
        indices[d].push_back(i);
      } else {
        for (uint64_t full = top; full < i; full++)
          endDim(d + 1);
      }
      top = 0;
      idx[d] = i;
    }
    values.push_back(val);
  }

  /// Finds the lexicographically differing dimension.
  uint64_t lexDiff(uint64_t *cursor) {
    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
      if (cursor[r] > idx[r])
        return r;
      else
        assert(cursor[r] == idx[r] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
    return -1u;
  }

  /// Returns true if dimension is compressed.
  inline bool isCompressedDim(uint64_t d) const {
    return (!pointers[d].empty());
  }

private:
  std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<uint64_t> rev;   // "reverse" permutation
  std::vector<uint64_t> idx;   // index cursor
  std::vector<std::vector<P>> pointers;
  std::vector<std::vector<I>> indices;
  std::vector<V> values;
};

/// Helper to convert string to lower case.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = tolower(*c);
  return token;
}

/// Read the MME header of a general sparse matrix of type real.
static void readMMEHeader(FILE *file, char *filename, char *line,
                          uint64_t *idata, bool *isSymmetric) {
  char header[64];
  char object[64];
  char format[64];
  char field[64];
  char symmetry[64];
  // Read header line.
  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
             symmetry) != 5) {
    fprintf(stderr, "Corrupt header in %s\n", filename);
    exit(1);
  }
  *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
  // Make sure this is a general sparse matrix.
  if (strcmp(toLower(header), "%%matrixmarket") ||
      strcmp(toLower(object), "matrix") ||
      strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
      (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
    fprintf(stderr,
            "Cannot find a general sparse matrix with type real in %s\n",
            filename);
    exit(1);
  }
  // Skip comments.
  while (1) {
    if (!fgets(line, kColWidth, file)) {
      fprintf(stderr, "Cannot find data in %s\n", filename);
      exit(1);
    }
    if (line[0] != '%')
      break;
  }
  // Next line contains M N NNZ.
  idata[0] = 2; // rank
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3) {
    fprintf(stderr, "Cannot find size in %s\n", filename);
    exit(1);
  }
}
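
// For example, a conforming MME input parsed by readMMEHeader looks like:
//
//   %%MatrixMarket matrix coordinate real general
//   % optional comment lines ...
//   3 4 2
//
// where the final line provides M, N, and NNZ (stored into idata[2],
// idata[3], and idata[1], respectively).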

/// Read the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments followed
/// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
                                uint64_t *idata) {
  // Skip comments.
  while (1) {
    if (!fgets(line, kColWidth, file)) {
      fprintf(stderr, "Cannot find data in %s\n", filename);
      exit(1);
    }
    if (line[0] != '#')
      break;
  }
  // Next line contains RANK and NNZ.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
    fprintf(stderr, "Cannot find metadata in %s\n", filename);
    exit(1);
  }
  // Followed by a line with the dimension sizes (one per rank).
  for (uint64_t r = 0; r < idata[0]; r++) {
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size %s\n", filename);
      exit(1);
    }
  }
  fgets(line, kColWidth, file); // end of line
}
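
// For example, an "extended" FROSTT header for a 3-dimensional tensor with
// five nonzeros looks like:
//
//   # optional comment lines ...
//   3 5
//   2 3 4
//
// where the first data line provides the rank and NNZ, and the second line
// provides the per-dimension sizes.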

/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme.
template <typename V>
static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
                                               const uint64_t *sizes,
                                               const uint64_t *perm) {
  // Open the file.
  FILE *file = fopen(filename, "r");
  if (!file) {
    fprintf(stderr, "Cannot find %s\n", filename);
    exit(1);
  }
  // Perform some file-format-dependent setup.
  char line[kColWidth];
  uint64_t idata[512];
  bool isSymmetric = false;
  if (strstr(filename, ".mtx")) {
    readMMEHeader(file, filename, line, idata, &isSymmetric);
  } else if (strstr(filename, ".tns")) {
    readExtFROSTTHeader(file, filename, line, idata);
  } else {
    fprintf(stderr, "Unknown format %s\n", filename);
    exit(1);
  }
  // Prepare sparse tensor object with per-dimension sizes
  // and the number of nonzeros as initial capacity.
  assert(rank == idata[0] && "rank mismatch");
  uint64_t nnz = idata[1];
  for (uint64_t r = 0; r < rank; r++)
    assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
           "dimension size mismatch");
  SparseTensorCOO<V> *tensor =
      SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
  // Read all nonzero elements.
  std::vector<uint64_t> indices(rank);
  for (uint64_t k = 0; k < nnz; k++) {
    if (!fgets(line, kColWidth, file)) {
      fprintf(stderr, "Cannot find next line of data in %s\n", filename);
      exit(1);
    }
    char *linePtr = line;
    for (uint64_t r = 0; r < rank; r++) {
      uint64_t idx = strtoul(linePtr, &linePtr, 10);
      // Add 0-based index.
      indices[perm[r]] = idx - 1;
    }
    // The external formats always store the numerical values with the type
    // double, but we cast these values to the sparse tensor object type.
    double value = strtod(linePtr, &linePtr);
    tensor->add(indices, value);
    // We currently choose to deal with symmetric matrices by fully
    // constructing them. In the future, we may want to make symmetry
    // implicit for storage reasons.
    if (isSymmetric && indices[0] != indices[1])
      tensor->add({indices[1], indices[0]}, value);
  }
  // Close the file and return tensor.
  fclose(file);
  return tensor;
}
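
// A minimal sketch of reading a file into a coordinate scheme (hypothetical
// client code; a size of 0 skips the per-dimension size check, and the
// path shown is illustrative):
//
//   char filename[] = "matrix.mtx";
//   uint64_t sizes[] = {0, 0};
//   uint64_t perm[] = {0, 1};
//   SparseTensorCOO<double> *coo =
//       openSparseTensorCOO<double>(filename, 2, sizes, perm);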

} // namespace

extern "C" {

/// This type is used in the public API at all places where MLIR expects
/// values with the built-in type "index". For now, we simply assume that
/// type is 64-bit, but targets with different "index" bit widths should link
/// with an alternatively built runtime support library.
// TODO: support such targets?
typedef uint64_t index_t;

//===----------------------------------------------------------------------===//
//
// Public API with methods that operate on MLIR buffers (memrefs) to interact
// with sparse tensors, which are only visible as opaque pointers externally.
// These methods should be used exclusively by MLIR compiler-generated code.
//
// Some macro magic is used to generate implementations for all required type
// combinations that can be called from MLIR compiler-generated code.
//
//===----------------------------------------------------------------------===//

#define CASE(p, i, v, P, I, V)                                                 \
  if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
    SparseTensorCOO<V> *tensor = nullptr;                                      \
    if (action <= Action::kFromCOO) {                                          \
      if (action == Action::kFromFile) {                                       \
        char *filename = static_cast<char *>(ptr);                             \
        tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm);          \
      } else if (action == Action::kFromCOO) {                                 \
        tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
      } else {                                                                 \
        assert(action == Action::kEmpty);                                      \
      }                                                                        \
      return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,  \
                                                           sparsity, tensor);  \
    }                                                                          \
    if (action == Action::kEmptyCOO)                                           \
      return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);        \
    tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);    \
    if (action == Action::kToIterator) {                                       \
      tensor->startIterator();                                                 \
    } else {                                                                   \
      assert(action == Action::kToCOO);                                        \
    }                                                                          \
    return tensor;                                                             \
  }

#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
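
// For illustration, the first CASE instantiation in newSparseTensor below
// behaves (conceptually) like:
//
//   if (ptrTp == OverheadType::kU64 && indTp == OverheadType::kU64 &&
//       valTp == PrimaryType::kF64) {
//     // ... dispatch on `action` and return an opaque pointer to either a
//     // SparseTensorStorage<uint64_t, uint64_t, double> or a
//     // SparseTensorCOO<double> ...
//   }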

#define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
    assert(ref && tensor);                                                     \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
    ref->basePtr = ref->data = v->data();                                      \
    ref->offset = 0;                                                           \
    ref->sizes[0] = v->size();                                                 \
    ref->strides[0] = 1;                                                       \
  }

#define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
                           index_t d) {                                        \
    assert(ref && tensor);                                                     \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
    ref->basePtr = ref->data = v->data();                                      \
    ref->offset = 0;                                                           \
    ref->sizes[0] = v->size();                                                 \
    ref->strides[0] = 1;                                                       \
  }

#define IMPL_ADDELT(NAME, TYPE)                                                \
  void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
                            StridedMemRefType<index_t, 1> *iref,               \
                            StridedMemRefType<index_t, 1> *pref) {             \
    assert(tensor && iref && pref);                                            \
    assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
    assert(iref->sizes[0] == pref->sizes[0]);                                  \
    const index_t *indx = iref->data + iref->offset;                           \
    const index_t *perm = pref->data + pref->offset;                           \
    uint64_t isize = iref->sizes[0];                                           \
    std::vector<index_t> indices(isize);                                       \
    for (uint64_t r = 0; r < isize; r++)                                       \
      indices[perm[r]] = indx[r];                                              \
    static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
    return tensor;                                                             \
  }

#define IMPL_GETNEXT(NAME, V)                                                  \
  bool _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *iref,  \
                           StridedMemRefType<V, 0> *vref) {                    \
    assert(tensor && iref && vref);                                            \
    assert(iref->strides[0] == 1);                                             \
    index_t *indx = iref->data + iref->offset;                                 \
    V *value = vref->data + vref->offset;                                      \
    const uint64_t isize = iref->sizes[0];                                     \
    auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
    const Element<V> *elem = iter->getNext();                                  \
    if (elem == nullptr) {                                                     \
      delete iter;                                                             \
      return false;                                                            \
    }                                                                          \
    for (uint64_t r = 0; r < isize; r++)                                       \
      indx[r] = elem->indices[r];                                              \
    *value = elem->value;                                                      \
    return true;                                                               \
  }
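
// A sketch of the loop that compiler-generated code drives around such an
// iterator (using a hypothetical f64 instantiation of IMPL_GETNEXT; note
// that a false return also deletes the iterator):
//
//   while (_mlir_ciface_getNextF64(iter, &iref, &vref)) {
//     // consume iref.data[0..rank) and *vref.data
//   }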

#define IMPL_LEXINSERT(NAME, V)                                                \
  void _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *cref,  \
                           V val) {                                            \
    assert(tensor && cref);                                                    \
    assert(cref->strides[0] == 1);                                             \
    index_t *cursor = cref->data + cref->offset;                               \
    assert(cursor);                                                            \
    static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
  }

#define IMPL_EXPINSERT(NAME, V)                                                \
  void _mlir_ciface_##NAME(                                                    \
      void *tensor, StridedMemRefType<index_t, 1> *cref,                       \
      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
      StridedMemRefType<index_t, 1> *aref, index_t count) {                    \
    assert(tensor && cref && vref && fref && aref);                            \
    assert(cref->strides[0] == 1);                                             \
    assert(vref->strides[0] == 1);                                             \
    assert(fref->strides[0] == 1);                                             \
    assert(aref->strides[0] == 1);                                             \
    assert(vref->sizes[0] == fref->sizes[0]);                                  \
    index_t *cursor = cref->data + cref->offset;                               \
    V *values = vref->data + vref->offset;                                     \
    bool *filled = fref->data + fref->offset;                                  \
    index_t *added = aref->data + aref->offset;                                \
    static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
        cursor, values, filled, added, count);                                 \
  }

/// Constructs a new sparse tensor. This is the "swiss army knife"
/// method for materializing sparse tensors into the computation.
///
/// Action:
/// kEmpty = returns empty storage to fill later
/// kFromFile = returns storage, where ptr contains filename to read
/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
void *
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
                             StridedMemRefType<index_t, 1> *sref,
                             StridedMemRefType<index_t, 1> *pref,
                             OverheadType ptrTp, OverheadType indTp,
                             PrimaryType valTp, Action action, void *ptr) {
  assert(aref && sref && pref);
  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
         pref->strides[0] == 1);
  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
  const DimLevelType *sparsity = aref->data + aref->offset;
  const index_t *sizes = sref->data + sref->offset;
  const index_t *perm = pref->data + pref->offset;
  uint64_t rank = aref->sizes[0];

  // Double matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
       uint64_t, double);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
       uint32_t, double);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
       uint16_t, double);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
       uint8_t, double);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
       uint64_t, double);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
       uint32_t, double);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
       uint16_t, double);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
       uint8_t, double);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
       uint64_t, double);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
       uint32_t, double);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
       uint16_t, double);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
       uint8_t, double);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
       uint64_t, double);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
       uint32_t, double);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
       uint16_t, double);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
       uint8_t, double);

  // Float matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
       uint64_t, float);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
       uint32_t, float);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
       uint16_t, float);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
       uint8_t, float);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
       uint64_t, float);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
       uint32_t, float);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
       uint16_t, float);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
       uint8_t, float);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
       uint64_t, float);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
       uint32_t, float);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
       uint16_t, float);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
       uint8_t, float);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
       uint64_t, float);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
       uint32_t, float);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
       uint16_t, float);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
       uint8_t, float);

  // Integral matrices with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
920845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
921845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
922845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
923845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
924845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
925845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
926845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
927845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
928845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
929845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
930845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
9318a91bc7bSHarrietAkot 
9328a91bc7bSHarrietAkot   // Unsupported case (add above if needed).
9338a91bc7bSHarrietAkot   fputs("unsupported combination of types\n", stderr);
9348a91bc7bSHarrietAkot   exit(1);
9358a91bc7bSHarrietAkot }
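
// Example (illustrative sketch only, not part of the library): reading a
// rank-2 double tensor from a file into all-compressed storage with 64-bit
// overhead types. MLIR compiler-generated code normally populates the three
// memref descriptors; here they are filled in by hand, assuming the
// StridedMemRefType layout (basePtr, data, offset, sizes, strides) from
// CRunnerUtils.h and that a dimension size of 0 means "take the size from
// the file":
//
//   DimLevelType lvl[2] = {DimLevelType::kCompressed,
//                          DimLevelType::kCompressed};
//   index_t sizes[2] = {0, 0};
//   index_t perm[2] = {0, 1}; // identity dimension ordering
//   StridedMemRefType<DimLevelType, 1> aref{lvl, lvl, 0, {2}, {1}};
//   StridedMemRefType<index_t, 1> sref{sizes, sizes, 0, {2}, {1}};
//   StridedMemRefType<index_t, 1> pref{perm, perm, 0, {2}, {1}};
//   void *tensor = _mlir_ciface_newSparseTensor(
//       &aref, &sref, &pref, OverheadType::kU64, OverheadType::kU64,
//       PrimaryType::kF64, Action::kFromFile, getTensorFilename(0));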

/// Methods that provide direct access to pointers.
IMPL_GETOVERHEAD(sparsePointers, index_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)

/// Methods that provide direct access to indices.
IMPL_GETOVERHEAD(sparseIndices, index_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)

/// Methods that provide direct access to values.
IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
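
// Example (illustrative sketch): reading back the values array of an f64
// tensor through the macro-generated entry point, assumed here to expand to
//   void _mlir_ciface_sparseValuesF64(StridedMemRefType<double, 1> *, void *);
// The callee points the memref at internal storage, so the caller copies
// nothing and frees nothing:
//
//   StridedMemRefType<double, 1> vals;
//   _mlir_ciface_sparseValuesF64(&vals, tensor);
//   for (int64_t i = 0; i < vals.sizes[0]; i++)
//     printf("values[%" PRId64 "] = %f\n", i, vals.data[vals.offset + i]);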

/// Helpers to add a value to a coordinate scheme, one per value type.
IMPL_ADDELT(addEltF64, double)
IMPL_ADDELT(addEltF32, float)
IMPL_ADDELT(addEltI64, int64_t)
IMPL_ADDELT(addEltI32, int32_t)
IMPL_ADDELT(addEltI16, int16_t)
IMPL_ADDELT(addEltI8, int8_t)

/// Helpers to enumerate the elements of a coordinate scheme, one per value
/// type.
IMPL_GETNEXT(getNextF64, double)
IMPL_GETNEXT(getNextF32, float)
IMPL_GETNEXT(getNextI64, int64_t)
IMPL_GETNEXT(getNextI32, int32_t)
IMPL_GETNEXT(getNextI16, int16_t)
IMPL_GETNEXT(getNextI8, int8_t)
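
// Example (illustrative sketch): enumerating all stored elements of a rank-2
// f64 tensor via an iterator obtained with Action::kToIterator, assuming the
// macro-generated signature
//   bool _mlir_ciface_getNextF64(void *iter,
//                                StridedMemRefType<index_t, 1> *iref,
//                                StridedMemRefType<double, 0> *vref);
// which yields one element per call and returns false once exhausted:
//
//   index_t idx[2];
//   double val;
//   StridedMemRefType<index_t, 1> iref{idx, idx, 0, {2}, {1}};
//   StridedMemRefType<double, 0> vref{&val, &val, 0};
//   while (_mlir_ciface_getNextF64(iter, &iref, &vref))
//     printf("(%" PRIu64 ", %" PRIu64 ") = %f\n", idx[0], idx[1], val);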

/// Helpers to insert elements in lexicographic index order, one per value
/// type.
IMPL_LEXINSERT(lexInsertF64, double)
IMPL_LEXINSERT(lexInsertF32, float)
IMPL_LEXINSERT(lexInsertI64, int64_t)
IMPL_LEXINSERT(lexInsertI32, int32_t)
IMPL_LEXINSERT(lexInsertI16, int16_t)
IMPL_LEXINSERT(lexInsertI8, int8_t)
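
// Example (illustrative sketch): filling storage obtained with Action::kEmpty
// by inserting elements in lexicographic index order, assuming the
// macro-generated signature
//   void _mlir_ciface_lexInsertF64(void *tensor,
//                                  StridedMemRefType<index_t, 1> *cref,
//                                  double val);
// with endInsert() finalizing the storage afterwards:
//
//   index_t cursor[2];
//   StridedMemRefType<index_t, 1> cref{cursor, cursor, 0, {2}, {1}};
//   cursor[0] = 0; cursor[1] = 0;
//   _mlir_ciface_lexInsertF64(tensor, &cref, 1.0);
//   cursor[0] = 1; cursor[1] = 1;
//   _mlir_ciface_lexInsertF64(tensor, &cref, 5.0);
//   endInsert(tensor);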

/// Helpers to insert using expansion, one per value type.
IMPL_EXPINSERT(expInsertF64, double)
IMPL_EXPINSERT(expInsertF32, float)
IMPL_EXPINSERT(expInsertI64, int64_t)
IMPL_EXPINSERT(expInsertI32, int32_t)
IMPL_EXPINSERT(expInsertI16, int16_t)
IMPL_EXPINSERT(expInsertI8, int8_t)

#undef CASE
#undef CASE_SECSAME
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
#undef IMPL_ADDELT
#undef IMPL_GETNEXT
#undef IMPL_LEXINSERT
#undef IMPL_EXPINSERT

//===----------------------------------------------------------------------===//
//
// Public API with methods that accept C-style data structures to interact
// with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used both by MLIR compiler-generated code and by an
// external runtime that wants to interact with such code.
//
//===----------------------------------------------------------------------===//

/// Helper method to read a sparse tensor filename from the environment,
/// using the naming convention ${TENSOR0}, ${TENSOR1}, etc. Returns a null
/// pointer if the corresponding environment variable is not set.
char *getTensorFilename(index_t id) {
  char var[80];
  sprintf(var, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  return env;
}
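
// Example: with `export TENSOR0=/tmp/wide.mtx` in the environment,
//
//   char *filename = getTensorFilename(0); // "/tmp/wide.mtx"
//
// and the result is a null pointer whenever the variable is unset.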

/// Returns size of sparse tensor in given dimension.
index_t sparseDimSize(void *tensor, index_t d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

/// Finalizes lexicographic insertions.
void endInsert(void *tensor) {
  static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
}

/// Releases sparse tensor storage.
void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}

/// Initializes a sparse tensor from a COO-flavored format expressed using
/// C-style data structures. The expected parameters are:
///
///   rank:    rank of the tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
///   values:  an "nse" array with the values of all specified elements
///   indices: a flat "nse x rank" array with the indices of all specified
///            elements
///
/// For example, the sparse matrix
///     | 1.0 0.0 0.0 |
///     | 0.0 5.0 3.0 |
/// can be passed as
///      rank    = 2
///      nse     = 3
///      shape   = [2, 3]
///      values  = [1.0, 5.0, 3.0]
///      indices = [ 0, 0,  1, 1,  1, 2]
//
// TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
//
void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
                                double *values, uint64_t *indices) {
  // Setup all-dims compressed and default ordering.
  std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
  std::vector<uint64_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  // Convert external format to internal COO.
  SparseTensorCOO<double> *tensor = SparseTensorCOO<double>::newSparseTensorCOO(
      rank, shape, perm.data(), nse);
  std::vector<uint64_t> idx(rank);
  for (uint64_t i = 0, base = 0; i < nse; i++) {
    for (uint64_t r = 0; r < rank; r++)
      idx[r] = indices[base + r];
    tensor->add(idx, values[i]);
    base += rank;
  }
  // Return sparse tensor storage format as opaque pointer.
  return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
      rank, shape, perm.data(), sparse.data(), tensor);
}
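
// Example (illustrative sketch): materializing the 2x3 matrix from the
// documentation comment above and releasing it again.
//
//   uint64_t shape[] = {2, 3};
//   double values[] = {1.0, 5.0, 3.0};
//   uint64_t indices[] = {0, 0, 1, 1, 1, 2};
//   void *tensor = convertToMLIRSparseTensor(2, 3, shape, values, indices);
//   assert(sparseDimSize(tensor, 0) == 2 && sparseDimSize(tensor, 1) == 3);
//   delSparseTensor(tensor);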

/// Converts a sparse tensor to a COO-flavored format expressed using C-style
/// data structures. The expected output parameters are pointers for these
/// values:
///
///   rank:    rank of the tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with the size of each dimension
///   values:  an "nse" array with the values of all specified elements
///   indices: a flat "nse x rank" array with the indices of all specified
///            elements
///
/// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
/// from convertToMLIRSparseTensor. The output shape, values, and indices
/// arrays are allocated here with new[]; ownership passes to the caller, who
/// must eventually release them with delete[].
///
//  TODO: Currently, values are copied from SparseTensorStorage to
//  SparseTensorCOO, then to the output. We may want to reduce the number of
//  copies.
//
//  TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
//
void convertFromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
                                 uint64_t **pShape, double **pValues,
                                 uint64_t **pIndices) {
  SparseTensorStorage<uint64_t, uint64_t, double> *sparseTensor =
      static_cast<SparseTensorStorage<uint64_t, uint64_t, double> *>(tensor);
  uint64_t rank = sparseTensor->getRank();
  std::vector<uint64_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  SparseTensorCOO<double> *coo = sparseTensor->toCOO(perm.data());

  const std::vector<Element<double>> &elements = coo->getElements();
  uint64_t nse = elements.size();

  uint64_t *shape = new uint64_t[rank];
  for (uint64_t i = 0; i < rank; i++)
    shape[i] = coo->getSizes()[i];

  double *values = new double[nse];
  uint64_t *indices = new uint64_t[rank * nse];

  for (uint64_t i = 0, base = 0; i < nse; i++) {
    values[i] = elements[i].value;
    for (uint64_t j = 0; j < rank; j++)
      indices[base + j] = elements[i].indices[j];
    base += rank;
  }

  delete coo;
  *pRank = rank;
  *pNse = nse;
  *pShape = shape;
  *pValues = values;
  *pIndices = indices;
}
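
// Example (illustrative sketch): round-tripping the tensor built in the
// convertToMLIRSparseTensor example above. The returned arrays are owned by
// the caller and must be released with delete[]:
//
//   uint64_t rank, nse;
//   uint64_t *shape, *indices;
//   double *values;
//   convertFromMLIRSparseTensor(tensor, &rank, &nse, &shape, &values,
//                               &indices);
//   // ... consume rank, nse, shape, values, and indices ...
//   delete[] shape;
//   delete[] values;
//   delete[] indices;
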
} // extern "C"

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS