1*8a91bc7bSHarrietAkot //===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
2*8a91bc7bSHarrietAkot //
3*8a91bc7bSHarrietAkot // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*8a91bc7bSHarrietAkot // See https://llvm.org/LICENSE.txt for license information.
5*8a91bc7bSHarrietAkot // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*8a91bc7bSHarrietAkot //
7*8a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
8*8a91bc7bSHarrietAkot //
9*8a91bc7bSHarrietAkot // This file implements a light-weight runtime support library that is useful
10*8a91bc7bSHarrietAkot // for sparse tensor manipulations. The functionality provided in this library
11*8a91bc7bSHarrietAkot // is meant to simplify benchmarking, testing, and debugging MLIR code that
12*8a91bc7bSHarrietAkot // operates on sparse tensors. The provided functionality is **not** part
13*8a91bc7bSHarrietAkot // of core MLIR, however.
14*8a91bc7bSHarrietAkot //
15*8a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
16*8a91bc7bSHarrietAkot 
17*8a91bc7bSHarrietAkot #include "mlir/ExecutionEngine/CRunnerUtils.h"
18*8a91bc7bSHarrietAkot 
19*8a91bc7bSHarrietAkot #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
20*8a91bc7bSHarrietAkot 
21*8a91bc7bSHarrietAkot #include <algorithm>
22*8a91bc7bSHarrietAkot #include <cassert>
23*8a91bc7bSHarrietAkot #include <cctype>
24*8a91bc7bSHarrietAkot #include <cinttypes>
25*8a91bc7bSHarrietAkot #include <cstdio>
26*8a91bc7bSHarrietAkot #include <cstdlib>
27*8a91bc7bSHarrietAkot #include <cstring>
28*8a91bc7bSHarrietAkot #include <numeric>
29*8a91bc7bSHarrietAkot #include <vector>
30*8a91bc7bSHarrietAkot 
31*8a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
32*8a91bc7bSHarrietAkot //
33*8a91bc7bSHarrietAkot // Internal support for storing and reading sparse tensors.
34*8a91bc7bSHarrietAkot //
35*8a91bc7bSHarrietAkot // The following memory-resident sparse storage schemes are supported:
36*8a91bc7bSHarrietAkot //
37*8a91bc7bSHarrietAkot // (a) A coordinate scheme for temporarily storing and lexicographically
38*8a91bc7bSHarrietAkot //     sorting a sparse tensor by index (SparseTensorCOO).
39*8a91bc7bSHarrietAkot //
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
43*8a91bc7bSHarrietAkot //
44*8a91bc7bSHarrietAkot // The following external formats are supported:
45*8a91bc7bSHarrietAkot //
46*8a91bc7bSHarrietAkot // (1) Matrix Market Exchange (MME): *.mtx
47*8a91bc7bSHarrietAkot //     https://math.nist.gov/MatrixMarket/formats.html
48*8a91bc7bSHarrietAkot //
49*8a91bc7bSHarrietAkot // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
50*8a91bc7bSHarrietAkot //     http://frostt.io/tensors/file-formats.html
51*8a91bc7bSHarrietAkot //
52*8a91bc7bSHarrietAkot // Two public APIs are supported:
53*8a91bc7bSHarrietAkot //
54*8a91bc7bSHarrietAkot // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
55*8a91bc7bSHarrietAkot //     tensors. These methods should be used exclusively by MLIR
56*8a91bc7bSHarrietAkot //     compiler-generated code.
57*8a91bc7bSHarrietAkot //
58*8a91bc7bSHarrietAkot // (II) Methods that accept C-style data structures to interact with sparse
59*8a91bc7bSHarrietAkot //      tensors. These methods can be used by any external runtime that wants
60*8a91bc7bSHarrietAkot //      to interact with MLIR compiler-generated code.
61*8a91bc7bSHarrietAkot //
62*8a91bc7bSHarrietAkot // In both cases (I) and (II), the SparseTensorStorage format is externally
63*8a91bc7bSHarrietAkot // only visible as an opaque pointer.
64*8a91bc7bSHarrietAkot //
65*8a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
66*8a91bc7bSHarrietAkot 
67*8a91bc7bSHarrietAkot namespace {
68*8a91bc7bSHarrietAkot 
/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
template <typename V>
struct Element {
  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val) {}
  std::vector<uint64_t> indices;
  V value;
};

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO {
public:
  /// Constructs an empty tensor with the given per-dimension sizes,
  /// reserving room for `capacity` elements when `capacity` is nonzero.
  SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
      : sizes(szs), iteratorLocked(false), iteratorPos(0) {
    if (capacity)
      elements.reserve(capacity);
  }
  /// Adds element as indices and value. The indices must be given in this
  /// tensor's (possibly permuted) dimension order and lie within `sizes`
  /// (checked by assert only).
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t rank = getRank();
    assert(rank == ind.size());
    for (uint64_t r = 0; r < rank; r++)
      assert(ind[r] < sizes[r]); // within bounds
    elements.emplace_back(ind, val);
  }
  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    std::sort(elements.begin(), elements.end(), lexOrder);
  }
  /// Returns rank.
  uint64_t getRank() const { return sizes.size(); }
  /// Getter for sizes array.
  const std::vector<uint64_t> &getSizes() const { return sizes; }
  /// Getter for elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode. While in this mode add() and sort() are
  /// disallowed (asserted).
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }
  /// Get the next element, or nullptr once all elements have been visited
  /// (which also leaves iterator mode).
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme allocated with `new`.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *sizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++)
      permsz[perm[r]] = sizes[r];
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  /// Returns true if indices of e1 < indices of e2.
  static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
    uint64_t rank = e1.indices.size();
    assert(rank == e2.indices.size());
    for (uint64_t r = 0; r < rank; r++) {
      if (e1.indices[r] == e2.indices[r])
        continue;
      return e1.indices[r] < e2.indices[r];
    }
    return false;
  }
  const std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<Element<V>> elements;
  bool iteratorLocked;
  // Must be at least as wide as elements.size(): a 32-bit `unsigned` counter
  // would wrap for tensors with more than 2^32 - 1 elements, making getNext()
  // restart from the beginning instead of terminating.
  uint64_t iteratorPos;
};
160*8a91bc7bSHarrietAkot 
/// Abstract base class of sparse tensor storage. C++ has no partial
/// specialization of member functions, so each accessor below is overloaded
/// once per supported element type instead. A concrete subclass overrides
/// only the overloads that match its own types; every remaining overload
/// reports the unsupported type on stderr and terminates the process.
class SparseTensorStorageBase {
  /// Prints "unsupported <kind>" to stderr and exits with status 1.
  void fatal(const char *kind) {
    fprintf(stderr, "unsupported %s\n", kind);
    exit(1);
  }

public:
  /// Per-dimension storage annotation.
  enum DimLevelType : uint8_t { kDense = 0, kCompressed = 1, kSingleton = 2 };

  /// Returns the size of dimension `d`; must be provided by subclasses.
  virtual uint64_t getDimSize(uint64_t d) = 0;

  // Overhead storage: pointer arrays, one overload per pointer bit width.
  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }

  // Overhead storage: index arrays, one overload per index bit width.
  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }

  // Primary storage: the numerical values, one overload per value type.
  virtual void getValues(std::vector<double> **) { fatal("valf64"); }
  virtual void getValues(std::vector<float> **) { fatal("valf32"); }
  virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }

  virtual ~SparseTensorStorageBase() = default;
};
195*8a91bc7bSHarrietAkot 
196*8a91bc7bSHarrietAkot /// A memory-resident sparse tensor using a storage scheme based on
197*8a91bc7bSHarrietAkot /// per-dimension sparse/dense annotations. This data structure provides a
198*8a91bc7bSHarrietAkot /// bufferized form of a sparse tensor type. In contrast to generating setup
199*8a91bc7bSHarrietAkot /// methods for each differently annotated sparse tensor, this method provides
200*8a91bc7bSHarrietAkot /// a convenient "one-size-fits-all" solution that simply takes an input tensor
201*8a91bc7bSHarrietAkot /// and annotations to implement all required setup in a general manner.
202*8a91bc7bSHarrietAkot template <typename P, typename I, typename V>
203*8a91bc7bSHarrietAkot class SparseTensorStorage : public SparseTensorStorageBase {
204*8a91bc7bSHarrietAkot public:
205*8a91bc7bSHarrietAkot   /// Constructs a sparse tensor storage scheme with the given dimensions,
206*8a91bc7bSHarrietAkot   /// permutation, and per-dimension dense/sparse annotations, using
207*8a91bc7bSHarrietAkot   /// the coordinate scheme tensor for the initial contents if provided.
208*8a91bc7bSHarrietAkot   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
209*8a91bc7bSHarrietAkot                       const uint8_t *sparsity, SparseTensorCOO<V> *tensor)
210*8a91bc7bSHarrietAkot       : sizes(szs), rev(getRank()), pointers(getRank()), indices(getRank()) {
211*8a91bc7bSHarrietAkot     uint64_t rank = getRank();
212*8a91bc7bSHarrietAkot     // Store "reverse" permutation.
213*8a91bc7bSHarrietAkot     for (uint64_t r = 0; r < rank; r++)
214*8a91bc7bSHarrietAkot       rev[perm[r]] = r;
215*8a91bc7bSHarrietAkot     // Provide hints on capacity of pointers and indices.
216*8a91bc7bSHarrietAkot     // TODO: needs fine-tuning based on sparsity
217*8a91bc7bSHarrietAkot     for (uint64_t r = 0, s = 1; r < rank; r++) {
218*8a91bc7bSHarrietAkot       s *= sizes[r];
219*8a91bc7bSHarrietAkot       if (sparsity[r] == kCompressed) {
220*8a91bc7bSHarrietAkot         pointers[r].reserve(s + 1);
221*8a91bc7bSHarrietAkot         indices[r].reserve(s);
222*8a91bc7bSHarrietAkot         s = 1;
223*8a91bc7bSHarrietAkot       } else {
224*8a91bc7bSHarrietAkot         assert(sparsity[r] == kDense && "singleton not yet supported");
225*8a91bc7bSHarrietAkot       }
226*8a91bc7bSHarrietAkot     }
227*8a91bc7bSHarrietAkot     // Prepare sparse pointer structures for all dimensions.
228*8a91bc7bSHarrietAkot     for (uint64_t r = 0; r < rank; r++)
229*8a91bc7bSHarrietAkot       if (sparsity[r] == kCompressed)
230*8a91bc7bSHarrietAkot         pointers[r].push_back(0);
231*8a91bc7bSHarrietAkot     // Then assign contents from coordinate scheme tensor if provided.
232*8a91bc7bSHarrietAkot     if (tensor) {
233*8a91bc7bSHarrietAkot       uint64_t nnz = tensor->getElements().size();
234*8a91bc7bSHarrietAkot       values.reserve(nnz);
235*8a91bc7bSHarrietAkot       fromCOO(tensor, sparsity, 0, nnz, 0);
236*8a91bc7bSHarrietAkot     }
237*8a91bc7bSHarrietAkot   }
238*8a91bc7bSHarrietAkot 
239*8a91bc7bSHarrietAkot   virtual ~SparseTensorStorage() {}
240*8a91bc7bSHarrietAkot 
241*8a91bc7bSHarrietAkot   /// Get the rank of the tensor.
242*8a91bc7bSHarrietAkot   uint64_t getRank() const { return sizes.size(); }
243*8a91bc7bSHarrietAkot 
244*8a91bc7bSHarrietAkot   /// Get the size in the given dimension of the tensor.
245*8a91bc7bSHarrietAkot   uint64_t getDimSize(uint64_t d) override {
246*8a91bc7bSHarrietAkot     assert(d < getRank());
247*8a91bc7bSHarrietAkot     return sizes[d];
248*8a91bc7bSHarrietAkot   }
249*8a91bc7bSHarrietAkot 
250*8a91bc7bSHarrietAkot   // Partially specialize these three methods based on template types.
251*8a91bc7bSHarrietAkot   void getPointers(std::vector<P> **out, uint64_t d) override {
252*8a91bc7bSHarrietAkot     assert(d < getRank());
253*8a91bc7bSHarrietAkot     *out = &pointers[d];
254*8a91bc7bSHarrietAkot   }
255*8a91bc7bSHarrietAkot   void getIndices(std::vector<I> **out, uint64_t d) override {
256*8a91bc7bSHarrietAkot     assert(d < getRank());
257*8a91bc7bSHarrietAkot     *out = &indices[d];
258*8a91bc7bSHarrietAkot   }
259*8a91bc7bSHarrietAkot   void getValues(std::vector<V> **out) override { *out = &values; }
260*8a91bc7bSHarrietAkot 
261*8a91bc7bSHarrietAkot   /// Returns this sparse tensor storage scheme as a new memory-resident
262*8a91bc7bSHarrietAkot   /// sparse tensor in coordinate scheme with the given dimension order.
263*8a91bc7bSHarrietAkot   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
264*8a91bc7bSHarrietAkot     // Restore original order of the dimension sizes and allocate coordinate
265*8a91bc7bSHarrietAkot     // scheme with desired new ordering specified in perm.
266*8a91bc7bSHarrietAkot     uint64_t rank = getRank();
267*8a91bc7bSHarrietAkot     std::vector<uint64_t> orgsz(rank);
268*8a91bc7bSHarrietAkot     for (uint64_t r = 0; r < rank; r++)
269*8a91bc7bSHarrietAkot       orgsz[rev[r]] = sizes[r];
270*8a91bc7bSHarrietAkot     SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
271*8a91bc7bSHarrietAkot         rank, orgsz.data(), perm, values.size());
272*8a91bc7bSHarrietAkot     // Populate coordinate scheme restored from old ordering and changed with
273*8a91bc7bSHarrietAkot     // new ordering. Rather than applying both reorderings during the recursion,
274*8a91bc7bSHarrietAkot     // we compute the combine permutation in advance.
275*8a91bc7bSHarrietAkot     std::vector<uint64_t> reord(rank);
276*8a91bc7bSHarrietAkot     for (uint64_t r = 0; r < rank; r++)
277*8a91bc7bSHarrietAkot       reord[r] = perm[rev[r]];
278*8a91bc7bSHarrietAkot     std::vector<uint64_t> idx(rank);
279*8a91bc7bSHarrietAkot     toCOO(tensor, reord, idx, 0, 0);
280*8a91bc7bSHarrietAkot     assert(tensor->getElements().size() == values.size());
281*8a91bc7bSHarrietAkot     return tensor;
282*8a91bc7bSHarrietAkot   }
283*8a91bc7bSHarrietAkot 
284*8a91bc7bSHarrietAkot   /// Factory method. Constructs a sparse tensor storage scheme with the given
285*8a91bc7bSHarrietAkot   /// dimensions, permutation, and per-dimension dense/sparse annotations,
286*8a91bc7bSHarrietAkot   /// using the coordinate scheme tensor for the initial contents if provided.
287*8a91bc7bSHarrietAkot   /// In the latter case, the coordinate scheme must respect the same
288*8a91bc7bSHarrietAkot   /// permutation as is desired for the new sparse tensor storage.
289*8a91bc7bSHarrietAkot   static SparseTensorStorage<P, I, V> *
290*8a91bc7bSHarrietAkot   newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
291*8a91bc7bSHarrietAkot                   const uint8_t *sparsity, SparseTensorCOO<V> *tensor) {
292*8a91bc7bSHarrietAkot     SparseTensorStorage<P, I, V> *n = nullptr;
293*8a91bc7bSHarrietAkot     if (tensor) {
294*8a91bc7bSHarrietAkot       assert(tensor->getRank() == rank);
295*8a91bc7bSHarrietAkot       for (uint64_t r = 0; r < rank; r++)
296*8a91bc7bSHarrietAkot         assert(sizes[r] == 0 || tensor->getSizes()[perm[r]] == sizes[r]);
297*8a91bc7bSHarrietAkot       tensor->sort(); // sort lexicographically
298*8a91bc7bSHarrietAkot       n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
299*8a91bc7bSHarrietAkot                                            tensor);
300*8a91bc7bSHarrietAkot       delete tensor;
301*8a91bc7bSHarrietAkot     } else {
302*8a91bc7bSHarrietAkot       std::vector<uint64_t> permsz(rank);
303*8a91bc7bSHarrietAkot       for (uint64_t r = 0; r < rank; r++)
304*8a91bc7bSHarrietAkot         permsz[perm[r]] = sizes[r];
305*8a91bc7bSHarrietAkot       n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, tensor);
306*8a91bc7bSHarrietAkot     }
307*8a91bc7bSHarrietAkot     return n;
308*8a91bc7bSHarrietAkot   }
309*8a91bc7bSHarrietAkot 
310*8a91bc7bSHarrietAkot private:
311*8a91bc7bSHarrietAkot   /// Initializes sparse tensor storage scheme from a memory-resident sparse
312*8a91bc7bSHarrietAkot   /// tensor in coordinate scheme. This method prepares the pointers and
313*8a91bc7bSHarrietAkot   /// indices arrays under the given per-dimension dense/sparse annotations.
314*8a91bc7bSHarrietAkot   void fromCOO(SparseTensorCOO<V> *tensor, const uint8_t *sparsity, uint64_t lo,
315*8a91bc7bSHarrietAkot                uint64_t hi, uint64_t d) {
316*8a91bc7bSHarrietAkot     const std::vector<Element<V>> &elements = tensor->getElements();
317*8a91bc7bSHarrietAkot     // Once dimensions are exhausted, insert the numerical values.
318*8a91bc7bSHarrietAkot     if (d == getRank()) {
319*8a91bc7bSHarrietAkot       assert(lo >= hi || lo < elements.size());
320*8a91bc7bSHarrietAkot       values.push_back(lo < hi ? elements[lo].value : 0);
321*8a91bc7bSHarrietAkot       return;
322*8a91bc7bSHarrietAkot     }
323*8a91bc7bSHarrietAkot     assert(d < getRank());
324*8a91bc7bSHarrietAkot     // Visit all elements in this interval.
325*8a91bc7bSHarrietAkot     uint64_t full = 0;
326*8a91bc7bSHarrietAkot     while (lo < hi) {
327*8a91bc7bSHarrietAkot       assert(lo < elements.size() && hi <= elements.size());
328*8a91bc7bSHarrietAkot       // Find segment in interval with same index elements in this dimension.
329*8a91bc7bSHarrietAkot       uint64_t idx = elements[lo].indices[d];
330*8a91bc7bSHarrietAkot       uint64_t seg = lo + 1;
331*8a91bc7bSHarrietAkot       while (seg < hi && elements[seg].indices[d] == idx)
332*8a91bc7bSHarrietAkot         seg++;
333*8a91bc7bSHarrietAkot       // Handle segment in interval for sparse or dense dimension.
334*8a91bc7bSHarrietAkot       if (sparsity[d] == kCompressed) {
335*8a91bc7bSHarrietAkot         indices[d].push_back(idx);
336*8a91bc7bSHarrietAkot       } else {
337*8a91bc7bSHarrietAkot         // For dense storage we must fill in all the zero values between
338*8a91bc7bSHarrietAkot         // the previous element (when last we ran this for-loop) and the
339*8a91bc7bSHarrietAkot         // current element.
340*8a91bc7bSHarrietAkot         for (; full < idx; full++)
341*8a91bc7bSHarrietAkot           fromCOO(tensor, sparsity, 0, 0, d + 1); // pass empty
342*8a91bc7bSHarrietAkot         full++;
343*8a91bc7bSHarrietAkot       }
344*8a91bc7bSHarrietAkot       fromCOO(tensor, sparsity, lo, seg, d + 1);
345*8a91bc7bSHarrietAkot       // And move on to next segment in interval.
346*8a91bc7bSHarrietAkot       lo = seg;
347*8a91bc7bSHarrietAkot     }
348*8a91bc7bSHarrietAkot     // Finalize the sparse pointer structure at this dimension.
349*8a91bc7bSHarrietAkot     if (sparsity[d] == kCompressed) {
350*8a91bc7bSHarrietAkot       pointers[d].push_back(indices[d].size());
351*8a91bc7bSHarrietAkot     } else {
352*8a91bc7bSHarrietAkot       // For dense storage we must fill in all the zero values after
353*8a91bc7bSHarrietAkot       // the last element.
354*8a91bc7bSHarrietAkot       for (uint64_t sz = sizes[d]; full < sz; full++)
355*8a91bc7bSHarrietAkot         fromCOO(tensor, sparsity, 0, 0, d + 1); // pass empty
356*8a91bc7bSHarrietAkot     }
357*8a91bc7bSHarrietAkot   }
358*8a91bc7bSHarrietAkot 
359*8a91bc7bSHarrietAkot   /// Stores the sparse tensor storage scheme into a memory-resident sparse
360*8a91bc7bSHarrietAkot   /// tensor in coordinate scheme.
361*8a91bc7bSHarrietAkot   void toCOO(SparseTensorCOO<V> *tensor, std::vector<uint64_t> &reord,
362*8a91bc7bSHarrietAkot              std::vector<uint64_t> &idx, uint64_t pos, uint64_t d) {
363*8a91bc7bSHarrietAkot     assert(d <= getRank());
364*8a91bc7bSHarrietAkot     if (d == getRank()) {
365*8a91bc7bSHarrietAkot       assert(pos < values.size());
366*8a91bc7bSHarrietAkot       tensor->add(idx, values[pos]);
367*8a91bc7bSHarrietAkot     } else if (pointers[d].empty()) {
368*8a91bc7bSHarrietAkot       // Dense dimension.
369*8a91bc7bSHarrietAkot       for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
370*8a91bc7bSHarrietAkot         idx[reord[d]] = i;
371*8a91bc7bSHarrietAkot         toCOO(tensor, reord, idx, off + i, d + 1);
372*8a91bc7bSHarrietAkot       }
373*8a91bc7bSHarrietAkot     } else {
374*8a91bc7bSHarrietAkot       // Sparse dimension.
375*8a91bc7bSHarrietAkot       for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
376*8a91bc7bSHarrietAkot         idx[reord[d]] = indices[d][ii];
377*8a91bc7bSHarrietAkot         toCOO(tensor, reord, idx, ii, d + 1);
378*8a91bc7bSHarrietAkot       }
379*8a91bc7bSHarrietAkot     }
380*8a91bc7bSHarrietAkot   }
381*8a91bc7bSHarrietAkot 
382*8a91bc7bSHarrietAkot private:
383*8a91bc7bSHarrietAkot   std::vector<uint64_t> sizes; // per-dimension sizes
384*8a91bc7bSHarrietAkot   std::vector<uint64_t> rev;   // "reverse" permutation
385*8a91bc7bSHarrietAkot   std::vector<std::vector<P>> pointers;
386*8a91bc7bSHarrietAkot   std::vector<std::vector<I>> indices;
387*8a91bc7bSHarrietAkot   std::vector<V> values;
388*8a91bc7bSHarrietAkot };
389*8a91bc7bSHarrietAkot 
/// Helper to convert a NUL-terminated string to lower case in place.
/// Returns `token` itself so the call can be used directly in comparisons.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    // Cast through unsigned char: passing a negative char value (e.g. a
    // non-ASCII byte where char is signed) to tolower is undefined behavior.
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
  return token;
}
396*8a91bc7bSHarrietAkot 
397*8a91bc7bSHarrietAkot /// Read the MME header of a general sparse matrix of type real.
398*8a91bc7bSHarrietAkot static void readMMEHeader(FILE *file, char *name, uint64_t *idata) {
399*8a91bc7bSHarrietAkot   char line[1025];
400*8a91bc7bSHarrietAkot   char header[64];
401*8a91bc7bSHarrietAkot   char object[64];
402*8a91bc7bSHarrietAkot   char format[64];
403*8a91bc7bSHarrietAkot   char field[64];
404*8a91bc7bSHarrietAkot   char symmetry[64];
405*8a91bc7bSHarrietAkot   // Read header line.
406*8a91bc7bSHarrietAkot   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
407*8a91bc7bSHarrietAkot              symmetry) != 5) {
408*8a91bc7bSHarrietAkot     fprintf(stderr, "Corrupt header in %s\n", name);
409*8a91bc7bSHarrietAkot     exit(1);
410*8a91bc7bSHarrietAkot   }
411*8a91bc7bSHarrietAkot   // Make sure this is a general sparse matrix.
412*8a91bc7bSHarrietAkot   if (strcmp(toLower(header), "%%matrixmarket") ||
413*8a91bc7bSHarrietAkot       strcmp(toLower(object), "matrix") ||
414*8a91bc7bSHarrietAkot       strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
415*8a91bc7bSHarrietAkot       strcmp(toLower(symmetry), "general")) {
416*8a91bc7bSHarrietAkot     fprintf(stderr,
417*8a91bc7bSHarrietAkot             "Cannot find a general sparse matrix with type real in %s\n", name);
418*8a91bc7bSHarrietAkot     exit(1);
419*8a91bc7bSHarrietAkot   }
420*8a91bc7bSHarrietAkot   // Skip comments.
421*8a91bc7bSHarrietAkot   while (1) {
422*8a91bc7bSHarrietAkot     if (!fgets(line, 1025, file)) {
423*8a91bc7bSHarrietAkot       fprintf(stderr, "Cannot find data in %s\n", name);
424*8a91bc7bSHarrietAkot       exit(1);
425*8a91bc7bSHarrietAkot     }
426*8a91bc7bSHarrietAkot     if (line[0] != '%')
427*8a91bc7bSHarrietAkot       break;
428*8a91bc7bSHarrietAkot   }
429*8a91bc7bSHarrietAkot   // Next line contains M N NNZ.
430*8a91bc7bSHarrietAkot   idata[0] = 2; // rank
431*8a91bc7bSHarrietAkot   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
432*8a91bc7bSHarrietAkot              idata + 1) != 3) {
433*8a91bc7bSHarrietAkot     fprintf(stderr, "Cannot find size in %s\n", name);
434*8a91bc7bSHarrietAkot     exit(1);
435*8a91bc7bSHarrietAkot   }
436*8a91bc7bSHarrietAkot }
437*8a91bc7bSHarrietAkot 
/// Read the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments (lines
/// beginning with '#') followed by two lines that define the rank, the number
/// of nonzeros, and the dimensions sizes (one per rank) of the sparse tensor.
///
/// On success `idata` holds {rank, nnz, size_0, ..., size_{rank-1}}.
/// `idataSize` is the number of entries available in `idata` (at least 2);
/// it defaults to 512, the buffer size used by openSparseTensorCOO. Since the
/// rank is read from the file, a rank whose sizes would not fit is rejected
/// instead of overflowing the buffer. Any error prints a diagnostic that
/// mentions `name` and exits with status 1.
static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata,
                                uint64_t idataSize = 512) {
  char line[1025];
  // Skip comments.
  while (1) {
    if (!fgets(line, sizeof(line), file)) {
      fprintf(stderr, "Cannot find data in %s\n", name);
      exit(1);
    }
    if (line[0] != '#')
      break;
  }
  // Next line contains RANK and NNZ.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
    fprintf(stderr, "Cannot find metadata in %s\n", name);
    exit(1);
  }
  // Two slots are taken by rank and nnz; the remaining ones hold the sizes.
  assert(idataSize >= 2 && "idata must hold at least rank and nnz");
  if (idata[0] > idataSize - 2) {
    fprintf(stderr, "Rank %" PRIu64 " is too large in %s\n", idata[0], name);
    exit(1);
  }
  // Followed by a line with the dimension sizes (one per rank).
  for (uint64_t r = 0; r < idata[0]; r++) {
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size in %s\n", name);
      exit(1);
    }
  }
}
466*8a91bc7bSHarrietAkot 
467*8a91bc7bSHarrietAkot /// Reads a sparse tensor with the given filename into a memory-resident
468*8a91bc7bSHarrietAkot /// sparse tensor in coordinate scheme.
469*8a91bc7bSHarrietAkot template <typename V>
470*8a91bc7bSHarrietAkot static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
471*8a91bc7bSHarrietAkot                                                const uint64_t *sizes,
472*8a91bc7bSHarrietAkot                                                const uint64_t *perm) {
473*8a91bc7bSHarrietAkot   // Open the file.
474*8a91bc7bSHarrietAkot   FILE *file = fopen(filename, "r");
475*8a91bc7bSHarrietAkot   if (!file) {
476*8a91bc7bSHarrietAkot     fprintf(stderr, "Cannot find %s\n", filename);
477*8a91bc7bSHarrietAkot     exit(1);
478*8a91bc7bSHarrietAkot   }
479*8a91bc7bSHarrietAkot   // Perform some file format dependent set up.
480*8a91bc7bSHarrietAkot   uint64_t idata[512];
481*8a91bc7bSHarrietAkot   if (strstr(filename, ".mtx")) {
482*8a91bc7bSHarrietAkot     readMMEHeader(file, filename, idata);
483*8a91bc7bSHarrietAkot   } else if (strstr(filename, ".tns")) {
484*8a91bc7bSHarrietAkot     readExtFROSTTHeader(file, filename, idata);
485*8a91bc7bSHarrietAkot   } else {
486*8a91bc7bSHarrietAkot     fprintf(stderr, "Unknown format %s\n", filename);
487*8a91bc7bSHarrietAkot     exit(1);
488*8a91bc7bSHarrietAkot   }
489*8a91bc7bSHarrietAkot   // Prepare sparse tensor object with per-dimension sizes
490*8a91bc7bSHarrietAkot   // and the number of nonzeros as initial capacity.
491*8a91bc7bSHarrietAkot   assert(rank == idata[0] && "rank mismatch");
492*8a91bc7bSHarrietAkot   uint64_t nnz = idata[1];
493*8a91bc7bSHarrietAkot   for (uint64_t r = 0; r < rank; r++)
494*8a91bc7bSHarrietAkot     assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
495*8a91bc7bSHarrietAkot            "dimension size mismatch");
496*8a91bc7bSHarrietAkot   SparseTensorCOO<V> *tensor =
497*8a91bc7bSHarrietAkot       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
498*8a91bc7bSHarrietAkot   //  Read all nonzero elements.
499*8a91bc7bSHarrietAkot   std::vector<uint64_t> indices(rank);
500*8a91bc7bSHarrietAkot   for (uint64_t k = 0; k < nnz; k++) {
501*8a91bc7bSHarrietAkot     uint64_t idx = -1;
502*8a91bc7bSHarrietAkot     for (uint64_t r = 0; r < rank; r++) {
503*8a91bc7bSHarrietAkot       if (fscanf(file, "%" PRIu64, &idx) != 1) {
504*8a91bc7bSHarrietAkot         fprintf(stderr, "Cannot find next index in %s\n", filename);
505*8a91bc7bSHarrietAkot         exit(1);
506*8a91bc7bSHarrietAkot       }
507*8a91bc7bSHarrietAkot       // Add 0-based index.
508*8a91bc7bSHarrietAkot       indices[perm[r]] = idx - 1;
509*8a91bc7bSHarrietAkot     }
510*8a91bc7bSHarrietAkot     // The external formats always store the numerical values with the type
511*8a91bc7bSHarrietAkot     // double, but we cast these values to the sparse tensor object type.
512*8a91bc7bSHarrietAkot     double value;
513*8a91bc7bSHarrietAkot     if (fscanf(file, "%lg\n", &value) != 1) {
514*8a91bc7bSHarrietAkot       fprintf(stderr, "Cannot find next value in %s\n", filename);
515*8a91bc7bSHarrietAkot       exit(1);
516*8a91bc7bSHarrietAkot     }
517*8a91bc7bSHarrietAkot     tensor->add(indices, value);
518*8a91bc7bSHarrietAkot   }
519*8a91bc7bSHarrietAkot   // Close the file and return tensor.
520*8a91bc7bSHarrietAkot   fclose(file);
521*8a91bc7bSHarrietAkot   return tensor;
522*8a91bc7bSHarrietAkot }
523*8a91bc7bSHarrietAkot 
524*8a91bc7bSHarrietAkot } // anonymous namespace
525*8a91bc7bSHarrietAkot 
526*8a91bc7bSHarrietAkot extern "C" {
527*8a91bc7bSHarrietAkot 
528*8a91bc7bSHarrietAkot /// This type is used in the public API at all places where MLIR expects
529*8a91bc7bSHarrietAkot /// values with the built-in type "index". For now, we simply assume that
530*8a91bc7bSHarrietAkot /// type is 64-bit, but targets with different "index" bit widths should link
531*8a91bc7bSHarrietAkot /// with an alternatively built runtime support library.
532*8a91bc7bSHarrietAkot // TODO: support such targets?
533*8a91bc7bSHarrietAkot typedef uint64_t index_t;
534*8a91bc7bSHarrietAkot 
535*8a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
536*8a91bc7bSHarrietAkot //
537*8a91bc7bSHarrietAkot // Public API with methods that operate on MLIR buffers (memrefs) to interact
538*8a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally.
539*8a91bc7bSHarrietAkot // These methods should be used exclusively by MLIR compiler-generated code.
540*8a91bc7bSHarrietAkot //
541*8a91bc7bSHarrietAkot // Some macro magic is used to generate implementations for all required type
542*8a91bc7bSHarrietAkot // combinations that can be called from MLIR compiler-generated code.
543*8a91bc7bSHarrietAkot //
544*8a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
545*8a91bc7bSHarrietAkot 
546*8a91bc7bSHarrietAkot enum OverheadTypeEnum : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
547*8a91bc7bSHarrietAkot 
548*8a91bc7bSHarrietAkot enum PrimaryTypeEnum : uint32_t {
549*8a91bc7bSHarrietAkot   kF64 = 1,
550*8a91bc7bSHarrietAkot   kF32 = 2,
551*8a91bc7bSHarrietAkot   kI64 = 3,
552*8a91bc7bSHarrietAkot   kI32 = 4,
553*8a91bc7bSHarrietAkot   kI16 = 5,
554*8a91bc7bSHarrietAkot   kI8 = 6
555*8a91bc7bSHarrietAkot };
556*8a91bc7bSHarrietAkot 
557*8a91bc7bSHarrietAkot enum Action : uint32_t {
558*8a91bc7bSHarrietAkot   kEmpty = 0,
559*8a91bc7bSHarrietAkot   kFromFile = 1,
560*8a91bc7bSHarrietAkot   kFromCOO = 2,
561*8a91bc7bSHarrietAkot   kEmptyCOO = 3,
562*8a91bc7bSHarrietAkot   kToCOO = 4,
563*8a91bc7bSHarrietAkot   kToIter = 5
564*8a91bc7bSHarrietAkot };
565*8a91bc7bSHarrietAkot 
566*8a91bc7bSHarrietAkot #define CASE(p, i, v, P, I, V)                                                 \
567*8a91bc7bSHarrietAkot   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
568*8a91bc7bSHarrietAkot     SparseTensorCOO<V> *tensor = nullptr;                                      \
569*8a91bc7bSHarrietAkot     if (action <= kFromCOO) {                                                  \
570*8a91bc7bSHarrietAkot       if (action == kFromFile) {                                               \
571*8a91bc7bSHarrietAkot         char *filename = static_cast<char *>(ptr);                             \
572*8a91bc7bSHarrietAkot         tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm);          \
573*8a91bc7bSHarrietAkot       } else if (action == kFromCOO) {                                         \
574*8a91bc7bSHarrietAkot         tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
575*8a91bc7bSHarrietAkot       } else {                                                                 \
576*8a91bc7bSHarrietAkot         assert(action == kEmpty);                                              \
577*8a91bc7bSHarrietAkot       }                                                                        \
578*8a91bc7bSHarrietAkot       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,  \
579*8a91bc7bSHarrietAkot                                                            sparsity, tensor);  \
580*8a91bc7bSHarrietAkot     } else if (action == kEmptyCOO) {                                          \
581*8a91bc7bSHarrietAkot       return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);        \
582*8a91bc7bSHarrietAkot     } else {                                                                   \
583*8a91bc7bSHarrietAkot       tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);  \
584*8a91bc7bSHarrietAkot       if (action == kToIter) {                                                 \
585*8a91bc7bSHarrietAkot         tensor->startIterator();                                               \
586*8a91bc7bSHarrietAkot       } else {                                                                 \
587*8a91bc7bSHarrietAkot         assert(action == kToCOO);                                              \
588*8a91bc7bSHarrietAkot       }                                                                        \
589*8a91bc7bSHarrietAkot       return tensor;                                                           \
590*8a91bc7bSHarrietAkot     }                                                                          \
591*8a91bc7bSHarrietAkot   }
592*8a91bc7bSHarrietAkot 
593*8a91bc7bSHarrietAkot #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
594*8a91bc7bSHarrietAkot   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
595*8a91bc7bSHarrietAkot     assert(ref);                                                               \
596*8a91bc7bSHarrietAkot     assert(tensor);                                                            \
597*8a91bc7bSHarrietAkot     std::vector<TYPE> *v;                                                      \
598*8a91bc7bSHarrietAkot     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
599*8a91bc7bSHarrietAkot     ref->basePtr = ref->data = v->data();                                      \
600*8a91bc7bSHarrietAkot     ref->offset = 0;                                                           \
601*8a91bc7bSHarrietAkot     ref->sizes[0] = v->size();                                                 \
602*8a91bc7bSHarrietAkot     ref->strides[0] = 1;                                                       \
603*8a91bc7bSHarrietAkot   }
604*8a91bc7bSHarrietAkot 
605*8a91bc7bSHarrietAkot #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
606*8a91bc7bSHarrietAkot   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
607*8a91bc7bSHarrietAkot                            index_t d) {                                        \
608*8a91bc7bSHarrietAkot     assert(ref);                                                               \
609*8a91bc7bSHarrietAkot     assert(tensor);                                                            \
610*8a91bc7bSHarrietAkot     std::vector<TYPE> *v;                                                      \
611*8a91bc7bSHarrietAkot     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
612*8a91bc7bSHarrietAkot     ref->basePtr = ref->data = v->data();                                      \
613*8a91bc7bSHarrietAkot     ref->offset = 0;                                                           \
614*8a91bc7bSHarrietAkot     ref->sizes[0] = v->size();                                                 \
615*8a91bc7bSHarrietAkot     ref->strides[0] = 1;                                                       \
616*8a91bc7bSHarrietAkot   }
617*8a91bc7bSHarrietAkot 
618*8a91bc7bSHarrietAkot #define IMPL_ADDELT(NAME, TYPE)                                                \
619*8a91bc7bSHarrietAkot   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
620*8a91bc7bSHarrietAkot                             StridedMemRefType<index_t, 1> *iref,               \
621*8a91bc7bSHarrietAkot                             StridedMemRefType<index_t, 1> *pref) {             \
622*8a91bc7bSHarrietAkot     assert(tensor);                                                            \
623*8a91bc7bSHarrietAkot     assert(iref);                                                              \
624*8a91bc7bSHarrietAkot     assert(pref);                                                              \
625*8a91bc7bSHarrietAkot     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
626*8a91bc7bSHarrietAkot     assert(iref->sizes[0] == pref->sizes[0]);                                  \
627*8a91bc7bSHarrietAkot     const index_t *indx = iref->data + iref->offset;                           \
628*8a91bc7bSHarrietAkot     const index_t *perm = pref->data + pref->offset;                           \
629*8a91bc7bSHarrietAkot     uint64_t isize = iref->sizes[0];                                           \
630*8a91bc7bSHarrietAkot     std::vector<index_t> indices(isize);                                       \
631*8a91bc7bSHarrietAkot     for (uint64_t r = 0; r < isize; r++)                                       \
632*8a91bc7bSHarrietAkot       indices[perm[r]] = indx[r];                                              \
633*8a91bc7bSHarrietAkot     static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value);         \
634*8a91bc7bSHarrietAkot     return tensor;                                                             \
635*8a91bc7bSHarrietAkot   }
636*8a91bc7bSHarrietAkot 
637*8a91bc7bSHarrietAkot #define IMPL_GETNEXT(NAME, V)                                                  \
638*8a91bc7bSHarrietAkot   bool _mlir_ciface_##NAME(void *tensor, StridedMemRefType<uint64_t, 1> *iref, \
639*8a91bc7bSHarrietAkot                            StridedMemRefType<V, 0> *vref) {                    \
640*8a91bc7bSHarrietAkot     assert(iref->strides[0] == 1);                                             \
641*8a91bc7bSHarrietAkot     uint64_t *indx = iref->data + iref->offset;                                \
642*8a91bc7bSHarrietAkot     V *value = vref->data + vref->offset;                                      \
643*8a91bc7bSHarrietAkot     const uint64_t isize = iref->sizes[0];                                     \
644*8a91bc7bSHarrietAkot     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
645*8a91bc7bSHarrietAkot     const Element<V> *elem = iter->getNext();                                  \
646*8a91bc7bSHarrietAkot     if (elem == nullptr) {                                                     \
647*8a91bc7bSHarrietAkot       delete iter;                                                             \
648*8a91bc7bSHarrietAkot       return false;                                                            \
649*8a91bc7bSHarrietAkot     }                                                                          \
650*8a91bc7bSHarrietAkot     for (uint64_t r = 0; r < isize; r++)                                       \
651*8a91bc7bSHarrietAkot       indx[r] = elem->indices[r];                                              \
652*8a91bc7bSHarrietAkot     *value = elem->value;                                                      \
653*8a91bc7bSHarrietAkot     return true;                                                               \
654*8a91bc7bSHarrietAkot   }
655*8a91bc7bSHarrietAkot 
656*8a91bc7bSHarrietAkot /// Constructs a new sparse tensor. This is the "swiss army knife"
657*8a91bc7bSHarrietAkot /// method for materializing sparse tensors into the computation.
658*8a91bc7bSHarrietAkot ///
659*8a91bc7bSHarrietAkot /// action:
660*8a91bc7bSHarrietAkot /// kEmpty = returns empty storage to fill later
661*8a91bc7bSHarrietAkot /// kFromFile = returns storage, where ptr contains filename to read
662*8a91bc7bSHarrietAkot /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
663*8a91bc7bSHarrietAkot /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
664*8a91bc7bSHarrietAkot /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
665*8a91bc7bSHarrietAkot /// kToIter = returns iterator from storage in ptr (call getNext() to use)
666*8a91bc7bSHarrietAkot void *
667*8a91bc7bSHarrietAkot _mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
668*8a91bc7bSHarrietAkot                              StridedMemRefType<index_t, 1> *sref,
669*8a91bc7bSHarrietAkot                              StridedMemRefType<index_t, 1> *pref,
670*8a91bc7bSHarrietAkot                              uint32_t ptrTp, uint32_t indTp, uint32_t valTp,
671*8a91bc7bSHarrietAkot                              uint32_t action, void *ptr) {
672*8a91bc7bSHarrietAkot   assert(aref && sref && pref);
673*8a91bc7bSHarrietAkot   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
674*8a91bc7bSHarrietAkot          pref->strides[0] == 1);
675*8a91bc7bSHarrietAkot   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
676*8a91bc7bSHarrietAkot   const uint8_t *sparsity = aref->data + aref->offset;
677*8a91bc7bSHarrietAkot   const index_t *sizes = sref->data + sref->offset;
678*8a91bc7bSHarrietAkot   const index_t *perm = pref->data + pref->offset;
679*8a91bc7bSHarrietAkot   uint64_t rank = aref->sizes[0];
680*8a91bc7bSHarrietAkot 
681*8a91bc7bSHarrietAkot   // Double matrices with all combinations of overhead storage.
682*8a91bc7bSHarrietAkot   CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
683*8a91bc7bSHarrietAkot   CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
684*8a91bc7bSHarrietAkot   CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
685*8a91bc7bSHarrietAkot   CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
686*8a91bc7bSHarrietAkot   CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
687*8a91bc7bSHarrietAkot   CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
688*8a91bc7bSHarrietAkot   CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
689*8a91bc7bSHarrietAkot   CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
690*8a91bc7bSHarrietAkot   CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
691*8a91bc7bSHarrietAkot   CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
692*8a91bc7bSHarrietAkot   CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
693*8a91bc7bSHarrietAkot   CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
694*8a91bc7bSHarrietAkot   CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
695*8a91bc7bSHarrietAkot   CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
696*8a91bc7bSHarrietAkot   CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
697*8a91bc7bSHarrietAkot   CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
698*8a91bc7bSHarrietAkot 
699*8a91bc7bSHarrietAkot   // Float matrices with all combinations of overhead storage.
700*8a91bc7bSHarrietAkot   CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
701*8a91bc7bSHarrietAkot   CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
702*8a91bc7bSHarrietAkot   CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
703*8a91bc7bSHarrietAkot   CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
704*8a91bc7bSHarrietAkot   CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
705*8a91bc7bSHarrietAkot   CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
706*8a91bc7bSHarrietAkot   CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
707*8a91bc7bSHarrietAkot   CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
708*8a91bc7bSHarrietAkot   CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
709*8a91bc7bSHarrietAkot   CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
710*8a91bc7bSHarrietAkot   CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
711*8a91bc7bSHarrietAkot   CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
712*8a91bc7bSHarrietAkot   CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
713*8a91bc7bSHarrietAkot   CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
714*8a91bc7bSHarrietAkot   CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
715*8a91bc7bSHarrietAkot   CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
716*8a91bc7bSHarrietAkot 
717*8a91bc7bSHarrietAkot   // Integral matrices with same overhead storage.
718*8a91bc7bSHarrietAkot   CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
719*8a91bc7bSHarrietAkot   CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
720*8a91bc7bSHarrietAkot   CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
721*8a91bc7bSHarrietAkot   CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
722*8a91bc7bSHarrietAkot   CASE(kU32, kU32, kI32, uint32_t, uint32_t, int32_t);
723*8a91bc7bSHarrietAkot   CASE(kU32, kU32, kI16, uint32_t, uint32_t, int16_t);
724*8a91bc7bSHarrietAkot   CASE(kU32, kU32, kI8, uint32_t, uint32_t, int8_t);
725*8a91bc7bSHarrietAkot   CASE(kU16, kU16, kI32, uint16_t, uint16_t, int32_t);
726*8a91bc7bSHarrietAkot   CASE(kU16, kU16, kI16, uint16_t, uint16_t, int16_t);
727*8a91bc7bSHarrietAkot   CASE(kU16, kU16, kI8, uint16_t, uint16_t, int8_t);
728*8a91bc7bSHarrietAkot   CASE(kU8, kU8, kI32, uint8_t, uint8_t, int32_t);
729*8a91bc7bSHarrietAkot   CASE(kU8, kU8, kI16, uint8_t, uint8_t, int16_t);
730*8a91bc7bSHarrietAkot   CASE(kU8, kU8, kI8, uint8_t, uint8_t, int8_t);
731*8a91bc7bSHarrietAkot 
732*8a91bc7bSHarrietAkot   // Unsupported case (add above if needed).
733*8a91bc7bSHarrietAkot   fputs("unsupported combination of types\n", stderr);
734*8a91bc7bSHarrietAkot   exit(1);
735*8a91bc7bSHarrietAkot }
736*8a91bc7bSHarrietAkot 
737*8a91bc7bSHarrietAkot /// Methods that provide direct access to pointers.
738*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers, index_t, getPointers)
739*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
740*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
741*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
742*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
743*8a91bc7bSHarrietAkot 
744*8a91bc7bSHarrietAkot /// Methods that provide direct access to indices.
745*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices, index_t, getIndices)
746*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
747*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
748*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
749*8a91bc7bSHarrietAkot IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
750*8a91bc7bSHarrietAkot 
751*8a91bc7bSHarrietAkot /// Methods that provide direct access to values.
752*8a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
753*8a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
754*8a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
755*8a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
756*8a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
757*8a91bc7bSHarrietAkot IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
758*8a91bc7bSHarrietAkot 
759*8a91bc7bSHarrietAkot /// Helper to add value to coordinate scheme, one per value type.
760*8a91bc7bSHarrietAkot IMPL_ADDELT(addEltF64, double)
761*8a91bc7bSHarrietAkot IMPL_ADDELT(addEltF32, float)
762*8a91bc7bSHarrietAkot IMPL_ADDELT(addEltI64, int64_t)
763*8a91bc7bSHarrietAkot IMPL_ADDELT(addEltI32, int32_t)
764*8a91bc7bSHarrietAkot IMPL_ADDELT(addEltI16, int16_t)
765*8a91bc7bSHarrietAkot IMPL_ADDELT(addEltI8, int8_t)
766*8a91bc7bSHarrietAkot 
767*8a91bc7bSHarrietAkot /// Helper to enumerate elements of coordinate scheme, one per value type.
768*8a91bc7bSHarrietAkot IMPL_GETNEXT(getNextF64, double)
769*8a91bc7bSHarrietAkot IMPL_GETNEXT(getNextF32, float)
770*8a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI64, int64_t)
771*8a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI32, int32_t)
772*8a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI16, int16_t)
773*8a91bc7bSHarrietAkot IMPL_GETNEXT(getNextI8, int8_t)
774*8a91bc7bSHarrietAkot 
775*8a91bc7bSHarrietAkot #undef CASE
776*8a91bc7bSHarrietAkot #undef IMPL_SPARSEVALUES
777*8a91bc7bSHarrietAkot #undef IMPL_GETOVERHEAD
778*8a91bc7bSHarrietAkot #undef IMPL_ADDELT
779*8a91bc7bSHarrietAkot #undef IMPL_GETNEXT
780*8a91bc7bSHarrietAkot 
781*8a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
782*8a91bc7bSHarrietAkot //
783*8a91bc7bSHarrietAkot // Public API with methods that accept C-style data structures to interact
784*8a91bc7bSHarrietAkot // with sparse tensors, which are only visible as opaque pointers externally.
785*8a91bc7bSHarrietAkot // These methods can be used both by MLIR compiler-generated code as well as by
786*8a91bc7bSHarrietAkot // an external runtime that wants to interact with MLIR compiler-generated code.
787*8a91bc7bSHarrietAkot //
788*8a91bc7bSHarrietAkot //===----------------------------------------------------------------------===//
789*8a91bc7bSHarrietAkot 
790*8a91bc7bSHarrietAkot /// Helper method to read a sparse tensor filename from the environment,
791*8a91bc7bSHarrietAkot /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
792*8a91bc7bSHarrietAkot char *getTensorFilename(index_t id) {
793*8a91bc7bSHarrietAkot   char var[80];
794*8a91bc7bSHarrietAkot   sprintf(var, "TENSOR%" PRIu64, id);
795*8a91bc7bSHarrietAkot   char *env = getenv(var);
796*8a91bc7bSHarrietAkot   return env;
797*8a91bc7bSHarrietAkot }
798*8a91bc7bSHarrietAkot 
799*8a91bc7bSHarrietAkot /// Returns size of sparse tensor in given dimension.
800*8a91bc7bSHarrietAkot index_t sparseDimSize(void *tensor, index_t d) {
801*8a91bc7bSHarrietAkot   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
802*8a91bc7bSHarrietAkot }
803*8a91bc7bSHarrietAkot 
804*8a91bc7bSHarrietAkot /// Releases sparse tensor storage.
805*8a91bc7bSHarrietAkot void delSparseTensor(void *tensor) {
806*8a91bc7bSHarrietAkot   delete static_cast<SparseTensorStorageBase *>(tensor);
807*8a91bc7bSHarrietAkot }
808*8a91bc7bSHarrietAkot 
809*8a91bc7bSHarrietAkot /// Initializes sparse tensor from a COO-flavored format expressed using C-style
810*8a91bc7bSHarrietAkot /// data structures. The expected parameters are:
811*8a91bc7bSHarrietAkot ///
812*8a91bc7bSHarrietAkot ///   rank:    rank of tensor
813*8a91bc7bSHarrietAkot ///   nse:     number of specified elements (usually the nonzeros)
814*8a91bc7bSHarrietAkot ///   shape:   array with dimension size for each rank
815*8a91bc7bSHarrietAkot ///   values:  a "nse" array with values for all specified elements
816*8a91bc7bSHarrietAkot ///   indices: a flat "nse x rank" array with indices for all specified elements
817*8a91bc7bSHarrietAkot ///
818*8a91bc7bSHarrietAkot /// For example, the sparse matrix
819*8a91bc7bSHarrietAkot ///     | 1.0 0.0 0.0 |
820*8a91bc7bSHarrietAkot ///     | 0.0 5.0 3.0 |
821*8a91bc7bSHarrietAkot /// can be passed as
822*8a91bc7bSHarrietAkot ///      rank    = 2
823*8a91bc7bSHarrietAkot ///      nse     = 3
824*8a91bc7bSHarrietAkot ///      shape   = [2, 3]
825*8a91bc7bSHarrietAkot ///      values  = [1.0, 5.0, 3.0]
826*8a91bc7bSHarrietAkot ///      indices = [ 0, 0,  1, 1,  1, 2]
827*8a91bc7bSHarrietAkot //
828*8a91bc7bSHarrietAkot // TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
829*8a91bc7bSHarrietAkot //
830*8a91bc7bSHarrietAkot void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
831*8a91bc7bSHarrietAkot                                 double *values, uint64_t *indices) {
832*8a91bc7bSHarrietAkot   // Setup all-dims compressed and default ordering.
833*8a91bc7bSHarrietAkot   std::vector<uint8_t> sparse(rank, SparseTensorStorageBase::kCompressed);
834*8a91bc7bSHarrietAkot   std::vector<uint64_t> perm(rank);
835*8a91bc7bSHarrietAkot   std::iota(perm.begin(), perm.end(), 0);
836*8a91bc7bSHarrietAkot   // Convert external format to internal COO.
837*8a91bc7bSHarrietAkot   SparseTensorCOO<double> *tensor = SparseTensorCOO<double>::newSparseTensorCOO(
838*8a91bc7bSHarrietAkot       rank, shape, perm.data(), nse);
839*8a91bc7bSHarrietAkot   std::vector<uint64_t> idx(rank);
840*8a91bc7bSHarrietAkot   for (uint64_t i = 0, base = 0; i < nse; i++) {
841*8a91bc7bSHarrietAkot     for (uint64_t r = 0; r < rank; r++)
842*8a91bc7bSHarrietAkot       idx[r] = indices[base + r];
843*8a91bc7bSHarrietAkot     tensor->add(idx, values[i]);
844*8a91bc7bSHarrietAkot     base += rank;
845*8a91bc7bSHarrietAkot   }
846*8a91bc7bSHarrietAkot   // Return sparse tensor storage format as opaque pointer.
847*8a91bc7bSHarrietAkot   return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
848*8a91bc7bSHarrietAkot       rank, shape, perm.data(), sparse.data(), tensor);
849*8a91bc7bSHarrietAkot }
850*8a91bc7bSHarrietAkot 
851*8a91bc7bSHarrietAkot } // extern "C"
852*8a91bc7bSHarrietAkot 
853*8a91bc7bSHarrietAkot #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
854