//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the conversion of sparse tensor types to actual
// sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Reduction kinds.
enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
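// As a brief illustration of the underlying storage scheme (a sketch of a
// standard CSR layout, not tied to any particular kernel), the 3x4 matrix
//
//   | 1 0 0 2 |
//   | 0 0 3 0 |
//   | 4 0 5 6 |
//
// is stored as pointers = [0, 2, 3, 6], indices = [0, 3, 2, 0, 2, 3], and
// values = [1, 2, 3, 4, 5, 6]. The `pointers`, `indices`, and `buffers`
// arrays below hold such 1-D buffers, organized by tensor and dimension.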
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
          OpOperand *op, unsigned nest)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can use a scalarized reduction.
  unsigned redExp;
  Value redVal;
  Reduction redKind;
  // Sparse tensor as output. Implemented either through direct injective
  // insertion in lexicographic index order (where indices are updated in
  // the temporary array `lexIdx`) or through access pattern expansion (TODO).
  OpOperand *sparseOut;
  unsigned outerParNest;
  Value lexIdx;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
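/// For example, under a hypothetical column-major annotation with dimension
/// ordering (d0, d1) -> (d1, d0), perm(enc, 0) yields 1 and perm(enc, 1)
/// yields 0; without an explicit ordering, perm is the identity.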
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects affine expressions
/// that are not a direct index for annotated tensors.
// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
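/// For example (a sketch of the rules implemented by findAffine above):
/// a subscript a(i, j) is admissible for any tensor, whereas a compound
/// subscript such as a(i + j, j) is admissible only when `a` is a dense
/// (non-annotated) tensor.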
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order.
/// This is essential for sparse storage formats since these only support
/// access along fixed dimensions. Even for dense storage formats, however,
/// the natural index order yields innermost unit-stride access with better
/// spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
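/// For example (a sketch, assuming the attribute behind kInplaceableAttrName
/// is spelled "linalg.inplaceable"), a function argument such as
///   %arg0: tensor<?xf64> {linalg.inplaceable = true}
/// yields true.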
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              linalg::comprehensive_bufferize::BufferizableOpInterface::
                  kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression is admissible for
/// codegen. Sets `sparseOut` to the tensor and `outerParNest` to the outer
/// injective nesting depth when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned exp,
                                  OpOperand **sparseOut,
                                  unsigned &outerParNest) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  auto iteratorTypes = op.iterator_types().getValue();
  unsigned numLoops = iteratorTypes.size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
  if (merger.isConjunction(tensor, exp) && isInPlace(lhs->get()))
    return true;
  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
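  // For example (a sketch), a materialized sparse output in
  // x(i, j) = a(i, j) + b(i, j) is accepted here: there are no reduction
  // iterators, so the injective nest spans all loops and matches the output
  // rank (case (1) in the situations listed below).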
  if (isMaterializing(lhs->get())) {
    unsigned nest = 0;
    for (unsigned i = 0; i < numLoops; i++) {
      if (isReductionIterator(iteratorTypes[topSort[i]]))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension (TODO: accept).
    if (nest == op.getRank(lhs)) {
      *sparseOut = lhs;
      outerParNest = nest;
      return true;
    }
  }
  return false;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (reductions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to name encoding.
static StringRef getReductionName(Reduction kind) {
  switch (kind) {
  case kNoReduc:
    break;
  case kSum:
    return "add";
  case kProduct:
    return "mul";
  case kAnd:
    return "and";
  case kOr:
    return "or";
  case kXor:
    return "xor";
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
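/// For example (a sketch with vector length 4 and incoming scalar value r):
/// a sum reduction starts from | 0 | 0 | 0 | r |, so that after element-wise
/// accumulation of lanes <a, b, c, d>, the final horizontal add yields
/// a + b + c + d + r, folding the scalar value in exactly once.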
static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
                                Location loc, VectorType vtp) {
  Value r = codegen.redVal;
  switch (codegen.redKind) {
  case kNoReduc:
    break;
  case kSum:
  case kXor: {
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    Attribute zero = rewriter.getZeroAttr(vtp);
    Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero);
    return rewriter.create<vector::InsertElementOp>(
        loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0));
  }
  case kProduct: {
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    Type etp = vtp.getElementType();
    Attribute one;
    if (etp.isa<FloatType>())
      one = rewriter.getFloatAttr(etp, 1.0);
    else
      one = rewriter.getIntegerAttr(etp, 1);
    Value vec = rewriter.create<arith::ConstantOp>(
        loc, vtp, DenseElementsAttr::get(vtp, one));
    return rewriter.create<vector::InsertElementOp>(
        loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0));
  }
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
                               Location loc, VectorType vtp) {
  StringRef name = getReductionName(codegen.redKind);
  StringAttr kind = rewriter.getStringAttr(name);
  return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
                                              codegen.redVal, ValueRange{});
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Generates buffer for the output tensor.
/// Note that all sparse kernels assume that when all elements are written
/// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
/// to all zeroes and only nonzero values are computed and written out. For
/// updates (viz. x(i) += y(i) * z(i)), only nonzero values are used for the
/// updates and no assumption on the original contents of the output buffer
/// is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  if (isMaterializing(tensor)) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
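  // For example (a sketch), a CSR-annotated input tensor<?x?xf64, #CSR>
  // contributes a pointers buffer and an indices buffer for its compressed
  // dimension plus one values buffer, whereas a plain dense input simply
  // bufferizes to a memref of its full shape.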
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else if (t == codegen.sparseOut) {
      // True sparse output needs a lexIdx array.
      Value rank = rewriter.create<arith::ConstantIndexOp>(loc, op.getRank(t));
      auto dynShape = {ShapedType::kDynamicSize};
      auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
      codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, genIntType(rewriter, 1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
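  // The affine map below computes min(step, hi - iv) active lanes; e.g., a
  // hypothetical trip with hi = 10 and step = 4 yields a full 4-lane mask at
  // iv = 0 and iv = 4, but only a 2-lane mask at iv = 8.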
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass =
      rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
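/// For example, a dense subscript expression such as i + 2 materializes as
/// an arith.addi of the current loop index for i and an arith.constant
/// 2 : index, per the cases below.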
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<arith::ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Insertion (a sparse tensor output "loads" as zero).
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  if (t == codegen.sparseOut) {
    Type tp = getElementTypeOrSelf(t->get().getType());
    return rewriter.create<arith::ConstantOp>(op.getLoc(), tp,
                                              rewriter.getZeroAttr(tp));
  }
  // Actual load.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Insertion.
  OpOperand *t = op.getOutputOperand(0);
  if (t == codegen.sparseOut) {
    rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices.
    // For 32-bit values, zero extending the elements into 64-bit loses some
    // performance since the 32-bit indexed gather/scatter is more efficient
    // than the 64-bit index variant (if the negative 32-bit index space is
    // unused, the enableSIMDIndex32 flag can preserve this performance).
    // For 64-bit values, there is no good way to state that the indices are
    // unsigned, which creates the potential of incorrect address calculations
    // in the unlikely case we need such extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load =
          rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64));
    load =
        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Recursively generates tensor expression.
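/// For example (a sketch), for a kernel x(i) = a(i) * b(i), the
/// multiplication node first generates the two tensor loads recursively
/// and then delegates construction of the actual arithmetic operation
/// to the merger.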
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool atStart,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    if (!atLevel)
      return;
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      // Start or end a scalarized reduction.
      if (atStart) {
        Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
        codegen.redKind = getReduction(last);
        codegen.redExp = exp;
        updateReduc(merger, codegen, load);
      } else {
        Value redVal = codegen.redVal;
        updateReduc(merger, codegen, Value());
        codegen.redExp = -1u;
        codegen.redKind = kNoReduc;
        genTensorStore(merger, codegen, rewriter, op, redVal);
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      merger.exp(exp).val =
          atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0)
                       ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                       : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// This prevents effective vectorization.
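/// For example (a sketch): if a sparse factor forces loop order (j, i) while
/// a dense tensor is accessed as a(i, j), then the innermost loop index i
/// appears in the outer dimension of `a`, the access becomes strided, and
/// vectorization along i is rejected.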
999a2c9d4bbSAart Bik static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1000849f016cSAart Bik unsigned idx) { 10012f2b5b7dSTobias Gysi for (OpOperand *t : op.getInputAndOutputOperands()) { 10022f2b5b7dSTobias Gysi if (!getSparseTensorEncoding(t->get().getType())) { 10032f2b5b7dSTobias Gysi auto map = op.getTiedIndexingMap(t); 1004c194b49cSAart Bik for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1005b1d44e59SAart Bik AffineExpr a = map.getResult(d); 1006849f016cSAart Bik // Report non-unit stride if innermost index appears at an outer 1007849f016cSAart Bik // dimension (true non-unit stride) or if the innermost index appears 1008849f016cSAart Bik // in a compound subscript in the innermost dimension. Even if the 1009849f016cSAart Bik // latter is unit stride, it does not play well with scatter/gather. 1010c8d5dcb0SAart Bik // TODO: accept unit stride affine innermost like a[i,j+k+1]? 1011849f016cSAart Bik if (a.isFunctionOfDim(idx) && 1012849f016cSAart Bik ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) 1013a2c9d4bbSAart Bik return false; 1014a2c9d4bbSAart Bik } 1015a2c9d4bbSAart Bik } 1016a2c9d4bbSAart Bik } 1017a2c9d4bbSAart Bik return true; 1018a2c9d4bbSAart Bik } 1019a2c9d4bbSAart Bik 1020a2c9d4bbSAart Bik /// Generates a for-loop on a single index. 1021a2c9d4bbSAart Bik static Operation *genFor(Merger &merger, CodeGen &codegen, 1022a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1023a2c9d4bbSAart Bik bool isOuter, bool isInner, unsigned idx, 1024a2c9d4bbSAart Bik llvm::BitVector &indices) { 1025a2c9d4bbSAart Bik unsigned fb = indices.find_first(); 1026a2c9d4bbSAart Bik unsigned tensor = merger.tensor(fb); 1027a2c9d4bbSAart Bik assert(idx == merger.index(fb)); 1028a2c9d4bbSAart Bik auto iteratorTypes = op.iterator_types().getValue(); 1029583a7542STobias Gysi bool isReduction = isReductionIterator(iteratorTypes[idx]); 1030a2c9d4bbSAart Bik bool isSparse = merger.isDim(fb, Dim::kSparse); 1031f66e5769SAart Bik bool isVector = !codegen.sparseOut && 1032f66e5769SAart Bik isVectorFor(codegen, isInner, isSparse) && 1033a2c9d4bbSAart Bik denseUnitStrides(merger, op, idx); 1034a2c9d4bbSAart Bik bool isParallel = 1035f66e5769SAart Bik !codegen.sparseOut && 1036a2c9d4bbSAart Bik isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1037a2c9d4bbSAart Bik 1038a2c9d4bbSAart Bik // Prepare vector length. 1039a2c9d4bbSAart Bik if (isVector) 1040a2c9d4bbSAart Bik codegen.curVecLength = codegen.options.vectorLength; 1041a2c9d4bbSAart Bik 1042a2c9d4bbSAart Bik // Loop bounds and increment. 1043a2c9d4bbSAart Bik Location loc = op.getLoc(); 1044a2c9d4bbSAart Bik Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1045a2c9d4bbSAart Bik Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1046a54f4eaeSMogball Value step = 1047a54f4eaeSMogball rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength); 1048a2c9d4bbSAart Bik 1049a2c9d4bbSAart Bik // Emit a parallel loop. 
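// For example, the emitted loop looks roughly like the following
// illustrative sketch (SSA names invented):
//   scf.parallel (%p) = (%lo) to (%hi) step (%c1) {
//     ... body addresses the sparse positions through %p ...
//   }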
1050a2c9d4bbSAart Bik if (isParallel) { 1051a2c9d4bbSAart Bik assert(!isVector); 1052a2c9d4bbSAart Bik scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1053a2c9d4bbSAart Bik if (isSparse) 1054a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1055a2c9d4bbSAart Bik else 1056a2c9d4bbSAart Bik codegen.loops[idx] = parOp.getInductionVars()[0]; 1057a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(parOp.getBody()); 1058a2c9d4bbSAart Bik return parOp; 1059a2c9d4bbSAart Bik } 1060a2c9d4bbSAart Bik 10617373cabcSAart Bik // Emit a sequential or vector loop. 1062a2c9d4bbSAart Bik SmallVector<Value, 4> operands; 10637373cabcSAart Bik if (codegen.redVal) { 10647373cabcSAart Bik // In a vector loop, bring reduction into SIMD form, if not already. 10657373cabcSAart Bik if (isVector && !codegen.redVal.getType().isa<VectorType>()) { 10667373cabcSAart Bik VectorType vtp = vectorType(codegen, codegen.redVal.getType()); 10677373cabcSAart Bik Value vred = genVectorReducInit(codegen, rewriter, loc, vtp); 10687373cabcSAart Bik updateReduc(merger, codegen, vred); 10697373cabcSAart Bik } 10707373cabcSAart Bik operands.push_back(codegen.redVal); 1071a2c9d4bbSAart Bik } 1072a2c9d4bbSAart Bik scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 10737373cabcSAart Bik if (codegen.redVal) 10747373cabcSAart Bik updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); 1075a2c9d4bbSAart Bik // Assign induction variable to sparse or dense index. 1076a2c9d4bbSAart Bik Value iv = forOp.getInductionVar(); 1077a2c9d4bbSAart Bik if (isSparse) 1078a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = iv; 1079a2c9d4bbSAart Bik else 1080a2c9d4bbSAart Bik codegen.loops[idx] = iv; 1081a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(forOp.getBody()); 1082a2c9d4bbSAart Bik // Share vector iteration mask between all subsequent loads/stores. 1083a2c9d4bbSAart Bik if (isVector) 1084a2c9d4bbSAart Bik codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1085a2c9d4bbSAart Bik return forOp; 1086a2c9d4bbSAart Bik } 1087a2c9d4bbSAart Bik 1088a2c9d4bbSAart Bik /// Emit a while-loop for co-iteration over multiple indices. 1089a2c9d4bbSAart Bik static Operation *genWhile(Merger &merger, CodeGen &codegen, 1090a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1091a2c9d4bbSAart Bik unsigned idx, bool needsUniv, 1092a2c9d4bbSAart Bik llvm::BitVector &indices) { 1093a2c9d4bbSAart Bik SmallVector<Type, 4> types; 1094a2c9d4bbSAart Bik SmallVector<Value, 4> operands; 1095a2c9d4bbSAart Bik // Construct the while-loop with a parameter for each index. 
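// Schematically, the result resembles the following illustrative sketch,
// where %pA/%pB are sparse position indices and %u the optional universal
// index (all names invented):
//   %r:3 = scf.while (%pA = %a0, %pB = %b0, %u = %u0)
//       : (index, index, index) -> (index, index, index) {
//     %c = ... conjunction of "position < high" tests ...
//     scf.condition(%c) %pA, %pB, %u : index, index, index
//   } do {
//   ^bb0(%pA: index, %pB: index, %u: index):
//     ... loop body and induction ...
//     scf.yield %pA2, %pB2, %u2 : index, index, index
//   }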
1096a2c9d4bbSAart Bik Type indexType = rewriter.getIndexType(); 1097a2c9d4bbSAart Bik for (unsigned b = 0, be = indices.size(); b < be; b++) { 1098a2c9d4bbSAart Bik if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1099a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1100a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1101a2c9d4bbSAart Bik types.push_back(indexType); 1102a2c9d4bbSAart Bik operands.push_back(codegen.pidxs[tensor][idx]); 1103a2c9d4bbSAart Bik } 1104a2c9d4bbSAart Bik } 11057373cabcSAart Bik if (codegen.redVal) { 11067373cabcSAart Bik types.push_back(codegen.redVal.getType()); 11077373cabcSAart Bik operands.push_back(codegen.redVal); 11087373cabcSAart Bik } 1109a2c9d4bbSAart Bik if (needsUniv) { 1110a2c9d4bbSAart Bik types.push_back(indexType); 1111a2c9d4bbSAart Bik operands.push_back(codegen.loops[idx]); 1112a2c9d4bbSAart Bik } 11137373cabcSAart Bik assert(types.size() == operands.size()); 1114a2c9d4bbSAart Bik Location loc = op.getLoc(); 1115a2c9d4bbSAart Bik scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 1116a2c9d4bbSAart Bik Block *before = rewriter.createBlock(&whileOp.before(), {}, types); 1117a2c9d4bbSAart Bik Block *after = rewriter.createBlock(&whileOp.after(), {}, types); 1118a2c9d4bbSAart Bik 1119a2c9d4bbSAart Bik // Build the "before" region, which effectively consists 1120a2c9d4bbSAart Bik // of a conjunction of "i < upper" tests on all induction variables. 1121a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(&whileOp.before().front()); 1122a2c9d4bbSAart Bik Value cond; 1123a2c9d4bbSAart Bik unsigned o = 0; 1124a2c9d4bbSAart Bik for (unsigned b = 0, be = indices.size(); b < be; b++) { 1125a2c9d4bbSAart Bik if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1126a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1127a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1128a2c9d4bbSAart Bik Value op1 = before->getArgument(o); 1129a2c9d4bbSAart Bik Value op2 = codegen.highs[tensor][idx]; 1130a54f4eaeSMogball Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 1131a54f4eaeSMogball op1, op2); 1132a54f4eaeSMogball cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc; 1133a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = after->getArgument(o++); 1134a2c9d4bbSAart Bik } 1135a2c9d4bbSAart Bik } 11367373cabcSAart Bik if (codegen.redVal) 11377373cabcSAart Bik updateReduc(merger, codegen, after->getArgument(o++)); 1138a2c9d4bbSAart Bik if (needsUniv) 1139a2c9d4bbSAart Bik codegen.loops[idx] = after->getArgument(o++); 1140a2c9d4bbSAart Bik assert(o == operands.size()); 1141a2c9d4bbSAart Bik rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1142a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(&whileOp.after().front()); 1143a2c9d4bbSAart Bik return whileOp; 1144a2c9d4bbSAart Bik } 1145a2c9d4bbSAart Bik 1146a2c9d4bbSAart Bik /// Generates a for-loop or a while-loop, depending on whether it implements 1147a2c9d4bbSAart Bik /// singleton iteration or co-iteration over the given conjunction.
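/// For example, a singleton conjunction such as x(i) = 2 * a(i) over a single
/// sparse vector a iterates with one for-loop over a's stored positions,
/// whereas co-iteration as in x(i) = a(i) + b(i) with both a and b sparse
/// needs a while-loop that advances the positions of a and b in tandem.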
1148a2c9d4bbSAart Bik static Operation *genLoop(Merger &merger, CodeGen &codegen, 1149a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1150a2c9d4bbSAart Bik std::vector<unsigned> &topSort, unsigned at, 1151a2c9d4bbSAart Bik bool needsUniv, llvm::BitVector &indices) { 1152a2c9d4bbSAart Bik unsigned idx = topSort[at]; 1153a2c9d4bbSAart Bik if (indices.count() == 1) { 1154a2c9d4bbSAart Bik bool isOuter = at == 0; 1155a2c9d4bbSAart Bik bool isInner = at == topSort.size() - 1; 1156a2c9d4bbSAart Bik return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1157a2c9d4bbSAart Bik indices); 1158a2c9d4bbSAart Bik } 1159a2c9d4bbSAart Bik return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1160a2c9d4bbSAart Bik } 1161a2c9d4bbSAart Bik 1162a2c9d4bbSAart Bik /// Generates the local variables for this loop, consisting of the sparse 1163a2c9d4bbSAart Bik /// indices, restored universal dense index, and dense positions. 1164a2c9d4bbSAart Bik static void genLocals(Merger &merger, CodeGen &codegen, 1165a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1166a2c9d4bbSAart Bik std::vector<unsigned> &topSort, unsigned at, 1167a2c9d4bbSAart Bik bool needsUniv, llvm::BitVector &locals) { 1168a2c9d4bbSAart Bik Location loc = op.getLoc(); 1169a2c9d4bbSAart Bik unsigned idx = topSort[at]; 1170a2c9d4bbSAart Bik 1171a2c9d4bbSAart Bik // Initialize sparse indices. 1172a2c9d4bbSAart Bik Value min; 1173a2c9d4bbSAart Bik for (unsigned b = 0, be = locals.size(); b < be; b++) { 1174a2c9d4bbSAart Bik if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1175a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1176a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1177a2c9d4bbSAart Bik Value ptr = codegen.indices[tensor][idx]; 1178a2c9d4bbSAart Bik Value s = codegen.pidxs[tensor][idx]; 1179a2c9d4bbSAart Bik Value load = genLoad(codegen, rewriter, loc, ptr, s); 1180a2c9d4bbSAart Bik codegen.idxs[tensor][idx] = load; 1181a2c9d4bbSAart Bik if (!needsUniv) { 1182a2c9d4bbSAart Bik if (min) { 1183a54f4eaeSMogball Value cmp = rewriter.create<arith::CmpIOp>( 1184a54f4eaeSMogball loc, arith::CmpIPredicate::ult, load, min); 1185a2c9d4bbSAart Bik min = rewriter.create<SelectOp>(loc, cmp, load, min); 1186a2c9d4bbSAart Bik } else { 1187a2c9d4bbSAart Bik min = load; 1188a2c9d4bbSAart Bik } 1189a2c9d4bbSAart Bik } 1190a2c9d4bbSAart Bik } 1191a2c9d4bbSAart Bik } 1192a2c9d4bbSAart Bik 1193a2c9d4bbSAart Bik // Merge dense universal index over minimum. 1194a2c9d4bbSAart Bik if (min) { 1195a2c9d4bbSAart Bik assert(!needsUniv); 1196a2c9d4bbSAart Bik codegen.loops[idx] = min; 1197a2c9d4bbSAart Bik } 1198a2c9d4bbSAart Bik 1199727a63e0SAart Bik // Initialize dense positions. Note that we generate dense indices of the 1200727a63e0SAart Bik // output tensor unconditionally, since they may not appear in the lattice, 1201727a63e0SAart Bik // but may be needed for linearized codegen. 1202a2c9d4bbSAart Bik for (unsigned b = 0, be = locals.size(); b < be; b++) { 1203727a63e0SAart Bik if ((locals[b] || merger.isOutTensor(b, idx)) && 1204727a63e0SAart Bik merger.isDim(b, Dim::kDense)) { 1205a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1206a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1207a2c9d4bbSAart Bik unsigned pat = at; 1208a2c9d4bbSAart Bik for (; pat != 0; pat--) 1209a2c9d4bbSAart Bik if (codegen.pidxs[tensor][topSort[pat - 1]]) 1210a2c9d4bbSAart Bik break; 1211a54f4eaeSMogball Value p = (pat == 0) ? 
rewriter.create<arith::ConstantIndexOp>(loc, 0) 1212a2c9d4bbSAart Bik : codegen.pidxs[tensor][topSort[pat - 1]]; 1213a2c9d4bbSAart Bik codegen.pidxs[tensor][idx] = genAddress( 1214a2c9d4bbSAart Bik codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1215a2c9d4bbSAart Bik } 1216a2c9d4bbSAart Bik } 1217f66e5769SAart Bik 1218f66e5769SAart Bik // Move the insertion indices in lexicographic index order. 1219f66e5769SAart Bik if (codegen.sparseOut) { 1220f66e5769SAart Bik Value pos = rewriter.create<arith::ConstantIndexOp>(loc, at); 1221f66e5769SAart Bik rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx, 1222f66e5769SAart Bik pos); 1223f66e5769SAart Bik } 1224a2c9d4bbSAart Bik } 1225a2c9d4bbSAart Bik 1226a2c9d4bbSAart Bik /// Generates the induction structure for a while-loop. 1227a2c9d4bbSAart Bik static void genWhileInduction(Merger &merger, CodeGen &codegen, 1228a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1229a2c9d4bbSAart Bik unsigned idx, bool needsUniv, 12307373cabcSAart Bik llvm::BitVector &induction, 12317373cabcSAart Bik scf::WhileOp whileOp) { 1232a2c9d4bbSAart Bik Location loc = op.getLoc(); 12337373cabcSAart Bik // Finalize each else branch of all if statements. 12347373cabcSAart Bik if (codegen.redVal) { 12357373cabcSAart Bik while (auto ifOp = dyn_cast_or_null<scf::IfOp>( 12367373cabcSAart Bik rewriter.getInsertionBlock()->getParentOp())) { 12377373cabcSAart Bik rewriter.create<scf::YieldOp>(loc, codegen.redVal); 12387373cabcSAart Bik updateReduc(merger, codegen, ifOp.getResult(0)); 12397373cabcSAart Bik rewriter.setInsertionPointAfter(ifOp); 12407373cabcSAart Bik } 12417373cabcSAart Bik } 12427373cabcSAart Bik rewriter.setInsertionPointToEnd(&whileOp.after().front()); 12437373cabcSAart Bik // Finalize the induction. Note that the induction could be performed 12447373cabcSAart Bik // in the individual if-branches to avoid re-evaluating the conditions. 12457373cabcSAart Bik // However, that would result in a rather elaborate forest of yield 12467373cabcSAart Bik // instructions during code generation. Moreover, performing the induction 12477373cabcSAart Bik // after the if-statements more closely resembles code generated by TACO. 
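// Concretely, for each sparse tensor the conditional advance built below
// amounts to the following illustrative scalar sketch (names invented):
//   %cmp = arith.cmpi eq, %index, %univ : index
//   %add = arith.addi %pidx, %c1 : index
//   %new = select %cmp, %add, %pidx : index
// so a position only moves forward when its index participated in the
// iteration that just completed.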
1248a2c9d4bbSAart Bik unsigned o = 0; 1249a2c9d4bbSAart Bik SmallVector<Value, 4> operands; 1250a54f4eaeSMogball Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 1251a2c9d4bbSAart Bik for (unsigned b = 0, be = induction.size(); b < be; b++) { 1252a2c9d4bbSAart Bik if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1253a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1254a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1255a2c9d4bbSAart Bik Value op1 = codegen.idxs[tensor][idx]; 1256a2c9d4bbSAart Bik Value op2 = codegen.loops[idx]; 1257a2c9d4bbSAart Bik Value op3 = codegen.pidxs[tensor][idx]; 1258a54f4eaeSMogball Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1259a54f4eaeSMogball op1, op2); 1260a54f4eaeSMogball Value add = rewriter.create<arith::AddIOp>(loc, op3, one); 1261a2c9d4bbSAart Bik operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3)); 12627373cabcSAart Bik codegen.pidxs[tensor][idx] = whileOp->getResult(o++); 1263a2c9d4bbSAart Bik } 1264a2c9d4bbSAart Bik } 12657373cabcSAart Bik if (codegen.redVal) { 12667373cabcSAart Bik operands.push_back(codegen.redVal); 12677373cabcSAart Bik updateReduc(merger, codegen, whileOp->getResult(o++)); 12687373cabcSAart Bik } 1269a2c9d4bbSAart Bik if (needsUniv) { 1270a54f4eaeSMogball operands.push_back( 1271a54f4eaeSMogball rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one)); 12727373cabcSAart Bik codegen.loops[idx] = whileOp->getResult(o++); 1273a2c9d4bbSAart Bik } 1274a2c9d4bbSAart Bik assert(o == operands.size()); 1275a2c9d4bbSAart Bik rewriter.create<scf::YieldOp>(loc, operands); 12767373cabcSAart Bik rewriter.setInsertionPointAfter(whileOp); 12777373cabcSAart Bik } 12787373cabcSAart Bik 12797373cabcSAart Bik /// Generates the induction structure for a for-loop. 12807373cabcSAart Bik static void genForInduction(Merger &merger, CodeGen &codegen, 12817373cabcSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 12827373cabcSAart Bik Operation *loop) { 12837373cabcSAart Bik Location loc = op.getLoc(); 12847373cabcSAart Bik unsigned o = 0; 12857373cabcSAart Bik SmallVector<Value, 4> operands; 12867373cabcSAart Bik if (codegen.redVal) { 12877373cabcSAart Bik operands.push_back(codegen.redVal); 12887373cabcSAart Bik updateReduc(merger, codegen, loop->getResult(o++)); 12897373cabcSAart Bik } 12907373cabcSAart Bik assert(o == operands.size()); 12917373cabcSAart Bik if (o > 0) 12927373cabcSAart Bik rewriter.create<scf::YieldOp>(loc, operands); 12937373cabcSAart Bik rewriter.setInsertionPointAfter(loop); 1294a2c9d4bbSAart Bik } 1295a2c9d4bbSAart Bik 1296a2c9d4bbSAart Bik /// Generates a single if-statement within a while-loop. 
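/// For example, for a sparse tensor co-iterated with other tensors, the
/// guard constructed below looks roughly like (illustrative sketch):
///   %c = arith.cmpi eq, %index_a, %univ : index
///   scf.if %c { ... } else { ... }
/// with result types added only when a reduction value is live.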
1297a2c9d4bbSAart Bik static scf::IfOp genIf(Merger &merger, CodeGen &codegen, 1298a2c9d4bbSAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1299a2c9d4bbSAart Bik unsigned idx, llvm::BitVector &conditions) { 1300a2c9d4bbSAart Bik Location loc = op.getLoc(); 13017373cabcSAart Bik SmallVector<Type, 4> types; 1302a2c9d4bbSAart Bik Value cond; 1303a2c9d4bbSAart Bik for (unsigned b = 0, be = conditions.size(); b < be; b++) { 1304a2c9d4bbSAart Bik if (conditions[b]) { 1305a2c9d4bbSAart Bik unsigned tensor = merger.tensor(b); 1306a2c9d4bbSAart Bik assert(idx == merger.index(b)); 1307a2c9d4bbSAart Bik Value clause; 1308a2c9d4bbSAart Bik if (merger.isDim(b, Dim::kSparse)) { 1309a2c9d4bbSAart Bik Value op1 = codegen.idxs[tensor][idx]; 1310a2c9d4bbSAart Bik Value op2 = codegen.loops[idx]; 1311a54f4eaeSMogball clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1312a54f4eaeSMogball op1, op2); 1313a2c9d4bbSAart Bik } else { 1314a54f4eaeSMogball clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true 1315a2c9d4bbSAart Bik } 1316a54f4eaeSMogball cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause; 1317a2c9d4bbSAart Bik } 1318a2c9d4bbSAart Bik } 13197373cabcSAart Bik if (codegen.redVal) 13207373cabcSAart Bik types.push_back(codegen.redVal.getType()); 13217373cabcSAart Bik scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true); 1322a2c9d4bbSAart Bik rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); 1323a2c9d4bbSAart Bik return ifOp; 1324a2c9d4bbSAart Bik } 1325a2c9d4bbSAart Bik 13267373cabcSAart Bik /// Generates end of true branch of if-statement within a while-loop. 13277373cabcSAart Bik static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 13287373cabcSAart Bik linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) { 13297373cabcSAart Bik if (codegen.redVal) { 13307373cabcSAart Bik rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal); 13317373cabcSAart Bik updateReduc(merger, codegen, ifInput); 13327373cabcSAart Bik } 13337373cabcSAart Bik rewriter.setInsertionPointToStart(&ifOp.elseRegion().front()); 13347373cabcSAart Bik } 13357373cabcSAart Bik 1336c8d5dcb0SAart Bik //===----------------------------------------------------------------------===// 1337c8d5dcb0SAart Bik // Sparse compiler synthesis methods (loop sequence). 1338c8d5dcb0SAart Bik //===----------------------------------------------------------------------===// 1339c8d5dcb0SAart Bik 1340c8d5dcb0SAart Bik /// Starts a loop sequence at given level. Returns true if 1341c8d5dcb0SAart Bik /// the universal loop index must be maintained at this level. 1342c8d5dcb0SAart Bik static bool startLoopSeq(Merger &merger, CodeGen &codegen, 1343c8d5dcb0SAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1344c8d5dcb0SAart Bik std::vector<unsigned> &topSort, unsigned exp, 1345c8d5dcb0SAart Bik unsigned at, unsigned idx, unsigned ldx, 1346c8d5dcb0SAart Bik unsigned lts) { 1347c8d5dcb0SAart Bik assert(codegen.curVecLength == 1); 13487373cabcSAart Bik assert(!codegen.loops[idx]); 1349c8d5dcb0SAart Bik // Emit invariants at this loop sequence level. 13507373cabcSAart Bik genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true); 1351c8d5dcb0SAart Bik // Emit further initialization at this loop sequence level.
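// For a sparse dimension, the genInit call below then loads the position
// bounds from the pointers array, roughly (illustrative scalar sketch):
//   %lo = memref.load %pointers[%p0] : memref<?xindex>
//   %p1 = arith.addi %p0, %c1 : index
//   %hi = memref.load %pointers[%p1] : memref<?xindex>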
1352c8d5dcb0SAart Bik unsigned l0 = merger.set(lts)[0]; 13537373cabcSAart Bik bool needsUniv = 13547373cabcSAart Bik genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits); 1355c8d5dcb0SAart Bik // Maintain the universal index only if it is actually 1356c8d5dcb0SAart Bik // consumed by a subsequent lattice point. 13577373cabcSAart Bik if (needsUniv) { 1358c8d5dcb0SAart Bik unsigned lsize = merger.set(lts).size(); 1359c8d5dcb0SAart Bik for (unsigned i = 1; i < lsize; i++) { 1360c8d5dcb0SAart Bik unsigned li = merger.set(lts)[i]; 1361c8d5dcb0SAart Bik if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) 1362c8d5dcb0SAart Bik return true; 1363c8d5dcb0SAart Bik } 1364c8d5dcb0SAart Bik } 1365c8d5dcb0SAart Bik return false; 1366c8d5dcb0SAart Bik } 1367c8d5dcb0SAart Bik 1368c8d5dcb0SAart Bik /// Starts a single loop in current sequence. 1369c8d5dcb0SAart Bik static Operation *startLoop(Merger &merger, CodeGen &codegen, 1370c8d5dcb0SAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1371c8d5dcb0SAart Bik std::vector<unsigned> &topSort, unsigned at, 1372c8d5dcb0SAart Bik unsigned li, bool needsUniv) { 1373c8d5dcb0SAart Bik assert(codegen.curVecLength == 1); 1374c8d5dcb0SAart Bik // Emit the for/while-loop control. 1375c8d5dcb0SAart Bik Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at, 1376c8d5dcb0SAart Bik needsUniv, merger.lat(li).simple); 1377c8d5dcb0SAart Bik // Emit the locals for this loop. 1378c8d5dcb0SAart Bik genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, 1379c8d5dcb0SAart Bik merger.lat(li).bits); 1380c8d5dcb0SAart Bik return loop; 1381c8d5dcb0SAart Bik } 1382c8d5dcb0SAart Bik 1383c8d5dcb0SAart Bik /// Ends a single loop in current sequence. Returns the new value for needsUniv. 1384c8d5dcb0SAart Bik static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1385c8d5dcb0SAart Bik linalg::GenericOp op, Operation *loop, unsigned idx, 1386c8d5dcb0SAart Bik unsigned li, bool needsUniv) { 1387c8d5dcb0SAart Bik codegen.curVecLength = 1; 1388c8d5dcb0SAart Bik // End a while-loop. 1389c8d5dcb0SAart Bik if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) { 1390c8d5dcb0SAart Bik genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv, 13917373cabcSAart Bik merger.lat(li).bits, whileOp); 1392c8d5dcb0SAart Bik return needsUniv; 1393c8d5dcb0SAart Bik } 1394c8d5dcb0SAart Bik // End a for-loop. 13957373cabcSAart Bik genForInduction(merger, codegen, rewriter, op, loop); 1396c8d5dcb0SAart Bik return false; 1397c8d5dcb0SAart Bik } 1398c8d5dcb0SAart Bik 1399c8d5dcb0SAart Bik /// Ends a loop sequence at given level. 1400c8d5dcb0SAart Bik static void endLoopSeq(Merger &merger, CodeGen &codegen, 1401c8d5dcb0SAart Bik PatternRewriter &rewriter, linalg::GenericOp op, 1402c8d5dcb0SAart Bik unsigned exp, unsigned idx, unsigned ldx) { 1403c8d5dcb0SAart Bik assert(codegen.curVecLength == 1); 1404c8d5dcb0SAart Bik codegen.loops[idx] = Value(); 14057373cabcSAart Bik // Bring a pending reduction back from SIMD form when sequence ends. 14067373cabcSAart Bik if (codegen.redVal) 14077373cabcSAart Bik if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>()) 14087373cabcSAart Bik updateReduc(merger, codegen, 14097373cabcSAart Bik genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp)); 14107373cabcSAart Bik // Unmark bookkeeping of invariants and loop index.
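// This mirrors the atStart=true call in startLoopSeq: for the output tensor
// it stores a pending scalarized reduction back, and for all other tensors
// it clears the cached loop-invariant loads.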
14117373cabcSAart Bik genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false); 1412c8d5dcb0SAart Bik } 1413c8d5dcb0SAart Bik 1414a2c9d4bbSAart Bik /// Recursively generates code while computing iteration lattices in order 1415a2c9d4bbSAart Bik /// to manage the complexity of implementing co-iteration over unions 1416a2c9d4bbSAart Bik /// and intersections of sparse iteration spaces. 1417a2c9d4bbSAart Bik static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1418a2c9d4bbSAart Bik linalg::GenericOp op, std::vector<unsigned> &topSort, 1419a2c9d4bbSAart Bik unsigned exp, unsigned at) { 1420a2c9d4bbSAart Bik // At each leaf, assign remaining tensor (sub)expression to output tensor. 1421a2c9d4bbSAart Bik if (at == topSort.size()) { 1422a2c9d4bbSAart Bik Value rhs = genExp(merger, codegen, rewriter, op, exp); 1423b1d44e59SAart Bik genTensorStore(merger, codegen, rewriter, op, rhs); 1424a2c9d4bbSAart Bik return; 1425a2c9d4bbSAart Bik } 1426a2c9d4bbSAart Bik 1427a2c9d4bbSAart Bik // Construct iteration lattices for current loop index, with L0 at top. 1428a2c9d4bbSAart Bik unsigned idx = topSort[at]; 1429a2c9d4bbSAart Bik unsigned ldx = at == 0 ? -1u : topSort[at - 1]; 1430c8d5dcb0SAart Bik unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx)); 1431a2c9d4bbSAart Bik 1432c8d5dcb0SAart Bik // Start a loop sequence. 1433c8d5dcb0SAart Bik bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at, 1434c8d5dcb0SAart Bik idx, ldx, lts); 1435c8d5dcb0SAart Bik 1436c8d5dcb0SAart Bik // Emit a loop for every lattice point L0 >= Li in this loop sequence. 1437c8d5dcb0SAart Bik unsigned lsize = merger.set(lts).size(); 1438a2c9d4bbSAart Bik for (unsigned i = 0; i < lsize; i++) { 1439c8d5dcb0SAart Bik // Start a loop. 1440a2c9d4bbSAart Bik unsigned li = merger.set(lts)[i]; 1441a2c9d4bbSAart Bik Operation *loop = 1442c8d5dcb0SAart Bik startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv); 1443a2c9d4bbSAart Bik 1444a2c9d4bbSAart Bik // Visit all lattice points with Li >= Lj to generate the 1445a2c9d4bbSAart Bik // loop-body, possibly with if statements for co-iteration. 14467373cabcSAart Bik Value ifInput = codegen.redVal; 1447a2c9d4bbSAart Bik bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr; 1448a2c9d4bbSAart Bik for (unsigned j = 0; j < lsize; j++) { 1449a2c9d4bbSAart Bik unsigned lj = merger.set(lts)[j]; 1450a2c9d4bbSAart Bik unsigned ej = merger.lat(lj).exp; 1451a2c9d4bbSAart Bik if (li == lj || merger.latGT(li, lj)) { 1452a2c9d4bbSAart Bik // Recurse into body of each branch. 1453a2c9d4bbSAart Bik if (isWhile) { 1454a2c9d4bbSAart Bik scf::IfOp ifOp = 1455a2c9d4bbSAart Bik genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple); 1456a2c9d4bbSAart Bik genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1); 14577373cabcSAart Bik endIf(merger, codegen, rewriter, op, ifOp, ifInput); 1458a2c9d4bbSAart Bik } else { 1459a2c9d4bbSAart Bik genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1); 1460a2c9d4bbSAart Bik } 1461a2c9d4bbSAart Bik } 1462a2c9d4bbSAart Bik } 1463a2c9d4bbSAart Bik 1464c8d5dcb0SAart Bik // End a loop. 1465c8d5dcb0SAart Bik needsUniv = 1466c8d5dcb0SAart Bik endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv); 1467a2c9d4bbSAart Bik } 1468a2c9d4bbSAart Bik 1469c8d5dcb0SAart Bik // End a loop sequence.
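// For example, a union such as x(i) = a(i) + b(i) with both a and b sparse
// expands into a while-loop that co-iterates a and b (with if-statements
// for the overlap and the two singleton cases), followed by two for-loops
// that mop up the remaining entries of a and b; the call below finalizes
// that loop sequence.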
1470c8d5dcb0SAart Bik endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx); 1471a2c9d4bbSAart Bik } 1472a2c9d4bbSAart Bik 1473727a63e0SAart Bik /// Converts the result computed by the sparse kernel into the required form. 147436b66ab9SAart Bik static void genResult(Merger &merger, CodeGen &codegen, 147536b66ab9SAart Bik PatternRewriter &rewriter, linalg::GenericOp op) { 147636b66ab9SAart Bik OpOperand *lhs = op.getOutputOperand(0); 147736b66ab9SAart Bik Type resType = lhs->get().getType(); 1478f66e5769SAart Bik Value result; 1479f66e5769SAart Bik if (getSparseTensorEncoding(resType)) { 1480f66e5769SAart Bik // The sparse tensor rematerializes from the original sparse tensor's 1481f66e5769SAart Bik // underlying sparse storage format. 1482f66e5769SAart Bik rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(), 1483f66e5769SAart Bik codegen.sparseOut == lhs); 148436b66ab9SAart Bik } else { 1485f66e5769SAart Bik // To rematerialize a non-annotated tensor, simply load it 148636b66ab9SAart Bik // from the bufferized value. 1487f66e5769SAart Bik Value val = codegen.buffers.back(); // value array 148857470abcSAlexander Belyaev rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val); 148936b66ab9SAart Bik } 1490727a63e0SAart Bik } 1491727a63e0SAart Bik 14925da21338SAart Bik //===----------------------------------------------------------------------===// 14935da21338SAart Bik // Sparse compiler rewriting methods. 14945da21338SAart Bik //===----------------------------------------------------------------------===// 14955da21338SAart Bik 1496a2c9d4bbSAart Bik namespace { 1497a2c9d4bbSAart Bik 1498a2c9d4bbSAart Bik /// Sparse rewriting rule for generic Linalg operation. 1499a2c9d4bbSAart Bik struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> { 1500a2c9d4bbSAart Bik public: 1501a2c9d4bbSAart Bik GenericOpSparsifier(MLIRContext *context, SparsificationOptions o) 1502a2c9d4bbSAart Bik : OpRewritePattern<linalg::GenericOp>(context), options(o) {} 1503a2c9d4bbSAart Bik 1504a2c9d4bbSAart Bik LogicalResult matchAndRewrite(linalg::GenericOp op, 1505a2c9d4bbSAart Bik PatternRewriter &rewriter) const override { 1506a2c9d4bbSAart Bik // Detects sparse annotations and translates the per-dimension sparsity 1507a2c9d4bbSAart Bik // information for all tensors to loop indices in the kernel. 1508a2c9d4bbSAart Bik assert(op.getNumOutputs() == 1); 15092f2b5b7dSTobias Gysi unsigned numTensors = op.getNumInputsAndOutputs(); 1510a2c9d4bbSAart Bik unsigned numLoops = op.iterator_types().getValue().size(); 1511a2c9d4bbSAart Bik Merger merger(numTensors, numLoops); 1512bf9ef3efSAart Bik if (!findSparseAnnotations(merger, op)) 1513bf9ef3efSAart Bik return failure(); 1514a2c9d4bbSAart Bik 1515a2c9d4bbSAart Bik // Computes a topologically sorted iteration graph to ensure 1516a2c9d4bbSAart Bik // tensors are visited in natural index order. Fails on cycles. 1517a2c9d4bbSAart Bik // This assumes that higher-level passes have already put the 1518a2c9d4bbSAart Bik // tensors in each tensor expression in a feasible order.
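// The sort masks below are tried from most constrained (dense and undefined
// dimensions contribute edges too) to least constrained (sparse dimensions
// only), so the most informative acyclic iteration graph wins.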
1519a2c9d4bbSAart Bik std::vector<unsigned> topSort; 1520b6d1a31cSAart Bik if (!computeIterationGraph(merger, op, topSort, 1521b6d1a31cSAart Bik SortMask::kIncludeUndef | 1522b6d1a31cSAart Bik SortMask::kIncludeDense) && 1523b6d1a31cSAart Bik !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) && 1524b6d1a31cSAart Bik !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) && 1525b6d1a31cSAart Bik !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) 1526a2c9d4bbSAart Bik return failure(); 1527a2c9d4bbSAart Bik 1528266a7414SAart Bik // Builds the tensor expression for the Linalg operation in SSA form. 15297373cabcSAart Bik Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op); 15307373cabcSAart Bik if (!optExp.hasValue()) 1531266a7414SAart Bik return failure(); 15327373cabcSAart Bik unsigned exp = optExp.getValue(); 1533a2c9d4bbSAart Bik 1534266a7414SAart Bik // Rejects an inadmissible tensor expression. 1535f66e5769SAart Bik OpOperand *sparseOut = nullptr; 1536*7d4da4e1SAart Bik unsigned outerParNest = 0; 1537*7d4da4e1SAart Bik if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut, 1538*7d4da4e1SAart Bik outerParNest)) 153936b66ab9SAart Bik return failure(); 154036b66ab9SAart Bik 1541a2c9d4bbSAart Bik // Recursively generates code. 1542*7d4da4e1SAart Bik merger.setHasSparseOut(sparseOut != nullptr); 1543*7d4da4e1SAart Bik CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest); 1544c8d5dcb0SAart Bik genBuffers(merger, codegen, rewriter, op); 15457373cabcSAart Bik genStmt(merger, codegen, rewriter, op, topSort, exp, 0); 154636b66ab9SAart Bik genResult(merger, codegen, rewriter, op); 1547a2c9d4bbSAart Bik return success(); 1548a2c9d4bbSAart Bik } 1549a2c9d4bbSAart Bik 1550a2c9d4bbSAart Bik private: 1551a2c9d4bbSAart Bik /// Options to control sparse code generation. 1552a2c9d4bbSAart Bik SparsificationOptions options; 1553a2c9d4bbSAart Bik }; 1554a2c9d4bbSAart Bik 1555a2c9d4bbSAart Bik } // namespace 1556a2c9d4bbSAart Bik 1557a2c9d4bbSAart Bik /// Populates the given patterns list with rewriting rules required for 1558a2c9d4bbSAart Bik /// the sparsification of linear algebra operations. 1559a2c9d4bbSAart Bik void mlir::populateSparsificationPatterns( 1560a2c9d4bbSAart Bik RewritePatternSet &patterns, const SparsificationOptions &options) { 1561a2c9d4bbSAart Bik patterns.add<GenericOpSparsifier>(patterns.getContext(), options); 1562a2c9d4bbSAart Bik } 1563
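// A typical driver (an illustrative sketch, not part of this file) applies
// these patterns greedily inside a pass:
//   RewritePatternSet patterns(&getContext());
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));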