//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering sparse tensor types to actual sparse code.
//
// The concept of letting a compiler generate sparse code automatically was
// pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
// formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
// Algebra Compiler (TACO). The implementation in this file closely follows
// the "sparse iteration theory" that forms the foundation of TACO. A rewriting
// rule is applied to each tensor expression in linalg (MLIR's tensor index
// notation) where the sparsity of tensors is indicated with annotation using
// a per-dimension specification of sparse/dense storage together with a
// specification of the order on the dimensions. Subsequently, a topologically
// sorted iteration graph, reflecting the required order on indices with respect
// to the dimensions of each tensor, is constructed to ensure that all tensors
// are visited in natural index order. Next, iteration lattices are constructed
// for the tensor expression for every index in topological order. Each
// iteration lattice point consists of a conjunction of tensor indices together
// with a tensor (sub)expression that needs to be evaluated for that
// conjunction. Within the lattice, iteration points are ordered according to
// the way indices are exhausted. As such these iteration lattices drive actual
// sparse code generation, which consists of a tedious but relatively
// straightforward one-to-one mapping from iteration lattices to combinations
// of for-loops, while-loops, and if-statements.
//
// [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
// PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
// [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
// David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
// Proceedings of the ACM on Programming Languages, October 2017.
// [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
// PhD thesis, MIT, February, 2020 (tensor-compiler.org).
//
// Implementation detail: We use llvm::SmallVector for vectors with
// variable lengths and std::vector for vectors with fixed lengths.
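//
// Conceptual example (an illustrative sketch only; names like pointers_a,
// indices_a, and values_a are hypothetical and merely denote the compressed
// storage of tensor a): for an element-wise multiplication x(i) = a(i) * b(i)
// in which a is annotated compressed and b and x are dense, the conjunction
// of index conditions reduces to iterating over the stored entries of a only,
// roughly
//
//   for (p = pointers_a[0]; p < pointers_a[1]; p++) {
//     i = indices_a[p];
//     x[i] = values_a[p] * b[i];
//   }
//
// whereas an addition x(i) = a(i) + b(i) requires a disjunction, which yields
// a while-loop that co-iterates a with the dense index i before a remaining
// dense for-loop, all driven by the iteration lattices discussed above.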
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
enum class Dim { kSparse, kDense, kSingle, kUndef };

/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
/// stored directly. For binary operations, e0 and e1 denote the index of the
/// children tensor expressions.
struct TensorExp {
  TensorExp(Kind k, unsigned x, unsigned y, Value v)
      : kind(k), e0(x), e1(y), val(v) {
    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
  }
  Kind kind;
  /// Indices of children expression(s).
  unsigned e0;
  unsigned e1;
  /// Direct link to IR for an invariant. During code generation,
  /// this field is used to cache "hoisted" loop invariant tensor loads.
  Value val;
};

/// Lattice point. Each lattice point consists of a conjunction of tensor
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
    bits.set(b);
  }
  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
  /// Conjunction of tensor loop indices as bitvector. This represents
  /// all indices involved in the tensor expression.
  llvm::BitVector bits;
  /// Simplified conjunction of tensor loop indices as bitvector. This
  /// represents a simplified condition under which this tensor expression
  /// must execute. Pre-computed during codegen to avoid repeated eval.
  llvm::BitVector simple;
  /// Index of the tensor expression.
  unsigned exp;
};

/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
class Merger {
public:
  /// Constructs a merger for the given number of tensors and loops. The
  /// user supplies the number of tensors involved in the kernel, with the
  /// last tensor in this set denoting the output tensor. The merger adds an
  /// additional synthetic tensor at the end of this set to represent all
  /// invariant expressions in the kernel.
  Merger(unsigned t, unsigned l)
      : outTensor(t - 1), numTensors(t + 1), numLoops(l),
        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}

  /// Adds a tensor expression. Returns its index.
  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
    unsigned e = tensorExps.size();
    tensorExps.push_back(TensorExp(k, e0, e1, v));
    return e;
  }
  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }

  /// Adds an iteration lattice point. Returns its index.
  unsigned addLat(unsigned t, unsigned i, unsigned e) {
    assert(t < numTensors && i < numLoops);
    unsigned p = latPoints.size();
    latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
    return p;
  }

  /// Adds a new, initially empty, set. Returns its index.
  unsigned addSet() {
    unsigned s = latSets.size();
    latSets.emplace_back(SmallVector<unsigned, 16>());
    return s;
  }

  /// Computes a single conjunction of two lattice points by taking the "union"
  /// of loop indices (effectively constructing a larger "intersection" of those
  /// indices) with a newly constructed tensor (sub)expression of given kind.
  /// Returns the index of the new lattice point.
  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
    unsigned p = latPoints.size();
    llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
    nb |= latPoints[p1].bits;
    unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
    latPoints.push_back(LatPoint(nb, e));
    return p;
  }

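  // Illustrative sketch of the two merge operations below (not a formal
  // definition): for a multiplication x(i) = a(i) * b(i) with both a and b
  // sparse, the conjunctive merge yields the single lattice point { a /\ b },
  // so only indices present in both operands are visited; for an addition
  // x(i) = a(i) + b(i), the disjunctive merge yields { a /\ b, a, b }, which
  // later drives a while-loop that co-iterates a and b, followed by for-loops
  // over the remaining entries of a and of b, respectively.
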
  /// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
  /// their cartesian product. Returns the index of the new set.
  unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
    unsigned s = addSet();
    for (unsigned p0 : latSets[s0])
      for (unsigned p1 : latSets[s1])
        latSets[s].push_back(conjLatPoint(kind, p0, p1));
    return s;
  }

  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
  /// Returns the index of the new set.
  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
    unsigned s = takeConj(kind, s0, s1);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
    return s;
  }

  /// Optimizes the iteration lattice points in the given set. This
  /// method should be called right before code generation to avoid
  /// generating redundant loops and conditions.
  unsigned optimizeSet(unsigned s0) {
    unsigned s = addSet();
    assert(latSets[s0].size() != 0);
    unsigned p0 = latSets[s0][0];
    for (unsigned p1 : latSets[s0]) {
      bool add = true;
      if (p0 != p1) {
        // Is this a straightforward copy?
        unsigned e = latPoints[p1].exp;
        if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
          continue;
        // Conjunction already covered?
        for (unsigned p2 : latSets[s]) {
          assert(!latGT(p1, p2)); // Lj => Li would be bad
          if (onlyDenseDiff(p2, p1)) {
            add = false;
            break;
          }
        }
        assert(!add || latGT(p0, p1));
      }
      if (add)
        latSets[s].push_back(p1);
    }
    for (unsigned p : latSets[s])
      latPoints[p].simple = simplifyCond(s, p);
    return s;
  }

  /// Simplifies the conditions in a conjunction of a given lattice point
  /// within the given set using just two basic rules:
  /// (1) multiple dense conditions are reduced to a single dense condition, and
  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
  llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
    // First determine if this lattice point is a *singleton*, i.e.,
    // the last point in a lattice, so that no other point is less than this one.
    bool isSingleton = true;
    for (unsigned p1 : latSets[s]) {
      if (p0 != p1 && latGT(p0, p1)) {
        isSingleton = false;
        break;
      }
    }
    // Now apply the two basic rules.
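    // (Illustrative aside: for a singleton point that iterates, say, one
    // sparse tensor together with several dense tensors, the loop below
    // clears all dense conditions and keeps only the sparse condition,
    // since the dense dimensions can be addressed directly without a test.)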
    llvm::BitVector simple = latPoints[p0].bits;
    bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
    for (unsigned b = 0, be = simple.size(); b < be; b++) {
      if (simple[b] && !isDim(b, Dim::kSparse)) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
    return simple;
  }

  /// Returns true if Li > Lj.
  bool latGT(unsigned i, unsigned j) const {
    const llvm::BitVector &bitsi = latPoints[i].bits;
    const llvm::BitVector &bitsj = latPoints[j].bits;
    assert(bitsi.size() == bitsj.size());
    if (bitsi.count() > bitsj.count()) {
      for (unsigned b = 0, be = bitsj.size(); b < be; b++)
        if (bitsj[b] && !bitsi[b])
          return false;
      return true;
    }
    return false;
  }

  /// Returns true if Li and Lj only differ in dense.
  bool onlyDenseDiff(unsigned i, unsigned j) {
    llvm::BitVector tmp = latPoints[j].bits;
    tmp ^= latPoints[i].bits;
    return !hasAnyDimOf(tmp, Dim::kSparse);
  }

  /// Bit translation. Loop index i of tensor t is encoded as the single
  /// bit b = i * numTensors + t.
  unsigned tensor(unsigned b) const { return b % numTensors; }
  unsigned index(unsigned b) const { return b / numTensors; }

  /// Returns true if bit corresponds to queried dim.
  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }

  /// Returns true if tensor access at given index has queried dim.
  bool isDim(unsigned t, unsigned i, Dim d) const {
    assert(t < numTensors && i < numLoops);
    return dims[t][i] == d;
  }

  /// Returns true if any set bit corresponds to queried dim.
  bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
    for (unsigned b = 0, be = bits.size(); b < be; b++)
      if (bits[b] && isDim(b, d))
        return true;
    return false;
  }

  /// Setter.
  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }

  /// Getters.
  TensorExp &exp(unsigned e) { return tensorExps[e]; }
  LatPoint &lat(unsigned l) { return latPoints[l]; }
  SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }

private:
  const unsigned outTensor;
  const unsigned numTensors;
  const unsigned numLoops;

  std::vector<std::vector<Dim>> dims;
  llvm::SmallVector<TensorExp, 32> tensorExps;
  llvm::SmallVector<LatPoint, 16> latPoints;
  llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
};

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// These arrays are set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

/// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  unsigned numTensors = op.getNumShapedOperands();
  unsigned lhs = numTensors - 1;
  for (unsigned t = 0; t < numTensors; t++) {
    auto map = op.getIndexingMap(t);
    unsigned rank = op.getShapedType(t).getRank();
    auto enc = getSparseTensorEncoding(op.getShapedType(t));
    if (enc) {
      annotated = true;
      if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
        return false; // TODO: handle permutations
      if (t == lhs)
        return false; // TODO: handle sparse outputs
    }
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      merger.setDim(t, idx, toDim(enc, d));
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  bool sparseOnly) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  unsigned numTensors = op.getNumShapedOperands();
  for (unsigned t = 0; t < numTensors; t++) {
    auto map = op.getIndexingMap(t);
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when sparse only is requested.
    if (sparseOnly && !getSparseTensorEncoding(op.getShapedType(t)))
      continue;
    // At the moment, we take the index variables in the tensor access
    // expression in the order in which they appear (conceptually a
    // "row-major" layout of every tensor). So, a tensor access A_ijk
    // forces the ordering i < j < k on the loop indices.
    // TODO: support affine map to define alternative dimension orders.
    for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
      unsigned f = map.getDimPosition(d - 1);
      unsigned t = map.getDimPosition(d);
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
/// This simplifies constructing (sub)expressions during iteration lattice
/// building (compared to using the SSA representation everywhere).
static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
                                         Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    if (arg.getOwner()->getParentOp() == op) {
      // Any parameter of the generic op is considered a tensor,
      // indexed by the implicit loop bounds.
      auto map = op.getIndexingMap(argN);
      if (map.isProjectedPermutation())
        return merger.addExp(Kind::kTensor, argN);
      // Cannot handle (yet).
      return None;
    }
    // Any parameter of a higher op is invariant.
    return merger.addExp(Kind::kInvariant, val);
  }
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front()) {
    // Something defined outside is invariant.
    return merger.addExp(Kind::kInvariant, val);
  } else if (def->getNumOperands() == 2) {
    // Construct binary operations if subexpressions could be built.
    auto x = buildTensorExp(merger, op, def->getOperand(0));
    auto y = buildTensorExp(merger, op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return merger.addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return merger.addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return merger.addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return merger.addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build (yet).
  return None;
}

/// Builds the iteration lattices in a bottom-up traversal given the remaining
/// tensor (sub)expression and the next loop index in the iteration graph.
static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
                              unsigned exp, unsigned idx) {
  Kind kind = merger.exp(exp).kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = merger.addSet();
    unsigned t =
        kind == Kind::kTensor ? merger.exp(exp).e0 : op.getNumShapedOperands();
    merger.set(s).push_back(merger.addLat(t, idx, exp));
    return s;
  }
  unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
  unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
  case Kind::kMulI:
    return merger.takeConj(kind, s0, s1);
  case Kind::kAddF:
  case Kind::kAddI:
    return merger.takeDisj(kind, s0, s1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutput(0);
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<linalg::CopyOp>(loc, init, alloc);
  return alloc;
}

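// Illustrative note on the expected storage layout (a sketch, not a formal
// specification): each compressed dimension of an annotated tensor is backed
// by a pointers array and an indices array, and a single values array holds
// all stored entries; for a dense-by-compressed matrix this corresponds to
// the familiar CSR format. The bufferization below exposes these arrays as
// 1-D memrefs through the ToPointersOp, ToIndicesOp, and ToValuesOp
// primitives.
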
/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a more general bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  unsigned numTensors = op.getNumShapedOperands();
  unsigned numInputs = op.getNumInputs();
  assert(numTensors == numInputs + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (unsigned t = 0; t < numTensors; t++) {
    Value tensor = t < numInputs ? op.getInput(t) : op.getOutput(0);
    auto tensorType = op.getShapedType(t);
    auto shape = tensorType.getShape();
    auto map = op.getIndexingMap(t);
    auto enc = getSparseTensorEncoding(tensorType);
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
      unsigned i = map.getDimPosition(d);
      // Handle sparse storage schemes.
      if (merger.isDim(t, i, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointer and indices.
        codegen.pointers[t][i] =
            rewriter.create<ToPointersOp>(loc, ptrTp, tensor, dim);
        codegen.indices[t][i] =
            rewriter.create<ToIndicesOp>(loc, indTp, tensor, dim);
      }
      // Find lower and upper bound in current dimension.
      Value up;
      if (shape[d] == MemRefType::kDynamicSize) {
        up = rewriter.create<memref::DimOp>(loc, tensor, d);
        args.push_back(up);
      } else {
        up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
      }
      codegen.sizes[i] = codegen.highs[t][i] = up;
    }
    // Perform the required bufferization. All dense inputs materialize
    // from the input tensor. The dense output tensor needs special
    // handling. Sparse inputs use a sparse primitive to obtain the values.
    if (!enc) {
      auto denseTp = MemRefType::get(shape, tensorType.getElementType());
      if (t < numInputs)
        codegen.buffers[t] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
      else
        codegen.buffers[t] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, tensorType.getElementType());
      codegen.buffers[t] = rewriter.create<ToValuesOp>(loc, sparseTp, tensor);
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  Value end = rewriter.create<SubIOp>(loc, hi, iv);
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  unsigned tensor = merger.exp(exp).e0;
  auto map = op.getIndexingMap(tensor);
  auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
    unsigned idx = map.getDimPosition(i);
    args.push_back(codegen.loops[idx]); // universal dense index
    if (enc) {
      args.clear();
      args.push_back(codegen.pidxs[tensor][idx]); // position index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned tensor, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  unsigned lhs = op.getNumShapedOperands() - 1;
  if (lhs == tensor && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  auto map = op.getIndexingMap(tensor);
  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
    unsigned idx = map.getDimPosition(i);
    args.push_back(codegen.loops[idx]); // universal dense index
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in a subsequent gather/scatter operation,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (in
    // the future, we could introduce a flag that states the negative space
    // of 32-bit indices is unused). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  unsigned lhs = op.getNumShapedOperands() - 1;
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  else if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
    return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    unsigned tensor = merger.exp(exp).e0;
    auto map = op.getIndexingMap(tensor);
    for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
      unsigned idx = map.getDimPosition(i);
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    unsigned lhs = op.getNumShapedOperands() - 1;
    if (lhs == tensor) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).e0;
    unsigned e1 = merger.exp(exp).e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns true if the given loop should be vectorized. Any implicit inner
/// loop in the Linalg operation is a candidate; whether it is actually
/// converted to SIMD code depends on the requested vectorization strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns true if the given loop should be parallelized. Any implicit loop
/// in the Linalg operation that is marked "parallel" is a candidate; whether
/// it is actually converted to a parallel operation depends on the requested
/// parallelization strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  unsigned numTensors = op.getNumShapedOperands();
  for (unsigned t = 0; t < numTensors; t++) {
    if (!getSparseTensorEncoding(op.getShapedType(t))) {
      auto map = op.getIndexingMap(t);
      unsigned r = map.getNumResults();
      for (unsigned i = 0; i < r; i++) {
        if (map.getDimPosition(i) == idx && i != r - 1)
          return false;
      }
    }
  }
  return true;
}

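// Example of the rejected case (illustrative): with an indexing map
// (i, j) -> (j, i) on a dense tensor, index i appears in a non-innermost
// result position, so accessing the tensor along i is strided; the check
// above then returns false and vectorization of loops over i is disabled.
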
/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

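// Schematic view of the generated loops (a sketch; the actual IR depends on
// the chosen options and on the element types involved):
//
//   scf.parallel (%i) = (%lo) to (%hi) step (%step) { ... }       // parallel
//
//   scf.for %i = %lo to %hi step %step iter_args(%red = %init) {  // sequential
//     ...
//     scf.yield %red_new
//   }
//
// where the iter_args clause appears only for a scalarized reduction, and the
// yield of the updated reduction value is emitted later by the caller
// (genStmt).
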
/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists of a conjunction
  // of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

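// Rough shape of the emitted co-iteration construct (a sketch; the operand
// and result lists depend on which sparse indices participate and on whether
// the universal index is maintained):
//
//   %r:n = scf.while (%p_a = ..., %p_b = ..., %u = ...) {
//     // "before": conjunction of bounds tests, one per sparse index
//     %c_a = cmpi ult, %p_a, %hi_a
//     %c_b = cmpi ult, %p_b, %hi_b
//     %c   = and %c_a, %c_b
//     scf.condition(%c) %p_a, %p_b, %u
//   } do {
//   ^bb0(%p_a: index, %p_b: index, %u: index):
//     // "after": loop body; induction is appended later by genWhileInduction
//   }
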
/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

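// For example (illustrative): when co-iterating two sparse tensors a and b
// without a universal index, the code above loads idx_a = indices_a[p_a] and
// idx_b = indices_b[p_b] and sets the current loop index to min(idx_a, idx_b)
// using the cmpi/select pair; dense positions are then rebuilt through
// genAddress, which (roughly) combines the enclosing position, the dimension
// size, and the current loop index into a flat buffer position.
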
/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

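// Conceptually, each sparse position only advances when its index was the one
// visited in this iteration:
//
//   %hit   = cmpi eq, %idx_t, %i        // did tensor t participate?
//   %next  = addi %p_t, %c1
//   %p_t'  = select %hit, %next, %p_t
//
// while the universal index, if maintained, is unconditionally incremented.
// These values are yielded back as the operands of the next while iteration.
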
/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

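// For instance (illustrative), during co-iteration of sparse a and b, the
// lattice point that covers only a is guarded roughly by
//
//   %c = cmpi eq, %idx_a, %i
//   scf.if %c { <case: a contributes> } else { <remaining cases> }
//
// whereas dense conditions contribute a constant "true" to the conjunction.
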
/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    unsigned lhs = op.getNumShapedOperands() - 1;
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

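// As a concrete (but schematic) example: for a sparse vector addition
// x(i) = a(i) + b(i) with both a and b sparse in i, the optimized lattice for
// index i contains the points {a and b, a, b}. The recursion above then emits,
// in order: a while-loop that co-iterates a and b with if-statements for the
// "both", "a only", and "b only" cases, followed by one loop that exhausts the
// remaining entries of a and one that exhausts the remaining entries of b,
// mirroring the TACO-style merge of a union iteration space.
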
namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumShapedOperands();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
    Operation *yield = op.region().front().getTerminator();
    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
    if (!exp.hasValue())
      return failure(); // build failure

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    Value result = rewriter.create<memref::TensorLoadOp>(
        op.getLoc(), codegen.buffers.back());
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
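
// A minimal usage sketch (caller code assumed, not part of this file): a pass
// that wants sparsification would typically populate the patterns and drive
// them greedily, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//
// with applyPatternsAndFoldGreedily from the greedy pattern rewrite driver
// and "options" an appropriately constructed SparsificationOptions.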