//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
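// The masks combine as bits so that, for example, kIncludeDense |
// kIncludeUndef takes both dense and undefined loop constraints into
// account, whereas kSparseOnly restricts the iteration graph to the
// constraints imposed by sparse dimensions only.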
// Reduction kinds.
enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
          OpOperand *op)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        redKind(kNoReduc), sparseOut(op), lexIdx(), curVecLength(1),
        curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can use a scalarized reduction.
  unsigned redExp;
  Value redVal;
  Reduction redKind;
  // Sparse tensor as output.
  OpOperand *sparseOut;
  Value lexIdx;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}
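// For example (a sketch), under a typical CSR-style encoding with
// dimLevelType = [ "dense", "compressed" ], toDim returns Dim::kDense
// for d == 0 and Dim::kSparse for d == 1.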
/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects affine expressions
/// that are not a direct index for annotated tensors.
// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
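// The visit states used below are 0 = not visited yet, 1 = currently on
// the recursion stack (revisiting such a node means a back edge, that is,
// a cycle), and 2 = fully processed.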
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
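// For example (a sketch), a kernel that reads both A_ij and B_ji adds the
// orderings i < j and j < i once both tensors contribute constraints,
// which yields a cycle; masking out dense tensor constraints with
// SortMask::kSparseOnly may break such cycles.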
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if the tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              linalg::comprehensive_bufferize::BufferizableOpInterface::
                  kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if the tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}
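// For example (a sketch in the MLIR syntax of this prototype, assuming a
// sparse encoding #SM), an output obtained as
//   %t = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #SM>
// materializes uninitialized, whereas a tensor passed in as a function
// argument does not.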
/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
/// Sets `sparseOut` when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp, OpOperand **sparseOut) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
  if (merger.isConjunction(tensor, exp) && isInPlace(lhs->get()))
    return true;
  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
  if (isMaterializing(lhs->get())) {
    // In this first sparse tensor output implementation, this is enforced by
    // rejecting any reduction loops (since the sparse parallel loops give a
    // lexicographically sorted and injective view into that tensor).
    // TODO: generalize to include reductions
    for (auto attr : op.iterator_types())
      if (isReductionIterator(attr))
        return false;
    *sparseOut = lhs;
    return true;
  }
  return false;
}
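// For example (a sketch), an expression such as x(i) = x(i) * 2 with a
// sparse, in-place x only scales values that are already stored and is
// thus "simply dynamic", whereas x(i) = y(i) may have to insert entries
// at positions where x stores no value yet.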
//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (reductions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to name encoding.
static StringRef getReductionName(Reduction kind) {
  switch (kind) {
  case kNoReduc:
    break;
  case kSum:
    return "add";
  case kProduct:
    return "mul";
  case kAnd:
    return "and";
  case kOr:
    return "or";
  case kXor:
    return "xor";
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}
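// Note that subtraction maps to kSum as well: the unused vector lanes
// carry the same additive identity (zero), so the final horizontal
// reduction is still an "add".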
/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
                                Location loc, VectorType vtp) {
  Value r = codegen.redVal;
  switch (codegen.redKind) {
  case kNoReduc:
    break;
  case kSum:
  case kXor: {
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    Attribute zero = rewriter.getZeroAttr(vtp);
    Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero);
    return rewriter.create<vector::InsertElementOp>(
        loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0));
  }
  case kProduct: {
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    Type etp = vtp.getElementType();
    Attribute one;
    if (etp.isa<FloatType>())
      one = rewriter.getFloatAttr(etp, 1.0);
    else
      one = rewriter.getIntegerAttr(etp, 1);
    Value vec = rewriter.create<arith::ConstantOp>(
        loc, vtp, DenseElementsAttr::get(vtp, one));
    return rewriter.create<vector::InsertElementOp>(
        loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0));
  }
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
                               Location loc, VectorType vtp) {
  StringRef name = getReductionName(codegen.redKind);
  StringAttr kind = rewriter.getStringAttr(name);
  return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
                                              codegen.redVal, ValueRange{});
}
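// For example (a sketch), a kSum reduction over f32 with vector length 4
// carries a <4xf32> partial result through the loop, which the operation
// above collapses into a single f32 by means of vector.reduction "add".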
/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  if (isMaterializing(tensor)) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a more general bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
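      // For example (a sketch), for the compressed dimension of a CSR
      // matrix %a, the primitives above yield the equivalent of
      //   %ptr = sparse_tensor.pointers %a, %dim : ... to memref<?xindex>
      //   %ind = sparse_tensor.indices %a, %dim : ... to memref<?xindex>
      // with element types determined by the encoding's bit widths.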
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else if (t == codegen.sparseOut) {
      // True sparse output needs a lexIdx array.
      Value rank = rewriter.create<arith::ConstantIndexOp>(loc, op.getRank(t));
      auto dynShape = {ShapedType::kDynamicSize};
      auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
      codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, genIntType(rewriter, 1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
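  // That is, end = min(step, hi - iv), so the final vector iteration
  // receives a mask that covers only the remaining hi - iv lanes.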
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass =
      rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}
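// For example (a sketch), an indirect access a[ind[lo:hi]] becomes a
// vector.gather through the index vector, whereas a unit-stride access
// a[lo:hi] becomes a vector.maskedload, both predicated by the shared
// iteration mask and given an all-zero pass-through vector.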
/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<arith::ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
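// For example (a sketch), the innermost compressed dimension of a CSR
// matrix contributes the single position index pidx, which points directly
// into the values buffer, whereas a dense 2-d tensor contributes one
// (possibly compound affine) subscript per dimension.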
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Insertion.
  OpOperand *t = op.getOutputOperand(0);
  if (t == codegen.sparseOut) {
    rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
    return;
  }
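  // Note that the lexicographic insertion above relies on the
  // admissibility check to guarantee that values arrive in
  // lexicographic index order (no reduction loops are involved).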
  // Actual store.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
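  // For example, an i16 value with bit pattern 0xFFFF denotes position
  // 65535, so zero extension preserves the intended position, whereas
  // sign extension would incorrectly yield -1.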
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load =
          rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64));
    load =
        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}
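// For example (a sketch), an all-dense annotated "sparse" output of rank 2
// is linearized so that element (i, j) resides at position sz_1 * i + j of
// the 1-dim values buffer.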
/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}
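// For example (a sketch), in a kernel x(i,j) = a(i) * b(i,j) under loop
// order i before j, the load a(i) is exhausted once i is known and can
// be hoisted out of the inner j-loop.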
/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool atStart,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    if (!atLevel)
      return;
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      // Start or end a scalarized reduction.
      if (atStart) {
        Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
        codegen.redKind = getReduction(last);
        codegen.redExp = exp;
        updateReduc(merger, codegen, load);
      } else {
        Value redVal = codegen.redVal;
        updateReduc(merger, codegen, Value());
        codegen.redExp = -1u;
        codegen.redKind = kNoReduc;
        genTensorStore(merger, codegen, rewriter, op, redVal);
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      merger.exp(exp).val =
          atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  return needsUniv;
}
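// For example (a sketch), for the compressed dimension of a CSR matrix the
// sparse initialization above loads the classic row bounds, viz.
// lo = pointers[i] and hi = pointers[i + 1], which delimit the stored
// entries of row i.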
/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}
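// For example (a sketch), if the iteration graph placed index j innermost
// for a row-major dense access a(j, i), then j appears in the outer
// dimension and each innermost step strides over an entire row, which
// inhibits effective vectorization.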
992849f016cSAart Bik if (a.isFunctionOfDim(idx) && 993849f016cSAart Bik ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 994a2c9d4bbSAart Bik return false; 995a2c9d4bbSAart Bik } 996a2c9d4bbSAart Bik } 997a2c9d4bbSAart Bik } 998a2c9d4bbSAart Bik return true; 999a2c9d4bbSAart Bik } 1000a2c9d4bbSAart Bik 1001a2c9d4bbSAart Bik /// Generates a for-loop on a single index. 1002a2c9d4bbSAart Bik static Operation *genFor(Merger &merger, CodeGen &codegen, 1003a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1004a2c9d4bbSAart Bik bool isOuter, bool isInner, unsigned idx, 1005a2c9d4bbSAart Bik llvm::BitVector &indices) { 1006a2c9d4bbSAart Bik unsigned fb = indices.find_first(); 1007a2c9d4bbSAart Bik unsigned tensor = merger.tensor(fb); 1008a2c9d4bbSAart Bik assert(idx == merger.index(fb)); 1009a2c9d4bbSAart Bik auto iteratorTypes = op.iterator_types().getValue(); 1010583a7542STobias Gysi bool isReduction = isReductionIterator(iteratorTypes[idx]); 1011a2c9d4bbSAart Bik bool isSparse = merger.isDim(fb, Dim::kSparse); 1012f66e5769SAart Bik bool isVector = !codegen.sparseOut && 1013f66e5769SAart Bik isVectorFor(codegen, isInner, isSparse) && 1014a2c9d4bbSAart Bik denseUnitStrides(merger, op, idx); 1015a2c9d4bbSAart Bik bool isParallel = 1016f66e5769SAart Bik !codegen.sparseOut && 1017a2c9d4bbSAart Bik isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1018a2c9d4bbSAart Bik 1019a2c9d4bbSAart Bik // Prepare vector length. 1020a2c9d4bbSAart Bik if (isVector) 1021a2c9d4bbSAart Bik codegen.curVecLength = codegen.options.vectorLength; 1022a2c9d4bbSAart Bik 1023a2c9d4bbSAart Bik // Loop bounds and increment. 1024a2c9d4bbSAart Bik Location loc = op.getLoc(); 1025a2c9d4bbSAart Bik Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1026a2c9d4bbSAart Bik Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1027a54f4eaeSMogball Value step = 1028a54f4eaeSMogball rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength); 1029a2c9d4bbSAart Bik 1030a2c9d4bbSAart Bik // Emit a parallel loop. 1031a2c9d4bbSAart Bik if (isParallel) { 1032a2c9d4bbSAart Bik assert(!isVector); 1033a2c9d4bbSAart Bik scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1034a2c9d4bbSAart Bik if (isSparse) 1035a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1036a2c9d4bbSAart Bik else 1037a2c9d4bbSAart Bik codegen.loops[idx] = parOp.getInductionVars()[0]; 1038a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(parOp.getBody()); 1039a2c9d4bbSAart Bik return parOp; 1040a2c9d4bbSAart Bik } 1041a2c9d4bbSAart Bik 10427373cabcSAart Bik // Emit a sequential or vector loop. 1043a2c9d4bbSAart Bik SmallVector<Value, 4> operands; 10447373cabcSAart Bik if (codegen.redVal) { 10457373cabcSAart Bik // In a vector loop, bring reduction into SIMD form, if not already. 
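// Sketch of the SIMD form (assuming a sum reduction and vector length VL):
// the scalar reduction value is placed in lane 0 of an otherwise neutral
// vector, e.g. %vred = insert %red into dense<0.0> : vector<VLxf32>, so
// that subsequent lane-wise updates preserve the running total.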
10467373cabcSAart Bik if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
10477373cabcSAart Bik VectorType vtp = vectorType(codegen, codegen.redVal.getType());
10487373cabcSAart Bik Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
10497373cabcSAart Bik updateReduc(merger, codegen, vred);
10507373cabcSAart Bik }
10517373cabcSAart Bik operands.push_back(codegen.redVal);
1052a2c9d4bbSAart Bik }
1053a2c9d4bbSAart Bik scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
10547373cabcSAart Bik if (codegen.redVal)
10557373cabcSAart Bik updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
1056a2c9d4bbSAart Bik // Assign induction variable to sparse or dense index.
1057a2c9d4bbSAart Bik Value iv = forOp.getInductionVar();
1058a2c9d4bbSAart Bik if (isSparse)
1059a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = iv;
1060a2c9d4bbSAart Bik else
1061a2c9d4bbSAart Bik codegen.loops[idx] = iv;
1062a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(forOp.getBody());
1063a2c9d4bbSAart Bik // Share vector iteration mask between all subsequent loads/stores.
1064a2c9d4bbSAart Bik if (isVector)
1065a2c9d4bbSAart Bik codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
1066a2c9d4bbSAart Bik return forOp;
1067a2c9d4bbSAart Bik }
1068a2c9d4bbSAart Bik 
1069a2c9d4bbSAart Bik /// Generates a while-loop for co-iteration over multiple indices.
1070a2c9d4bbSAart Bik static Operation *genWhile(Merger &merger, CodeGen &codegen,
1071a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op,
1072a2c9d4bbSAart Bik unsigned idx, bool needsUniv,
1073a2c9d4bbSAart Bik llvm::BitVector &indices) {
1074a2c9d4bbSAart Bik SmallVector<Type, 4> types;
1075a2c9d4bbSAart Bik SmallVector<Value, 4> operands;
1076a2c9d4bbSAart Bik // Construct the while-loop with a parameter for each index.
1077a2c9d4bbSAart Bik Type indexType = rewriter.getIndexType();
1078a2c9d4bbSAart Bik for (unsigned b = 0, be = indices.size(); b < be; b++) {
1079a2c9d4bbSAart Bik if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1080a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b);
1081a2c9d4bbSAart Bik assert(idx == merger.index(b));
1082a2c9d4bbSAart Bik types.push_back(indexType);
1083a2c9d4bbSAart Bik operands.push_back(codegen.pidxs[tensor][idx]);
1084a2c9d4bbSAart Bik }
1085a2c9d4bbSAart Bik }
10867373cabcSAart Bik if (codegen.redVal) {
10877373cabcSAart Bik types.push_back(codegen.redVal.getType());
10887373cabcSAart Bik operands.push_back(codegen.redVal);
10897373cabcSAart Bik }
1090a2c9d4bbSAart Bik if (needsUniv) {
1091a2c9d4bbSAart Bik types.push_back(indexType);
1092a2c9d4bbSAart Bik operands.push_back(codegen.loops[idx]);
1093a2c9d4bbSAart Bik }
10947373cabcSAart Bik assert(types.size() == operands.size());
1095a2c9d4bbSAart Bik Location loc = op.getLoc();
1096a2c9d4bbSAart Bik scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
1097a2c9d4bbSAart Bik Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
1098a2c9d4bbSAart Bik Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
1099a2c9d4bbSAart Bik 
1100a2c9d4bbSAart Bik // Build the "before" region, which effectively consists
1101a2c9d4bbSAart Bik // of a conjunction of "i < upper" tests on all induction variables.
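// For example, co-iterating two sparse tensors over index i yields roughly:
//   %c0 = arith.cmpi ult, %p0, %hi0
//   %c1 = arith.cmpi ult, %p1, %hi1
//   %cond = arith.andi %c0, %c1
//   scf.condition(%cond) %p0, %p1 : index, index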
1102a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(&whileOp.before().front()); 1103a2c9d4bbSAart Bik Value cond; 1104a2c9d4bbSAart Bik unsigned o = 0; 1105a2c9d4bbSAart Bik for (unsigned b = 0, be = indices.size(); b < be; b++) { 1106a2c9d4bbSAart Bik if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1107a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1108a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1109a2c9d4bbSAart Bik Value op1 = before->getArgument(o); 1110a2c9d4bbSAart Bik Value op2 = codegen.highs[tensor][idx]; 1111a54f4eaeSMogball Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1112a54f4eaeSMogball op1, op2); 1113a54f4eaeSMogball cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc; 1114a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = after->getArgument(o++); 1115a2c9d4bbSAart Bik } 1116a2c9d4bbSAart Bik } 11177373cabcSAart Bik if (codegen.redVal) 11187373cabcSAart Bik updateReduc(merger, codegen, after->getArgument(o++)); 1119a2c9d4bbSAart Bik if (needsUniv) 1120a2c9d4bbSAart Bik codegen.loops[idx] = after->getArgument(o++); 1121a2c9d4bbSAart Bik assert(o == operands.size()); 1122a2c9d4bbSAart Bik rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1123a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(&whileOp.after().front()); 1124a2c9d4bbSAart Bik return whileOp; 1125a2c9d4bbSAart Bik } 1126a2c9d4bbSAart Bik 1127a2c9d4bbSAart Bik /// Generates a for-loop or a while-loop, depending on whether it implements 1128a2c9d4bbSAart Bik /// singleton iteration or co-iteration over the given conjunction. 1129a2c9d4bbSAart Bik static Operation *genLoop(Merger &merger, CodeGen &codegen, 1130a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1131a2c9d4bbSAart Bik std::vector<unsigned> &topSort, unsigned at, 1132a2c9d4bbSAart Bik bool needsUniv, llvm::BitVector &indices) { 1133a2c9d4bbSAart Bik unsigned idx = topSort[at]; 1134a2c9d4bbSAart Bik if (indices.count() == 1) { 1135a2c9d4bbSAart Bik bool isOuter = at == 0; 1136a2c9d4bbSAart Bik bool isInner = at == topSort.size() - 1; 1137a2c9d4bbSAart Bik return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1138a2c9d4bbSAart Bik indices); 1139a2c9d4bbSAart Bik } 1140a2c9d4bbSAart Bik return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1141a2c9d4bbSAart Bik } 1142a2c9d4bbSAart Bik 1143a2c9d4bbSAart Bik /// Generates the local variables for this loop, consisting of the sparse 1144a2c9d4bbSAart Bik /// indices, restored universal dense index, and dense positions. 1145a2c9d4bbSAart Bik static void genLocals(Merger &merger, CodeGen &codegen, 1146a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1147a2c9d4bbSAart Bik std::vector<unsigned> &topSort, unsigned at, 1148a2c9d4bbSAart Bik bool needsUniv, llvm::BitVector &locals) { 1149a2c9d4bbSAart Bik Location loc = op.getLoc(); 1150a2c9d4bbSAart Bik unsigned idx = topSort[at]; 1151a2c9d4bbSAart Bik 1152a2c9d4bbSAart Bik // Initialize sparse indices. 
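// For example, with two sparse tensors active at this loop level, each
// coordinate is loaded as %idx_t = memref.load %indices_t[%pidx_t], and
// (absent a universal index) the loop coordinate becomes
// min(%idx_0, %idx_1), computed with the cmpi/select pair below.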
1153a2c9d4bbSAart Bik Value min; 1154a2c9d4bbSAart Bik for (unsigned b = 0, be = locals.size(); b < be; b++) { 1155a2c9d4bbSAart Bik if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1156a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1157a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1158a2c9d4bbSAart Bik Value ptr = codegen.indices[tensor][idx]; 1159a2c9d4bbSAart Bik Value s = codegen.pidxs[tensor][idx]; 1160a2c9d4bbSAart Bik Value load = genLoad(codegen, rewriter, loc, ptr, s); 1161a2c9d4bbSAart Bik codegen.idxs[tensor][idx] = load; 1162a2c9d4bbSAart Bik if (!needsUniv) { 1163a2c9d4bbSAart Bik if (min) { 1164a54f4eaeSMogball Value cmp = rewriter.create<arith::CmpIOp>( 1165a54f4eaeSMogball loc, arith::CmpIPredicate::ult, load, min); 1166a2c9d4bbSAart Bik min = rewriter.create<SelectOp>(loc, cmp, load, min); 1167a2c9d4bbSAart Bik } else { 1168a2c9d4bbSAart Bik min = load; 1169a2c9d4bbSAart Bik } 1170a2c9d4bbSAart Bik } 1171a2c9d4bbSAart Bik } 1172a2c9d4bbSAart Bik } 1173a2c9d4bbSAart Bik 1174a2c9d4bbSAart Bik // Merge dense universal index over minimum. 1175a2c9d4bbSAart Bik if (min) { 1176a2c9d4bbSAart Bik assert(!needsUniv); 1177a2c9d4bbSAart Bik codegen.loops[idx] = min; 1178a2c9d4bbSAart Bik } 1179a2c9d4bbSAart Bik 1180727a63e0SAart Bik // Initialize dense positions. Note that we generate dense indices of the 1181727a63e0SAart Bik // output tensor unconditionally, since they may not appear in the lattice, 1182727a63e0SAart Bik // but may be needed for linearized codegen. 1183a2c9d4bbSAart Bik for (unsigned b = 0, be = locals.size(); b < be; b++) { 1184727a63e0SAart Bik if ((locals[b] || merger.isOutTensor(b, idx)) && 1185727a63e0SAart Bik merger.isDim(b, Dim::kDense)) { 1186a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1187a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1188a2c9d4bbSAart Bik unsigned pat = at; 1189a2c9d4bbSAart Bik for (; pat != 0; pat--) 1190a2c9d4bbSAart Bik if (codegen.pidxs[tensor][topSort[pat - 1]]) 1191a2c9d4bbSAart Bik break; 1192a54f4eaeSMogball Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0) 1193a2c9d4bbSAart Bik : codegen.pidxs[tensor][topSort[pat - 1]]; 1194a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = genAddress( 1195a2c9d4bbSAart Bik codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1196a2c9d4bbSAart Bik } 1197a2c9d4bbSAart Bik } 1198f66e5769SAart Bik 1199f66e5769SAart Bik // Move the insertion indices in lexicographic index order. 1200f66e5769SAart Bik if (codegen.sparseOut) { 1201f66e5769SAart Bik Value pos = rewriter.create<arith::ConstantIndexOp>(loc, at); 1202f66e5769SAart Bik rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx, 1203f66e5769SAart Bik pos); 1204f66e5769SAart Bik } 1205a2c9d4bbSAart Bik } 1206a2c9d4bbSAart Bik 1207a2c9d4bbSAart Bik /// Generates the induction structure for a while-loop. 1208a2c9d4bbSAart Bik static void genWhileInduction(Merger &merger, CodeGen &codegen, 1209a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1210a2c9d4bbSAart Bik unsigned idx, bool needsUniv, 12117373cabcSAart Bik llvm::BitVector &induction, 12127373cabcSAart Bik scf::WhileOp whileOp) { 1213a2c9d4bbSAart Bik Location loc = op.getLoc(); 12147373cabcSAart Bik // Finalize each else branch of all if statements. 
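// Conceptually, each pending scf.if is closed by yielding the current
// reduction from its else branch, e.g. scf.yield %red : f32, so the
// scalarized reduction keeps flowing as the if-result in SSA form.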
12157373cabcSAart Bik if (codegen.redVal) { 12167373cabcSAart Bik while (auto ifOp = dyn_cast_or_null<scf::IfOp>( 12177373cabcSAart Bik rewriter.getInsertionBlock()->getParentOp())) { 12187373cabcSAart Bik rewriter.create<scf::YieldOp>(loc, codegen.redVal); 12197373cabcSAart Bik updateReduc(merger, codegen, ifOp.getResult(0)); 12207373cabcSAart Bik rewriter.setInsertionPointAfter(ifOp); 12217373cabcSAart Bik } 12227373cabcSAart Bik } 12237373cabcSAart Bik rewriter.setInsertionPointToEnd(&whileOp.after().front()); 12247373cabcSAart Bik // Finalize the induction. Note that the induction could be performed 12257373cabcSAart Bik // in the individual if-branches to avoid re-evaluating the conditions. 12267373cabcSAart Bik // However, that would result in a rather elaborate forest of yield 12277373cabcSAart Bik // instructions during code generation. Moreover, performing the induction 12287373cabcSAart Bik // after the if-statements more closely resembles code generated by TACO. 1229a2c9d4bbSAart Bik unsigned o = 0; 1230a2c9d4bbSAart Bik SmallVector<Value, 4> operands; 1231a54f4eaeSMogball Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 1232a2c9d4bbSAart Bik for (unsigned b = 0, be = induction.size(); b < be; b++) { 1233a2c9d4bbSAart Bik if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1234a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1235a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1236a2c9d4bbSAart Bik Value op1 = codegen.idxs[tensor][idx]; 1237a2c9d4bbSAart Bik Value op2 = codegen.loops[idx]; 1238a2c9d4bbSAart Bik Value op3 = codegen.pidxs[tensor][idx]; 1239a54f4eaeSMogball Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1240a54f4eaeSMogball op1, op2); 1241a54f4eaeSMogball Value add = rewriter.create<arith::AddIOp>(loc, op3, one); 1242a2c9d4bbSAart Bik operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3)); 12437373cabcSAart Bik codegen.pidxs[tensor][idx] = whileOp->getResult(o++); 1244a2c9d4bbSAart Bik } 1245a2c9d4bbSAart Bik } 12467373cabcSAart Bik if (codegen.redVal) { 12477373cabcSAart Bik operands.push_back(codegen.redVal); 12487373cabcSAart Bik updateReduc(merger, codegen, whileOp->getResult(o++)); 12497373cabcSAart Bik } 1250a2c9d4bbSAart Bik if (needsUniv) { 1251a54f4eaeSMogball operands.push_back( 1252a54f4eaeSMogball rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one)); 12537373cabcSAart Bik codegen.loops[idx] = whileOp->getResult(o++); 1254a2c9d4bbSAart Bik } 1255a2c9d4bbSAart Bik assert(o == operands.size()); 1256a2c9d4bbSAart Bik rewriter.create<scf::YieldOp>(loc, operands); 12577373cabcSAart Bik rewriter.setInsertionPointAfter(whileOp); 12587373cabcSAart Bik } 12597373cabcSAart Bik 12607373cabcSAart Bik /// Generates the induction structure for a for-loop. 
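/// For a reduction for-loop this amounts to yielding the updated value,
/// e.g. scf.yield %red : f32, and rebinding the reduction to the loop's
/// result; a loop without a reduction needs no explicit yield here.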
12617373cabcSAart Bik static void genForInduction(Merger &merger, CodeGen &codegen, 12627373cabcSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 12637373cabcSAart Bik Operation *loop) { 12647373cabcSAart Bik Location loc = op.getLoc(); 12657373cabcSAart Bik unsigned o = 0; 12667373cabcSAart Bik SmallVector<Value, 4> operands; 12677373cabcSAart Bik if (codegen.redVal) { 12687373cabcSAart Bik operands.push_back(codegen.redVal); 12697373cabcSAart Bik updateReduc(merger, codegen, loop->getResult(o++)); 12707373cabcSAart Bik } 12717373cabcSAart Bik assert(o == operands.size()); 12727373cabcSAart Bik if (o > 0) 12737373cabcSAart Bik rewriter.create<scf::YieldOp>(loc, operands); 12747373cabcSAart Bik rewriter.setInsertionPointAfter(loop); 1275a2c9d4bbSAart Bik } 1276a2c9d4bbSAart Bik 1277a2c9d4bbSAart Bik /// Generates a single if-statement within a while-loop. 1278a2c9d4bbSAart Bik static scf::IfOp genIf(Merger &merger, CodeGen &codegen, 1279a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1280a2c9d4bbSAart Bik unsigned idx, llvm::BitVector &conditions) { 1281a2c9d4bbSAart Bik Location loc = op.getLoc(); 12827373cabcSAart Bik SmallVector<Type, 4> types; 1283a2c9d4bbSAart Bik Value cond; 1284a2c9d4bbSAart Bik for (unsigned b = 0, be = conditions.size(); b < be; b++) { 1285a2c9d4bbSAart Bik if (conditions[b]) { 1286a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1287a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1288a2c9d4bbSAart Bik Value clause; 1289a2c9d4bbSAart Bik if (merger.isDim(b, Dim::kSparse)) { 1290a2c9d4bbSAart Bik Value op1 = codegen.idxs[tensor][idx]; 1291a2c9d4bbSAart Bik Value op2 = codegen.loops[idx]; 1292a54f4eaeSMogball clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1293a54f4eaeSMogball op1, op2); 1294a2c9d4bbSAart Bik } else { 1295a54f4eaeSMogball clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true 1296a2c9d4bbSAart Bik } 1297a54f4eaeSMogball cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause; 1298a2c9d4bbSAart Bik } 1299a2c9d4bbSAart Bik } 13007373cabcSAart Bik if (codegen.redVal) 13017373cabcSAart Bik types.push_back(codegen.redVal.getType()); 13027373cabcSAart Bik scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true); 1303a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); 1304a2c9d4bbSAart Bik return ifOp; 1305a2c9d4bbSAart Bik } 1306a2c9d4bbSAart Bik 13077373cabcSAart Bik /// Generates end of true branch of if-statement within a while-loop. 13087373cabcSAart Bik static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 13097373cabcSAart Bik linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) { 13107373cabcSAart Bik if (codegen.redVal) { 13117373cabcSAart Bik rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal); 13127373cabcSAart Bik updateReduc(merger, codegen, ifInput); 13137373cabcSAart Bik } 13147373cabcSAart Bik rewriter.setInsertionPointToStart(&ifOp.elseRegion().front()); 13157373cabcSAart Bik } 13167373cabcSAart Bik 1317c8d5dcb0SAart Bik //===----------------------------------------------------------------------===// 1318c8d5dcb0SAart Bik // Sparse compiler synthesis methods (loop sequence). 1319c8d5dcb0SAart Bik //===----------------------------------------------------------------------===// 1320c8d5dcb0SAart Bik 1321c8d5dcb0SAart Bik /// Starts a loop sequence at given level. Returns true if 1322c8d5dcb0SAart Bik /// the universal loop index must be maintained at this level. 
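/// For example, for x(i) = a(i) + b(i) with sparse a and dense b, the
/// lattice point that covers "only b" has no sparse dimension, so the
/// universal index i is kept to drive the remaining dense iteration.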
1323c8d5dcb0SAart Bik static bool startLoopSeq(Merger &merger, CodeGen &codegen,
1324c8d5dcb0SAart Bik PatternRewriter &rewriter, linalg::GenericOp op,
1325c8d5dcb0SAart Bik std::vector<unsigned> &topSort, unsigned exp,
1326c8d5dcb0SAart Bik unsigned at, unsigned idx, unsigned ldx,
1327c8d5dcb0SAart Bik unsigned lts) {
1328c8d5dcb0SAart Bik assert(codegen.curVecLength == 1);
13297373cabcSAart Bik assert(!codegen.loops[idx]);
1330c8d5dcb0SAart Bik // Emit invariants at this loop sequence level.
13317373cabcSAart Bik genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
1332c8d5dcb0SAart Bik // Emit further initialization at this loop sequence level.
1333c8d5dcb0SAart Bik unsigned l0 = merger.set(lts)[0];
13347373cabcSAart Bik bool needsUniv =
13357373cabcSAart Bik genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
1336c8d5dcb0SAart Bik // Maintain the universal index only if it is actually
1337c8d5dcb0SAart Bik // consumed by a subsequent lattice point.
13387373cabcSAart Bik if (needsUniv) {
1339c8d5dcb0SAart Bik unsigned lsize = merger.set(lts).size();
1340c8d5dcb0SAart Bik for (unsigned i = 1; i < lsize; i++) {
1341c8d5dcb0SAart Bik unsigned li = merger.set(lts)[i];
1342c8d5dcb0SAart Bik if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
1343c8d5dcb0SAart Bik return true;
1344c8d5dcb0SAart Bik }
1345c8d5dcb0SAart Bik }
1346c8d5dcb0SAart Bik return false;
1347c8d5dcb0SAart Bik }
1348c8d5dcb0SAart Bik 
1349c8d5dcb0SAart Bik /// Starts a single loop in current sequence.
1350c8d5dcb0SAart Bik static Operation *startLoop(Merger &merger, CodeGen &codegen,
1351c8d5dcb0SAart Bik PatternRewriter &rewriter, linalg::GenericOp op,
1352c8d5dcb0SAart Bik std::vector<unsigned> &topSort, unsigned at,
1353c8d5dcb0SAart Bik unsigned li, bool needsUniv) {
1354c8d5dcb0SAart Bik assert(codegen.curVecLength == 1);
1355c8d5dcb0SAart Bik // Emit the for/while-loop control.
1356c8d5dcb0SAart Bik Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
1357c8d5dcb0SAart Bik needsUniv, merger.lat(li).simple);
1358c8d5dcb0SAart Bik // Emit the locals for this loop.
1359c8d5dcb0SAart Bik genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
1360c8d5dcb0SAart Bik merger.lat(li).bits);
1361c8d5dcb0SAart Bik return loop;
1362c8d5dcb0SAart Bik }
1363c8d5dcb0SAart Bik 
1364c8d5dcb0SAart Bik /// Ends a single loop in current sequence. Returns the new value for needsUniv.
1365c8d5dcb0SAart Bik static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1366c8d5dcb0SAart Bik linalg::GenericOp op, Operation *loop, unsigned idx,
1367c8d5dcb0SAart Bik unsigned li, bool needsUniv) {
1368c8d5dcb0SAart Bik codegen.curVecLength = 1;
1369c8d5dcb0SAart Bik // End a while-loop.
1370c8d5dcb0SAart Bik if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1371c8d5dcb0SAart Bik genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
13727373cabcSAart Bik merger.lat(li).bits, whileOp);
1373c8d5dcb0SAart Bik return needsUniv;
1374c8d5dcb0SAart Bik }
1375c8d5dcb0SAart Bik // End a for-loop.
13767373cabcSAart Bik genForInduction(merger, codegen, rewriter, op, loop);
1377c8d5dcb0SAart Bik return false;
1378c8d5dcb0SAart Bik }
1379c8d5dcb0SAart Bik 
1380c8d5dcb0SAart Bik /// Ends a loop sequence at given level.
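/// Among other things, a reduction still in SIMD form is folded back to a
/// scalar here, conceptually %s = vector.reduction "add", %vred.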
1381c8d5dcb0SAart Bik static void endLoopSeq(Merger &merger, CodeGen &codegen,
1382c8d5dcb0SAart Bik PatternRewriter &rewriter, linalg::GenericOp op,
1383c8d5dcb0SAart Bik unsigned exp, unsigned idx, unsigned ldx) {
1384c8d5dcb0SAart Bik assert(codegen.curVecLength == 1);
1385c8d5dcb0SAart Bik codegen.loops[idx] = Value();
13867373cabcSAart Bik // Bring a pending reduction back from SIMD form when sequence ends.
13877373cabcSAart Bik if (codegen.redVal)
13887373cabcSAart Bik if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
13897373cabcSAart Bik updateReduc(merger, codegen,
13907373cabcSAart Bik genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
13917373cabcSAart Bik // Unmark bookkeeping of invariants and loop index.
13927373cabcSAart Bik genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
1393c8d5dcb0SAart Bik }
1394c8d5dcb0SAart Bik 
1395a2c9d4bbSAart Bik /// Recursively generates code while computing iteration lattices in order
1396a2c9d4bbSAart Bik /// to manage the complexity of implementing co-iteration over unions
1397a2c9d4bbSAart Bik /// and intersections of sparse iteration spaces.
1398a2c9d4bbSAart Bik static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1399a2c9d4bbSAart Bik linalg::GenericOp op, std::vector<unsigned> &topSort,
1400a2c9d4bbSAart Bik unsigned exp, unsigned at) {
1401a2c9d4bbSAart Bik // At each leaf, assign remaining tensor (sub)expression to output tensor.
1402a2c9d4bbSAart Bik if (at == topSort.size()) {
1403a2c9d4bbSAart Bik Value rhs = genExp(merger, codegen, rewriter, op, exp);
1404b1d44e59SAart Bik genTensorStore(merger, codegen, rewriter, op, rhs);
1405a2c9d4bbSAart Bik return;
1406a2c9d4bbSAart Bik }
1407a2c9d4bbSAart Bik 
1408a2c9d4bbSAart Bik // Construct iteration lattices for current loop index, with L0 at top.
1409a2c9d4bbSAart Bik unsigned idx = topSort[at];
1410a2c9d4bbSAart Bik unsigned ldx = at == 0 ? -1u : topSort[at - 1];
1411c8d5dcb0SAart Bik unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
1412a2c9d4bbSAart Bik 
1413c8d5dcb0SAart Bik // Start a loop sequence.
1414c8d5dcb0SAart Bik bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
1415c8d5dcb0SAart Bik idx, ldx, lts);
1416c8d5dcb0SAart Bik 
1417c8d5dcb0SAart Bik // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1418c8d5dcb0SAart Bik unsigned lsize = merger.set(lts).size();
1419a2c9d4bbSAart Bik for (unsigned i = 0; i < lsize; i++) {
1420c8d5dcb0SAart Bik // Start a loop.
1421a2c9d4bbSAart Bik unsigned li = merger.set(lts)[i];
1422a2c9d4bbSAart Bik Operation *loop =
1423c8d5dcb0SAart Bik startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);
1424a2c9d4bbSAart Bik 
1425a2c9d4bbSAart Bik // Visit all lattice points with Li >= Lj to generate the
1426a2c9d4bbSAart Bik // loop-body, possibly with if statements for co-iteration.
14277373cabcSAart Bik Value ifInput = codegen.redVal;
1428a2c9d4bbSAart Bik bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1429a2c9d4bbSAart Bik for (unsigned j = 0; j < lsize; j++) {
1430a2c9d4bbSAart Bik unsigned lj = merger.set(lts)[j];
1431a2c9d4bbSAart Bik unsigned ej = merger.lat(lj).exp;
1432a2c9d4bbSAart Bik if (li == lj || merger.latGT(li, lj)) {
1433a2c9d4bbSAart Bik // Recurse into body of each branch.
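// For example, for x(i) = a(i) + b(i) with both inputs sparse, lattice
// point "a and b" computes a(i) + b(i), while points "only a" and
// "only b" copy the single present value; in the while-loop case each
// body is guarded by the coordinate-equality tests emitted in genIf.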
1434a2c9d4bbSAart Bik if (isWhile) {
1435a2c9d4bbSAart Bik scf::IfOp ifOp =
1436a2c9d4bbSAart Bik genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1437a2c9d4bbSAart Bik genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
14387373cabcSAart Bik endIf(merger, codegen, rewriter, op, ifOp, ifInput);
1439a2c9d4bbSAart Bik } else {
1440a2c9d4bbSAart Bik genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1441a2c9d4bbSAart Bik }
1442a2c9d4bbSAart Bik }
1443a2c9d4bbSAart Bik }
1444a2c9d4bbSAart Bik 
1445c8d5dcb0SAart Bik // End a loop.
1446c8d5dcb0SAart Bik needsUniv =
1447c8d5dcb0SAart Bik endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
1448a2c9d4bbSAart Bik }
1449a2c9d4bbSAart Bik 
1450c8d5dcb0SAart Bik // End a loop sequence.
1451c8d5dcb0SAart Bik endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
1452a2c9d4bbSAart Bik }
1453a2c9d4bbSAart Bik 
1454727a63e0SAart Bik /// Converts the result computed by the sparse kernel into the required form.
145536b66ab9SAart Bik static void genResult(Merger &merger, CodeGen &codegen,
145636b66ab9SAart Bik PatternRewriter &rewriter, linalg::GenericOp op) {
145736b66ab9SAart Bik OpOperand *lhs = op.getOutputOperand(0);
145836b66ab9SAart Bik Type resType = lhs->get().getType();
1460f66e5769SAart Bik if (getSparseTensorEncoding(resType)) {
1461f66e5769SAart Bik // The sparse tensor rematerializes from the original sparse tensor's
1462f66e5769SAart Bik // underlying sparse storage format.
1463f66e5769SAart Bik rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
1464f66e5769SAart Bik codegen.sparseOut == lhs);
146536b66ab9SAart Bik } else {
1466f66e5769SAart Bik // To rematerialize a non-annotated tensor, simply load it
146736b66ab9SAart Bik // from the bufferized value.
1468f66e5769SAart Bik Value val = codegen.buffers.back(); // value array
1469f66e5769SAart Bik rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, val);
147036b66ab9SAart Bik }
1471727a63e0SAart Bik }
1472727a63e0SAart Bik 
14735da21338SAart Bik //===----------------------------------------------------------------------===//
14745da21338SAart Bik // Sparse compiler rewriting methods.
14755da21338SAart Bik //===----------------------------------------------------------------------===//
14765da21338SAart Bik 
1477a2c9d4bbSAart Bik namespace {
1478a2c9d4bbSAart Bik 
1479a2c9d4bbSAart Bik /// Sparse rewriting rule for generic Linalg operation.
1480a2c9d4bbSAart Bik struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1481a2c9d4bbSAart Bik public:
1482a2c9d4bbSAart Bik GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1483a2c9d4bbSAart Bik : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1484a2c9d4bbSAart Bik 
1485a2c9d4bbSAart Bik LogicalResult matchAndRewrite(linalg::GenericOp op,
1486a2c9d4bbSAart Bik PatternRewriter &rewriter) const override {
1487a2c9d4bbSAart Bik // Detects sparse annotations and translates the per-dimension sparsity
1488a2c9d4bbSAart Bik // information for all tensors to loop indices in the kernel.
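// For example, a CSR-annotated operand (dimLevelType = ["dense",
// "compressed"]) contributes Dim::kDense for its row index and
// Dim::kSparse for its column index.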
1489a2c9d4bbSAart Bik assert(op.getNumOutputs() == 1);
14902f2b5b7dSTobias Gysi unsigned numTensors = op.getNumInputsAndOutputs();
1491a2c9d4bbSAart Bik unsigned numLoops = op.iterator_types().getValue().size();
1492a2c9d4bbSAart Bik Merger merger(numTensors, numLoops);
1493bf9ef3efSAart Bik if (!findSparseAnnotations(merger, op))
1494bf9ef3efSAart Bik return failure();
1495a2c9d4bbSAart Bik 
1496a2c9d4bbSAart Bik // Computes a topologically sorted iteration graph to ensure
1497a2c9d4bbSAart Bik // tensors are visited in natural index order. Fails on cycles.
1498a2c9d4bbSAart Bik // This assumes that higher-level passes have already put the
1499a2c9d4bbSAart Bik // tensors in each tensor expression in a feasible order.
1500a2c9d4bbSAart Bik std::vector<unsigned> topSort;
1501b6d1a31cSAart Bik if (!computeIterationGraph(merger, op, topSort,
1502b6d1a31cSAart Bik SortMask::kIncludeUndef |
1503b6d1a31cSAart Bik SortMask::kIncludeDense) &&
1504b6d1a31cSAart Bik !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
1505b6d1a31cSAart Bik !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
1506b6d1a31cSAart Bik !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
1507a2c9d4bbSAart Bik return failure();
1508a2c9d4bbSAart Bik 
1509266a7414SAart Bik // Builds the tensor expression for the Linalg operation in SSA form.
15107373cabcSAart Bik Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
15117373cabcSAart Bik if (!optExp.hasValue())
1512266a7414SAart Bik return failure();
15137373cabcSAart Bik unsigned exp = optExp.getValue();
1514a2c9d4bbSAart Bik 
1515266a7414SAart Bik // Rejects an inadmissible tensor expression.
1516f66e5769SAart Bik OpOperand *sparseOut = nullptr;
1517f66e5769SAart Bik if (!isAdmissableTensorExp(merger, op, exp, &sparseOut))
151836b66ab9SAart Bik return failure();
151936b66ab9SAart Bik 
1520a2c9d4bbSAart Bik // Recursively generates code.
1521f66e5769SAart Bik CodeGen codegen(options, numTensors, numLoops, sparseOut);
1522c8d5dcb0SAart Bik genBuffers(merger, codegen, rewriter, op);
15237373cabcSAart Bik genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
152436b66ab9SAart Bik genResult(merger, codegen, rewriter, op);
1525a2c9d4bbSAart Bik return success();
1526a2c9d4bbSAart Bik }
1527a2c9d4bbSAart Bik 
1528a2c9d4bbSAart Bik private:
1529a2c9d4bbSAart Bik /// Options to control sparse code generation.
1530a2c9d4bbSAart Bik SparsificationOptions options;
1531a2c9d4bbSAart Bik };
1532a2c9d4bbSAart Bik 
1533a2c9d4bbSAart Bik } // namespace
1534a2c9d4bbSAart Bik 
1535a2c9d4bbSAart Bik /// Populates the given patterns list with rewriting rules required for
1536a2c9d4bbSAart Bik /// the sparsification of linear algebra operations.
1537a2c9d4bbSAart Bik void mlir::populateSparsificationPatterns(
1538a2c9d4bbSAart Bik RewritePatternSet &patterns, const SparsificationOptions &options) {
1539a2c9d4bbSAart Bik patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1540a2c9d4bbSAart Bik }
1541
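// Usage sketch (hypothetical caller; the actual wiring lives in the
// sparsification pass, not in this file):
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));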