196a23911SAart Bik //===- Sparsification.cpp - Implementation of sparsification --------------===// 2a2c9d4bbSAart Bik // 3a2c9d4bbSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a2c9d4bbSAart Bik // See https://llvm.org/LICENSE.txt for license information. 5a2c9d4bbSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a2c9d4bbSAart Bik // 7a2c9d4bbSAart Bik //===----------------------------------------------------------------------===// 8a2c9d4bbSAart Bik // 9160399c7SAart Bik // This file implements converting sparse tensor types to actual sparse code. 10a2c9d4bbSAart Bik // 11a2c9d4bbSAart Bik //===----------------------------------------------------------------------===// 12a2c9d4bbSAart Bik 1376a18618SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 14*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15a2c9d4bbSAart Bik #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16a2c9d4bbSAart Bik #include "mlir/Dialect/Linalg/Utils/Utils.h" 1766f878ceSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 18a2c9d4bbSAart Bik #include "mlir/Dialect/SCF/SCF.h" 1976a18618SMatthias Springer #include "mlir/Dialect/SCF/Transforms.h" 20a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 21a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 22744146f6SGus Smith #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 23a2c9d4bbSAart Bik #include "mlir/Dialect/StandardOps/IR/Ops.h" 24a2c9d4bbSAart Bik #include "mlir/Dialect/Vector/VectorOps.h" 25a2c9d4bbSAart Bik #include "mlir/IR/Matchers.h" 2696a23911SAart Bik #include "mlir/IR/TensorEncoding.h" 27a2c9d4bbSAart Bik #include "llvm/ADT/SmallBitVector.h" 28a2c9d4bbSAart Bik 29a2c9d4bbSAart Bik using namespace mlir; 3096a23911SAart Bik using namespace mlir::sparse_tensor; 31a2c9d4bbSAart Bik 325da21338SAart Bik //===----------------------------------------------------------------------===// 335da21338SAart Bik // Declarations of data structures. 345da21338SAart Bik //===----------------------------------------------------------------------===// 355da21338SAart Bik 36a2c9d4bbSAart Bik namespace { 37a2c9d4bbSAart Bik 38b6d1a31cSAart Bik // Iteration graph sorting. 39b6d1a31cSAart Bik enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 }; 40b6d1a31cSAart Bik 415da21338SAart Bik // Reduction kinds. 425da21338SAart Bik enum Reduction { kSum, kProduct, kAnd, kOr, kXor }; 435da21338SAart Bik 44a2c9d4bbSAart Bik // Code generation. 45a2c9d4bbSAart Bik struct CodeGen { 4696a23911SAart Bik CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops) 47a2c9d4bbSAart Bik : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 48a2c9d4bbSAart Bik pointers(numTensors, std::vector<Value>(numLoops)), 49a2c9d4bbSAart Bik indices(numTensors, std::vector<Value>(numLoops)), 50a2c9d4bbSAart Bik highs(numTensors, std::vector<Value>(numLoops)), 51a2c9d4bbSAart Bik pidxs(numTensors, std::vector<Value>(numLoops)), 52a2c9d4bbSAart Bik idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), 53a2c9d4bbSAart Bik curVecLength(1), curVecMask() {} 54a2c9d4bbSAart Bik /// Sparsification options. 5596a23911SAart Bik SparsificationOptions options; 56a2c9d4bbSAart Bik /// Universal dense indices and upper bounds (by index). The loops array 57a2c9d4bbSAart Bik /// is updated with the value of the universal dense index in the current 58a2c9d4bbSAart Bik /// loop. 
The sizes array is set once with the inferred dimension sizes. 59a2c9d4bbSAart Bik std::vector<Value> loops; 60a2c9d4bbSAart Bik std::vector<Value> sizes; 61a2c9d4bbSAart Bik /// Buffers for storing dense and sparse numerical values (by tensor). 62a2c9d4bbSAart Bik /// This array is set once during bufferization of all tensors. 63a2c9d4bbSAart Bik std::vector<Value> buffers; 64a2c9d4bbSAart Bik /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 65a2c9d4bbSAart Bik /// This array is set once during bufferization of all sparse tensors. 66a2c9d4bbSAart Bik std::vector<std::vector<Value>> pointers; 67a2c9d4bbSAart Bik std::vector<std::vector<Value>> indices; 68a2c9d4bbSAart Bik /// Sparse iteration information (by tensor and index). These arrays 69a2c9d4bbSAart Bik /// are updated to remain current within the current loop. 70a2c9d4bbSAart Bik std::vector<std::vector<Value>> highs; 71a2c9d4bbSAart Bik std::vector<std::vector<Value>> pidxs; 72a2c9d4bbSAart Bik std::vector<std::vector<Value>> idxs; 73a2c9d4bbSAart Bik /// Current reduction, updated during code generation. When indices of a 74a2c9d4bbSAart Bik /// reduction are exhausted, all inner loops can "scalarize" the reduction. 75a2c9d4bbSAart Bik // TODO: currently only done for (a chain of) innermost for-loops, where it 76a2c9d4bbSAart Bik // is most effective; we could generalize to more outer and while-loops. 77a2c9d4bbSAart Bik unsigned redExp; 78a2c9d4bbSAart Bik Value redVal; 795da21338SAart Bik Reduction redKind; 80a2c9d4bbSAart Bik // Current vector length and mask. 81a2c9d4bbSAart Bik unsigned curVecLength; 82a2c9d4bbSAart Bik Value curVecMask; 83a2c9d4bbSAart Bik }; 84a2c9d4bbSAart Bik 85a2c9d4bbSAart Bik } // namespace 86a2c9d4bbSAart Bik 875da21338SAart Bik //===----------------------------------------------------------------------===// 885da21338SAart Bik // Sparse compiler analysis methods. 895da21338SAart Bik //===----------------------------------------------------------------------===// 905da21338SAart Bik 915da21338SAart Bik /// Helper method to apply dimension ordering permutation. 925da21338SAart Bik static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) { 93c194b49cSAart Bik if (enc) { 94c194b49cSAart Bik auto order = enc.getDimOrdering(); 95c194b49cSAart Bik if (order) { 96c194b49cSAart Bik assert(order.isPermutation()); 97c194b49cSAart Bik return order.getDimPosition(d); 98c194b49cSAart Bik } 99c194b49cSAart Bik } 100c194b49cSAart Bik return d; 101c194b49cSAart Bik } 102c194b49cSAart Bik 1035da21338SAart Bik /// Helper method to translate dim level type to internal representation. 1045da21338SAart Bik static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) { 10596a23911SAart Bik if (enc) { 10696a23911SAart Bik SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 10796a23911SAart Bik if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 10896a23911SAart Bik return Dim::kSparse; 10996a23911SAart Bik if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 11096a23911SAart Bik return Dim::kSingle; 11196a23911SAart Bik } 11296a23911SAart Bik return Dim::kDense; 11396a23911SAart Bik } 11496a23911SAart Bik 115b1d44e59SAart Bik /// Helper method to inspect affine expressions. Rejects cases where the 116b1d44e59SAart Bik /// same index is used in more than one dimension of a tensor. Also rejects 117b1d44e59SAart Bik /// affine expressions that are not a direct index for annotated tensors. 
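/// For example, a dense (non-annotated) operand may use a compound
/// subscript such as B(i + j, k), whereas a sparse subscript such as
/// A(i + j) is rejected here, as is any operand that repeats an index,
/// e.g. A(i, i). (The operand names are purely illustrative.)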
118b1d44e59SAart Bik /// TODO: accept more affine cases for sparse tensors 119b1d44e59SAart Bik static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, 120b1d44e59SAart Bik bool isDense) { 121b1d44e59SAart Bik switch (a.getKind()) { 122b1d44e59SAart Bik case AffineExprKind::DimId: { 123b1d44e59SAart Bik unsigned idx = a.cast<AffineDimExpr>().getPosition(); 124b1d44e59SAart Bik if (!merger.isDim(tensor, idx, Dim::kUndef)) 125b1d44e59SAart Bik return false; // used more than once 126b1d44e59SAart Bik merger.setDim(tensor, idx, dim); 127b1d44e59SAart Bik return true; 128b1d44e59SAart Bik } 129b1d44e59SAart Bik case AffineExprKind::Add: 130b1d44e59SAart Bik case AffineExprKind::Mul: { 131b1d44e59SAart Bik if (!isDense) 132b1d44e59SAart Bik return false; 133b1d44e59SAart Bik auto binOp = a.cast<AffineBinaryOpExpr>(); 134b1d44e59SAart Bik return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) && 135b1d44e59SAart Bik findAffine(merger, tensor, binOp.getRHS(), dim, isDense); 136b1d44e59SAart Bik } 137b1d44e59SAart Bik case AffineExprKind::Constant: 138b1d44e59SAart Bik return isDense; 139b1d44e59SAart Bik default: 140b1d44e59SAart Bik return false; 141b1d44e59SAart Bik } 142b1d44e59SAart Bik } 143b1d44e59SAart Bik 14496a23911SAart Bik /// Helper method to inspect sparse encodings in the tensor types. 145a2c9d4bbSAart Bik /// Fills the per-dimension sparsity information for all tensors. 146b1d44e59SAart Bik /// Returns true if the sparse annotations and affine subscript 147b1d44e59SAart Bik /// expressions of all tensors are admissable. Returns false if 148b1d44e59SAart Bik /// no annotations are found or inadmissable constructs occur. 149bf9ef3efSAart Bik static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 150bf9ef3efSAart Bik bool annotated = false; 1512f2b5b7dSTobias Gysi for (OpOperand *t : op.getInputAndOutputOperands()) { 1522f2b5b7dSTobias Gysi auto map = op.getTiedIndexingMap(t); 1532f2b5b7dSTobias Gysi auto enc = getSparseTensorEncoding(t->get().getType()); 154727a63e0SAart Bik if (enc) 155bf9ef3efSAart Bik annotated = true; 1562f2b5b7dSTobias Gysi assert(map.getNumResults() == op.getRank(t)); 157c194b49cSAart Bik for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 158b1d44e59SAart Bik unsigned tensor = t->getOperandNumber(); 159b1d44e59SAart Bik AffineExpr a = map.getResult(perm(enc, d)); 160b1d44e59SAart Bik if (!findAffine(merger, tensor, a, toDim(enc, d), !enc)) 161b1d44e59SAart Bik return false; // inadmissable affine expression 162a2c9d4bbSAart Bik } 163a2c9d4bbSAart Bik } 164bf9ef3efSAart Bik return annotated; 165a2c9d4bbSAart Bik } 166a2c9d4bbSAart Bik 167a2c9d4bbSAart Bik /// A DFS helper to compute a topological sort. Note that recursion is 168a2c9d4bbSAart Bik /// bounded by the number of implicit loops, which is always small. 169a2c9d4bbSAart Bik /// Returns false when a cycle is detected. 170a2c9d4bbSAart Bik static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 171a2c9d4bbSAart Bik std::vector<unsigned> &topSort, 172a2c9d4bbSAart Bik std::vector<std::vector<bool>> &adjM) { 173a2c9d4bbSAart Bik if (visit[i] != 0) 174a2c9d4bbSAart Bik return visit[i] != 1; // 1 denotes cycle! 
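// Visit states: 0 = not yet visited, 1 = currently on the DFS stack
// (so seeing 1 again means a back edge, i.e. a cycle), 2 = fully done.
// Mark this node as being on the active DFS path.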
175a2c9d4bbSAart Bik visit[i] = 1; 176a2c9d4bbSAart Bik for (unsigned j = 0, e = visit.size(); j < e; j++) 177a2c9d4bbSAart Bik if (adjM[i][j]) 178a2c9d4bbSAart Bik if (!topSortDFS(j, visit, topSort, adjM)) 179a2c9d4bbSAart Bik return false; 180a2c9d4bbSAart Bik visit[i] = 2; 181a2c9d4bbSAart Bik topSort.push_back(i); 182a2c9d4bbSAart Bik return true; 183a2c9d4bbSAart Bik } 184a2c9d4bbSAart Bik 185b1d44e59SAart Bik /// Helper method to add all constraints from the indices in one affine 186b1d44e59SAart Bik /// expression before all indices in the other affine expression. For 187b1d44e59SAart Bik /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. 188b1d44e59SAart Bik static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, 189b1d44e59SAart Bik AffineExpr a, AffineExpr b, unsigned fidx) { 190b1d44e59SAart Bik switch (a.getKind()) { 191b1d44e59SAart Bik case AffineExprKind::DimId: { 192b1d44e59SAart Bik unsigned idx = a.cast<AffineDimExpr>().getPosition(); 193b1d44e59SAart Bik if (b) 194b1d44e59SAart Bik addAffineOrderings(adjM, b, AffineExpr(), idx); 195b1d44e59SAart Bik else 196b1d44e59SAart Bik adjM[fidx][idx] = true; 197b1d44e59SAart Bik break; 198b1d44e59SAart Bik } 199b1d44e59SAart Bik case AffineExprKind::Add: 200b1d44e59SAart Bik case AffineExprKind::Mul: { 201b1d44e59SAart Bik auto binOp = a.cast<AffineBinaryOpExpr>(); 202b1d44e59SAart Bik addAffineOrderings(adjM, binOp.getLHS(), b, fidx); 203b1d44e59SAart Bik addAffineOrderings(adjM, binOp.getRHS(), b, fidx); 204b1d44e59SAart Bik break; 205b1d44e59SAart Bik } 206b1d44e59SAart Bik default: 207b1d44e59SAart Bik break; 208b1d44e59SAart Bik } 209b1d44e59SAart Bik } 210b1d44e59SAart Bik 211a2c9d4bbSAart Bik /// Computes a topologically sorted iteration graph for the linalg operation. 212a2c9d4bbSAart Bik /// Ensures all tensors are visited in natural index order. This is essential 213a2c9d4bbSAart Bik /// for sparse storage formats since these only support access along fixed 214a2c9d4bbSAart Bik /// dimensions. Even for dense storage formats, however, the natural index 215a2c9d4bbSAart Bik /// order yields innermost unit-stride access with better spatial locality. 216a2c9d4bbSAart Bik static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 217a2c9d4bbSAart Bik std::vector<unsigned> &topSort, 218b6d1a31cSAart Bik unsigned mask) { 219a2c9d4bbSAart Bik // Set up an n x n from/to adjacency matrix of the iteration graph 220a2c9d4bbSAart Bik // for the implicit loop indices i_0 .. i_n-1. 221a2c9d4bbSAart Bik unsigned n = op.getNumLoops(); 222a2c9d4bbSAart Bik std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 223a2c9d4bbSAart Bik 224a2c9d4bbSAart Bik // Iterate over the indexing maps of every tensor in the tensor expression. 2252f2b5b7dSTobias Gysi for (OpOperand *t : op.getInputAndOutputOperands()) { 2262f2b5b7dSTobias Gysi auto map = op.getTiedIndexingMap(t); 2272f2b5b7dSTobias Gysi auto enc = getSparseTensorEncoding(t->get().getType()); 228a2c9d4bbSAart Bik assert(map.getNumDims() == n); 229b6d1a31cSAart Bik // Skip dense tensor constraints when not requested. 230b6d1a31cSAart Bik if (!(mask & SortMask::kIncludeDense) && !enc) 231a2c9d4bbSAart Bik continue; 232c194b49cSAart Bik // Each tensor expression and optional dimension ordering (row-major 233c194b49cSAart Bik // by default) puts an ordering constraint on the loop indices. 
For
234c194b49cSAart Bik // example, the tensor expression A_ijk forces the ordering i < j < k
235c194b49cSAart Bik // on the loop indices if no explicit dimension ordering is given.
236c194b49cSAart Bik for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
237b1d44e59SAart Bik AffineExpr f = map.getResult(perm(enc, d - 1));
238b1d44e59SAart Bik AffineExpr t = map.getResult(perm(enc, d));
239b1d44e59SAart Bik addAffineOrderings(adjM, f, t, 0);
240a2c9d4bbSAart Bik }
241b6d1a31cSAart Bik // Push unrelated loops into sparse iteration space, so these
242b6d1a31cSAart Bik // will be skipped more often.
243b6d1a31cSAart Bik if (mask & SortMask::kIncludeUndef) {
244b6d1a31cSAart Bik unsigned tensor = t->getOperandNumber();
245b6d1a31cSAart Bik for (unsigned i = 0; i < n; i++)
246b6d1a31cSAart Bik if (merger.isDim(tensor, i, Dim::kSparse))
247b6d1a31cSAart Bik for (unsigned j = 0; j < n; j++)
248b6d1a31cSAart Bik if (merger.isDim(tensor, j, Dim::kUndef))
249b6d1a31cSAart Bik adjM[i][j] = true;
250b6d1a31cSAart Bik }
251a2c9d4bbSAart Bik }
252a2c9d4bbSAart Bik 
253a2c9d4bbSAart Bik // Topologically sort the iteration graph to determine loop order.
254a2c9d4bbSAart Bik // Report failure for a cyclic iteration graph.
255a2c9d4bbSAart Bik topSort.clear();
256a2c9d4bbSAart Bik topSort.reserve(n);
257a2c9d4bbSAart Bik std::vector<unsigned> visit(n, 0);
258a2c9d4bbSAart Bik for (unsigned i = 0; i < n; i++)
259a2c9d4bbSAart Bik if (visit[i] == 0)
260a2c9d4bbSAart Bik if (!topSortDFS(i, visit, topSort, adjM))
261a2c9d4bbSAart Bik return false; // cycle!
262a2c9d4bbSAart Bik std::reverse(std::begin(topSort), std::end(topSort));
263a2c9d4bbSAart Bik return true;
264a2c9d4bbSAart Bik }
265a2c9d4bbSAart Bik 
26636b66ab9SAart Bik /// Returns true when the tensor expression is admissable for codegen.
26736b66ab9SAart Bik /// Since all sparse input tensors are admissable, we just need to check
26836b66ab9SAart Bik /// whether the output tensor in the tensor expression codegen is admissable.
26936b66ab9SAart Bik static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
27036b66ab9SAart Bik unsigned exp) {
27136b66ab9SAart Bik OpOperand *lhs = op.getOutputOperand(0);
27236b66ab9SAart Bik unsigned tensor = lhs->getOperandNumber();
27336b66ab9SAart Bik auto enc = getSparseTensorEncoding(lhs->get().getType());
27436b66ab9SAart Bik // A non-annotated output tensor is assumed dense, and becomes a random
275b1d44e59SAart Bik // access n-dim memref. Admissable since insertions cannot occur.
27636b66ab9SAart Bik if (!enc)
27736b66ab9SAart Bik return true;
27836b66ab9SAart Bik // An all-dense annotated "sparse" output tensor becomes a linearized random
27936b66ab9SAart Bik // access 1-dim memref. Also admissable since insertions cannot occur.
28036b66ab9SAart Bik bool allDense = true;
28136b66ab9SAart Bik unsigned numLoops = op.iterator_types().getValue().size();
28236b66ab9SAart Bik for (unsigned i = 0; i < numLoops; i++)
28336b66ab9SAart Bik if (merger.isDim(tensor, i, Dim::kSparse)) {
28436b66ab9SAart Bik allDense = false;
28536b66ab9SAart Bik break;
28636b66ab9SAart Bik }
28736b66ab9SAart Bik if (allDense)
28836b66ab9SAart Bik return true;
28936b66ab9SAart Bik // A tensor expression with a sparse output tensor that changes its values
29036b66ab9SAart Bik // but not its nonzero structure, an operation called "simply dynamic" in
29136b66ab9SAart Bik // [Bik96,Ch9], is also admissable without special codegen.
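// For example, x(i) = x(i) * y(i) only updates values at positions that are
// already stored in x, so the set of stored entries never grows.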
29245b3cfe8SAart Bik if (merger.isConjunction(tensor, exp)) 29336b66ab9SAart Bik return true; 29436b66ab9SAart Bik // Reject for now since this requires changes to the nonzero structure. 29536b66ab9SAart Bik // TODO: implement "workspaces" [Kjolstad2019] 29636b66ab9SAart Bik return false; 29736b66ab9SAart Bik } 29836b66ab9SAart Bik 2995da21338SAart Bik //===----------------------------------------------------------------------===// 3005da21338SAart Bik // Sparse compiler synthesis methods. 3015da21338SAart Bik //===----------------------------------------------------------------------===// 3025da21338SAart Bik 3035da21338SAart Bik /// Maps reduction kind to name encoding. 3045da21338SAart Bik static StringRef getReductionName(Reduction kind) { 3055da21338SAart Bik switch (kind) { 3065da21338SAart Bik case kSum: 3075da21338SAart Bik return "add"; 3085da21338SAart Bik case kProduct: 3095da21338SAart Bik return "mul"; 3105da21338SAart Bik case kAnd: 3115da21338SAart Bik return "and"; 3125da21338SAart Bik case kOr: 3135da21338SAart Bik return "or"; 3145da21338SAart Bik case kXor: 3155da21338SAart Bik return "xor"; 3165da21338SAart Bik } 3175da21338SAart Bik llvm_unreachable("unknown reduction kind"); 3185da21338SAart Bik } 3195da21338SAart Bik 3205da21338SAart Bik /// Maps operation to reduction. 3215da21338SAart Bik static Reduction getReduction(Kind kind) { 3225da21338SAart Bik switch (kind) { 3235da21338SAart Bik case Kind::kAddF: 3245da21338SAart Bik case Kind::kAddI: 3255da21338SAart Bik case Kind::kSubF: 3265da21338SAart Bik case Kind::kSubI: 3275da21338SAart Bik return kSum; 3285da21338SAart Bik case Kind::kMulF: 3295da21338SAart Bik case Kind::kMulI: 3305da21338SAart Bik return kProduct; 3315da21338SAart Bik case Kind::kAndI: 3325da21338SAart Bik return kAnd; 3335da21338SAart Bik case Kind::kOrI: 3345da21338SAart Bik return kOr; 3355da21338SAart Bik case Kind::kXorI: 3365da21338SAart Bik return kXor; 3375da21338SAart Bik default: 3385da21338SAart Bik llvm_unreachable("unexpected reduction operator"); 3395da21338SAart Bik } 3405da21338SAart Bik } 3415da21338SAart Bik 3425da21338SAart Bik /// Generates an initial value for a vector reductions, following the scheme 3435da21338SAart Bik /// given in Chapter 5 of "The Software Vectorization Handbook", where the 3445da21338SAart Bik /// initial scalar value is correctly embedded in the vector reduction value, 3455da21338SAart Bik /// and a straightforward horizontal reduction will complete the operation. 3465da21338SAart Bik static Value genReductionInit(PatternRewriter &rewriter, Location loc, 3475da21338SAart Bik Reduction kind, VectorType vtp, Value r) { 3485da21338SAart Bik switch (kind) { 3495da21338SAart Bik case kSum: 3505da21338SAart Bik case kXor: { 3515da21338SAart Bik // Initialize reduction vector to: | 0 | .. | 0 | r | 3525da21338SAart Bik Attribute zero = rewriter.getZeroAttr(vtp); 3535da21338SAart Bik Value vec = rewriter.create<ConstantOp>(loc, vtp, zero); 3545da21338SAart Bik return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0); 3555da21338SAart Bik } 3565da21338SAart Bik case kProduct: { 3575da21338SAart Bik // Initialize reduction vector to: | 1 | .. 
| 1 | r |
3585da21338SAart Bik Type etp = vtp.getElementType();
3595da21338SAart Bik Attribute one;
3605da21338SAart Bik if (etp.isa<FloatType>())
3615da21338SAart Bik one = rewriter.getFloatAttr(etp, 1.0);
3625da21338SAart Bik else
3635da21338SAart Bik one = rewriter.getIntegerAttr(etp, 1);
3645da21338SAart Bik Value vec =
3655da21338SAart Bik rewriter.create<ConstantOp>(loc, vtp, DenseElementsAttr::get(vtp, one));
3665da21338SAart Bik return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
3675da21338SAart Bik }
3685da21338SAart Bik case kAnd:
3695da21338SAart Bik case kOr:
3705da21338SAart Bik // Initialize reduction vector to: | r | .. | r | r |
3715da21338SAart Bik return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
3725da21338SAart Bik }
3735da21338SAart Bik llvm_unreachable("unknown reduction kind");
3745da21338SAart Bik }
3755da21338SAart Bik 
376a2c9d4bbSAart Bik /// Maps sparse integer option to actual integral storage type.
37796a23911SAart Bik static Type genIntType(PatternRewriter &rewriter, unsigned width) {
37896a23911SAart Bik if (width == 0)
379a2c9d4bbSAart Bik return rewriter.getIndexType();
38096a23911SAart Bik return rewriter.getIntegerType(width);
381a2c9d4bbSAart Bik }
382a2c9d4bbSAart Bik 
3835879da49SAart Bik /// Detects in-place annotation on tensor argument.
3845879da49SAart Bik static bool getInPlace(Value val) {
3855879da49SAart Bik if (auto arg = val.dyn_cast<BlockArgument>())
3865879da49SAart Bik if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
3875879da49SAart Bik if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
3885879da49SAart Bik arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
3895879da49SAart Bik return attr.getValue();
3905879da49SAart Bik return false;
3915879da49SAart Bik }
3925879da49SAart Bik 
393ec97a205SAart Bik /// Generates buffer for the output tensor. Note that all sparse kernels
394ec97a205SAart Bik /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
395ec97a205SAart Bik /// the output buffer is already initialized to all zeroes and only nonzero
396ec97a205SAart Bik /// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
397ec97a205SAart Bik /// only nonzero values are used for the updates and no assumption on the
398ec97a205SAart Bik /// original contents of the output buffer is necessary.
399a2c9d4bbSAart Bik static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
400a2c9d4bbSAart Bik linalg::GenericOp op, MemRefType denseTp,
401a2c9d4bbSAart Bik ArrayRef<Value> args) {
402a2c9d4bbSAart Bik Location loc = op.getLoc();
4032f2b5b7dSTobias Gysi Value tensor = op.getOutputOperand(0)->get();
404a2c9d4bbSAart Bik // The output tensor could simply materialize from the buffer that will
405a2c9d4bbSAart Bik // be generated for the tensor present in the outs() clause. This has
406a2c9d4bbSAart Bik // the major advantage that the sparse kernel only updates the nonzero
4075879da49SAart Bik // positions for the output tensor.
4085879da49SAart Bik if (getInPlace(tensor))
409a2c9d4bbSAart Bik return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
410a2c9d4bbSAart Bik // By default, a new buffer is allocated which is initialized to the
411a2c9d4bbSAart Bik // tensor defined in the outs() clause. This is always correct but
412a2c9d4bbSAart Bik // introduces a dense initialization component that may negatively
413ec97a205SAart Bik // impact the running complexity of the sparse kernel.
If the tensor 414ec97a205SAart Bik // materializes within this method, we need to preserve the zero 415ec97a205SAart Bik // initialization assumption of all sparse output buffers. 416ec97a205SAart Bik if (auto init = tensor.getDefiningOp<linalg::InitTensorOp>()) { 417ec97a205SAart Bik Type tp = denseTp.getElementType(); 418ec97a205SAart Bik Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 419ec97a205SAart Bik Value zero = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp)); 420ec97a205SAart Bik rewriter.create<linalg::FillOp>(loc, zero, alloc); 421ec97a205SAart Bik return alloc; 422ec97a205SAart Bik } 423a2c9d4bbSAart Bik Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 424a2c9d4bbSAart Bik Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 42568ac2e53SAart Bik rewriter.create<memref::CopyOp>(loc, init, alloc); 426a2c9d4bbSAart Bik return alloc; 427a2c9d4bbSAart Bik } 428a2c9d4bbSAart Bik 429a2c9d4bbSAart Bik /// Local bufferization of all dense and sparse data structures. 430a2c9d4bbSAart Bik /// This code enables testing the first prototype sparse compiler. 431a2c9d4bbSAart Bik // TODO: replace this with a proliferated bufferization strategy 432727a63e0SAart Bik static bool genBuffers(Merger &merger, CodeGen &codegen, 433a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op) { 434a2c9d4bbSAart Bik Location loc = op.getLoc(); 4352f2b5b7dSTobias Gysi assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); 436a2c9d4bbSAart Bik // For every tensor, find lower and upper bound on dimensions, set the 437a2c9d4bbSAart Bik // same bounds on loop indices, and obtain dense or sparse buffer(s). 438a2c9d4bbSAart Bik SmallVector<Value, 4> args; 4392f2b5b7dSTobias Gysi for (OpOperand *t : op.getInputAndOutputOperands()) { 440727a63e0SAart Bik unsigned tensor = t->getOperandNumber(); 4412f2b5b7dSTobias Gysi auto shape = op.getShape(t); 4422f2b5b7dSTobias Gysi auto map = op.getTiedIndexingMap(t); 4432f2b5b7dSTobias Gysi auto enc = getSparseTensorEncoding(t->get().getType()); 444a2c9d4bbSAart Bik // Scan all dimensions of current tensor. 445a2c9d4bbSAart Bik args.clear(); 446c194b49cSAart Bik for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 447b1d44e59SAart Bik AffineExpr a = map.getResult(perm(enc, d)); 448b1d44e59SAart Bik if (a.getKind() != AffineExprKind::DimId) 449b1d44e59SAart Bik continue; // compound 450b1d44e59SAart Bik unsigned idx = a.cast<AffineDimExpr>().getPosition(); 451a2c9d4bbSAart Bik // Handle sparse storage schemes. 452727a63e0SAart Bik if (merger.isDim(tensor, idx, Dim::kSparse)) { 453a2c9d4bbSAart Bik auto dynShape = {ShapedType::kDynamicSize}; 454a2c9d4bbSAart Bik auto ptrTp = MemRefType::get( 45596a23911SAart Bik dynShape, genIntType(rewriter, enc.getPointerBitWidth())); 456a2c9d4bbSAart Bik auto indTp = MemRefType::get( 45796a23911SAart Bik dynShape, genIntType(rewriter, enc.getIndexBitWidth())); 458*a54f4eaeSMogball Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d); 459a2c9d4bbSAart Bik // Generate sparse primitives to obtains pointer and indices. 460727a63e0SAart Bik codegen.pointers[tensor][idx] = 4612f2b5b7dSTobias Gysi rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim); 462727a63e0SAart Bik codegen.indices[tensor][idx] = 4632f2b5b7dSTobias Gysi rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim); 464a2c9d4bbSAart Bik } 465d37d72eaSAart Bik // Find upper bound in current dimension. 
466817303efSAart Bik unsigned p = perm(enc, d); 467d37d72eaSAart Bik Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p); 468d37d72eaSAart Bik if (shape[p] == MemRefType::kDynamicSize) 469a2c9d4bbSAart Bik args.push_back(up); 470817303efSAart Bik assert(codegen.highs[tensor][idx] == nullptr); 471727a63e0SAart Bik codegen.sizes[idx] = codegen.highs[tensor][idx] = up; 472a2c9d4bbSAart Bik } 473727a63e0SAart Bik // Perform the required bufferization. Dense inputs materialize 474727a63e0SAart Bik // from the input tensors. Dense outputs need special handling. 475727a63e0SAart Bik // Sparse inputs use sparse primitives to obtain the values. 476727a63e0SAart Bik // We also accept in-place all-dense annotated "sparse" outputs. 4772f2b5b7dSTobias Gysi Type elementType = getElementTypeOrSelf(t->get().getType()); 47896a23911SAart Bik if (!enc) { 479727a63e0SAart Bik // Non-annotated dense tensors. 4802f2b5b7dSTobias Gysi auto denseTp = MemRefType::get(shape, elementType); 481727a63e0SAart Bik if (tensor < op.getNumInputs()) 482727a63e0SAart Bik codegen.buffers[tensor] = 4832f2b5b7dSTobias Gysi rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get()); 484a2c9d4bbSAart Bik else 485727a63e0SAart Bik codegen.buffers[tensor] = 486a2c9d4bbSAart Bik genOutputBuffer(codegen, rewriter, op, denseTp, args); 487a2c9d4bbSAart Bik } else { 488727a63e0SAart Bik // Annotated sparse tensors. 489727a63e0SAart Bik if (tensor == op.getNumInputs() && !getInPlace(t->get())) 490727a63e0SAart Bik return false; // reject output if not in-place 491a2c9d4bbSAart Bik auto dynShape = {ShapedType::kDynamicSize}; 4922f2b5b7dSTobias Gysi auto sparseTp = MemRefType::get(dynShape, elementType); 493727a63e0SAart Bik codegen.buffers[tensor] = 4942f2b5b7dSTobias Gysi rewriter.create<ToValuesOp>(loc, sparseTp, t->get()); 495a2c9d4bbSAart Bik } 496a2c9d4bbSAart Bik } 497727a63e0SAart Bik return true; 498a2c9d4bbSAart Bik } 499a2c9d4bbSAart Bik 500a2c9d4bbSAart Bik /// Constructs vector type. 501a2c9d4bbSAart Bik static VectorType vectorType(CodeGen &codegen, Type etp) { 502a2c9d4bbSAart Bik return VectorType::get(codegen.curVecLength, etp); 503a2c9d4bbSAart Bik } 504a2c9d4bbSAart Bik 505a2c9d4bbSAart Bik /// Constructs vector type from pointer. 506a2c9d4bbSAart Bik static VectorType vectorType(CodeGen &codegen, Value ptr) { 507a2c9d4bbSAart Bik return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 508a2c9d4bbSAart Bik } 509a2c9d4bbSAart Bik 510a2c9d4bbSAart Bik /// Constructs vector iteration mask. 511a2c9d4bbSAart Bik static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, 512a2c9d4bbSAart Bik Value iv, Value lo, Value hi, Value step) { 513a2c9d4bbSAart Bik Location loc = iv.getLoc(); 514a2c9d4bbSAart Bik VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1)); 515a2c9d4bbSAart Bik // Special case if the vector length evenly divides the trip count (for 516a2c9d4bbSAart Bik // example, "for i = 0, 128, 16"). A constant all-true mask is generated 517a2c9d4bbSAart Bik // so that all subsequent masked memory operations are immediately folded 518a2c9d4bbSAart Bik // into unconditional memory operations. 
519a2c9d4bbSAart Bik IntegerAttr loInt, hiInt, stepInt; 520a2c9d4bbSAart Bik if (matchPattern(lo, m_Constant(&loInt)) && 521a2c9d4bbSAart Bik matchPattern(hi, m_Constant(&hiInt)) && 522a2c9d4bbSAart Bik matchPattern(step, m_Constant(&stepInt))) { 523a2c9d4bbSAart Bik if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 524a2c9d4bbSAart Bik return rewriter.create<vector::BroadcastOp>( 525*a54f4eaeSMogball loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1)); 526a2c9d4bbSAart Bik } 527a2c9d4bbSAart Bik // Otherwise, generate a vector mask that avoids overrunning the upperbound 528a2c9d4bbSAart Bik // during vector execution. Here we rely on subsequent loop optimizations to 529a2c9d4bbSAart Bik // avoid executing the mask in all iterations, for example, by splitting the 530a2c9d4bbSAart Bik // loop into an unconditional vector loop and a scalar cleanup loop. 53176a18618SMatthias Springer auto minMap = AffineMap::get( 53276a18618SMatthias Springer /*dimCount=*/2, /*symbolCount=*/1, 53376a18618SMatthias Springer {rewriter.getAffineSymbolExpr(0), 53476a18618SMatthias Springer rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)}, 53576a18618SMatthias Springer rewriter.getContext()); 53676a18618SMatthias Springer Value end = 53776a18618SMatthias Springer rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step}); 538a2c9d4bbSAart Bik return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 539a2c9d4bbSAart Bik } 540a2c9d4bbSAart Bik 541a2c9d4bbSAart Bik /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 542a2c9d4bbSAart Bik static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 543a2c9d4bbSAart Bik Value ptr, ArrayRef<Value> args) { 544a2c9d4bbSAart Bik Location loc = ptr.getLoc(); 545a2c9d4bbSAart Bik VectorType vtp = vectorType(codegen, ptr); 546*a54f4eaeSMogball Value pass = 547*a54f4eaeSMogball rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp)); 548a2c9d4bbSAart Bik if (args.back().getType().isa<VectorType>()) { 549a2c9d4bbSAart Bik SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 550a2c9d4bbSAart Bik Value indexVec = args.back(); 551*a54f4eaeSMogball scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0); 552a2c9d4bbSAart Bik return rewriter.create<vector::GatherOp>( 553a2c9d4bbSAart Bik loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 554a2c9d4bbSAart Bik } 555a2c9d4bbSAart Bik return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 556a2c9d4bbSAart Bik codegen.curVecMask, pass); 557a2c9d4bbSAart Bik } 558a2c9d4bbSAart Bik 559a2c9d4bbSAart Bik /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 
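/// When the innermost subscript is itself a vector of indices (the sparse
/// case), a scatter is emitted; otherwise a masked store is used.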
560a2c9d4bbSAart Bik static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 561a2c9d4bbSAart Bik Value rhs, Value ptr, ArrayRef<Value> args) { 562a2c9d4bbSAart Bik Location loc = ptr.getLoc(); 563a2c9d4bbSAart Bik if (args.back().getType().isa<VectorType>()) { 564a2c9d4bbSAart Bik SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 565a2c9d4bbSAart Bik Value indexVec = args.back(); 566*a54f4eaeSMogball scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0); 567a2c9d4bbSAart Bik rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 568a2c9d4bbSAart Bik codegen.curVecMask, rhs); 569a2c9d4bbSAart Bik return; 570a2c9d4bbSAart Bik } 571a2c9d4bbSAart Bik rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 572a2c9d4bbSAart Bik rhs); 573a2c9d4bbSAart Bik } 574a2c9d4bbSAart Bik 575a2c9d4bbSAart Bik /// Generates a vectorized invariant. Here we rely on subsequent loop 576a2c9d4bbSAart Bik /// optimizations to hoist the invariant broadcast out of the vector loop. 577a2c9d4bbSAart Bik static Value genVectorInvariantValue(CodeGen &codegen, 578a2c9d4bbSAart Bik PatternRewriter &rewriter, Value val) { 579a2c9d4bbSAart Bik VectorType vtp = vectorType(codegen, val.getType()); 580a2c9d4bbSAart Bik return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 581a2c9d4bbSAart Bik } 582a2c9d4bbSAart Bik 583b1d44e59SAart Bik /// Generates an affine expression. 584b1d44e59SAart Bik // 585b1d44e59SAart Bik // TODO: generalize for sparse tensor subscripts 586b1d44e59SAart Bik // 587b1d44e59SAart Bik static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter, 588b1d44e59SAart Bik AffineExpr a, Location loc) { 589b1d44e59SAart Bik switch (a.getKind()) { 590b1d44e59SAart Bik case AffineExprKind::DimId: { 591b1d44e59SAart Bik unsigned idx = a.cast<AffineDimExpr>().getPosition(); 592b1d44e59SAart Bik return codegen.loops[idx]; // universal dense index 593b1d44e59SAart Bik } 594b1d44e59SAart Bik case AffineExprKind::Add: { 595b1d44e59SAart Bik auto binOp = a.cast<AffineBinaryOpExpr>(); 596*a54f4eaeSMogball return rewriter.create<arith::AddIOp>( 597b1d44e59SAart Bik loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 598b1d44e59SAart Bik genAffine(codegen, rewriter, binOp.getRHS(), loc)); 599b1d44e59SAart Bik } 600b1d44e59SAart Bik case AffineExprKind::Mul: { 601b1d44e59SAart Bik auto binOp = a.cast<AffineBinaryOpExpr>(); 602*a54f4eaeSMogball return rewriter.create<arith::MulIOp>( 603b1d44e59SAart Bik loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), 604b1d44e59SAart Bik genAffine(codegen, rewriter, binOp.getRHS(), loc)); 605b1d44e59SAart Bik } 606b1d44e59SAart Bik case AffineExprKind::Constant: { 607b1d44e59SAart Bik int64_t c = a.cast<AffineConstantExpr>().getValue(); 608*a54f4eaeSMogball return rewriter.create<arith::ConstantIndexOp>(loc, c); 609b1d44e59SAart Bik } 610b1d44e59SAart Bik default: 611b1d44e59SAart Bik llvm_unreachable("unexpected affine subscript"); 612b1d44e59SAart Bik } 613b1d44e59SAart Bik } 614b1d44e59SAart Bik 615b1d44e59SAart Bik /// Generates subscript for load/store on a dense or sparse tensor. 
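/// For an annotated sparse operand the single subscript is the current
/// position index into its values array; for a dense operand one index per
/// dimension is built from the affine map, e.g. (i, j) -> (i, i + j) yields
/// { loops[i], loops[i] + loops[j] }.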
616b1d44e59SAart Bik static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter, 617b1d44e59SAart Bik linalg::GenericOp op, OpOperand *t, 618b1d44e59SAart Bik SmallVector<Value, 4> &args) { 619b1d44e59SAart Bik unsigned tensor = t->getOperandNumber(); 620b1d44e59SAart Bik auto map = op.getTiedIndexingMap(t); 621b1d44e59SAart Bik auto enc = getSparseTensorEncoding(t->get().getType()); 622b1d44e59SAart Bik unsigned rank = map.getNumResults(); 623b1d44e59SAart Bik if (enc) { 624b1d44e59SAart Bik // Note that currently, all sparse subscripts are simple. 625b1d44e59SAart Bik // TODO: accept affine too? 626b1d44e59SAart Bik unsigned idx = map.getDimPosition(perm(enc, rank - 1)); 627b1d44e59SAart Bik assert(codegen.pidxs[tensor][idx] != nullptr); 628b1d44e59SAart Bik args.push_back(codegen.pidxs[tensor][idx]); // position index 629b1d44e59SAart Bik } else { 630b1d44e59SAart Bik for (unsigned d = 0; d < rank; d++) { 631b1d44e59SAart Bik AffineExpr a = map.getResult(perm(enc, d)); 632b1d44e59SAart Bik args.push_back(genAffine(codegen, rewriter, a, op.getLoc())); 633b1d44e59SAart Bik } 634b1d44e59SAart Bik } 635b1d44e59SAart Bik return codegen.buffers[tensor]; 636b1d44e59SAart Bik } 637b1d44e59SAart Bik 638a2c9d4bbSAart Bik /// Generates a load on a dense or sparse tensor. 639a2c9d4bbSAart Bik static Value genTensorLoad(Merger &merger, CodeGen &codegen, 640a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 641a2c9d4bbSAart Bik unsigned exp) { 642a2c9d4bbSAart Bik // Test if the load was hoisted to a higher loop nest. 643a2c9d4bbSAart Bik Value val = merger.exp(exp).val; 644a2c9d4bbSAart Bik if (val) { 645a2c9d4bbSAart Bik if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 646a2c9d4bbSAart Bik return genVectorInvariantValue(codegen, rewriter, val); 647a2c9d4bbSAart Bik return val; 648a2c9d4bbSAart Bik } 649a2c9d4bbSAart Bik // Actual load. 650a2c9d4bbSAart Bik SmallVector<Value, 4> args; 6514569c14aSGus Smith OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 652b1d44e59SAart Bik Value ptr = genSubscript(codegen, rewriter, op, t, args); 653a2c9d4bbSAart Bik if (codegen.curVecLength > 1) 654a2c9d4bbSAart Bik return genVectorLoad(codegen, rewriter, ptr, args); 655b1d44e59SAart Bik return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); 656a2c9d4bbSAart Bik } 657a2c9d4bbSAart Bik 658727a63e0SAart Bik /// Generates a store on a dense or sparse tensor. 659a2c9d4bbSAart Bik static void genTensorStore(Merger &merger, CodeGen &codegen, 660a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 661b1d44e59SAart Bik Value rhs) { 662a2c9d4bbSAart Bik // Test if this is a scalarized reduction. 663b1d44e59SAart Bik if (codegen.redVal) { 664a2c9d4bbSAart Bik if (codegen.curVecLength > 1) 665b1d44e59SAart Bik rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs, 666a2c9d4bbSAart Bik codegen.redVal); 667a2c9d4bbSAart Bik codegen.redVal = rhs; 668a2c9d4bbSAart Bik return; 669a2c9d4bbSAart Bik } 670a2c9d4bbSAart Bik // Actual store. 
671a2c9d4bbSAart Bik SmallVector<Value, 4> args;
672b1d44e59SAart Bik OpOperand *t = op.getOutputOperand(0);
673b1d44e59SAart Bik Value ptr = genSubscript(codegen, rewriter, op, t, args);
674a2c9d4bbSAart Bik if (codegen.curVecLength > 1)
675a2c9d4bbSAart Bik genVectorStore(codegen, rewriter, rhs, ptr, args);
676a2c9d4bbSAart Bik else
677b1d44e59SAart Bik rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
678a2c9d4bbSAart Bik }
679a2c9d4bbSAart Bik 
680a2c9d4bbSAart Bik /// Generates a pointer/index load from the sparse storage scheme. Narrower
681a2c9d4bbSAart Bik /// data types need to be zero extended before casting the value into the
682a2c9d4bbSAart Bik /// index type used for looping and indexing.
683a2c9d4bbSAart Bik static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
684a2c9d4bbSAart Bik Value ptr, Value s) {
685a2c9d4bbSAart Bik // See https://llvm.org/docs/GetElementPtr.html for some background on
686a2c9d4bbSAart Bik // the complications described below.
687a2c9d4bbSAart Bik if (codegen.curVecLength > 1) {
688a2c9d4bbSAart Bik // Since the index vector is used in a subsequent gather/scatter operation,
689a2c9d4bbSAart Bik // which effectively defines an unsigned pointer + signed index, we must
690a2c9d4bbSAart Bik // zero extend the vector to an index width. For 8-bit and 16-bit values,
691a2c9d4bbSAart Bik // a 32-bit index width suffices. For 32-bit values, zero extending the
692a2c9d4bbSAart Bik // elements into 64-bit loses some performance since the 32-bit indexed
69386e9bc1aSAart Bik // gather/scatter is more efficient than the 64-bit index variant (if the
69486e9bc1aSAart Bik // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
695727a63e0SAart Bik // preserve this performance). For 64-bit values, there is no good way
696a2c9d4bbSAart Bik // to state that the indices are unsigned, which creates the potential of
697a2c9d4bbSAart Bik // incorrect address calculations in the unlikely case we need such
698a2c9d4bbSAart Bik // extremely large offsets.
699a2c9d4bbSAart Bik Type etp = ptr.getType().cast<MemRefType>().getElementType();
700a2c9d4bbSAart Bik Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
701a2c9d4bbSAart Bik if (!etp.isa<IndexType>()) {
702a2c9d4bbSAart Bik if (etp.getIntOrFloatBitWidth() < 32)
703*a54f4eaeSMogball vload = rewriter.create<arith::ExtUIOp>(
704a2c9d4bbSAart Bik loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
70586e9bc1aSAart Bik else if (etp.getIntOrFloatBitWidth() < 64 &&
70686e9bc1aSAart Bik !codegen.options.enableSIMDIndex32)
707*a54f4eaeSMogball vload = rewriter.create<arith::ExtUIOp>(
708a2c9d4bbSAart Bik loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
709a2c9d4bbSAart Bik }
710a2c9d4bbSAart Bik return vload;
711a2c9d4bbSAart Bik }
712a2c9d4bbSAart Bik // For the scalar case, we simply zero extend narrower indices into 64-bit
713a2c9d4bbSAart Bik // values before casting to index without a performance penalty. Here too,
714a2c9d4bbSAart Bik // however, indices that already are 64-bit, in theory, cannot express the
715a2c9d4bbSAart Bik // full range as explained above.
716a2c9d4bbSAart Bik Value load = rewriter.create<memref::LoadOp>(loc, ptr, s); 717a2c9d4bbSAart Bik if (!load.getType().isa<IndexType>()) { 718a2c9d4bbSAart Bik if (load.getType().getIntOrFloatBitWidth() < 64) 719*a54f4eaeSMogball load = rewriter.create<arith::ExtUIOp>(loc, load, 720a2c9d4bbSAart Bik rewriter.getIntegerType(64)); 721*a54f4eaeSMogball load = 722*a54f4eaeSMogball rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType()); 723a2c9d4bbSAart Bik } 724a2c9d4bbSAart Bik return load; 725a2c9d4bbSAart Bik } 726a2c9d4bbSAart Bik 727a2c9d4bbSAart Bik /// Generates an invariant value. 728a2c9d4bbSAart Bik static Value genInvariantValue(Merger &merger, CodeGen &codegen, 729a2c9d4bbSAart Bik PatternRewriter &rewriter, unsigned exp) { 730a2c9d4bbSAart Bik Value val = merger.exp(exp).val; 731a2c9d4bbSAart Bik if (codegen.curVecLength > 1) 732a2c9d4bbSAart Bik return genVectorInvariantValue(codegen, rewriter, val); 733a2c9d4bbSAart Bik return val; 734a2c9d4bbSAart Bik } 735a2c9d4bbSAart Bik 736a2c9d4bbSAart Bik /// Generates an address computation "sz * p + i". 737a2c9d4bbSAart Bik static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, 738a2c9d4bbSAart Bik Location loc, Value size, Value p, Value i) { 739*a54f4eaeSMogball Value mul = rewriter.create<arith::MulIOp>(loc, size, p); 740a2c9d4bbSAart Bik if (auto vtp = i.getType().dyn_cast<VectorType>()) { 741*a54f4eaeSMogball Value inv = 742*a54f4eaeSMogball rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType()); 743a2c9d4bbSAart Bik mul = genVectorInvariantValue(codegen, rewriter, inv); 744a2c9d4bbSAart Bik } 745*a54f4eaeSMogball return rewriter.create<arith::AddIOp>(loc, mul, i); 746a2c9d4bbSAart Bik } 747a2c9d4bbSAart Bik 748a2c9d4bbSAart Bik /// Generates start of a reduction. 749a2c9d4bbSAart Bik static Value genReductionStart(Merger &merger, CodeGen &codegen, 750a2c9d4bbSAart Bik PatternRewriter &rewriter, 751a2c9d4bbSAart Bik linalg::GenericOp op) { 752a2c9d4bbSAart Bik if (codegen.redVal) 753a2c9d4bbSAart Bik return codegen.redVal; // chained with previous for-loop 7545da21338SAart Bik // Generate vector or scalar start of a reduction. 7555da21338SAart Bik unsigned vl = codegen.curVecLength; 7565da21338SAart Bik if (vl > 1) { 757a2c9d4bbSAart Bik VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]); 7585da21338SAart Bik assert(!merger.exp(codegen.redExp).val); 7595da21338SAart Bik codegen.curVecLength = 1; 7605da21338SAart Bik Value load = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 7615da21338SAart Bik codegen.curVecLength = vl; 7625da21338SAart Bik return genReductionInit(rewriter, op.getLoc(), codegen.redKind, vtp, load); 763a2c9d4bbSAart Bik } 764a2c9d4bbSAart Bik return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 765a2c9d4bbSAart Bik } 766a2c9d4bbSAart Bik 767a2c9d4bbSAart Bik /// Generates end of a reduction. 768a2c9d4bbSAart Bik static void genReductionEnd(Merger &merger, CodeGen &codegen, 769a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op) { 770a2c9d4bbSAart Bik Value red = codegen.redVal; 771a2c9d4bbSAart Bik if (!red) 772a2c9d4bbSAart Bik return; 773a2c9d4bbSAart Bik assert(codegen.curVecLength == 1); 774a2c9d4bbSAart Bik codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain 7755da21338SAart Bik // Generate vector or scalar end of a reduction. 
776a2c9d4bbSAart Bik if (auto vtp = red.getType().dyn_cast<VectorType>()) { 7775da21338SAart Bik StringRef name = getReductionName(codegen.redKind); 7785da21338SAart Bik StringAttr kind = rewriter.getStringAttr(name); 7795da21338SAart Bik red = rewriter.create<vector::ReductionOp>( 7805da21338SAart Bik op.getLoc(), vtp.getElementType(), kind, red, ValueRange{}); 781a2c9d4bbSAart Bik } 782b1d44e59SAart Bik genTensorStore(merger, codegen, rewriter, op, red); 783a2c9d4bbSAart Bik } 784a2c9d4bbSAart Bik 785a2c9d4bbSAart Bik /// Recursively generates tensor expression. 786a2c9d4bbSAart Bik static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 787a2c9d4bbSAart Bik linalg::GenericOp op, unsigned exp) { 788b8a021dbSAart Bik Location loc = op.getLoc(); 789123e8dfcSAart Bik if (exp == -1u) 790123e8dfcSAart Bik return Value(); 791a2c9d4bbSAart Bik if (merger.exp(exp).kind == Kind::kTensor) 792a2c9d4bbSAart Bik return genTensorLoad(merger, codegen, rewriter, op, exp); 793b8a021dbSAart Bik if (merger.exp(exp).kind == Kind::kInvariant) 794a2c9d4bbSAart Bik return genInvariantValue(merger, codegen, rewriter, exp); 7954569c14aSGus Smith Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); 7964569c14aSGus Smith Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); 79745b3cfe8SAart Bik return merger.buildExp(rewriter, loc, exp, v0, v1); 798a2c9d4bbSAart Bik } 799a2c9d4bbSAart Bik 800b1d44e59SAart Bik /// Determines if affine expression is invariant. 801b1d44e59SAart Bik static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, 802b1d44e59SAart Bik unsigned ldx, bool &atLevel) { 803b1d44e59SAart Bik switch (a.getKind()) { 804b1d44e59SAart Bik case AffineExprKind::DimId: { 805b1d44e59SAart Bik unsigned idx = a.cast<AffineDimExpr>().getPosition(); 806b1d44e59SAart Bik if (idx == ldx) 807b1d44e59SAart Bik atLevel = true; 808b1d44e59SAart Bik return codegen.loops[idx] != nullptr; // no longer in play? 809b1d44e59SAart Bik } 810b1d44e59SAart Bik case AffineExprKind::Add: 811b1d44e59SAart Bik case AffineExprKind::Mul: { 812b1d44e59SAart Bik auto binOp = a.cast<AffineBinaryOpExpr>(); 813b1d44e59SAart Bik return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && 814b1d44e59SAart Bik isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); 815b1d44e59SAart Bik } 816b1d44e59SAart Bik default: 817b1d44e59SAart Bik return true; 818b1d44e59SAart Bik } 819b1d44e59SAart Bik } 820b1d44e59SAart Bik 821a2c9d4bbSAart Bik /// Hoists loop invariant tensor loads for which indices have been exhausted. 822a2c9d4bbSAart Bik static void genInvariants(Merger &merger, CodeGen &codegen, 823a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 8245da21338SAart Bik unsigned exp, unsigned ldx, bool hoist, 8255da21338SAart Bik Kind last = Kind::kTensor) { 826123e8dfcSAart Bik if (exp == -1u) 827123e8dfcSAart Bik return; 828a2c9d4bbSAart Bik if (merger.exp(exp).kind == Kind::kTensor) { 829a2c9d4bbSAart Bik // Inspect tensor indices. 
830a2c9d4bbSAart Bik bool atLevel = ldx == -1u; 8314569c14aSGus Smith OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; 832619bfe8bSAart Bik auto map = op.getTiedIndexingMap(t); 833619bfe8bSAart Bik auto enc = getSparseTensorEncoding(t->get().getType()); 834c194b49cSAart Bik for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 835b1d44e59SAart Bik AffineExpr a = map.getResult(perm(enc, d)); 836b1d44e59SAart Bik if (!isInvariantAffine(codegen, a, ldx, atLevel)) 837a2c9d4bbSAart Bik return; // still in play 838a2c9d4bbSAart Bik } 839a2c9d4bbSAart Bik // All exhausted at this level (atLevel denotes exactly at this level). 8402f2b5b7dSTobias Gysi OpOperand *lhs = op.getOutputOperand(0); 841619bfe8bSAart Bik if (lhs == t) { 842a2c9d4bbSAart Bik codegen.redExp = hoist ? exp : -1u; 8435da21338SAart Bik codegen.redKind = getReduction(last); 844a2c9d4bbSAart Bik } else if (atLevel) { 845a2c9d4bbSAart Bik merger.exp(exp).val = 846a2c9d4bbSAart Bik hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 847a2c9d4bbSAart Bik } 848123e8dfcSAart Bik } else if (merger.exp(exp).kind != Kind::kInvariant) { 849a2c9d4bbSAart Bik // Traverse into the binary operations. Note that we only hoist 850a2c9d4bbSAart Bik // tensor loads, since subsequent MLIR/LLVM passes know how to 851a2c9d4bbSAart Bik // deal with all other kinds of derived loop invariants. 8525da21338SAart Bik Kind last = merger.exp(exp).kind; 8534569c14aSGus Smith unsigned e0 = merger.exp(exp).children.e0; 8544569c14aSGus Smith unsigned e1 = merger.exp(exp).children.e1; 8555da21338SAart Bik genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist, last); 8565da21338SAart Bik genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist, last); 857a2c9d4bbSAart Bik } 858a2c9d4bbSAart Bik } 859a2c9d4bbSAart Bik 860a2c9d4bbSAart Bik /// Generates initialization code for the subsequent loop sequence at 861a2c9d4bbSAart Bik /// current index level. Returns true if the loop sequence needs to 862a2c9d4bbSAart Bik /// maintain the universal index. 863a2c9d4bbSAart Bik static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 864a2c9d4bbSAart Bik linalg::GenericOp op, std::vector<unsigned> &topSort, 865a2c9d4bbSAart Bik unsigned at, llvm::BitVector &inits) { 866a2c9d4bbSAart Bik bool needsUniv = false; 867a2c9d4bbSAart Bik Location loc = op.getLoc(); 868a2c9d4bbSAart Bik unsigned idx = topSort[at]; 869a2c9d4bbSAart Bik 870a2c9d4bbSAart Bik // Initialize sparse positions. 871a2c9d4bbSAart Bik for (unsigned b = 0, be = inits.size(); b < be; b++) { 872a2c9d4bbSAart Bik if (inits[b]) { 873a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 874a2c9d4bbSAart Bik assert(idx == merger.index(b)); 875a2c9d4bbSAart Bik if (merger.isDim(b, Dim::kSparse)) { 876a2c9d4bbSAart Bik // Initialize sparse index. 877a2c9d4bbSAart Bik unsigned pat = at; 878a2c9d4bbSAart Bik for (; pat != 0; pat--) { 879a2c9d4bbSAart Bik if (codegen.pidxs[tensor][topSort[pat - 1]]) 880a2c9d4bbSAart Bik break; 881a2c9d4bbSAart Bik } 882a2c9d4bbSAart Bik Value ptr = codegen.pointers[tensor][idx]; 883*a54f4eaeSMogball Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 884*a54f4eaeSMogball Value p0 = (pat == 0) ? 
rewriter.create<arith::ConstantIndexOp>(loc, 0) 885a2c9d4bbSAart Bik : codegen.pidxs[tensor][topSort[pat - 1]]; 886a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 887*a54f4eaeSMogball Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one); 888a2c9d4bbSAart Bik codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 889a2c9d4bbSAart Bik } else { 890a2c9d4bbSAart Bik // Dense index still in play. 891a2c9d4bbSAart Bik needsUniv = true; 892a2c9d4bbSAart Bik } 893a2c9d4bbSAart Bik } 894a2c9d4bbSAart Bik } 895a2c9d4bbSAart Bik 896a2c9d4bbSAart Bik // Initialize the universal dense index. 897*a54f4eaeSMogball codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0); 898a2c9d4bbSAart Bik return needsUniv; 899a2c9d4bbSAart Bik } 900a2c9d4bbSAart Bik 901a2c9d4bbSAart Bik /// Returns vectorization strategy. Any implicit inner loop in the Linalg 902a2c9d4bbSAart Bik /// operation is a candidate. Whether it is actually converted to SIMD code 903a2c9d4bbSAart Bik /// depends on the requested strategy. 904a2c9d4bbSAart Bik static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 905a2c9d4bbSAart Bik switch (codegen.options.vectorizationStrategy) { 906a2c9d4bbSAart Bik case SparseVectorizationStrategy::kNone: 907a2c9d4bbSAart Bik return false; 908a2c9d4bbSAart Bik case SparseVectorizationStrategy::kDenseInnerLoop: 909a2c9d4bbSAart Bik return isInner && !isSparse; 910a2c9d4bbSAart Bik case SparseVectorizationStrategy::kAnyStorageInnerLoop: 911a2c9d4bbSAart Bik return isInner; 912a2c9d4bbSAart Bik } 913a2c9d4bbSAart Bik llvm_unreachable("unexpected vectorization strategy"); 914a2c9d4bbSAart Bik } 915a2c9d4bbSAart Bik 916a2c9d4bbSAart Bik /// Returns parallelization strategy. Any implicit loop in the Linalg operation 917a2c9d4bbSAart Bik /// that is marked "parallel" is a candidate. Whether it is actually converted 918a2c9d4bbSAart Bik /// to a parallel operation depends on the requested strategy. 919a2c9d4bbSAart Bik static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 920a2c9d4bbSAart Bik bool isSparse, bool isVector) { 921a2c9d4bbSAart Bik switch (codegen.options.parallelizationStrategy) { 922a2c9d4bbSAart Bik case SparseParallelizationStrategy::kNone: 923a2c9d4bbSAart Bik return false; 924a2c9d4bbSAart Bik case SparseParallelizationStrategy::kDenseOuterLoop: 925a2c9d4bbSAart Bik return isOuter && !isSparse && !isReduction && !isVector; 926a2c9d4bbSAart Bik case SparseParallelizationStrategy::kAnyStorageOuterLoop: 927a2c9d4bbSAart Bik return isOuter && !isReduction && !isVector; 928a2c9d4bbSAart Bik case SparseParallelizationStrategy::kDenseAnyLoop: 929a2c9d4bbSAart Bik return !isSparse && !isReduction && !isVector; 930a2c9d4bbSAart Bik case SparseParallelizationStrategy::kAnyStorageAnyLoop: 931a2c9d4bbSAart Bik return !isReduction && !isVector; 932a2c9d4bbSAart Bik } 933a2c9d4bbSAart Bik llvm_unreachable("unexpected parallelization strategy"); 934a2c9d4bbSAart Bik } 935a2c9d4bbSAart Bik 936849f016cSAart Bik /// Checks unit stride for dense tensors. The iteration graph may have ignored 937a2c9d4bbSAart Bik /// dense access patterns in order to avoid cycles (sparse access patterns are 938a2c9d4bbSAart Bik /// always placed innermost), but that means dense access has become strided. 939849f016cSAart Bik /// This prevents effective vectorization. 
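/// For example, when the iteration graph places j outermost and i innermost
/// for a dense operand accessed as B(i, j), the innermost i-loop walks B
/// with a non-unit stride (one full row under the default row-major memref
/// layout), so contiguous vector loads are no longer possible.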
940a2c9d4bbSAart Bik static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 941849f016cSAart Bik unsigned idx) { 9422f2b5b7dSTobias Gysi for (OpOperand *t : op.getInputAndOutputOperands()) { 9432f2b5b7dSTobias Gysi if (!getSparseTensorEncoding(t->get().getType())) { 9442f2b5b7dSTobias Gysi auto map = op.getTiedIndexingMap(t); 945c194b49cSAart Bik for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 946b1d44e59SAart Bik AffineExpr a = map.getResult(d); 947849f016cSAart Bik // Report non-unit stride if innermost index appears at an outer 948849f016cSAart Bik // dimension (true non-unit stride) or if the innermost index appears 949849f016cSAart Bik // in a compound subscript in the innermost dimension. Even if the 950849f016cSAart Bik // latter is unit stride, it does not play well with scatter/gather. 951849f016cSAart Bik if (a.isFunctionOfDim(idx) && 952849f016cSAart Bik ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 953a2c9d4bbSAart Bik return false; 954a2c9d4bbSAart Bik } 955a2c9d4bbSAart Bik } 956a2c9d4bbSAart Bik } 957a2c9d4bbSAart Bik return true; 958a2c9d4bbSAart Bik } 959a2c9d4bbSAart Bik 960a2c9d4bbSAart Bik /// Generates a for-loop on a single index. 961a2c9d4bbSAart Bik static Operation *genFor(Merger &merger, CodeGen &codegen, 962a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 963a2c9d4bbSAart Bik bool isOuter, bool isInner, unsigned idx, 964a2c9d4bbSAart Bik llvm::BitVector &indices) { 965a2c9d4bbSAart Bik unsigned fb = indices.find_first(); 966a2c9d4bbSAart Bik unsigned tensor = merger.tensor(fb); 967a2c9d4bbSAart Bik assert(idx == merger.index(fb)); 968a2c9d4bbSAart Bik auto iteratorTypes = op.iterator_types().getValue(); 969583a7542STobias Gysi bool isReduction = isReductionIterator(iteratorTypes[idx]); 970a2c9d4bbSAart Bik bool isSparse = merger.isDim(fb, Dim::kSparse); 971a2c9d4bbSAart Bik bool isVector = isVectorFor(codegen, isInner, isSparse) && 972a2c9d4bbSAart Bik denseUnitStrides(merger, op, idx); 973a2c9d4bbSAart Bik bool isParallel = 974a2c9d4bbSAart Bik isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 975a2c9d4bbSAart Bik 976a2c9d4bbSAart Bik // Prepare vector length. 977a2c9d4bbSAart Bik if (isVector) 978a2c9d4bbSAart Bik codegen.curVecLength = codegen.options.vectorLength; 979a2c9d4bbSAart Bik 980a2c9d4bbSAart Bik // Loop bounds and increment. 981a2c9d4bbSAart Bik Location loc = op.getLoc(); 982a2c9d4bbSAart Bik Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 983a2c9d4bbSAart Bik Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 984*a54f4eaeSMogball Value step = 985*a54f4eaeSMogball rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength); 986a2c9d4bbSAart Bik 987a2c9d4bbSAart Bik // Emit a parallel loop. 988a2c9d4bbSAart Bik if (isParallel) { 989a2c9d4bbSAart Bik assert(!isVector); 990a2c9d4bbSAart Bik scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 991a2c9d4bbSAart Bik if (isSparse) 992a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 993a2c9d4bbSAart Bik else 994a2c9d4bbSAart Bik codegen.loops[idx] = parOp.getInductionVars()[0]; 995a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(parOp.getBody()); 996a2c9d4bbSAart Bik return parOp; 997a2c9d4bbSAart Bik } 998a2c9d4bbSAart Bik 999a2c9d4bbSAart Bik // Emit a sequential loop, potentially with a scalarized reduction. 
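// In the scalarized case, the running reduction value is carried as an
// scf.for iter_arg instead of being stored back to memory in every iteration.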
1000a2c9d4bbSAart Bik bool scalarRed = isInner && codegen.redExp != -1u; 1001a2c9d4bbSAart Bik SmallVector<Value, 4> operands; 1002a2c9d4bbSAart Bik if (scalarRed) { 1003a2c9d4bbSAart Bik Value load = genReductionStart(merger, codegen, rewriter, op); 1004a2c9d4bbSAart Bik operands.push_back(load); 1005a2c9d4bbSAart Bik } 1006a2c9d4bbSAart Bik scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 1007a2c9d4bbSAart Bik if (scalarRed) { 1008a2c9d4bbSAart Bik codegen.redVal = merger.exp(codegen.redExp).val = 1009a2c9d4bbSAart Bik forOp.getRegionIterArgs().front(); 1010a2c9d4bbSAart Bik } 1011a2c9d4bbSAart Bik // Assign induction variable to sparse or dense index. 1012a2c9d4bbSAart Bik Value iv = forOp.getInductionVar(); 1013a2c9d4bbSAart Bik if (isSparse) 1014a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = iv; 1015a2c9d4bbSAart Bik else 1016a2c9d4bbSAart Bik codegen.loops[idx] = iv; 1017a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(forOp.getBody()); 1018a2c9d4bbSAart Bik // Share vector iteration mask between all subsequent loads/stores. 1019a2c9d4bbSAart Bik if (isVector) 1020a2c9d4bbSAart Bik codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1021a2c9d4bbSAart Bik return forOp; 1022a2c9d4bbSAart Bik } 1023a2c9d4bbSAart Bik 1024a2c9d4bbSAart Bik /// Emit a while-loop for co-iteration over multiple indices. 1025a2c9d4bbSAart Bik static Operation *genWhile(Merger &merger, CodeGen &codegen, 1026a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1027a2c9d4bbSAart Bik unsigned idx, bool needsUniv, 1028a2c9d4bbSAart Bik llvm::BitVector &indices) { 1029a2c9d4bbSAart Bik SmallVector<Type, 4> types; 1030a2c9d4bbSAart Bik SmallVector<Value, 4> operands; 1031a2c9d4bbSAart Bik // Construct the while-loop with a parameter for each index. 1032a2c9d4bbSAart Bik Type indexType = rewriter.getIndexType(); 1033a2c9d4bbSAart Bik for (unsigned b = 0, be = indices.size(); b < be; b++) { 1034a2c9d4bbSAart Bik if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1035a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1036a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1037a2c9d4bbSAart Bik types.push_back(indexType); 1038a2c9d4bbSAart Bik assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() && 1039a2c9d4bbSAart Bik "type mismatch for sparse index"); 1040a2c9d4bbSAart Bik operands.push_back(codegen.pidxs[tensor][idx]); 1041a2c9d4bbSAart Bik } 1042a2c9d4bbSAart Bik } 1043a2c9d4bbSAart Bik if (needsUniv) { 1044a2c9d4bbSAart Bik types.push_back(indexType); 1045a2c9d4bbSAart Bik assert(codegen.loops[idx].getType().isa<IndexType>() && 1046a2c9d4bbSAart Bik "type mismatch for universal index"); 1047a2c9d4bbSAart Bik operands.push_back(codegen.loops[idx]); 1048a2c9d4bbSAart Bik } 1049a2c9d4bbSAart Bik Location loc = op.getLoc(); 1050a2c9d4bbSAart Bik scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 1051a2c9d4bbSAart Bik Block *before = rewriter.createBlock(&whileOp.before(), {}, types); 1052a2c9d4bbSAart Bik Block *after = rewriter.createBlock(&whileOp.after(), {}, types); 1053a2c9d4bbSAart Bik 1054a2c9d4bbSAart Bik // Build the "before" region, which effectively consists 1055a2c9d4bbSAart Bik // of a conjunction of "i < upper" tests on all induction. 
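  // Illustrative sketch of the co-iteration header that results for two
  // sparse operands (names assumed for exposition only):
  //   scf.while (%pA = %a0, %pB = %b0) : (index, index) -> (index, index) {
  //     %cA = arith.cmpi ult, %pA, %hiA : index
  //     %cB = arith.cmpi ult, %pB, %hiB : index
  //     %c  = arith.andi %cA, %cB : i1
  //     scf.condition(%c) %pA, %pB : index, index
  //   } do { ... advance the matching pointer(s) ... }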
1056a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(&whileOp.before().front()); 1057a2c9d4bbSAart Bik Value cond; 1058a2c9d4bbSAart Bik unsigned o = 0; 1059a2c9d4bbSAart Bik for (unsigned b = 0, be = indices.size(); b < be; b++) { 1060a2c9d4bbSAart Bik if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1061a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1062a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1063a2c9d4bbSAart Bik Value op1 = before->getArgument(o); 1064a2c9d4bbSAart Bik Value op2 = codegen.highs[tensor][idx]; 1065*a54f4eaeSMogball Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1066*a54f4eaeSMogball op1, op2); 1067*a54f4eaeSMogball cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc; 1068a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = after->getArgument(o++); 1069a2c9d4bbSAart Bik } 1070a2c9d4bbSAart Bik } 1071a2c9d4bbSAart Bik if (needsUniv) 1072a2c9d4bbSAart Bik codegen.loops[idx] = after->getArgument(o++); 1073a2c9d4bbSAart Bik assert(o == operands.size()); 1074a2c9d4bbSAart Bik rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1075a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(&whileOp.after().front()); 1076a2c9d4bbSAart Bik return whileOp; 1077a2c9d4bbSAart Bik } 1078a2c9d4bbSAart Bik 1079a2c9d4bbSAart Bik /// Generates a for-loop or a while-loop, depending on whether it implements 1080a2c9d4bbSAart Bik /// singleton iteration or co-iteration over the given conjunction. 1081a2c9d4bbSAart Bik static Operation *genLoop(Merger &merger, CodeGen &codegen, 1082a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1083a2c9d4bbSAart Bik std::vector<unsigned> &topSort, unsigned at, 1084a2c9d4bbSAart Bik bool needsUniv, llvm::BitVector &indices) { 1085a2c9d4bbSAart Bik unsigned idx = topSort[at]; 1086a2c9d4bbSAart Bik if (indices.count() == 1) { 1087a2c9d4bbSAart Bik bool isOuter = at == 0; 1088a2c9d4bbSAart Bik bool isInner = at == topSort.size() - 1; 1089a2c9d4bbSAart Bik return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1090a2c9d4bbSAart Bik indices); 1091a2c9d4bbSAart Bik } 1092a2c9d4bbSAart Bik genReductionEnd(merger, codegen, rewriter, op); // cannot chain 1093a2c9d4bbSAart Bik return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1094a2c9d4bbSAart Bik } 1095a2c9d4bbSAart Bik 1096a2c9d4bbSAart Bik /// Generates the local variables for this loop, consisting of the sparse 1097a2c9d4bbSAart Bik /// indices, restored universal dense index, and dense positions. 1098a2c9d4bbSAart Bik static void genLocals(Merger &merger, CodeGen &codegen, 1099a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1100a2c9d4bbSAart Bik std::vector<unsigned> &topSort, unsigned at, 1101a2c9d4bbSAart Bik bool needsUniv, llvm::BitVector &locals) { 1102a2c9d4bbSAart Bik Location loc = op.getLoc(); 1103a2c9d4bbSAart Bik unsigned idx = topSort[at]; 1104a2c9d4bbSAart Bik 1105a2c9d4bbSAart Bik // Initialize sparse indices. 
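  // For each sparse tensor active at this loop level, the coordinate is
  // loaded as idx = indices[pidx] at the current pointer position. Without
  // a universal index, the smallest loaded coordinate (the running minimum
  // computed below) becomes the coordinate of the current iteration.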
1106a2c9d4bbSAart Bik Value min; 1107a2c9d4bbSAart Bik for (unsigned b = 0, be = locals.size(); b < be; b++) { 1108a2c9d4bbSAart Bik if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1109a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1110a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1111a2c9d4bbSAart Bik Value ptr = codegen.indices[tensor][idx]; 1112a2c9d4bbSAart Bik Value s = codegen.pidxs[tensor][idx]; 1113a2c9d4bbSAart Bik Value load = genLoad(codegen, rewriter, loc, ptr, s); 1114a2c9d4bbSAart Bik codegen.idxs[tensor][idx] = load; 1115a2c9d4bbSAart Bik if (!needsUniv) { 1116a2c9d4bbSAart Bik if (min) { 1117*a54f4eaeSMogball Value cmp = rewriter.create<arith::CmpIOp>( 1118*a54f4eaeSMogball loc, arith::CmpIPredicate::ult, load, min); 1119a2c9d4bbSAart Bik min = rewriter.create<SelectOp>(loc, cmp, load, min); 1120a2c9d4bbSAart Bik } else { 1121a2c9d4bbSAart Bik min = load; 1122a2c9d4bbSAart Bik } 1123a2c9d4bbSAart Bik } 1124a2c9d4bbSAart Bik } 1125a2c9d4bbSAart Bik } 1126a2c9d4bbSAart Bik 1127a2c9d4bbSAart Bik // Merge dense universal index over minimum. 1128a2c9d4bbSAart Bik if (min) { 1129a2c9d4bbSAart Bik assert(!needsUniv); 1130a2c9d4bbSAart Bik codegen.loops[idx] = min; 1131a2c9d4bbSAart Bik } 1132a2c9d4bbSAart Bik 1133727a63e0SAart Bik // Initialize dense positions. Note that we generate dense indices of the 1134727a63e0SAart Bik // output tensor unconditionally, since they may not appear in the lattice, 1135727a63e0SAart Bik // but may be needed for linearized codegen. 1136a2c9d4bbSAart Bik for (unsigned b = 0, be = locals.size(); b < be; b++) { 1137727a63e0SAart Bik if ((locals[b] || merger.isOutTensor(b, idx)) && 1138727a63e0SAart Bik merger.isDim(b, Dim::kDense)) { 1139a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1140a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1141a2c9d4bbSAart Bik unsigned pat = at; 1142a2c9d4bbSAart Bik for (; pat != 0; pat--) 1143a2c9d4bbSAart Bik if (codegen.pidxs[tensor][topSort[pat - 1]]) 1144a2c9d4bbSAart Bik break; 1145*a54f4eaeSMogball Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0) 1146a2c9d4bbSAart Bik : codegen.pidxs[tensor][topSort[pat - 1]]; 1147a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = genAddress( 1148a2c9d4bbSAart Bik codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1149a2c9d4bbSAart Bik } 1150a2c9d4bbSAart Bik } 1151a2c9d4bbSAart Bik } 1152a2c9d4bbSAart Bik 1153a2c9d4bbSAart Bik /// Generates the induction structure for a while-loop. 
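/// For every sparse tensor that was active at this index, the pointer
/// position advances only if its coordinate matched the loop coordinate,
/// i.e. pidx' = select(idx == loop, pidx + 1, pidx), while the universal
/// dense index, when present, is always incremented by one.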
1154a2c9d4bbSAart Bik static void genWhileInduction(Merger &merger, CodeGen &codegen,
1155a2c9d4bbSAart Bik                               PatternRewriter &rewriter, linalg::GenericOp op,
1156a2c9d4bbSAart Bik                               unsigned idx, bool needsUniv,
1157a2c9d4bbSAart Bik                               llvm::BitVector &induction, ResultRange results) {
1158a2c9d4bbSAart Bik   Location loc = op.getLoc();
1159a2c9d4bbSAart Bik   unsigned o = 0;
1160a2c9d4bbSAart Bik   SmallVector<Value, 4> operands;
1161*a54f4eaeSMogball   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1162a2c9d4bbSAart Bik   for (unsigned b = 0, be = induction.size(); b < be; b++) {
1163a2c9d4bbSAart Bik     if (induction[b] && merger.isDim(b, Dim::kSparse)) {
1164a2c9d4bbSAart Bik       unsigned tensor = merger.tensor(b);
1165a2c9d4bbSAart Bik       assert(idx == merger.index(b));
1166a2c9d4bbSAart Bik       Value op1 = codegen.idxs[tensor][idx];
1167a2c9d4bbSAart Bik       Value op2 = codegen.loops[idx];
1168a2c9d4bbSAart Bik       Value op3 = codegen.pidxs[tensor][idx];
1169*a54f4eaeSMogball       Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1170*a54f4eaeSMogball                                                  op1, op2);
1171*a54f4eaeSMogball       Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
1172a2c9d4bbSAart Bik       operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
1173a2c9d4bbSAart Bik       codegen.pidxs[tensor][idx] = results[o++];
1174a2c9d4bbSAart Bik     }
1175a2c9d4bbSAart Bik   }
1176a2c9d4bbSAart Bik   if (needsUniv) {
1177*a54f4eaeSMogball     operands.push_back(
1178*a54f4eaeSMogball         rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
1179a2c9d4bbSAart Bik     codegen.loops[idx] = results[o++];
1180a2c9d4bbSAart Bik   }
1181a2c9d4bbSAart Bik   assert(o == operands.size());
1182a2c9d4bbSAart Bik   rewriter.create<scf::YieldOp>(loc, operands);
1183a2c9d4bbSAart Bik }
1184a2c9d4bbSAart Bik 
1185a2c9d4bbSAart Bik /// Generates a single if-statement within a while-loop.
1186a2c9d4bbSAart Bik static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
1187a2c9d4bbSAart Bik                        PatternRewriter &rewriter, linalg::GenericOp op,
1188a2c9d4bbSAart Bik                        unsigned idx, llvm::BitVector &conditions) {
1189a2c9d4bbSAart Bik   Location loc = op.getLoc();
1190a2c9d4bbSAart Bik   Value cond;
1191a2c9d4bbSAart Bik   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1192a2c9d4bbSAart Bik     if (conditions[b]) {
1193a2c9d4bbSAart Bik       unsigned tensor = merger.tensor(b);
1194a2c9d4bbSAart Bik       assert(idx == merger.index(b));
1195a2c9d4bbSAart Bik       Value clause;
1196a2c9d4bbSAart Bik       if (merger.isDim(b, Dim::kSparse)) {
1197a2c9d4bbSAart Bik         Value op1 = codegen.idxs[tensor][idx];
1198a2c9d4bbSAart Bik         Value op2 = codegen.loops[idx];
1199*a54f4eaeSMogball         clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1200*a54f4eaeSMogball                                                 op1, op2);
1201a2c9d4bbSAart Bik       } else {
1202*a54f4eaeSMogball         clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
1203a2c9d4bbSAart Bik       }
1204*a54f4eaeSMogball       cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
1205a2c9d4bbSAart Bik     }
1206a2c9d4bbSAart Bik   }
1207a2c9d4bbSAart Bik   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
1208a2c9d4bbSAart Bik   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
1209a2c9d4bbSAart Bik   return ifOp;
1210a2c9d4bbSAart Bik }
1211a2c9d4bbSAart Bik 
1212a2c9d4bbSAart Bik /// Recursively generates code while computing iteration lattices in order
1213a2c9d4bbSAart Bik /// to manage the complexity of implementing co-iteration over unions
1214a2c9d4bbSAart Bik /// and intersections of sparse iteration spaces.
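/// For instance (illustrative), for x(i) = a(i) + b(i) with a and b sparse
/// in i, the optimized lattice yields a while-loop that co-iterates over the
/// overlap of a and b, followed by for-loops over the remaining tails where
/// only a or only b still has entries.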
1215a2c9d4bbSAart Bik static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1216a2c9d4bbSAart Bik                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1217a2c9d4bbSAart Bik                     unsigned exp, unsigned at) {
1218a2c9d4bbSAart Bik   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1219a2c9d4bbSAart Bik   if (at == topSort.size()) {
1220a2c9d4bbSAart Bik     Value rhs = genExp(merger, codegen, rewriter, op, exp);
1221b1d44e59SAart Bik     genTensorStore(merger, codegen, rewriter, op, rhs);
1222a2c9d4bbSAart Bik     return;
1223a2c9d4bbSAart Bik   }
1224a2c9d4bbSAart Bik   assert(codegen.curVecLength == 1);
1225a2c9d4bbSAart Bik 
1226a2c9d4bbSAart Bik   // Construct iteration lattices for current loop index, with L0 at top.
1227a2c9d4bbSAart Bik   // Then emit initialization code for the loop sequence at this level.
1228a2c9d4bbSAart Bik   // We maintain the universal dense index if dense indices are still
1229a2c9d4bbSAart Bik   // in play for a non-singleton loop sequence.
1230a2c9d4bbSAart Bik   Location loc = op.getLoc();
1231a2c9d4bbSAart Bik   unsigned idx = topSort[at];
1232043ce4e6SGus Smith   unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
1233a2c9d4bbSAart Bik   unsigned lsize = merger.set(lts).size();
1234a2c9d4bbSAart Bik   assert(lsize != 0);
1235a2c9d4bbSAart Bik   unsigned l0 = merger.set(lts)[0];
1236a2c9d4bbSAart Bik   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
1237a2c9d4bbSAart Bik   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
1238a2c9d4bbSAart Bik   bool needsUniv = false;
1239a2c9d4bbSAart Bik   if (genInit(merger, codegen, rewriter, op, topSort, at,
1240a2c9d4bbSAart Bik               merger.lat(l0).bits)) {
1241a2c9d4bbSAart Bik     // Maintain the universal index only if it is actually
1242a2c9d4bbSAart Bik     // consumed by a subsequent lattice point.
1243a2c9d4bbSAart Bik     for (unsigned i = 1; i < lsize; i++) {
1244a2c9d4bbSAart Bik       unsigned li = merger.set(lts)[i];
1245a2c9d4bbSAart Bik       if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
1246a2c9d4bbSAart Bik         needsUniv = true;
1247a2c9d4bbSAart Bik         break;
1248a2c9d4bbSAart Bik       }
1249a2c9d4bbSAart Bik     }
1250a2c9d4bbSAart Bik   }
1251a2c9d4bbSAart Bik 
1252a2c9d4bbSAart Bik   // Emit a loop for every lattice point L0 >= Li.
1253a2c9d4bbSAart Bik   for (unsigned i = 0; i < lsize; i++) {
1254a2c9d4bbSAart Bik     unsigned li = merger.set(lts)[i];
1255a2c9d4bbSAart Bik 
1256a2c9d4bbSAart Bik     // Emit loop.
1257a2c9d4bbSAart Bik     codegen.curVecLength = 1;
1258a2c9d4bbSAart Bik     llvm::BitVector indices = merger.lat(li).simple;
1259a2c9d4bbSAart Bik     Operation *loop =
1260a2c9d4bbSAart Bik         genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
1261a2c9d4bbSAart Bik     genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
1262a2c9d4bbSAart Bik               merger.lat(li).bits);
1263a2c9d4bbSAart Bik 
1264a2c9d4bbSAart Bik     // Visit all lattice points with Li >= Lj to generate the
1265a2c9d4bbSAart Bik     // loop-body, possibly with if statements for co-iteration.
1266a2c9d4bbSAart Bik     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1267a2c9d4bbSAart Bik     for (unsigned j = 0; j < lsize; j++) {
1268a2c9d4bbSAart Bik       unsigned lj = merger.set(lts)[j];
1269a2c9d4bbSAart Bik       unsigned ej = merger.lat(lj).exp;
1270a2c9d4bbSAart Bik       if (li == lj || merger.latGT(li, lj)) {
1271a2c9d4bbSAart Bik         // Recurse into body of each branch.
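        // Within a while-loop, each lattice point Lj is guarded by an
        // if-statement so its body only runs when all tensors in Lj
        // contribute at the current coordinate; within a for-loop the
        // single body is emitted unconditionally.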
1272a2c9d4bbSAart Bik         if (isWhile) {
1273a2c9d4bbSAart Bik           scf::IfOp ifOp =
1274a2c9d4bbSAart Bik               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1275a2c9d4bbSAart Bik           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1276a2c9d4bbSAart Bik           rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
1277a2c9d4bbSAart Bik         } else {
1278a2c9d4bbSAart Bik           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1279a2c9d4bbSAart Bik         }
1280a2c9d4bbSAart Bik       }
1281a2c9d4bbSAart Bik     }
1282a2c9d4bbSAart Bik 
1283a2c9d4bbSAart Bik     // Wrap-up induction and restore insertion point.
1284a2c9d4bbSAart Bik     if (isWhile) {
1285a2c9d4bbSAart Bik       scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
1286a2c9d4bbSAart Bik       rewriter.setInsertionPointToEnd(&whileOp.after().front());
1287a2c9d4bbSAart Bik       genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
1288a2c9d4bbSAart Bik                         merger.lat(li).bits, whileOp.results());
1289a2c9d4bbSAart Bik     } else {
1290a2c9d4bbSAart Bik       needsUniv = false;
1291a2c9d4bbSAart Bik       if (codegen.redVal) {
1292a2c9d4bbSAart Bik         rewriter.create<scf::YieldOp>(loc, codegen.redVal);
1293a2c9d4bbSAart Bik         codegen.redVal = loop->getResult(0);
1294a2c9d4bbSAart Bik       }
1295a2c9d4bbSAart Bik     }
1296a2c9d4bbSAart Bik     rewriter.setInsertionPointAfter(loop);
1297a2c9d4bbSAart Bik   }
1298a2c9d4bbSAart Bik 
1299a2c9d4bbSAart Bik   // Wrap-up loop sequence.
1300a2c9d4bbSAart Bik   codegen.curVecLength = 1;
1301a2c9d4bbSAart Bik   genReductionEnd(merger, codegen, rewriter, op);
1302a2c9d4bbSAart Bik   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
1303a2c9d4bbSAart Bik   codegen.loops[idx] = Value();
1304a2c9d4bbSAart Bik }
1305a2c9d4bbSAart Bik 
1306727a63e0SAart Bik /// Converts the result computed by the sparse kernel into the required form.
130736b66ab9SAart Bik static void genResult(Merger &merger, CodeGen &codegen,
130836b66ab9SAart Bik                       PatternRewriter &rewriter, linalg::GenericOp op) {
130936b66ab9SAart Bik   Location loc = op.getLoc();
131036b66ab9SAart Bik   OpOperand *lhs = op.getOutputOperand(0);
131136b66ab9SAart Bik   Type resType = lhs->get().getType();
131236b66ab9SAart Bik   unsigned tensor = lhs->getOperandNumber();
131336b66ab9SAart Bik   auto map = op.getTiedIndexingMap(lhs);
131436b66ab9SAart Bik   auto enc = getSparseTensorEncoding(resType);
131536b66ab9SAart Bik   Value result = codegen.buffers.back(); // value array
131636b66ab9SAart Bik   if (enc) {
131736b66ab9SAart Bik     // The sparse annotation unambiguously defines the arrays needed
131836b66ab9SAart Bik     // to "reconstruct" the sparse tensor from the storage scheme
131936b66ab9SAart Bik     // (even though lowering should never need this eventually).
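    // The argument list for the ToTensorOp built below contains, per sparse
    // dimension of the output and in dimension order, the pointers array
    // followed by the indices array, with the values array appended last.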
132036b66ab9SAart Bik     SmallVector<Value, 4> args;
132136b66ab9SAart Bik     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
1322b1d44e59SAart Bik       AffineExpr a = map.getResult(perm(enc, d));
1323b1d44e59SAart Bik       if (a.getKind() != AffineExprKind::DimId)
1324b1d44e59SAart Bik         continue; // compound
1325b1d44e59SAart Bik       unsigned idx = a.cast<AffineDimExpr>().getPosition();
132636b66ab9SAart Bik       if (merger.isDim(tensor, idx, Dim::kSparse)) {
132736b66ab9SAart Bik         args.push_back(codegen.pointers[tensor][idx]);
132836b66ab9SAart Bik         args.push_back(codegen.indices[tensor][idx]);
132936b66ab9SAart Bik       }
133036b66ab9SAart Bik     }
133136b66ab9SAart Bik     args.push_back(result);
133236b66ab9SAart Bik     result = rewriter.create<ToTensorOp>(loc, resType, args);
133336b66ab9SAart Bik   } else {
133436b66ab9SAart Bik     // To "reconstruct" a non-annotated tensor, simply load it
133536b66ab9SAart Bik     // from the bufferized value.
133636b66ab9SAart Bik     result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
133736b66ab9SAart Bik   }
1338727a63e0SAart Bik   rewriter.replaceOp(op, result);
1339727a63e0SAart Bik }
1340727a63e0SAart Bik 
13415da21338SAart Bik //===----------------------------------------------------------------------===//
13425da21338SAart Bik // Sparse compiler rewriting methods.
13435da21338SAart Bik //===----------------------------------------------------------------------===//
13445da21338SAart Bik 
1345a2c9d4bbSAart Bik namespace {
1346a2c9d4bbSAart Bik 
1347a2c9d4bbSAart Bik /// Sparse rewriting rule for generic Linalg operation.
1348a2c9d4bbSAart Bik struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1349a2c9d4bbSAart Bik public:
1350a2c9d4bbSAart Bik   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1351a2c9d4bbSAart Bik       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1352a2c9d4bbSAart Bik 
1353a2c9d4bbSAart Bik   LogicalResult matchAndRewrite(linalg::GenericOp op,
1354a2c9d4bbSAart Bik                                 PatternRewriter &rewriter) const override {
1355a2c9d4bbSAart Bik     // Detects sparse annotations and translates the per-dimension sparsity
1356a2c9d4bbSAart Bik     // information for all tensors to loop indices in the kernel.
1357a2c9d4bbSAart Bik     assert(op.getNumOutputs() == 1);
13582f2b5b7dSTobias Gysi     unsigned numTensors = op.getNumInputsAndOutputs();
1359a2c9d4bbSAart Bik     unsigned numLoops = op.iterator_types().getValue().size();
1360a2c9d4bbSAart Bik     Merger merger(numTensors, numLoops);
1361bf9ef3efSAart Bik     if (!findSparseAnnotations(merger, op))
1362bf9ef3efSAart Bik       return failure();
1363a2c9d4bbSAart Bik 
1364a2c9d4bbSAart Bik     // Computes a topologically sorted iteration graph to ensure
1365a2c9d4bbSAart Bik     // tensors are visited in natural index order. Fails on cycles.
1366a2c9d4bbSAart Bik     // This assumes that higher-level passes have already put the
1367a2c9d4bbSAart Bik     // tensors in each tensor expression in a feasible order.
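    // The sort below is attempted under progressively weaker constraints
    // (dense and undefined dimensions included first, then only one of the
    // two, then sparse dimensions only), so the kernel is rejected only if
    // even the most relaxed mask still yields a cycle.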
1368a2c9d4bbSAart Bik     std::vector<unsigned> topSort;
1369b6d1a31cSAart Bik     if (!computeIterationGraph(merger, op, topSort,
1370b6d1a31cSAart Bik                                SortMask::kIncludeUndef |
1371b6d1a31cSAart Bik                                    SortMask::kIncludeDense) &&
1372b6d1a31cSAart Bik         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
1373b6d1a31cSAart Bik         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
1374b6d1a31cSAart Bik         !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
1375a2c9d4bbSAart Bik       return failure();
1376a2c9d4bbSAart Bik 
1377266a7414SAart Bik     // Builds the tensor expression for the Linalg operation in SSA form.
1378266a7414SAart Bik     Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
1379a2c9d4bbSAart Bik     if (!exp.hasValue())
1380266a7414SAart Bik       return failure();
1381a2c9d4bbSAart Bik 
1382266a7414SAart Bik     // Rejects an inadmissible tensor expression.
138336b66ab9SAart Bik     if (!isAdmissableTensorExp(merger, op, exp.getValue()))
138436b66ab9SAart Bik       return failure();
138536b66ab9SAart Bik 
1386a2c9d4bbSAart Bik     // Recursively generates code.
1387a2c9d4bbSAart Bik     CodeGen codegen(options, numTensors, numLoops);
1388727a63e0SAart Bik     if (!genBuffers(merger, codegen, rewriter, op))
1389727a63e0SAart Bik       return failure(); // could not bufferize
1390a2c9d4bbSAart Bik     genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
139136b66ab9SAart Bik     genResult(merger, codegen, rewriter, op);
1392a2c9d4bbSAart Bik     return success();
1393a2c9d4bbSAart Bik   }
1394a2c9d4bbSAart Bik 
1395a2c9d4bbSAart Bik private:
1396a2c9d4bbSAart Bik   /// Options to control sparse code generation.
1397a2c9d4bbSAart Bik   SparsificationOptions options;
1398a2c9d4bbSAart Bik };
1399a2c9d4bbSAart Bik 
1400a2c9d4bbSAart Bik } // namespace
1401a2c9d4bbSAart Bik 
1402a2c9d4bbSAart Bik /// Populates the given patterns list with rewriting rules required for
1403a2c9d4bbSAart Bik /// the sparsification of linear algebra operations.
1404a2c9d4bbSAart Bik void mlir::populateSparsificationPatterns(
1405a2c9d4bbSAart Bik     RewritePatternSet &patterns, const SparsificationOptions &options) {
1406a2c9d4bbSAart Bik   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1407a2c9d4bbSAart Bik }
1408
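// Example of driving these patterns from a pass (an illustrative sketch only;
// the surrounding pass boilerplate is assumed and not part of this file):
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));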