18a91bc7bSHarrietAkot //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
28a91bc7bSHarrietAkot //
38a91bc7bSHarrietAkot // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48a91bc7bSHarrietAkot // See https://llvm.org/LICENSE.txt for license information.
58a91bc7bSHarrietAkot // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68a91bc7bSHarrietAkot //
78a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
88a91bc7bSHarrietAkot //
98a91bc7bSHarrietAkot // This file implements a light-weight runtime support library that is useful
108a91bc7bSHarrietAkot // for sparse tensor manipulations. The functionality provided in this library
118a91bc7bSHarrietAkot // is meant to simplify benchmarking, testing, and debugging MLIR code that
128a91bc7bSHarrietAkot // operates on sparse tensors. The provided functionality is **not** part
138a91bc7bSHarrietAkot // of core MLIR, however.
148a91bc7bSHarrietAkot //
158a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
168a91bc7bSHarrietAkot 
17845561ecSwren romano #include "mlir/ExecutionEngine/SparseTensorUtils.h"
188a91bc7bSHarrietAkot #include "mlir/ExecutionEngine/CRunnerUtils.h"
198a91bc7bSHarrietAkot 
208a91bc7bSHarrietAkot #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
218a91bc7bSHarrietAkot 
228a91bc7bSHarrietAkot #include <algorithm>
238a91bc7bSHarrietAkot #include <cassert>
248a91bc7bSHarrietAkot #include <cctype>
258a91bc7bSHarrietAkot #include <cinttypes>
268a91bc7bSHarrietAkot #include <cstdio>
278a91bc7bSHarrietAkot #include <cstdlib>
288a91bc7bSHarrietAkot #include <cstring>
29efa15f41SAart Bik #include <fstream>
30efa15f41SAart Bik #include <iostream>
314d0a18d0Swren romano #include <limits>
328a91bc7bSHarrietAkot #include <numeric>
338a91bc7bSHarrietAkot #include <vector>
348a91bc7bSHarrietAkot 
358a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
368a91bc7bSHarrietAkot //
378a91bc7bSHarrietAkot // Internal support for storing and reading sparse tensors.
388a91bc7bSHarrietAkot //
398a91bc7bSHarrietAkot // The following memory-resident sparse storage schemes are supported:
408a91bc7bSHarrietAkot //
418a91bc7bSHarrietAkot // (a) A coordinate scheme for temporarily storing and lexicographically
428a91bc7bSHarrietAkot //     sorting a sparse tensor by index (SparseTensorCOO).
438a91bc7bSHarrietAkot //
448a91bc7bSHarrietAkot // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
468a91bc7bSHarrietAkot //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
478a91bc7bSHarrietAkot //
488a91bc7bSHarrietAkot // The following external formats are supported:
498a91bc7bSHarrietAkot //
508a91bc7bSHarrietAkot // (1) Matrix Market Exchange (MME): *.mtx
518a91bc7bSHarrietAkot //     https://math.nist.gov/MatrixMarket/formats.html
528a91bc7bSHarrietAkot //
538a91bc7bSHarrietAkot // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
548a91bc7bSHarrietAkot //     http://frostt.io/tensors/file-formats.html
558a91bc7bSHarrietAkot //
568a91bc7bSHarrietAkot // Two public APIs are supported:
578a91bc7bSHarrietAkot //
588a91bc7bSHarrietAkot // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
598a91bc7bSHarrietAkot //     tensors. These methods should be used exclusively by MLIR
608a91bc7bSHarrietAkot //     compiler-generated code.
618a91bc7bSHarrietAkot //
628a91bc7bSHarrietAkot // (II) Methods that accept C-style data structures to interact with sparse
638a91bc7bSHarrietAkot //      tensors. These methods can be used by any external runtime that wants
648a91bc7bSHarrietAkot //      to interact with MLIR compiler-generated code.
658a91bc7bSHarrietAkot //
668a91bc7bSHarrietAkot // In both cases (I) and (II), the SparseTensorStorage format is externally
678a91bc7bSHarrietAkot // only visible as an opaque pointer.
688a91bc7bSHarrietAkot //
698a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
708a91bc7bSHarrietAkot 
718a91bc7bSHarrietAkot namespace {
728a91bc7bSHarrietAkot 
// Character-buffer width constant used elsewhere in this file (presumably
// the line buffer size for the MME/FROSTT text readers, including the
// terminator) -- TODO confirm against its users. Names declared inside an
// anonymous namespace already have internal linkage, so no `static` needed.
constexpr int kColWidth = 1025;
7403fe15ceSAart Bik 
/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
template <typename V>
struct Element {
  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val) {}
  std::vector<uint64_t> indices;
  V value;
  /// Returns true if indices of e1 < indices of e2 in lexicographic order.
  /// Both elements must have the same rank.
  static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
    uint64_t rank = e1.indices.size();
    assert(rank == e2.indices.size());
    (void)rank; // Only consumed by the assertion above.
    return std::lexicographical_compare(e1.indices.begin(), e1.indices.end(),
                                        e2.indices.begin(), e2.indices.end());
  }
};

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO {
public:
  /// Creates an empty tensor with the given per-dimension sizes, reserving
  /// room for `capacity` elements when `capacity` is nonzero.
  SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
      : sizes(szs) {
    if (capacity)
      elements.reserve(capacity);
  }
  /// Adds element as indices and value. The number of indices must equal
  /// the rank and each index must be below its dimension size (asserted).
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t rank = getRank();
    assert(rank == ind.size());
    for (uint64_t r = 0; r < rank; r++)
      assert(ind[r] < sizes[r]); // within bounds
    elements.emplace_back(ind, val);
  }
  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    // TODO: we may want to cache an `isSorted` bit, to avoid
    // unnecessary/redundant sorting.
    std::sort(elements.begin(), elements.end(), Element<V>::lexOrder);
  }
  /// Returns rank.
  uint64_t getRank() const { return sizes.size(); }
  /// Getter for sizes array.
  const std::vector<uint64_t> &getSizes() const { return sizes; }
  /// Getter for elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode. While iterating, add() and sort() are
  /// disallowed (asserted); iteration mode ends when getNext() is exhausted.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }
  /// Get the next element, or nullptr (leaving iterator mode) once all
  /// elements have been returned.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme, allocated with `new`.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *sizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++) {
      assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
      // An out-of-range entry would write past the end of `permsz`.
      assert(perm[r] < rank && "Permutation index out of range");
      permsz[perm[r]] = sizes[r];
    }
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  const std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<Element<V>> elements;
  /// True between startIterator() and exhaustion of getNext().
  bool iteratorLocked = false;
  /// Position of the next element handed out by getNext(). Deliberately
  /// 64-bit: a 32-bit cursor would wrap for more than 2^32 elements.
  uint64_t iteratorPos = 0;
};
1708a91bc7bSHarrietAkot 
/// Abstract base class of sparse tensor storage. Note that we use
/// function overloading to implement "partial" method specialization:
/// each typed accessor below defaults to reporting the unsupported type
/// combination and terminating, and a concrete subclass overrides only the
/// overloads that match its own pointer, index, and value types.
class SparseTensorStorageBase {
public:
  /// Dimension size query.
  virtual uint64_t getDimSize(uint64_t) const = 0;

  /// Overhead storage.
  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }

  /// Primary storage.
  virtual void getValues(std::vector<double> **) { fatal("valf64"); }
  virtual void getValues(std::vector<float> **) { fatal("valf32"); }
  virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }

  /// Element-wise insertion in lexicographic index order.
  virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
  virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
  virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
  virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
  // Tag follows the "ins" + "i"/"f" + width scheme of the siblings.
  virtual void lexInsert(const uint64_t *, int16_t) { fatal("insi16"); }
  virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }

  /// Expanded insertion.
  virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
    fatal("expf64");
  }
  virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
    fatal("expf32");
  }
  virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi64");
  }
  virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi32");
  }
  virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi16");
  }
  virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
    fatal("expi8");
  }

  /// Finishes insertion.
  virtual void endInsert() = 0;

  virtual ~SparseTensorStorageBase() = default;

private:
  /// Prints "unsupported <tp>" to stderr and exits the process with
  /// status 1. Never returns.
  [[noreturn]] static void fatal(const char *tp) {
    fprintf(stderr, "unsupported %s\n", tp);
    exit(1);
  }
};
2358a91bc7bSHarrietAkot 
2368a91bc7bSHarrietAkot /// A memory-resident sparse tensor using a storage scheme based on
2378a91bc7bSHarrietAkot /// per-dimension sparse/dense annotations. This data structure provides a
2388a91bc7bSHarrietAkot /// bufferized form of a sparse tensor type. In contrast to generating setup
2398a91bc7bSHarrietAkot /// methods for each differently annotated sparse tensor, this method provides
2408a91bc7bSHarrietAkot /// a convenient "one-size-fits-all" solution that simply takes an input tensor
2418a91bc7bSHarrietAkot /// and annotations to implement all required setup in a general manner.
2428a91bc7bSHarrietAkot template <typename P, typename I, typename V>
2438a91bc7bSHarrietAkot class SparseTensorStorage : public SparseTensorStorageBase {
2448a91bc7bSHarrietAkot public:
2458a91bc7bSHarrietAkot   /// Constructs a sparse tensor storage scheme with the given dimensions,
2468a91bc7bSHarrietAkot   /// permutation, and per-dimension dense/sparse annotations, using
2478a91bc7bSHarrietAkot   /// the coordinate scheme tensor for the initial contents if provided.
2488a91bc7bSHarrietAkot   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
249f66e5769SAart Bik                       const DimLevelType *sparsity,
250f66e5769SAart Bik                       SparseTensorCOO<V> *tensor = nullptr)
251f66e5769SAart Bik       : sizes(szs), rev(getRank()), idx(getRank()), pointers(getRank()),
252f66e5769SAart Bik         indices(getRank()) {
2538a91bc7bSHarrietAkot     uint64_t rank = getRank();
2548a91bc7bSHarrietAkot     // Store "reverse" permutation.
2558a91bc7bSHarrietAkot     for (uint64_t r = 0; r < rank; r++)
2568a91bc7bSHarrietAkot       rev[perm[r]] = r;
2578a91bc7bSHarrietAkot     // Provide hints on capacity of pointers and indices.
2588a91bc7bSHarrietAkot     // TODO: needs fine-tuning based on sparsity
259f66e5769SAart Bik     bool allDense = true;
260f66e5769SAart Bik     uint64_t sz = 1;
261f66e5769SAart Bik     for (uint64_t r = 0; r < rank; r++) {
2624d0a18d0Swren romano       assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
263f66e5769SAart Bik       sz *= sizes[r];
264845561ecSwren romano       if (sparsity[r] == DimLevelType::kCompressed) {
265f66e5769SAart Bik         pointers[r].reserve(sz + 1);
266f66e5769SAart Bik         indices[r].reserve(sz);
267f66e5769SAart Bik         sz = 1;
268f66e5769SAart Bik         allDense = false;
269289f84a4Swren romano         // Prepare the pointer structure.  We cannot use `appendPointer`
2704d0a18d0Swren romano         // here, because `isCompressedDim` won't work until after this
2714d0a18d0Swren romano         // preparation has been done.
2724d0a18d0Swren romano         pointers[r].push_back(0);
2738a91bc7bSHarrietAkot       } else {
274845561ecSwren romano         assert(sparsity[r] == DimLevelType::kDense &&
275845561ecSwren romano                "singleton not yet supported");
2768a91bc7bSHarrietAkot       }
2778a91bc7bSHarrietAkot     }
2788a91bc7bSHarrietAkot     // Then assign contents from coordinate scheme tensor if provided.
2798a91bc7bSHarrietAkot     if (tensor) {
2804d0a18d0Swren romano       // Ensure both preconditions of `fromCOO`.
2814d0a18d0Swren romano       assert(tensor->getSizes() == sizes && "Tensor size mismatch");
282cf358253Swren romano       tensor->sort();
2834d0a18d0Swren romano       // Now actually insert the `elements`.
284ceda1ae9Swren romano       const std::vector<Element<V>> &elements = tensor->getElements();
285ceda1ae9Swren romano       uint64_t nnz = elements.size();
2868a91bc7bSHarrietAkot       values.reserve(nnz);
287ceda1ae9Swren romano       fromCOO(elements, 0, nnz, 0);
2881ce77b56SAart Bik     } else if (allDense) {
289f66e5769SAart Bik       values.resize(sz, 0);
2908a91bc7bSHarrietAkot     }
2918a91bc7bSHarrietAkot   }
2928a91bc7bSHarrietAkot 
2930ae2e958SMehdi Amini   ~SparseTensorStorage() override = default;
2948a91bc7bSHarrietAkot 
2958a91bc7bSHarrietAkot   /// Get the rank of the tensor.
2968a91bc7bSHarrietAkot   uint64_t getRank() const { return sizes.size(); }
2978a91bc7bSHarrietAkot 
29846bdacaaSwren romano   /// Get the size of the given dimension of the tensor.
29946bdacaaSwren romano   uint64_t getDimSize(uint64_t d) const override {
3008a91bc7bSHarrietAkot     assert(d < getRank());
3018a91bc7bSHarrietAkot     return sizes[d];
3028a91bc7bSHarrietAkot   }
3038a91bc7bSHarrietAkot 
304f66e5769SAart Bik   /// Partially specialize these getter methods based on template types.
3058a91bc7bSHarrietAkot   void getPointers(std::vector<P> **out, uint64_t d) override {
3068a91bc7bSHarrietAkot     assert(d < getRank());
3078a91bc7bSHarrietAkot     *out = &pointers[d];
3088a91bc7bSHarrietAkot   }
3098a91bc7bSHarrietAkot   void getIndices(std::vector<I> **out, uint64_t d) override {
3108a91bc7bSHarrietAkot     assert(d < getRank());
3118a91bc7bSHarrietAkot     *out = &indices[d];
3128a91bc7bSHarrietAkot   }
3138a91bc7bSHarrietAkot   void getValues(std::vector<V> **out) override { *out = &values; }
3148a91bc7bSHarrietAkot 
31503fe15ceSAart Bik   /// Partially specialize lexicographical insertions based on template types.
316c03fd1e6Swren romano   void lexInsert(const uint64_t *cursor, V val) override {
3171ce77b56SAart Bik     // First, wrap up pending insertion path.
3181ce77b56SAart Bik     uint64_t diff = 0;
3191ce77b56SAart Bik     uint64_t top = 0;
3201ce77b56SAart Bik     if (!values.empty()) {
3211ce77b56SAart Bik       diff = lexDiff(cursor);
3221ce77b56SAart Bik       endPath(diff + 1);
3231ce77b56SAart Bik       top = idx[diff] + 1;
3241ce77b56SAart Bik     }
3251ce77b56SAart Bik     // Then continue with insertion path.
3261ce77b56SAart Bik     insPath(cursor, diff, top, val);
327f66e5769SAart Bik   }
328f66e5769SAart Bik 
3294f2ec7f9SAart Bik   /// Partially specialize expanded insertions based on template types.
3304f2ec7f9SAart Bik   /// Note that this method resets the values/filled-switch array back
3314f2ec7f9SAart Bik   /// to all-zero/false while only iterating over the nonzero elements.
3324f2ec7f9SAart Bik   void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
3334f2ec7f9SAart Bik                  uint64_t count) override {
3344f2ec7f9SAart Bik     if (count == 0)
3354f2ec7f9SAart Bik       return;
3364f2ec7f9SAart Bik     // Sort.
3374f2ec7f9SAart Bik     std::sort(added, added + count);
3384f2ec7f9SAart Bik     // Restore insertion path for first insert.
339*3bf2ba3bSwren romano     const uint64_t lastDim = getRank() - 1;
3404f2ec7f9SAart Bik     uint64_t index = added[0];
341*3bf2ba3bSwren romano     cursor[lastDim] = index;
3424f2ec7f9SAart Bik     lexInsert(cursor, values[index]);
3434f2ec7f9SAart Bik     assert(filled[index]);
3444f2ec7f9SAart Bik     values[index] = 0;
3454f2ec7f9SAart Bik     filled[index] = false;
3464f2ec7f9SAart Bik     // Subsequent insertions are quick.
3474f2ec7f9SAart Bik     for (uint64_t i = 1; i < count; i++) {
3484f2ec7f9SAart Bik       assert(index < added[i] && "non-lexicographic insertion");
3494f2ec7f9SAart Bik       index = added[i];
350*3bf2ba3bSwren romano       cursor[lastDim] = index;
351*3bf2ba3bSwren romano       insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
3524f2ec7f9SAart Bik       assert(filled[index]);
353*3bf2ba3bSwren romano       values[index] = 0;
3544f2ec7f9SAart Bik       filled[index] = false;
3554f2ec7f9SAart Bik     }
3564f2ec7f9SAart Bik   }
3574f2ec7f9SAart Bik 
358f66e5769SAart Bik   /// Finalizes lexicographic insertions.
3591ce77b56SAart Bik   void endInsert() override {
3601ce77b56SAart Bik     if (values.empty())
3611ce77b56SAart Bik       endDim(0);
3621ce77b56SAart Bik     else
3631ce77b56SAart Bik       endPath(0);
3641ce77b56SAart Bik   }
365f66e5769SAart Bik 
3668a91bc7bSHarrietAkot   /// Returns this sparse tensor storage scheme as a new memory-resident
3678a91bc7bSHarrietAkot   /// sparse tensor in coordinate scheme with the given dimension order.
3688a91bc7bSHarrietAkot   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
3698a91bc7bSHarrietAkot     // Restore original order of the dimension sizes and allocate coordinate
3708a91bc7bSHarrietAkot     // scheme with desired new ordering specified in perm.
3718a91bc7bSHarrietAkot     uint64_t rank = getRank();
3728a91bc7bSHarrietAkot     std::vector<uint64_t> orgsz(rank);
3738a91bc7bSHarrietAkot     for (uint64_t r = 0; r < rank; r++)
3748a91bc7bSHarrietAkot       orgsz[rev[r]] = sizes[r];
3758a91bc7bSHarrietAkot     SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
3768a91bc7bSHarrietAkot         rank, orgsz.data(), perm, values.size());
3778a91bc7bSHarrietAkot     // Populate coordinate scheme restored from old ordering and changed with
3788a91bc7bSHarrietAkot     // new ordering. Rather than applying both reorderings during the recursion,
3798a91bc7bSHarrietAkot     // we compute the combine permutation in advance.
3808a91bc7bSHarrietAkot     std::vector<uint64_t> reord(rank);
3818a91bc7bSHarrietAkot     for (uint64_t r = 0; r < rank; r++)
3828a91bc7bSHarrietAkot       reord[r] = perm[rev[r]];
383ceda1ae9Swren romano     toCOO(*tensor, reord, 0, 0);
3848a91bc7bSHarrietAkot     assert(tensor->getElements().size() == values.size());
3858a91bc7bSHarrietAkot     return tensor;
3868a91bc7bSHarrietAkot   }
3878a91bc7bSHarrietAkot 
3888a91bc7bSHarrietAkot   /// Factory method. Constructs a sparse tensor storage scheme with the given
3898a91bc7bSHarrietAkot   /// dimensions, permutation, and per-dimension dense/sparse annotations,
3908a91bc7bSHarrietAkot   /// using the coordinate scheme tensor for the initial contents if provided.
3918a91bc7bSHarrietAkot   /// In the latter case, the coordinate scheme must respect the same
3928a91bc7bSHarrietAkot   /// permutation as is desired for the new sparse tensor storage.
3938a91bc7bSHarrietAkot   static SparseTensorStorage<P, I, V> *
394d83a7068Swren romano   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
395845561ecSwren romano                   const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
3968a91bc7bSHarrietAkot     SparseTensorStorage<P, I, V> *n = nullptr;
3978a91bc7bSHarrietAkot     if (tensor) {
3988a91bc7bSHarrietAkot       assert(tensor->getRank() == rank);
3998a91bc7bSHarrietAkot       for (uint64_t r = 0; r < rank; r++)
400d83a7068Swren romano         assert(shape[r] == 0 || shape[r] == tensor->getSizes()[perm[r]]);
4018a91bc7bSHarrietAkot       n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
4028a91bc7bSHarrietAkot                                            tensor);
4038a91bc7bSHarrietAkot     } else {
4048a91bc7bSHarrietAkot       std::vector<uint64_t> permsz(rank);
405d83a7068Swren romano       for (uint64_t r = 0; r < rank; r++) {
406d83a7068Swren romano         assert(shape[r] > 0 && "Dimension size zero has trivial storage");
407d83a7068Swren romano         permsz[perm[r]] = shape[r];
408d83a7068Swren romano       }
409f66e5769SAart Bik       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
4108a91bc7bSHarrietAkot     }
4118a91bc7bSHarrietAkot     return n;
4128a91bc7bSHarrietAkot   }
4138a91bc7bSHarrietAkot 
4148a91bc7bSHarrietAkot private:
4154d0a18d0Swren romano   /// Appends the next free position of `indices[d]` to `pointers[d]`.
4164d0a18d0Swren romano   /// Thus, when called after inserting the last element of a segment,
4174d0a18d0Swren romano   /// it will append the position where the next segment begins.
418289f84a4Swren romano   inline void appendPointer(uint64_t d) {
4194d0a18d0Swren romano     assert(isCompressedDim(d)); // Entails `d < getRank()`.
4204d0a18d0Swren romano     uint64_t p = indices[d].size();
4214d0a18d0Swren romano     assert(p <= std::numeric_limits<P>::max() &&
4224d0a18d0Swren romano            "Pointer value is too large for the P-type");
4234d0a18d0Swren romano     pointers[d].push_back(p); // Here is where we convert to `P`.
4244d0a18d0Swren romano   }
4254d0a18d0Swren romano 
4264d0a18d0Swren romano   /// Appends the given index to `indices[d]`.
427289f84a4Swren romano   inline void appendIndex(uint64_t d, uint64_t i) {
4284d0a18d0Swren romano     assert(isCompressedDim(d)); // Entails `d < getRank()`.
4294d0a18d0Swren romano     assert(i <= std::numeric_limits<I>::max() &&
4304d0a18d0Swren romano            "Index value is too large for the I-type");
4314d0a18d0Swren romano     indices[d].push_back(i); // Here is where we convert to `I`.
4324d0a18d0Swren romano   }
4334d0a18d0Swren romano 
  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the pointers and
  /// indices arrays under the given per-dimension dense/sparse annotations.
  ///
  /// Recursively handles the half-open range `[lo, hi)` of `elements` at
  /// dimension `d`. The initial call passes `[0, nnz)` and `d == 0`. Each
  /// recursive call receives a segment whose elements agree on their index
  /// in dimensions `0 .. d-1`.
  ///
  /// Preconditions:
  /// (1) the `elements` must be lexicographically sorted.
  /// (2) the indices of every element are valid for `sizes` (equal rank
  ///     and pointwise less-than).
  void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
               uint64_t hi, uint64_t d) {
    // Once dimensions are exhausted, insert the numerical values.
    assert(d <= getRank() && hi <= elements.size());
    if (d == getRank()) {
      assert(lo < hi);
      // If the range holds several elements with identical indices, only
      // the first one's value is stored.
      values.push_back(elements[lo].value);
      return;
    }
    // Visit all elements in this interval.
    // `full` counts how many positions of a dense dimension `d` have been
    // emitted so far (either as zero fill or as a real segment).
    uint64_t full = 0;
    while (lo < hi) { // If `hi` is unchanged, then `lo < elements.size()`.
      // Find segment in interval with same index elements in this dimension.
      uint64_t i = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == i)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      if (isCompressedDim(d)) {
        appendIndex(d, i);
      } else {
        // For dense storage we must fill in all the zero values between
        // the previous element (when last we ran this for-loop) and the
        // current element.
        for (; full < i; full++)
          endDim(d + 1);
        // Account for position `i` itself, filled by the recursion below.
        full++;
      }
      fromCOO(elements, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    if (isCompressedDim(d)) {
      appendPointer(d);
    } else {
      // For dense storage we must fill in all the zero values after
      // the last element.
      for (uint64_t sz = sizes[d]; full < sz; full++)
        endDim(d + 1);
    }
  }
4848a91bc7bSHarrietAkot 
4858a91bc7bSHarrietAkot   /// Stores the sparse tensor storage scheme into a memory-resident sparse
4868a91bc7bSHarrietAkot   /// tensor in coordinate scheme.
487ceda1ae9Swren romano   void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
488f66e5769SAart Bik              uint64_t pos, uint64_t d) {
4898a91bc7bSHarrietAkot     assert(d <= getRank());
4908a91bc7bSHarrietAkot     if (d == getRank()) {
4918a91bc7bSHarrietAkot       assert(pos < values.size());
492ceda1ae9Swren romano       tensor.add(idx, values[pos]);
4931ce77b56SAart Bik     } else if (isCompressedDim(d)) {
4948a91bc7bSHarrietAkot       // Sparse dimension.
4958a91bc7bSHarrietAkot       for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
4968a91bc7bSHarrietAkot         idx[reord[d]] = indices[d][ii];
497f66e5769SAart Bik         toCOO(tensor, reord, ii, d + 1);
4988a91bc7bSHarrietAkot       }
4991ce77b56SAart Bik     } else {
5001ce77b56SAart Bik       // Dense dimension.
5011ce77b56SAart Bik       for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
5021ce77b56SAart Bik         idx[reord[d]] = i;
5031ce77b56SAart Bik         toCOO(tensor, reord, off + i, d + 1);
5048a91bc7bSHarrietAkot       }
5058a91bc7bSHarrietAkot     }
5061ce77b56SAart Bik   }
5071ce77b56SAart Bik 
5081ce77b56SAart Bik   /// Ends a deeper, never seen before dimension.
5091ce77b56SAart Bik   void endDim(uint64_t d) {
5101ce77b56SAart Bik     assert(d <= getRank());
5111ce77b56SAart Bik     if (d == getRank()) {
5121ce77b56SAart Bik       values.push_back(0);
5131ce77b56SAart Bik     } else if (isCompressedDim(d)) {
514289f84a4Swren romano       appendPointer(d);
5151ce77b56SAart Bik     } else {
5161ce77b56SAart Bik       for (uint64_t full = 0, sz = sizes[d]; full < sz; full++)
5171ce77b56SAart Bik         endDim(d + 1);
5181ce77b56SAart Bik     }
5191ce77b56SAart Bik   }
5201ce77b56SAart Bik 
5211ce77b56SAart Bik   /// Wraps up a single insertion path, inner to outer.
5221ce77b56SAart Bik   void endPath(uint64_t diff) {
5231ce77b56SAart Bik     uint64_t rank = getRank();
5241ce77b56SAart Bik     assert(diff <= rank);
5251ce77b56SAart Bik     for (uint64_t i = 0; i < rank - diff; i++) {
5261ce77b56SAart Bik       uint64_t d = rank - i - 1;
5271ce77b56SAart Bik       if (isCompressedDim(d)) {
528289f84a4Swren romano         appendPointer(d);
5291ce77b56SAart Bik       } else {
5301ce77b56SAart Bik         for (uint64_t full = idx[d] + 1, sz = sizes[d]; full < sz; full++)
5311ce77b56SAart Bik           endDim(d + 1);
5321ce77b56SAart Bik       }
5331ce77b56SAart Bik     }
5341ce77b56SAart Bik   }
5351ce77b56SAart Bik 
5361ce77b56SAart Bik   /// Continues a single insertion path, outer to inner.
537c03fd1e6Swren romano   void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
5381ce77b56SAart Bik     uint64_t rank = getRank();
5391ce77b56SAart Bik     assert(diff < rank);
5401ce77b56SAart Bik     for (uint64_t d = diff; d < rank; d++) {
5411ce77b56SAart Bik       uint64_t i = cursor[d];
5421ce77b56SAart Bik       if (isCompressedDim(d)) {
543289f84a4Swren romano         appendIndex(d, i);
5441ce77b56SAart Bik       } else {
5451ce77b56SAart Bik         for (uint64_t full = top; full < i; full++)
5461ce77b56SAart Bik           endDim(d + 1);
5471ce77b56SAart Bik       }
5481ce77b56SAart Bik       top = 0;
5491ce77b56SAart Bik       idx[d] = i;
5501ce77b56SAart Bik     }
5511ce77b56SAart Bik     values.push_back(val);
5521ce77b56SAart Bik   }
5531ce77b56SAart Bik 
5541ce77b56SAart Bik   /// Finds the lexicographic differing dimension.
55546bdacaaSwren romano   uint64_t lexDiff(const uint64_t *cursor) const {
5561ce77b56SAart Bik     for (uint64_t r = 0, rank = getRank(); r < rank; r++)
5571ce77b56SAart Bik       if (cursor[r] > idx[r])
5581ce77b56SAart Bik         return r;
5591ce77b56SAart Bik       else
5601ce77b56SAart Bik         assert(cursor[r] == idx[r] && "non-lexicographic insertion");
5611ce77b56SAart Bik     assert(0 && "duplication insertion");
5621ce77b56SAart Bik     return -1u;
5631ce77b56SAart Bik   }
5641ce77b56SAart Bik 
5651ce77b56SAart Bik   /// Returns true if dimension is compressed.
5661ce77b56SAart Bik   inline bool isCompressedDim(uint64_t d) const {
5674d0a18d0Swren romano     assert(d < getRank());
5681ce77b56SAart Bik     return (!pointers[d].empty());
5691ce77b56SAart Bik   }
5708a91bc7bSHarrietAkot 
5718a91bc7bSHarrietAkot private:
57246bdacaaSwren romano   const std::vector<uint64_t> sizes; // per-dimension sizes
5738a91bc7bSHarrietAkot   std::vector<uint64_t> rev;   // "reverse" permutation
574f66e5769SAart Bik   std::vector<uint64_t> idx;   // index cursor
5758a91bc7bSHarrietAkot   std::vector<std::vector<P>> pointers;
5768a91bc7bSHarrietAkot   std::vector<std::vector<I>> indices;
5778a91bc7bSHarrietAkot   std::vector<V> values;
5788a91bc7bSHarrietAkot };
5798a91bc7bSHarrietAkot 
/// Helper to convert string to lower case, in place.
///
/// Each character is passed to `tolower` as `unsigned char`. Passing a
/// negative `char` (any byte >= 0x80 where `char` is signed) to `tolower` is
/// undefined behavior.
///
/// \param token  NUL-terminated string; modified in place.
/// \returns      `token`, so the call can be used inside expressions.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
  return token;
}
5868a91bc7bSHarrietAkot 
5878a91bc7bSHarrietAkot /// Read the MME header of a general sparse matrix of type real.
58803fe15ceSAart Bik static void readMMEHeader(FILE *file, char *filename, char *line,
589bb56c2b3SMehdi Amini                           uint64_t *idata, bool *isSymmetric) {
5908a91bc7bSHarrietAkot   char header[64];
5918a91bc7bSHarrietAkot   char object[64];
5928a91bc7bSHarrietAkot   char format[64];
5938a91bc7bSHarrietAkot   char field[64];
5948a91bc7bSHarrietAkot   char symmetry[64];
5958a91bc7bSHarrietAkot   // Read header line.
5968a91bc7bSHarrietAkot   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
5978a91bc7bSHarrietAkot              symmetry) != 5) {
59803fe15ceSAart Bik     fprintf(stderr, "Corrupt header in %s\n", filename);
5998a91bc7bSHarrietAkot     exit(1);
6008a91bc7bSHarrietAkot   }
601bb56c2b3SMehdi Amini   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
6028a91bc7bSHarrietAkot   // Make sure this is a general sparse matrix.
6038a91bc7bSHarrietAkot   if (strcmp(toLower(header), "%%matrixmarket") ||
6048a91bc7bSHarrietAkot       strcmp(toLower(object), "matrix") ||
6058a91bc7bSHarrietAkot       strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
606bb56c2b3SMehdi Amini       (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
6078a91bc7bSHarrietAkot     fprintf(stderr,
60803fe15ceSAart Bik             "Cannot find a general sparse matrix with type real in %s\n",
60903fe15ceSAart Bik             filename);
6108a91bc7bSHarrietAkot     exit(1);
6118a91bc7bSHarrietAkot   }
6128a91bc7bSHarrietAkot   // Skip comments.
613e5639b3fSMehdi Amini   while (true) {
61403fe15ceSAart Bik     if (!fgets(line, kColWidth, file)) {
61503fe15ceSAart Bik       fprintf(stderr, "Cannot find data in %s\n", filename);
6168a91bc7bSHarrietAkot       exit(1);
6178a91bc7bSHarrietAkot     }
6188a91bc7bSHarrietAkot     if (line[0] != '%')
6198a91bc7bSHarrietAkot       break;
6208a91bc7bSHarrietAkot   }
6218a91bc7bSHarrietAkot   // Next line contains M N NNZ.
6228a91bc7bSHarrietAkot   idata[0] = 2; // rank
6238a91bc7bSHarrietAkot   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
6248a91bc7bSHarrietAkot              idata + 1) != 3) {
62503fe15ceSAart Bik     fprintf(stderr, "Cannot find size in %s\n", filename);
6268a91bc7bSHarrietAkot     exit(1);
6278a91bc7bSHarrietAkot   }
6288a91bc7bSHarrietAkot }
6298a91bc7bSHarrietAkot 
6308a91bc7bSHarrietAkot /// Read the "extended" FROSTT header. Although not part of the documented
6318a91bc7bSHarrietAkot /// format, we assume that the file starts with optional comments followed
6328a91bc7bSHarrietAkot /// by two lines that define the rank, the number of nonzeros, and the
6338a91bc7bSHarrietAkot /// dimensions sizes (one per rank) of the sparse tensor.
63403fe15ceSAart Bik static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
63503fe15ceSAart Bik                                 uint64_t *idata) {
6368a91bc7bSHarrietAkot   // Skip comments.
637e5639b3fSMehdi Amini   while (true) {
63803fe15ceSAart Bik     if (!fgets(line, kColWidth, file)) {
63903fe15ceSAart Bik       fprintf(stderr, "Cannot find data in %s\n", filename);
6408a91bc7bSHarrietAkot       exit(1);
6418a91bc7bSHarrietAkot     }
6428a91bc7bSHarrietAkot     if (line[0] != '#')
6438a91bc7bSHarrietAkot       break;
6448a91bc7bSHarrietAkot   }
6458a91bc7bSHarrietAkot   // Next line contains RANK and NNZ.
6468a91bc7bSHarrietAkot   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
64703fe15ceSAart Bik     fprintf(stderr, "Cannot find metadata in %s\n", filename);
6488a91bc7bSHarrietAkot     exit(1);
6498a91bc7bSHarrietAkot   }
6508a91bc7bSHarrietAkot   // Followed by a line with the dimension sizes (one per rank).
6518a91bc7bSHarrietAkot   for (uint64_t r = 0; r < idata[0]; r++) {
6528a91bc7bSHarrietAkot     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
65303fe15ceSAart Bik       fprintf(stderr, "Cannot find dimension size %s\n", filename);
6548a91bc7bSHarrietAkot       exit(1);
6558a91bc7bSHarrietAkot     }
6568a91bc7bSHarrietAkot   }
65703fe15ceSAart Bik   fgets(line, kColWidth, file); // end of line
6588a91bc7bSHarrietAkot }
6598a91bc7bSHarrietAkot 
6608a91bc7bSHarrietAkot /// Reads a sparse tensor with the given filename into a memory-resident
6618a91bc7bSHarrietAkot /// sparse tensor in coordinate scheme.
6628a91bc7bSHarrietAkot template <typename V>
6638a91bc7bSHarrietAkot static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
664d83a7068Swren romano                                                const uint64_t *shape,
6658a91bc7bSHarrietAkot                                                const uint64_t *perm) {
6668a91bc7bSHarrietAkot   // Open the file.
6678a91bc7bSHarrietAkot   FILE *file = fopen(filename, "r");
6688a91bc7bSHarrietAkot   if (!file) {
6693734c078Swren romano     assert(filename && "Received nullptr for filename");
6703734c078Swren romano     fprintf(stderr, "Cannot find file %s\n", filename);
6718a91bc7bSHarrietAkot     exit(1);
6728a91bc7bSHarrietAkot   }
6738a91bc7bSHarrietAkot   // Perform some file format dependent set up.
67403fe15ceSAart Bik   char line[kColWidth];
6758a91bc7bSHarrietAkot   uint64_t idata[512];
676bb56c2b3SMehdi Amini   bool isSymmetric = false;
6778a91bc7bSHarrietAkot   if (strstr(filename, ".mtx")) {
678bb56c2b3SMehdi Amini     readMMEHeader(file, filename, line, idata, &isSymmetric);
6798a91bc7bSHarrietAkot   } else if (strstr(filename, ".tns")) {
68003fe15ceSAart Bik     readExtFROSTTHeader(file, filename, line, idata);
6818a91bc7bSHarrietAkot   } else {
6828a91bc7bSHarrietAkot     fprintf(stderr, "Unknown format %s\n", filename);
6838a91bc7bSHarrietAkot     exit(1);
6848a91bc7bSHarrietAkot   }
6858a91bc7bSHarrietAkot   // Prepare sparse tensor object with per-dimension sizes
6868a91bc7bSHarrietAkot   // and the number of nonzeros as initial capacity.
6878a91bc7bSHarrietAkot   assert(rank == idata[0] && "rank mismatch");
6888a91bc7bSHarrietAkot   uint64_t nnz = idata[1];
6898a91bc7bSHarrietAkot   for (uint64_t r = 0; r < rank; r++)
690d83a7068Swren romano     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
6918a91bc7bSHarrietAkot            "dimension size mismatch");
6928a91bc7bSHarrietAkot   SparseTensorCOO<V> *tensor =
6938a91bc7bSHarrietAkot       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
6948a91bc7bSHarrietAkot   //  Read all nonzero elements.
6958a91bc7bSHarrietAkot   std::vector<uint64_t> indices(rank);
6968a91bc7bSHarrietAkot   for (uint64_t k = 0; k < nnz; k++) {
69703fe15ceSAart Bik     if (!fgets(line, kColWidth, file)) {
69803fe15ceSAart Bik       fprintf(stderr, "Cannot find next line of data in %s\n", filename);
6998a91bc7bSHarrietAkot       exit(1);
7008a91bc7bSHarrietAkot     }
70103fe15ceSAart Bik     char *linePtr = line;
70203fe15ceSAart Bik     for (uint64_t r = 0; r < rank; r++) {
70303fe15ceSAart Bik       uint64_t idx = strtoul(linePtr, &linePtr, 10);
7048a91bc7bSHarrietAkot       // Add 0-based index.
7058a91bc7bSHarrietAkot       indices[perm[r]] = idx - 1;
7068a91bc7bSHarrietAkot     }
7078a91bc7bSHarrietAkot     // The external formats always store the numerical values with the type
7088a91bc7bSHarrietAkot     // double, but we cast these values to the sparse tensor object type.
70903fe15ceSAart Bik     double value = strtod(linePtr, &linePtr);
7108a91bc7bSHarrietAkot     tensor->add(indices, value);
71102710413SBixia Zheng     // We currently chose to deal with symmetric matrices by fully constructing
71202710413SBixia Zheng     // them. In the future, we may want to make symmetry implicit for storage
71302710413SBixia Zheng     // reasons.
714bb56c2b3SMehdi Amini     if (isSymmetric && indices[0] != indices[1])
71502710413SBixia Zheng       tensor->add({indices[1], indices[0]}, value);
7168a91bc7bSHarrietAkot   }
7178a91bc7bSHarrietAkot   // Close the file and return tensor.
7188a91bc7bSHarrietAkot   fclose(file);
7198a91bc7bSHarrietAkot   return tensor;
7208a91bc7bSHarrietAkot }
7218a91bc7bSHarrietAkot 
722efa15f41SAart Bik /// Writes the sparse tensor to extended FROSTT format.
723efa15f41SAart Bik template <typename V>
72446bdacaaSwren romano static void outSparseTensor(void *tensor, void *dest, bool sort) {
7256438783fSAart Bik   assert(tensor && dest);
7266438783fSAart Bik   auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
7276438783fSAart Bik   if (sort)
7286438783fSAart Bik     coo->sort();
7296438783fSAart Bik   char *filename = static_cast<char *>(dest);
7306438783fSAart Bik   auto &sizes = coo->getSizes();
7316438783fSAart Bik   auto &elements = coo->getElements();
7326438783fSAart Bik   uint64_t rank = coo->getRank();
733efa15f41SAart Bik   uint64_t nnz = elements.size();
734efa15f41SAart Bik   std::fstream file;
735efa15f41SAart Bik   file.open(filename, std::ios_base::out | std::ios_base::trunc);
736efa15f41SAart Bik   assert(file.is_open());
737efa15f41SAart Bik   file << "; extended FROSTT format\n" << rank << " " << nnz << std::endl;
738efa15f41SAart Bik   for (uint64_t r = 0; r < rank - 1; r++)
739efa15f41SAart Bik     file << sizes[r] << " ";
740efa15f41SAart Bik   file << sizes[rank - 1] << std::endl;
741efa15f41SAart Bik   for (uint64_t i = 0; i < nnz; i++) {
742efa15f41SAart Bik     auto &idx = elements[i].indices;
743efa15f41SAart Bik     for (uint64_t r = 0; r < rank; r++)
744efa15f41SAart Bik       file << (idx[r] + 1) << " ";
745efa15f41SAart Bik     file << elements[i].value << std::endl;
746efa15f41SAart Bik   }
747efa15f41SAart Bik   file.flush();
748efa15f41SAart Bik   file.close();
749efa15f41SAart Bik   assert(file.good());
7506438783fSAart Bik }
7516438783fSAart Bik 
7526438783fSAart Bik /// Initializes sparse tensor from an external COO-flavored format.
7536438783fSAart Bik template <typename V>
75446bdacaaSwren romano static SparseTensorStorage<uint64_t, uint64_t, V> *
7556438783fSAart Bik toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
75620eaa88fSBixia Zheng                    uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
75720eaa88fSBixia Zheng   const DimLevelType *sparsity = (DimLevelType *)(sparse);
75820eaa88fSBixia Zheng #ifndef NDEBUG
75920eaa88fSBixia Zheng   // Verify that perm is a permutation of 0..(rank-1).
76020eaa88fSBixia Zheng   std::vector<uint64_t> order(perm, perm + rank);
76120eaa88fSBixia Zheng   std::sort(order.begin(), order.end());
7621e47888dSAart Bik   for (uint64_t i = 0; i < rank; ++i) {
76320eaa88fSBixia Zheng     if (i != order[i]) {
764988d4b0dSAart Bik       fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
76520eaa88fSBixia Zheng       exit(1);
76620eaa88fSBixia Zheng     }
76720eaa88fSBixia Zheng   }
76820eaa88fSBixia Zheng 
76920eaa88fSBixia Zheng   // Verify that the sparsity values are supported.
7701e47888dSAart Bik   for (uint64_t i = 0; i < rank; ++i) {
77120eaa88fSBixia Zheng     if (sparsity[i] != DimLevelType::kDense &&
77220eaa88fSBixia Zheng         sparsity[i] != DimLevelType::kCompressed) {
77320eaa88fSBixia Zheng       fprintf(stderr, "Unsupported sparsity value %d\n",
77420eaa88fSBixia Zheng               static_cast<int>(sparsity[i]));
77520eaa88fSBixia Zheng       exit(1);
77620eaa88fSBixia Zheng     }
77720eaa88fSBixia Zheng   }
77820eaa88fSBixia Zheng #endif
77920eaa88fSBixia Zheng 
7806438783fSAart Bik   // Convert external format to internal COO.
78163bdcaf9Swren romano   auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
7826438783fSAart Bik   std::vector<uint64_t> idx(rank);
7836438783fSAart Bik   for (uint64_t i = 0, base = 0; i < nse; i++) {
7846438783fSAart Bik     for (uint64_t r = 0; r < rank; r++)
785d8b229a1SAart Bik       idx[perm[r]] = indices[base + r];
78663bdcaf9Swren romano     coo->add(idx, values[i]);
7876438783fSAart Bik     base += rank;
7886438783fSAart Bik   }
7896438783fSAart Bik   // Return sparse tensor storage format as opaque pointer.
79063bdcaf9Swren romano   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
79163bdcaf9Swren romano       rank, shape, perm, sparsity, coo);
79263bdcaf9Swren romano   delete coo;
79363bdcaf9Swren romano   return tensor;
7946438783fSAart Bik }
7956438783fSAart Bik 
7966438783fSAart Bik /// Converts a sparse tensor to an external COO-flavored format.
7976438783fSAart Bik template <typename V>
79846bdacaaSwren romano static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
79946bdacaaSwren romano                                  uint64_t **pShape, V **pValues,
80046bdacaaSwren romano                                  uint64_t **pIndices) {
8016438783fSAart Bik   auto sparseTensor =
8026438783fSAart Bik       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
8036438783fSAart Bik   uint64_t rank = sparseTensor->getRank();
8046438783fSAart Bik   std::vector<uint64_t> perm(rank);
8056438783fSAart Bik   std::iota(perm.begin(), perm.end(), 0);
8066438783fSAart Bik   SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
8076438783fSAart Bik 
8086438783fSAart Bik   const std::vector<Element<V>> &elements = coo->getElements();
8096438783fSAart Bik   uint64_t nse = elements.size();
8106438783fSAart Bik 
8116438783fSAart Bik   uint64_t *shape = new uint64_t[rank];
8126438783fSAart Bik   for (uint64_t i = 0; i < rank; i++)
8136438783fSAart Bik     shape[i] = coo->getSizes()[i];
8146438783fSAart Bik 
8156438783fSAart Bik   V *values = new V[nse];
8166438783fSAart Bik   uint64_t *indices = new uint64_t[rank * nse];
8176438783fSAart Bik 
8186438783fSAart Bik   for (uint64_t i = 0, base = 0; i < nse; i++) {
8196438783fSAart Bik     values[i] = elements[i].value;
8206438783fSAart Bik     for (uint64_t j = 0; j < rank; j++)
8216438783fSAart Bik       indices[base + j] = elements[i].indices[j];
8226438783fSAart Bik     base += rank;
8236438783fSAart Bik   }
8246438783fSAart Bik 
8256438783fSAart Bik   delete coo;
8266438783fSAart Bik   *pRank = rank;
8276438783fSAart Bik   *pNse = nse;
8286438783fSAart Bik   *pShape = shape;
8296438783fSAart Bik   *pValues = values;
8306438783fSAart Bik   *pIndices = indices;
831efa15f41SAart Bik }
832efa15f41SAart Bik 
833be0a7e9fSMehdi Amini } // namespace
8348a91bc7bSHarrietAkot 
8358a91bc7bSHarrietAkot extern "C" {
8368a91bc7bSHarrietAkot 
8378a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
8388a91bc7bSHarrietAkot //
8398a91bc7bSHarrietAkot // Public API with methods that operate on MLIR buffers (memrefs) to interact
8408a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally.
8418a91bc7bSHarrietAkot // These methods should be used exclusively by MLIR compiler-generated code.
8428a91bc7bSHarrietAkot //
8438a91bc7bSHarrietAkot // Some macro magic is used to generate implementations for all required type
8448a91bc7bSHarrietAkot // combinations that can be called from MLIR compiler-generated code.
8458a91bc7bSHarrietAkot //
8468a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
8478a91bc7bSHarrietAkot 
// Dispatch helper for _mlir_ciface_newSparseTensor. It expands to an `if`
// that fires when the runtime type triple (ptrTp, indTp, valTp) equals
// (p, i, v), and it instantiates the C++ element types P (pointer overhead),
// I (index overhead) and V (values).
//
// The expansion reads the enclosing function's `action`, `ptr`, `rank`,
// `shape`, `perm` and `sparsity`, and it `return`s from that function:
//  - kEmpty, kFromFile, kFromCOO (the `action <= Action::kFromCOO` branch):
//    builds a new SparseTensorStorage<P, I, V>.
//      * kEmpty uses no COO.
//      * kFromFile reads a COO from the file named by `ptr`; this temporary
//        COO is deleted here.
//      * kFromCOO uses the caller's COO in `ptr`; it is NOT deleted here.
//  - kEmptyCOO: returns a new, empty SparseTensorCOO<V>.
//  - kToCOO, kToIterator: converts the storage in `ptr` into a new COO. For
//    kToIterator, the COO's iterator is started before returning.
// NOTE(review): the `<=` test assumes kEmpty, kFromFile and kFromCOO are the
// first enumerators of Action -- confirm against the header.
// Comments cannot go inside the body: a `//` comment there would swallow the
// line-continuation backslash.
#define CASE(p, i, v, P, I, V)                                                 \
  if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
    SparseTensorCOO<V> *coo = nullptr;                                         \
    if (action <= Action::kFromCOO) {                                          \
      if (action == Action::kFromFile) {                                       \
        char *filename = static_cast<char *>(ptr);                             \
        coo = openSparseTensorCOO<V>(filename, rank, shape, perm);             \
      } else if (action == Action::kFromCOO) {                                 \
        coo = static_cast<SparseTensorCOO<V> *>(ptr);                          \
      } else {                                                                 \
        assert(action == Action::kEmpty);                                      \
      }                                                                        \
      auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor(            \
          rank, shape, perm, sparsity, coo);                                   \
      if (action == Action::kFromFile)                                         \
        delete coo;                                                            \
      return tensor;                                                           \
    }                                                                          \
    if (action == Action::kEmptyCOO)                                           \
      return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm);        \
    coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);       \
    if (action == Action::kToIterator) {                                       \
      coo->startIterator();                                                    \
    } else {                                                                   \
      assert(action == Action::kToCOO);                                        \
    }                                                                          \
    return coo;                                                                \
  }
8768a91bc7bSHarrietAkot 
// Shorthand for CASE when the pointer and index overhead types coincide.
// `p`/`P` is passed for both the pointer and the index slot.
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
878845561ecSwren romano 
// Defines `_mlir_ciface_NAME(ref, tensor)`. It asks the storage accessor LIB
// for a `std::vector<TYPE>` and exposes that vector through `*ref` as a
// rank-1 memref with offset 0 and stride 1. The memref aliases the vector's
// buffer: nothing is copied, so the view stays valid only as long as the
// vector's buffer does.
#define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
    assert(ref && tensor);                                                     \
    auto *storage = static_cast<SparseTensorStorageBase *>(tensor);            \
    std::vector<TYPE> *vec = nullptr;                                          \
    storage->LIB(&vec);                                                        \
    auto *buffer = vec->data();                                                \
    ref->basePtr = buffer;                                                     \
    ref->data = buffer;                                                        \
    ref->offset = 0;                                                           \
    ref->sizes[0] = vec->size();                                               \
    ref->strides[0] = 1;                                                       \
  }
8898a91bc7bSHarrietAkot 
// Defines `_mlir_ciface_NAME(ref, tensor, d)`. It asks the storage accessor
// LIB for the `std::vector<TYPE>` of dimension `d` and exposes that vector
// through `*ref` as a rank-1 memref with offset 0 and stride 1. The memref
// aliases the vector's buffer; nothing is copied.
#define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
                           index_type d) {                                     \
    assert(ref && tensor);                                                     \
    auto *storage = static_cast<SparseTensorStorageBase *>(tensor);            \
    std::vector<TYPE> *vec = nullptr;                                          \
    storage->LIB(&vec, d);                                                     \
    auto *buffer = vec->data();                                                \
    ref->basePtr = buffer;                                                     \
    ref->data = buffer;                                                        \
    ref->offset = 0;                                                           \
    ref->sizes[0] = vec->size();                                               \
    ref->strides[0] = 1;                                                       \
  }
9018a91bc7bSHarrietAkot 
// Defines `_mlir_ciface_NAME(tensor, value, iref, pref)`. It adds one element
// holding `value` to the SparseTensorCOO<TYPE> in `tensor`. Coordinate k
// from `iref` is first moved to position pref[k]. Returns `tensor`.
#define IMPL_ADDELT(NAME, TYPE)                                                \
  void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
                            StridedMemRefType<index_type, 1> *iref,            \
                            StridedMemRefType<index_type, 1> *pref) {          \
    assert(tensor && iref && pref);                                            \
    assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
    assert(iref->sizes[0] == pref->sizes[0]);                                  \
    const index_type *src = iref->data + iref->offset;                         \
    const index_type *order = pref->data + pref->offset;                       \
    const uint64_t count = iref->sizes[0];                                     \
    std::vector<index_type> permuted(count);                                   \
    for (uint64_t k = 0; k < count; ++k)                                       \
      permuted[order[k]] = src[k];                                             \
    auto *coo = static_cast<SparseTensorCOO<TYPE> *>(tensor);                  \
    coo->add(permuted, value);                                                 \
    return tensor;                                                             \
  }
9188a91bc7bSHarrietAkot 
// Defines `_mlir_ciface_NAME(tensor, iref, vref)`. It calls getNext() on the
// SparseTensorCOO<V> in `tensor`.
//  - If getNext() yields nullptr, it returns false.
//  - Otherwise it copies the element's first iref->sizes[0] indices into
//    `iref`, stores its value into `vref`, and returns true.
#define IMPL_GETNEXT(NAME, V)                                                  \
  bool _mlir_ciface_##NAME(void *tensor,                                       \
                           StridedMemRefType<index_type, 1> *iref,             \
                           StridedMemRefType<V, 0> *vref) {                    \
    assert(tensor && iref && vref);                                            \
    assert(iref->strides[0] == 1);                                             \
    index_type *dstIndices = iref->data + iref->offset;                        \
    V *dstValue = vref->data + vref->offset;                                   \
    const uint64_t count = iref->sizes[0];                                     \
    auto *coo = static_cast<SparseTensorCOO<V> *>(tensor);                     \
    const Element<V> *elem = coo->getNext();                                   \
    if (!elem)                                                                 \
      return false;                                                            \
    for (uint64_t k = 0; k < count; ++k)                                       \
      dstIndices[k] = elem->indices[k];                                        \
    *dstValue = elem->value;                                                   \
    return true;                                                               \
  }
9378a91bc7bSHarrietAkot 
// Defines `_mlir_ciface_NAME(tensor, cref, val)`. It forwards one insertion
// of `val`, at the coordinates held in the unit-stride memref `cref`, to
// SparseTensorStorageBase::lexInsert.
// NOTE(review): the name suggests insertions must arrive in lexicographic
// coordinate order -- confirm in the storage implementation.
#define IMPL_LEXINSERT(NAME, V)                                                \
  void _mlir_ciface_##NAME(void *tensor,                                       \
                           StridedMemRefType<index_type, 1> *cref, V val) {    \
    assert(tensor && cref);                                                    \
    assert(cref->strides[0] == 1);                                             \
    index_type *cursor = cref->data + cref->offset;                            \
    assert(cursor);                                                            \
    auto *storage = static_cast<SparseTensorStorageBase *>(tensor);            \
    storage->lexInsert(cursor, val);                                           \
  }
947f66e5769SAart Bik 
// Defines `_mlir_ciface_NAME(tensor, cref, vref, fref, aref, count)`. It
// unpacks four unit-stride memrefs into raw pointers and forwards them,
// together with `count`, to SparseTensorStorageBase::expInsert:
//  - `cref`: the cursor;
//  - `vref`: the values;
//  - `fref`: the filled flags;
//  - `aref`: the added positions.
// `vref` and `fref` must have the same length.
#define IMPL_EXPINSERT(NAME, V)                                                \
  void _mlir_ciface_##NAME(                                                    \
      void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
      StridedMemRefType<index_type, 1> *aref, index_type count) {              \
    assert(tensor && cref && vref && fref && aref);                            \
    assert(cref->strides[0] == 1);                                             \
    assert(vref->strides[0] == 1);                                             \
    assert(fref->strides[0] == 1);                                             \
    assert(aref->strides[0] == 1);                                             \
    assert(vref->sizes[0] == fref->sizes[0]);                                  \
    auto *storage = static_cast<SparseTensorStorageBase *>(tensor);            \
    storage->expInsert(cref->data + cref->offset, vref->data + vref->offset,   \
                       fref->data + fref->offset, aref->data + aref->offset,   \
                       count);                                                 \
  }
9664f2ec7f9SAart Bik 
// _mlir_ciface_newSparseTensor below rewrites OverheadType::kIndex to kU64,
// so the CASE dispatch only needs cases for the fixed-width overhead types.
// That rewrite is sound only if index_type is exactly uint64_t. Asserting it
// here means any header change to index_type breaks the build, instead of
// letting this file silently get out of sync with its header.
static_assert(std::is_same<index_type, uint64_t>::value,
              "Expected index_type == uint64_t");
972bc04a470Swren romano 
9738a91bc7bSHarrietAkot /// Constructs a new sparse tensor. This is the "swiss army knife"
9748a91bc7bSHarrietAkot /// method for materializing sparse tensors into the computation.
9758a91bc7bSHarrietAkot ///
976845561ecSwren romano /// Action:
9778a91bc7bSHarrietAkot /// kEmpty = returns empty storage to fill later
9788a91bc7bSHarrietAkot /// kFromFile = returns storage, where ptr contains filename to read
9798a91bc7bSHarrietAkot /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
9808a91bc7bSHarrietAkot /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
9818a91bc7bSHarrietAkot /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
982845561ecSwren romano /// kToIterator = returns iterator from storage in ptr (call getNext() to use)
9838a91bc7bSHarrietAkot void *
984845561ecSwren romano _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
985d2215e79SRainer Orth                              StridedMemRefType<index_type, 1> *sref,
986d2215e79SRainer Orth                              StridedMemRefType<index_type, 1> *pref,
987845561ecSwren romano                              OverheadType ptrTp, OverheadType indTp,
988845561ecSwren romano                              PrimaryType valTp, Action action, void *ptr) {
9898a91bc7bSHarrietAkot   assert(aref && sref && pref);
9908a91bc7bSHarrietAkot   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
9918a91bc7bSHarrietAkot          pref->strides[0] == 1);
9928a91bc7bSHarrietAkot   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
993845561ecSwren romano   const DimLevelType *sparsity = aref->data + aref->offset;
994d83a7068Swren romano   const index_type *shape = sref->data + sref->offset;
995d2215e79SRainer Orth   const index_type *perm = pref->data + pref->offset;
9968a91bc7bSHarrietAkot   uint64_t rank = aref->sizes[0];
9978a91bc7bSHarrietAkot 
998bc04a470Swren romano   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
999bc04a470Swren romano   // This is safe because of the static_assert above.
1000bc04a470Swren romano   if (ptrTp == OverheadType::kIndex)
1001bc04a470Swren romano     ptrTp = OverheadType::kU64;
1002bc04a470Swren romano   if (indTp == OverheadType::kIndex)
1003bc04a470Swren romano     indTp = OverheadType::kU64;
1004bc04a470Swren romano 
10058a91bc7bSHarrietAkot   // Double matrices with all combinations of overhead storage.
1006845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
1007845561ecSwren romano        uint64_t, double);
1008845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
1009845561ecSwren romano        uint32_t, double);
1010845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
1011845561ecSwren romano        uint16_t, double);
1012845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
1013845561ecSwren romano        uint8_t, double);
1014845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
1015845561ecSwren romano        uint64_t, double);
1016845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
1017845561ecSwren romano        uint32_t, double);
1018845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
1019845561ecSwren romano        uint16_t, double);
1020845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
1021845561ecSwren romano        uint8_t, double);
1022845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
1023845561ecSwren romano        uint64_t, double);
1024845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
1025845561ecSwren romano        uint32_t, double);
1026845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
1027845561ecSwren romano        uint16_t, double);
1028845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
1029845561ecSwren romano        uint8_t, double);
1030845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
1031845561ecSwren romano        uint64_t, double);
1032845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
1033845561ecSwren romano        uint32_t, double);
1034845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
1035845561ecSwren romano        uint16_t, double);
1036845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
1037845561ecSwren romano        uint8_t, double);
10388a91bc7bSHarrietAkot 
10398a91bc7bSHarrietAkot   // Float matrices with all combinations of overhead storage.
1040845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
1041845561ecSwren romano        uint64_t, float);
1042845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
1043845561ecSwren romano        uint32_t, float);
1044845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
1045845561ecSwren romano        uint16_t, float);
1046845561ecSwren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
1047845561ecSwren romano        uint8_t, float);
1048845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
1049845561ecSwren romano        uint64_t, float);
1050845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
1051845561ecSwren romano        uint32_t, float);
1052845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
1053845561ecSwren romano        uint16_t, float);
1054845561ecSwren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
1055845561ecSwren romano        uint8_t, float);
1056845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
1057845561ecSwren romano        uint64_t, float);
1058845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
1059845561ecSwren romano        uint32_t, float);
1060845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
1061845561ecSwren romano        uint16_t, float);
1062845561ecSwren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
1063845561ecSwren romano        uint8_t, float);
1064845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
1065845561ecSwren romano        uint64_t, float);
1066845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
1067845561ecSwren romano        uint32_t, float);
1068845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
1069845561ecSwren romano        uint16_t, float);
1070845561ecSwren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
1071845561ecSwren romano        uint8_t, float);
10728a91bc7bSHarrietAkot 
1073845561ecSwren romano   // Integral matrices with both overheads of the same type.
1074845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
1075845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
1076845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
1077845561ecSwren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
1078845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
1079845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
1080845561ecSwren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
1081845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
1082845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
1083845561ecSwren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
1084845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
1085845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
1086845561ecSwren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
10878a91bc7bSHarrietAkot 
10888a91bc7bSHarrietAkot   // Unsupported case (add above if needed).
10898a91bc7bSHarrietAkot   fputs("unsupported combination of types\n", stderr);
10908a91bc7bSHarrietAkot   exit(1);
10918a91bc7bSHarrietAkot }
10928a91bc7bSHarrietAkot 
10938a91bc7bSHarrietAkot /// Methods that provide direct access to pointers.
1094d2215e79SRainer Orth IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
10958a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
10968a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
10978a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
10988a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
10998a91bc7bSHarrietAkot 
11008a91bc7bSHarrietAkot /// Methods that provide direct access to indices.
1101d2215e79SRainer Orth IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
11028a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
11038a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
11048a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
11058a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
11068a91bc7bSHarrietAkot 
11078a91bc7bSHarrietAkot /// Methods that provide direct access to values.
11088a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
11098a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
11108a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
11118a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
11128a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
11138a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
11148a91bc7bSHarrietAkot 
11158a91bc7bSHarrietAkot /// Helper to add value to coordinate scheme, one per value type.
11168a91bc7bSHarrietAkot IMPL_ADDELT(addEltF64, double)
11178a91bc7bSHarrietAkot IMPL_ADDELT(addEltF32, float)
11188a91bc7bSHarrietAkot IMPL_ADDELT(addEltI64, int64_t)
11198a91bc7bSHarrietAkot IMPL_ADDELT(addEltI32, int32_t)
11208a91bc7bSHarrietAkot IMPL_ADDELT(addEltI16, int16_t)
11218a91bc7bSHarrietAkot IMPL_ADDELT(addEltI8, int8_t)
11228a91bc7bSHarrietAkot 
11238a91bc7bSHarrietAkot /// Helper to enumerate elements of coordinate scheme, one per value type.
11248a91bc7bSHarrietAkot IMPL_GETNEXT(getNextF64, double)
11258a91bc7bSHarrietAkot IMPL_GETNEXT(getNextF32, float)
11268a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI64, int64_t)
11278a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI32, int32_t)
11288a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI16, int16_t)
11298a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI8, int8_t)
11308a91bc7bSHarrietAkot 
11316438783fSAart Bik /// Insert elements in lexicographical index order, one per value type.
1132f66e5769SAart Bik IMPL_LEXINSERT(lexInsertF64, double)
1133f66e5769SAart Bik IMPL_LEXINSERT(lexInsertF32, float)
1134f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI64, int64_t)
1135f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI32, int32_t)
1136f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI16, int16_t)
1137f66e5769SAart Bik IMPL_LEXINSERT(lexInsertI8, int8_t)
1138f66e5769SAart Bik 
11396438783fSAart Bik /// Insert using expansion, one per value type.
11404f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertF64, double)
11414f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertF32, float)
11424f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI64, int64_t)
11434f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI32, int32_t)
11444f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI16, int16_t)
11454f2ec7f9SAart Bik IMPL_EXPINSERT(expInsertI8, int8_t)
11464f2ec7f9SAart Bik 
// Retire the file-local helper macros so they cannot leak into the rest of
// the translation unit. CASE_SECSAME is used by _mlir_ciface_newSparseTensor
// alongside CASE and must be retired together with it.
#undef CASE
#undef CASE_SECSAME
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
#undef IMPL_ADDELT
#undef IMPL_GETNEXT
#undef IMPL_LEXINSERT
#undef IMPL_EXPINSERT
11546438783fSAart Bik 
11556438783fSAart Bik /// Output a sparse tensor, one per value type.
11566438783fSAart Bik void outSparseTensorF64(void *tensor, void *dest, bool sort) {
11576438783fSAart Bik   return outSparseTensor<double>(tensor, dest, sort);
11586438783fSAart Bik }
11596438783fSAart Bik void outSparseTensorF32(void *tensor, void *dest, bool sort) {
11606438783fSAart Bik   return outSparseTensor<float>(tensor, dest, sort);
11616438783fSAart Bik }
11626438783fSAart Bik void outSparseTensorI64(void *tensor, void *dest, bool sort) {
11636438783fSAart Bik   return outSparseTensor<int64_t>(tensor, dest, sort);
11646438783fSAart Bik }
11656438783fSAart Bik void outSparseTensorI32(void *tensor, void *dest, bool sort) {
11666438783fSAart Bik   return outSparseTensor<int32_t>(tensor, dest, sort);
11676438783fSAart Bik }
11686438783fSAart Bik void outSparseTensorI16(void *tensor, void *dest, bool sort) {
11696438783fSAart Bik   return outSparseTensor<int16_t>(tensor, dest, sort);
11706438783fSAart Bik }
11716438783fSAart Bik void outSparseTensorI8(void *tensor, void *dest, bool sort) {
11726438783fSAart Bik   return outSparseTensor<int8_t>(tensor, dest, sort);
11736438783fSAart Bik }
11748a91bc7bSHarrietAkot 
11758a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
11768a91bc7bSHarrietAkot //
11778a91bc7bSHarrietAkot // Public API with methods that accept C-style data structures to interact
11788a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally.
11798a91bc7bSHarrietAkot // These methods can be used both by MLIR compiler-generated code as well as by
11808a91bc7bSHarrietAkot // an external runtime that wants to interact with MLIR compiler-generated code.
11818a91bc7bSHarrietAkot //
11828a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
11838a91bc7bSHarrietAkot 
11848a91bc7bSHarrietAkot /// Helper method to read a sparse tensor filename from the environment,
11858a91bc7bSHarrietAkot /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
1186d2215e79SRainer Orth char *getTensorFilename(index_type id) {
11878a91bc7bSHarrietAkot   char var[80];
11888a91bc7bSHarrietAkot   sprintf(var, "TENSOR%" PRIu64, id);
11898a91bc7bSHarrietAkot   char *env = getenv(var);
11903734c078Swren romano   if (!env) {
11913734c078Swren romano     fprintf(stderr, "Environment variable %s is not set\n", var);
11923734c078Swren romano     exit(1);
11933734c078Swren romano   }
11948a91bc7bSHarrietAkot   return env;
11958a91bc7bSHarrietAkot }
11968a91bc7bSHarrietAkot 
11978a91bc7bSHarrietAkot /// Returns size of sparse tensor in given dimension.
1198d2215e79SRainer Orth index_type sparseDimSize(void *tensor, index_type d) {
11998a91bc7bSHarrietAkot   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
12008a91bc7bSHarrietAkot }
12018a91bc7bSHarrietAkot 
1202f66e5769SAart Bik /// Finalizes lexicographic insertions.
1203f66e5769SAart Bik void endInsert(void *tensor) {
1204f66e5769SAart Bik   return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
1205f66e5769SAart Bik }
1206f66e5769SAart Bik 
12078a91bc7bSHarrietAkot /// Releases sparse tensor storage.
12088a91bc7bSHarrietAkot void delSparseTensor(void *tensor) {
12098a91bc7bSHarrietAkot   delete static_cast<SparseTensorStorageBase *>(tensor);
12108a91bc7bSHarrietAkot }
12118a91bc7bSHarrietAkot 
121263bdcaf9Swren romano /// Releases sparse tensor coordinate scheme.
121363bdcaf9Swren romano #define IMPL_DELCOO(VNAME, V)                                                  \
121463bdcaf9Swren romano   void delSparseTensorCOO##VNAME(void *coo) {                                  \
121563bdcaf9Swren romano     delete static_cast<SparseTensorCOO<V> *>(coo);                             \
121663bdcaf9Swren romano   }
121763bdcaf9Swren romano IMPL_DELCOO(F64, double)
121863bdcaf9Swren romano IMPL_DELCOO(F32, float)
121963bdcaf9Swren romano IMPL_DELCOO(I64, int64_t)
122063bdcaf9Swren romano IMPL_DELCOO(I32, int32_t)
122163bdcaf9Swren romano IMPL_DELCOO(I16, int16_t)
122263bdcaf9Swren romano IMPL_DELCOO(I8, int8_t)
122363bdcaf9Swren romano #undef IMPL_DELCOO
122463bdcaf9Swren romano 
12258a91bc7bSHarrietAkot /// Initializes sparse tensor from a COO-flavored format expressed using C-style
12268a91bc7bSHarrietAkot /// data structures. The expected parameters are:
12278a91bc7bSHarrietAkot ///
12288a91bc7bSHarrietAkot ///   rank:    rank of tensor
12298a91bc7bSHarrietAkot ///   nse:     number of specified elements (usually the nonzeros)
12308a91bc7bSHarrietAkot ///   shape:   array with dimension size for each rank
12318a91bc7bSHarrietAkot ///   values:  a "nse" array with values for all specified elements
12328a91bc7bSHarrietAkot ///   indices: a flat "nse x rank" array with indices for all specified elements
123320eaa88fSBixia Zheng ///   perm:    the permutation of the dimensions in the storage
123420eaa88fSBixia Zheng ///   sparse:  the sparsity for the dimensions
12358a91bc7bSHarrietAkot ///
12368a91bc7bSHarrietAkot /// For example, the sparse matrix
12378a91bc7bSHarrietAkot ///     | 1.0 0.0 0.0 |
12388a91bc7bSHarrietAkot ///     | 0.0 5.0 3.0 |
12398a91bc7bSHarrietAkot /// can be passed as
12408a91bc7bSHarrietAkot ///      rank    = 2
12418a91bc7bSHarrietAkot ///      nse     = 3
12428a91bc7bSHarrietAkot ///      shape   = [2, 3]
12438a91bc7bSHarrietAkot ///      values  = [1.0, 5.0, 3.0]
12448a91bc7bSHarrietAkot ///      indices = [ 0, 0,  1, 1,  1, 2]
12458a91bc7bSHarrietAkot //
124620eaa88fSBixia Zheng // TODO: generalize beyond 64-bit indices.
12478a91bc7bSHarrietAkot //
12486438783fSAart Bik void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
124920eaa88fSBixia Zheng                                    double *values, uint64_t *indices,
125020eaa88fSBixia Zheng                                    uint64_t *perm, uint8_t *sparse) {
125120eaa88fSBixia Zheng   return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
125220eaa88fSBixia Zheng                                     sparse);
12538a91bc7bSHarrietAkot }
12546438783fSAart Bik void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
125520eaa88fSBixia Zheng                                    float *values, uint64_t *indices,
125620eaa88fSBixia Zheng                                    uint64_t *perm, uint8_t *sparse) {
125720eaa88fSBixia Zheng   return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
125820eaa88fSBixia Zheng                                    sparse);
12598a91bc7bSHarrietAkot }
12608a91bc7bSHarrietAkot 
12612f49e6b0SBixia Zheng /// Converts a sparse tensor to COO-flavored format expressed using C-style
12622f49e6b0SBixia Zheng /// data structures. The expected output parameters are pointers for these
12632f49e6b0SBixia Zheng /// values:
12642f49e6b0SBixia Zheng ///
12652f49e6b0SBixia Zheng ///   rank:    rank of tensor
12662f49e6b0SBixia Zheng ///   nse:     number of specified elements (usually the nonzeros)
12672f49e6b0SBixia Zheng ///   shape:   array with dimension size for each rank
12682f49e6b0SBixia Zheng ///   values:  a "nse" array with values for all specified elements
12692f49e6b0SBixia Zheng ///   indices: a flat "nse x rank" array with indices for all specified elements
12702f49e6b0SBixia Zheng ///
12712f49e6b0SBixia Zheng /// The input is a pointer to SparseTensorStorage<P, I, V>, typically returned
12722f49e6b0SBixia Zheng /// from convertToMLIRSparseTensor.
12732f49e6b0SBixia Zheng ///
12742f49e6b0SBixia Zheng //  TODO: Currently, values are copied from SparseTensorStorage to
12752f49e6b0SBixia Zheng //  SparseTensorCOO, then to the output. We may want to reduce the number of
12762f49e6b0SBixia Zheng //  copies.
12772f49e6b0SBixia Zheng //
12786438783fSAart Bik // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
12796438783fSAart Bik // compressed
12802f49e6b0SBixia Zheng //
12816438783fSAart Bik void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
12826438783fSAart Bik                                     uint64_t *pNse, uint64_t **pShape,
12836438783fSAart Bik                                     double **pValues, uint64_t **pIndices) {
12846438783fSAart Bik   fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
12852f49e6b0SBixia Zheng }
12866438783fSAart Bik void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
12876438783fSAart Bik                                     uint64_t *pNse, uint64_t **pShape,
12886438783fSAart Bik                                     float **pValues, uint64_t **pIndices) {
12896438783fSAart Bik   fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
12902f49e6b0SBixia Zheng }
1291efa15f41SAart Bik 
12928a91bc7bSHarrietAkot } // extern "C"
12938a91bc7bSHarrietAkot 
12948a91bc7bSHarrietAkot #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
1295