//===- SparseTensorUtils.cpp - Sparse Tensor Utils for MLIR execution -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library that is useful
// for sparse tensor manipulations. The functionality provided in this library
// is meant to simplify benchmarking, testing, and debugging MLIR code that
// operates on sparse tensors. The provided functionality is **not** part
// of core MLIR, however.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/CRunnerUtils.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <vector>

//===----------------------------------------------------------------------===//
//
// Internal support for storing and reading sparse tensors.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by index (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is externally
// only visible as an opaque pointer.
//
//===----------------------------------------------------------------------===//

namespace {

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
template <typename V>
struct Element {
  Element(const std::vector<uint64_t> &ind, V val) : indices(ind), value(val){};
  std::vector<uint64_t> indices;
  V value;
};

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO {
public:
  SparseTensorCOO(const std::vector<uint64_t> &szs, uint64_t capacity)
      : sizes(szs), iteratorLocked(false), iteratorPos(0) {
    if (capacity)
      elements.reserve(capacity);
  }
  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t rank = getRank();
    assert(rank == ind.size());
    for (uint64_t r = 0; r < rank; r++)
      assert(ind[r] < sizes[r]); // within bounds
    elements.emplace_back(ind, val);
  }
  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    std::sort(elements.begin(), elements.end(), lexOrder);
  }
  /// Returns rank.
  uint64_t getRank() const { return sizes.size(); }
  /// Getter for sizes array.
  const std::vector<uint64_t> &getSizes() const { return sizes; }
  /// Getter for elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switch into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }
  /// Get the next element.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *sizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++)
      permsz[perm[r]] = sizes[r];
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  /// Returns true if indices of e1 < indices of e2.
  static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
    uint64_t rank = e1.indices.size();
    assert(rank == e2.indices.size());
    for (uint64_t r = 0; r < rank; r++) {
      if (e1.indices[r] == e2.indices[r])
        continue;
      return e1.indices[r] < e2.indices[r];
    }
    return false;
  }
  const std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<Element<V>> elements;
  bool iteratorLocked;
  unsigned iteratorPos;
};

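// For illustration, a minimal (hypothetical) use of SparseTensorCOO that
// builds the 2x3 matrix
//   | 1.0 0.0 0.0 |
//   | 0.0 5.0 3.0 |
// with the identity dimension ordering could look as follows:
//
//   uint64_t sizes[] = {2, 3};
//   uint64_t perm[] = {0, 1};
//   SparseTensorCOO<double> *coo =
//       SparseTensorCOO<double>::newSparseTensorCOO(2, sizes, perm, 3);
//   coo->add({0, 0}, 1.0);
//   coo->add({1, 1}, 5.0);
//   coo->add({1, 2}, 3.0);
//   coo->sort(); // sort lexicographically by index
//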
/// Abstract base class of sparse tensor storage. Note that we use
/// function overloading to implement "partial" method specialization.
class SparseTensorStorageBase {
public:
  enum DimLevelType : uint8_t { kDense = 0, kCompressed = 1, kSingleton = 2 };

  virtual uint64_t getDimSize(uint64_t) = 0;

  // Overhead storage.
  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }

  // Primary storage.
  virtual void getValues(std::vector<double> **) { fatal("valf64"); }
  virtual void getValues(std::vector<float> **) { fatal("valf32"); }
  virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }

  virtual ~SparseTensorStorageBase() {}

private:
  void fatal(const char *tp) {
    fprintf(stderr, "unsupported %s\n", tp);
    exit(1);
  }
};

/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
/// methods for each differently annotated sparse tensor, this method provides
/// a convenient "one-size-fits-all" solution that simply takes an input tensor
/// and annotations to implement all required setup in a general manner.
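///
/// For example, with an identity dimension ordering and the per-dimension
/// annotations {dense, compressed} (i.e. CSR), the 2x3 matrix
///   | 1.0 0.0 0.0 |
///   | 0.0 5.0 3.0 |
/// is stored as
///   pointers[1] = [0, 1, 3]
///   indices[1]  = [0, 1, 2]
///   values      = [1.0, 5.0, 3.0]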
template <typename P, typename I, typename V>
class SparseTensorStorage : public SparseTensorStorageBase {
public:
  /// Constructs a sparse tensor storage scheme with the given dimensions,
  /// permutation, and per-dimension dense/sparse annotations, using
  /// the coordinate scheme tensor for the initial contents if provided.
  SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
                      const uint8_t *sparsity, SparseTensorCOO<V> *tensor)
      : sizes(szs), rev(getRank()), pointers(getRank()), indices(getRank()) {
    uint64_t rank = getRank();
    // Store "reverse" permutation.
    for (uint64_t r = 0; r < rank; r++)
      rev[perm[r]] = r;
    // Provide hints on capacity of pointers and indices.
    // TODO: needs fine-tuning based on sparsity
    for (uint64_t r = 0, s = 1; r < rank; r++) {
      s *= sizes[r];
      if (sparsity[r] == kCompressed) {
        pointers[r].reserve(s + 1);
        indices[r].reserve(s);
        s = 1;
      } else {
        assert(sparsity[r] == kDense && "singleton not yet supported");
      }
    }
    // Prepare sparse pointer structures for all dimensions.
    for (uint64_t r = 0; r < rank; r++)
      if (sparsity[r] == kCompressed)
        pointers[r].push_back(0);
    // Then assign contents from coordinate scheme tensor if provided.
    if (tensor) {
      uint64_t nnz = tensor->getElements().size();
      values.reserve(nnz);
      fromCOO(tensor, sparsity, 0, nnz, 0);
    }
  }

  virtual ~SparseTensorStorage() {}

  /// Get the rank of the tensor.
  uint64_t getRank() const { return sizes.size(); }

  /// Get the size in the given dimension of the tensor.
  uint64_t getDimSize(uint64_t d) override {
    assert(d < getRank());
    return sizes[d];
  }

  // Partially specialize these three methods based on template types.
  void getPointers(std::vector<P> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &pointers[d];
  }
  void getIndices(std::vector<I> **out, uint64_t d) override {
    assert(d < getRank());
    *out = &indices[d];
  }
  void getValues(std::vector<V> **out) override { *out = &values; }

  /// Returns this sparse tensor storage scheme as a new memory-resident
  /// sparse tensor in coordinate scheme with the given dimension order.
  SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
    // Restore original order of the dimension sizes and allocate coordinate
    // scheme with desired new ordering specified in perm.
    uint64_t rank = getRank();
    std::vector<uint64_t> orgsz(rank);
    for (uint64_t r = 0; r < rank; r++)
      orgsz[rev[r]] = sizes[r];
    SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
        rank, orgsz.data(), perm, values.size());
    // Populate coordinate scheme restored from old ordering and changed with
    // new ordering. Rather than applying both reorderings during the recursion,
    // we compute the combined permutation in advance.
    std::vector<uint64_t> reord(rank);
    for (uint64_t r = 0; r < rank; r++)
      reord[r] = perm[rev[r]];
    std::vector<uint64_t> idx(rank);
    toCOO(tensor, reord, idx, 0, 0);
    assert(tensor->getElements().size() == values.size());
    return tensor;
  }

  /// Factory method. Constructs a sparse tensor storage scheme with the given
  /// dimensions, permutation, and per-dimension dense/sparse annotations,
  /// using the coordinate scheme tensor for the initial contents if provided.
  /// In the latter case, the coordinate scheme must respect the same
  /// permutation as is desired for the new sparse tensor storage.
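  ///
  /// For example, a rank-2 tensor with sizes = {3, 4} (given in the original
  /// dimension order) and perm = {1, 0} is stored with the original second
  /// dimension outermost, i.e. with internal dimension sizes {4, 3}.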
  static SparseTensorStorage<P, I, V> *
  newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
                  const uint8_t *sparsity, SparseTensorCOO<V> *tensor) {
    SparseTensorStorage<P, I, V> *n = nullptr;
    if (tensor) {
      assert(tensor->getRank() == rank);
      for (uint64_t r = 0; r < rank; r++)
        assert(sizes[r] == 0 || tensor->getSizes()[perm[r]] == sizes[r]);
      tensor->sort(); // sort lexicographically
      n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
                                           tensor);
      delete tensor;
    } else {
      std::vector<uint64_t> permsz(rank);
      for (uint64_t r = 0; r < rank; r++)
        permsz[perm[r]] = sizes[r];
      n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, tensor);
    }
    return n;
  }

private:
  /// Initializes sparse tensor storage scheme from a memory-resident sparse
  /// tensor in coordinate scheme. This method prepares the pointers and
  /// indices arrays under the given per-dimension dense/sparse annotations.
  void fromCOO(SparseTensorCOO<V> *tensor, const uint8_t *sparsity, uint64_t lo,
               uint64_t hi, uint64_t d) {
    const std::vector<Element<V>> &elements = tensor->getElements();
    // Once dimensions are exhausted, insert the numerical values.
    if (d == getRank()) {
      assert(lo >= hi || lo < elements.size());
      values.push_back(lo < hi ? elements[lo].value : 0);
      return;
    }
    assert(d < getRank());
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) {
      assert(lo < elements.size() && hi <= elements.size());
      // Find segment in interval with same index elements in this dimension.
      uint64_t idx = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == idx)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      if (sparsity[d] == kCompressed) {
        indices[d].push_back(idx);
      } else {
        // For dense storage we must fill in all the zero values between
        // the previous element (from the last iteration of this loop) and
        // the current element.
        for (; full < idx; full++)
          fromCOO(tensor, sparsity, 0, 0, d + 1); // pass empty
        full++;
      }
      fromCOO(tensor, sparsity, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    if (sparsity[d] == kCompressed) {
      pointers[d].push_back(indices[d].size());
    } else {
      // For dense storage we must fill in all the zero values after
      // the last element.
      for (uint64_t sz = sizes[d]; full < sz; full++)
        fromCOO(tensor, sparsity, 0, 0, d + 1); // pass empty
    }
  }

  /// Stores the sparse tensor storage scheme into a memory-resident sparse
  /// tensor in coordinate scheme.
  void toCOO(SparseTensorCOO<V> *tensor, std::vector<uint64_t> &reord,
             std::vector<uint64_t> &idx, uint64_t pos, uint64_t d) {
    assert(d <= getRank());
    if (d == getRank()) {
      assert(pos < values.size());
      tensor->add(idx, values[pos]);
    } else if (pointers[d].empty()) {
      // Dense dimension.
      for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
        idx[reord[d]] = i;
        toCOO(tensor, reord, idx, off + i, d + 1);
      }
    } else {
      // Sparse dimension.
      for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
        idx[reord[d]] = indices[d][ii];
        toCOO(tensor, reord, idx, ii, d + 1);
      }
    }
  }

private:
  std::vector<uint64_t> sizes; // per-dimension sizes
  std::vector<uint64_t> rev;   // "reverse" permutation
  std::vector<std::vector<P>> pointers;
  std::vector<std::vector<I>> indices;
  std::vector<V> values;
};

/// Helper to convert string to lower case.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = tolower(*c);
  return token;
}

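// For illustration, a small (hypothetical) Matrix Market file accepted by the
// readers below, encoding the 2x3 matrix
//   | 1.0 0.0 0.0 |
//   | 0.0 5.0 3.0 |
// would look like:
//
//   %%MatrixMarket matrix coordinate real general
//   % optional comment lines
//   2 3 3
//   1 1 1.0
//   2 2 5.0
//   2 3 3.0
//
// viz. a header line, optional '%' comments, a line with M N NNZ, and then
// one line per nonzero with 1-based indices followed by the value.
//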
/// Read the MME header of a general sparse matrix of type real.
static void readMMEHeader(FILE *file, char *name, uint64_t *idata) {
  char line[1025];
  char header[64];
  char object[64];
  char format[64];
  char field[64];
  char symmetry[64];
  // Read header line.
  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
             symmetry) != 5) {
    fprintf(stderr, "Corrupt header in %s\n", name);
    exit(1);
  }
  // Make sure this is a general sparse matrix.
  if (strcmp(toLower(header), "%%matrixmarket") ||
      strcmp(toLower(object), "matrix") ||
      strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
      strcmp(toLower(symmetry), "general")) {
    fprintf(stderr,
            "Cannot find a general sparse matrix with type real in %s\n", name);
    exit(1);
  }
  // Skip comments.
  while (1) {
    if (!fgets(line, 1025, file)) {
      fprintf(stderr, "Cannot find data in %s\n", name);
      exit(1);
    }
    if (line[0] != '%')
      break;
  }
  // Next line contains M N NNZ.
  idata[0] = 2; // rank
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3) {
    fprintf(stderr, "Cannot find size in %s\n", name);
    exit(1);
  }
}

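// Similarly, a small (hypothetical) file in the "extended" FROSTT format read
// below, encoding the same 2x3 matrix, would look like:
//
//   # optional comment lines
//   2 3
//   2 3
//   1 1 1.0
//   2 2 5.0
//   2 3 3.0
//
// viz. optional '#' comments, a line with RANK and NNZ, a line with the
// dimension sizes, and then one line per nonzero (1-based indices and value).
//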
/// Read the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments followed
/// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
  char line[1025];
  // Skip comments.
  while (1) {
    if (!fgets(line, 1025, file)) {
      fprintf(stderr, "Cannot find data in %s\n", name);
      exit(1);
    }
    if (line[0] != '#')
      break;
  }
  // Next line contains RANK and NNZ.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
    fprintf(stderr, "Cannot find metadata in %s\n", name);
    exit(1);
  }
  // Followed by a line with the dimension sizes (one per rank).
  for (uint64_t r = 0; r < idata[0]; r++) {
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size %s\n", name);
      exit(1);
    }
  }
}

/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme.
template <typename V>
static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
                                               const uint64_t *sizes,
                                               const uint64_t *perm) {
  // Open the file.
  FILE *file = fopen(filename, "r");
  if (!file) {
    fprintf(stderr, "Cannot find %s\n", filename);
    exit(1);
  }
  // Perform some file format dependent set up.
  uint64_t idata[512];
  if (strstr(filename, ".mtx")) {
    readMMEHeader(file, filename, idata);
  } else if (strstr(filename, ".tns")) {
    readExtFROSTTHeader(file, filename, idata);
  } else {
    fprintf(stderr, "Unknown format %s\n", filename);
    exit(1);
  }
  // Prepare sparse tensor object with per-dimension sizes
  // and the number of nonzeros as initial capacity.
  assert(rank == idata[0] && "rank mismatch");
  uint64_t nnz = idata[1];
  for (uint64_t r = 0; r < rank; r++)
    assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) &&
           "dimension size mismatch");
  SparseTensorCOO<V> *tensor =
      SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
  // Read all nonzero elements.
  std::vector<uint64_t> indices(rank);
  for (uint64_t k = 0; k < nnz; k++) {
    uint64_t idx = -1;
    for (uint64_t r = 0; r < rank; r++) {
      if (fscanf(file, "%" PRIu64, &idx) != 1) {
        fprintf(stderr, "Cannot find next index in %s\n", filename);
        exit(1);
      }
      // Add 0-based index.
      indices[perm[r]] = idx - 1;
    }
    // The external formats always store the numerical values with the type
    // double, but we cast these values to the sparse tensor object type.
    double value;
    if (fscanf(file, "%lg\n", &value) != 1) {
      fprintf(stderr, "Cannot find next value in %s\n", filename);
      exit(1);
    }
    tensor->add(indices, value);
  }
  // Close the file and return tensor.
  fclose(file);
  return tensor;
}

} // anonymous namespace

extern "C" {

/// This type is used in the public API at all places where MLIR expects
/// values with the built-in type "index". For now, we simply assume that
/// type is 64-bit, but targets with different "index" bit widths should link
/// with an alternatively built runtime support library.
// TODO: support such targets?
typedef uint64_t index_t;

//===----------------------------------------------------------------------===//
//
// Public API with methods that operate on MLIR buffers (memrefs) to interact
// with sparse tensors, which are only visible as opaque pointers externally.
// These methods should be used exclusively by MLIR compiler-generated code.
//
// Some macro magic is used to generate implementations for all required type
// combinations that can be called from MLIR compiler-generated code.
//
//===----------------------------------------------------------------------===//

enum OverheadTypeEnum : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };

enum PrimaryTypeEnum : uint32_t {
  kF64 = 1,
  kF32 = 2,
  kI64 = 3,
  kI32 = 4,
  kI16 = 5,
  kI8 = 6
};

enum Action : uint32_t {
  kEmpty = 0,
  kFromFile = 1,
  kFromCOO = 2,
  kEmptyCOO = 3,
  kToCOO = 4,
  kToIter = 5
};

#define CASE(p, i, v, P, I, V) \
  if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
    SparseTensorCOO<V> *tensor = nullptr; \
    if (action <= kFromCOO) { \
      if (action == kFromFile) { \
        char *filename = static_cast<char *>(ptr); \
        tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm); \
      } else if (action == kFromCOO) { \
        tensor = static_cast<SparseTensorCOO<V> *>(ptr); \
      } else { \
        assert(action == kEmpty); \
      } \
      return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm, \
                                                           sparsity, tensor); \
    } else if (action == kEmptyCOO) { \
      return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm); \
    } else { \
      tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \
      if (action == kToIter) { \
        tensor->startIterator(); \
      } else { \
        assert(action == kToCOO); \
      } \
      return tensor; \
    } \
  }

#define IMPL_SPARSEVALUES(NAME, TYPE, LIB) \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \
    assert(ref); \
    assert(tensor); \
    std::vector<TYPE> *v; \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v); \
    ref->basePtr = ref->data = v->data(); \
    ref->offset = 0; \
    ref->sizes[0] = v->size(); \
    ref->strides[0] = 1; \
  }

#define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \
                           index_t d) { \
    assert(ref); \
    assert(tensor); \
    std::vector<TYPE> *v; \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d); \
    ref->basePtr = ref->data = v->data(); \
    ref->offset = 0; \
    ref->sizes[0] = v->size(); \
    ref->strides[0] = 1; \
  }

#define IMPL_ADDELT(NAME, TYPE) \
  void *_mlir_ciface_##NAME(void *tensor, TYPE value, \
                            StridedMemRefType<index_t, 1> *iref, \
                            StridedMemRefType<index_t, 1> *pref) { \
    assert(tensor); \
    assert(iref); \
    assert(pref); \
    assert(iref->strides[0] == 1 && pref->strides[0] == 1); \
    assert(iref->sizes[0] == pref->sizes[0]); \
    const index_t *indx = iref->data + iref->offset; \
    const index_t *perm = pref->data + pref->offset; \
    uint64_t isize = iref->sizes[0]; \
    std::vector<index_t> indices(isize); \
    for (uint64_t r = 0; r < isize; r++) \
      indices[perm[r]] = indx[r]; \
    static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value); \
    return tensor; \
  }

#define IMPL_GETNEXT(NAME, V) \
  bool _mlir_ciface_##NAME(void *tensor, StridedMemRefType<uint64_t, 1> *iref, \
                           StridedMemRefType<V, 0> *vref) { \
    assert(iref->strides[0] == 1); \
    uint64_t *indx = iref->data + iref->offset; \
    V *value = vref->data + vref->offset; \
    const uint64_t isize = iref->sizes[0]; \
    auto iter = static_cast<SparseTensorCOO<V> *>(tensor); \
    const Element<V> *elem = iter->getNext(); \
    if (elem == nullptr) { \
      delete iter; \
      return false; \
    } \
    for (uint64_t r = 0; r < isize; r++) \
      indx[r] = elem->indices[r]; \
    *value = elem->value; \
    return true; \
  }

/// Constructs a new sparse tensor. This is the "swiss army knife"
/// method for materializing sparse tensors into the computation.
///
/// action:
/// kEmpty = returns empty storage to fill later
/// kFromFile = returns storage, where ptr contains filename to read
/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
/// kToIter = returns iterator from storage in ptr (call getNext() to use)
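///
/// For illustration, a (hypothetical) call sequence that materializes an
/// all-compressed f64 matrix from a coordinate scheme and then reads back
/// one of its overhead arrays and the values could look roughly like:
///
///   void *coo = _mlir_ciface_newSparseTensor(..., kU64, kU64, kF64,
///                                            kEmptyCOO, nullptr);
///   coo = _mlir_ciface_addEltF64(coo, 1.0, ...); // once per nonzero
///   void *t = _mlir_ciface_newSparseTensor(..., kU64, kU64, kF64,
///                                          kFromCOO, coo);
///   _mlir_ciface_sparsePointers64(&pref, t, 1);
///   _mlir_ciface_sparseIndices64(&iref, t, 1);
///   _mlir_ciface_sparseValuesF64(&vref, t);
///   delSparseTensor(t);
///
/// where the elided arguments are the memref-wrapped annotation, size,
/// permutation, and index buffers.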
void *
_mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
                             StridedMemRefType<index_t, 1> *sref,
                             StridedMemRefType<index_t, 1> *pref,
                             uint32_t ptrTp, uint32_t indTp, uint32_t valTp,
                             uint32_t action, void *ptr) {
  assert(aref && sref && pref);
  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
         pref->strides[0] == 1);
  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
  const uint8_t *sparsity = aref->data + aref->offset;
  const index_t *sizes = sref->data + sref->offset;
  const index_t *perm = pref->data + pref->offset;
  uint64_t rank = aref->sizes[0];

  // Double matrices with all combinations of overhead storage.
  CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
  CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
  CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
  CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
  CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
  CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
  CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
  CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
  CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
  CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
  CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
  CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
  CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
  CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
  CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
  CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);

  // Float matrices with all combinations of overhead storage.
  CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
  CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
  CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
  CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
  CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
  CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
  CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
  CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
  CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
  CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
  CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
  CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
  CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
  CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
  CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
  CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);

  // Integral matrices with same overhead storage.
  CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
  CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
  CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
  CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
  CASE(kU32, kU32, kI32, uint32_t, uint32_t, int32_t);
  CASE(kU32, kU32, kI16, uint32_t, uint32_t, int16_t);
  CASE(kU32, kU32, kI8, uint32_t, uint32_t, int8_t);
  CASE(kU16, kU16, kI32, uint16_t, uint16_t, int32_t);
  CASE(kU16, kU16, kI16, uint16_t, uint16_t, int16_t);
  CASE(kU16, kU16, kI8, uint16_t, uint16_t, int8_t);
  CASE(kU8, kU8, kI32, uint8_t, uint8_t, int32_t);
  CASE(kU8, kU8, kI16, uint8_t, uint8_t, int16_t);
  CASE(kU8, kU8, kI8, uint8_t, uint8_t, int8_t);

  // Unsupported case (add above if needed).
  fputs("unsupported combination of types\n", stderr);
  exit(1);
}

/// Methods that provide direct access to pointers.
IMPL_GETOVERHEAD(sparsePointers, index_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)

/// Methods that provide direct access to indices.
IMPL_GETOVERHEAD(sparseIndices, index_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)

/// Methods that provide direct access to values.
IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)

/// Helper to add value to coordinate scheme, one per value type.
IMPL_ADDELT(addEltF64, double)
IMPL_ADDELT(addEltF32, float)
IMPL_ADDELT(addEltI64, int64_t)
IMPL_ADDELT(addEltI32, int32_t)
IMPL_ADDELT(addEltI16, int16_t)
IMPL_ADDELT(addEltI8, int8_t)

/// Helper to enumerate elements of coordinate scheme, one per value type.
IMPL_GETNEXT(getNextF64, double)
IMPL_GETNEXT(getNextF32, float)
IMPL_GETNEXT(getNextI64, int64_t)
IMPL_GETNEXT(getNextI32, int32_t)
IMPL_GETNEXT(getNextI16, int16_t)
IMPL_GETNEXT(getNextI8, int8_t)

#undef CASE
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
#undef IMPL_ADDELT
#undef IMPL_GETNEXT

//===----------------------------------------------------------------------===//
//
// Public API with methods that accept C-style data structures to interact
// with sparse tensors, which are only visible as opaque pointers externally.
// These methods can be used both by MLIR compiler-generated code as well as by
// an external runtime that wants to interact with MLIR compiler-generated code.
//
//===----------------------------------------------------------------------===//

/// Helper method to read a sparse tensor filename from the environment,
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
char *getTensorFilename(index_t id) {
  char var[80];
  sprintf(var, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  return env;
}

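// For example (hypothetical invocation), running a compiled kernel as
//
//   TENSOR0="wide.mtx" TENSOR1="deep.tns" ./kernel
//
// makes getTensorFilename(0) and getTensorFilename(1) return the two paths.
//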
/// Returns size of sparse tensor in given dimension.
index_t sparseDimSize(void *tensor, index_t d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

/// Releases sparse tensor storage.
void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}

/// Initializes sparse tensor from a COO-flavored format expressed using
/// C-style data structures. The expected parameters are:
///
///   rank:    rank of tensor
///   nse:     number of specified elements (usually the nonzeros)
///   shape:   array with dimension size for each rank
///   values:  a "nse" array with values for all specified elements
///   indices: a flat "nse x rank" array with indices for all specified elements
///
/// For example, the sparse matrix
///   | 1.0 0.0 0.0 |
///   | 0.0 5.0 3.0 |
/// can be passed as
///   rank    = 2
///   nse     = 3
///   shape   = [2, 3]
///   values  = [1.0, 5.0, 3.0]
///   indices = [ 0, 0,  1, 1,  1, 2]
//
// TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
//
void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
                                double *values, uint64_t *indices) {
  // Setup all-dims compressed and default ordering.
  std::vector<uint8_t> sparse(rank, SparseTensorStorageBase::kCompressed);
  std::vector<uint64_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  // Convert external format to internal COO.
  SparseTensorCOO<double> *tensor = SparseTensorCOO<double>::newSparseTensorCOO(
      rank, shape, perm.data(), nse);
  std::vector<uint64_t> idx(rank);
  for (uint64_t i = 0, base = 0; i < nse; i++) {
    for (uint64_t r = 0; r < rank; r++)
      idx[r] = indices[base + r];
    tensor->add(idx, values[i]);
    base += rank;
  }
  // Return sparse tensor storage format as opaque pointer.
  return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
      rank, shape, perm.data(), sparse.data(), tensor);
}

} // extern "C"

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS