//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the conversion of sparse tensor types to actual
// sparse code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Reduction kinds.
enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
          OpOperand *op)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        redKind(kNoReduc), sparseOut(op), lexIdx(), curVecLength(1),
        curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can use a scalarized reduction.
  unsigned redExp;
  Value redVal;
  Reduction redKind;
  // Sparse tensor as output.
  OpOperand *sparseOut;
  Value lexIdx;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
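/// For example, with a dimension ordering (d0, d1) -> (d1, d0) in the
/// encoding, perm(enc, 0) yields 1, that is, the first stored dimension
/// refers to the second tensor dimension.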
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects affine expressions
/// that are not a direct index for annotated tensors.
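/// For example, a subscript A(i,j) is accepted for any tensor, A(i,i) is
/// always rejected (the index i is used twice), and A(i+j,k) is accepted
/// for dense tensors only.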
// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
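  // Visit marks: 0 = not yet visited, 1 = currently on the DFS stack
  // (so a revisit indicates a cycle), and 2 = fully processed.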
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              linalg::comprehensive_bufferize::BufferizableOpInterface::
                  kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
/// Sets `sparseOut` when a "truly dynamic" sparse tensor output occurs.
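/// For example, x(i) = a(i) * b(i) with an annotated sparse output x is
/// accepted when x materializes uninitialized into the computation, since
/// there are no reduction loops and insertions thus occur in lexicographic
/// index order.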
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp, OpOperand **sparseOut) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
  if (merger.isConjunction(tensor, exp) && isInPlace(lhs->get()))
    return true;
  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
  if (isMaterializing(lhs->get())) {
    // In this first sparse tensor output implementation, this is enforced by
    // rejecting any reduction loops (since the sparse parallel loops give a
    // lexicographically sorted and injective view into that tensor).
    // TODO: generalize to include reductions
    for (auto attr : op.iterator_types())
      if (isReductionIterator(attr))
        return false;
    *sparseOut = lhs;
    return true;
  }
  return false;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (reductions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to name encoding.
static StringRef getReductionName(Reduction kind) {
  switch (kind) {
  case kNoReduc:
    break;
  case kSum:
    return "add";
  case kProduct:
    return "mul";
  case kAnd:
    return "and";
  case kOr:
    return "or";
  case kXor:
    return "xor";
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
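/// For example, for a sum reduction with running scalar value r and vector
/// length 4, the initial vector is | 0 | 0 | 0 | r |, so that a horizontal
/// add over all lanes recovers the correct scalar result.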
static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
                                Location loc, VectorType vtp) {
  Value r = codegen.redVal;
  switch (codegen.redKind) {
  case kNoReduc:
    break;
  case kSum:
  case kXor: {
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    Attribute zero = rewriter.getZeroAttr(vtp);
    Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero);
    return rewriter.create<vector::InsertElementOp>(
        loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0));
  }
  case kProduct: {
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    Type etp = vtp.getElementType();
    Attribute one;
    if (etp.isa<FloatType>())
      one = rewriter.getFloatAttr(etp, 1.0);
    else
      one = rewriter.getIntegerAttr(etp, 1);
    Value vec = rewriter.create<arith::ConstantOp>(
        loc, vtp, DenseElementsAttr::get(vtp, one));
    return rewriter.create<vector::InsertElementOp>(
        loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0));
  }
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
                               Location loc, VectorType vtp) {
  StringRef name = getReductionName(codegen.redKind);
  StringAttr kind = rewriter.getStringAttr(name);
  return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
                                              codegen.redVal, ValueRange{});
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  if (isMaterializing(tensor)) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else if (t == codegen.sparseOut) {
      // True sparse output needs a lexIdx array.
      Value rank = rewriter.create<arith::ConstantIndexOp>(loc, op.getRank(t));
      auto dynShape = {ShapedType::kDynamicSize};
      auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
      codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, genIntType(rewriter, 1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
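  // For example, with hi = 100, iv = 96, and step = 16, this folds to
  // end = min(16, 100 - 96) = 4, yielding the partial mask | 1 1 1 1 0 .. 0 |
  // for the final iteration.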
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass =
      rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<arith::ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Insertion.
  OpOperand *t = op.getOutputOperand(0);
  if (t == codegen.sparseOut) {
    rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load =
          rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64));
    load =
        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
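/// For example, for an all-dense annotated tensor that is linearized into
/// a 1-D values buffer, the element at outer position p and inner index i
/// of an inner dimension with size sz resides at flat offset sz * p + i.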
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
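/// For example, in x(i,j) = a(i) * b(i,j), the load a(i) is invariant in
/// the innermost j-loop and is hoisted once all its indices (just i) have
/// been exhausted.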
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool atStart,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    if (!atLevel)
      return;
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      // Start or end a scalarized reduction.
      if (atStart) {
        Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
        codegen.redKind = getReduction(last);
        codegen.redExp = exp;
        updateReduc(merger, codegen, load);
      } else {
        Value redVal = codegen.redVal;
        updateReduc(merger, codegen, Value());
        codegen.redExp = -1u;
        codegen.redKind = kNoReduc;
        genTensorStore(merger, codegen, rewriter, op, redVal);
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      merger.exp(exp).val =
          atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
922a2c9d4bbSAart Bik         Value ptr = codegen.pointers[tensor][idx];
923a54f4eaeSMogball         Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
924a54f4eaeSMogball         Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
925a2c9d4bbSAart Bik                               : codegen.pidxs[tensor][topSort[pat - 1]];
926a2c9d4bbSAart Bik         codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
927a54f4eaeSMogball         Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
928a2c9d4bbSAart Bik         codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
929a2c9d4bbSAart Bik       } else {
930a2c9d4bbSAart Bik         // Dense index still in play.
931a2c9d4bbSAart Bik         needsUniv = true;
932a2c9d4bbSAart Bik       }
933a2c9d4bbSAart Bik     }
934a2c9d4bbSAart Bik   }
935a2c9d4bbSAart Bik 
936a2c9d4bbSAart Bik   // Initialize the universal dense index.
937a54f4eaeSMogball   codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0);
938a2c9d4bbSAart Bik   return needsUniv;
939a2c9d4bbSAart Bik }
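
// For a dimension in compressed (sparse) storage, the loop bounds are read
// from the pointers array, so that the loop visits only stored entries.
// Roughly, the emitted IR looks like (illustrative names only):
//
//   %lo = memref.load %pointers[%p0] : memref<?xindex>
//   %p1 = arith.addi %p0, %c1 : index
//   %hi = memref.load %pointers[%p1] : memref<?xindex>
//
// and the subsequent loop then scans positions [%lo, %hi).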

/// Determines if the given loop should be vectorized. Any implicit inner loop
/// in the Linalg operation is a candidate; whether it is actually converted
/// to SIMD code depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Determines if the given loop should be parallelized. Any implicit loop in
/// the Linalg operation that is marked "parallel" is a candidate; whether it
/// is actually converted to a parallel operation depends on the requested
/// strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that may leave dense accesses strided, which
/// prevents effective vectorization.
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        // Report non-unit stride if innermost index appears at an outer
        // dimension (true non-unit stride) or if the innermost index appears
        // in a compound subscript in the innermost dimension. Even if the
        // latter is unit stride, it does not play well with scatter/gather.
        // TODO: accept unit stride affine innermost like a[i,j+k+1]?
        if (a.isFunctionOfDim(idx) &&
            ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
          return false;
      }
    }
  }
  return true;
}
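
// For example, under a loop order that makes i innermost, a dense access
// with indexing map (i, j) -> (i, j) is rejected: the innermost index i
// then selects the outer dimension of the dense tensor, so consecutive
// iterations are a full row apart rather than unit stride.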

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = !codegen.sparseOut &&
                  isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      !codegen.sparseOut &&
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step =
      rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential or vector loop.
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    // In a vector loop, bring reduction into SIMD form, if not already.
    if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
      VectorType vtp = vectorType(codegen, codegen.redVal.getType());
      Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
      updateReduc(merger, codegen, vred);
    }
    operands.push_back(codegen.redVal);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (codegen.redVal)
    updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}
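
// For a sparse dimension this emits a scan over the position range computed
// in genInit, threading any scalarized reduction through iter_args. A rough
// sketch (illustrative names only):
//
//   scf.for %ii = %lo to %hi step %c1 iter_args(%r = %red) -> (f64) {
//     ...
//     scf.yield %newr : f64
//   }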

/// Generates a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (codegen.redVal) {
    types.push_back(codegen.redVal.getType());
    operands.push_back(codegen.redVal);
  }
  if (needsUniv) {
    types.push_back(indexType);
    operands.push_back(codegen.loops[idx]);
  }
  assert(types.size() == operands.size());
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (codegen.redVal)
    updateReduc(merger, codegen, after->getArgument(o++));
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}
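
// Schematically, co-iterating two sparse tensors yields a while-loop of the
// following shape (illustrative sketch; the reduction and universal index
// arguments are optional):
//
//   %r:2 = scf.while (%p1 = %lo1, %p2 = %lo2)
//       : (index, index) -> (index, index) {
//     %t1 = arith.cmpi ult, %p1, %hi1 : index
//     %t2 = arith.cmpi ult, %p2, %hi2 : index
//     %c = arith.andi %t1, %t2 : i1
//     scf.condition(%c) %p1, %p2 : index, index
//   } do {
//   ^bb0(%p1: index, %p2: index):
//     ... // loop body, then yield the advanced positions
//   }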

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }

  // Move the insertion indices in lexicographic index order.
  if (codegen.sparseOut) {
    Value pos = rewriter.create<arith::ConstantIndexOp>(loc, at);
    rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
                                     pos);
  }
}
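
// When co-iterating two sparse tensors, for instance, the locals amount to
// (illustrative sketch):
//
//   %i1 = memref.load %indices1[%p1] : memref<?xindex>
//   %i2 = memref.load %indices2[%p2] : memref<?xindex>
//   %c = arith.cmpi ult, %i1, %i2 : index
//   %min = select %c, %i1, %i2 : index
//
// after which %min acts as the current loop index.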

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction,
                              scf::WhileOp whileOp) {
  Location loc = op.getLoc();
  // Finalize each else branch of all if statements.
  if (codegen.redVal) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               rewriter.getInsertionBlock()->getParentOp())) {
      rewriter.create<scf::YieldOp>(loc, codegen.redVal);
      updateReduc(merger, codegen, ifOp.getResult(0));
      rewriter.setInsertionPointAfter(ifOp);
    }
  }
  rewriter.setInsertionPointToEnd(&whileOp.after().front());
  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(whileOp);
}
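
// The induction advances only those tensors whose current index matched the
// loop index, e.g. per tensor (illustrative sketch):
//
//   %c = arith.cmpi eq, %i1, %idx : index
//   %a = arith.addi %p1, %c1 : index
//   %n = select %c, %a, %p1 : index
//
// with all %n values (plus the optional reduction and universal index)
// yielded back to the while-loop.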

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  assert(o == operands.size());
  if (o > 0)
    rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(loop);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      } else {
        clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}
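
// For a lattice point that covers only a subset of the co-iterated tensors,
// the body is guarded roughly as follows (illustrative sketch):
//
//   %c = arith.cmpi eq, %i1, %idx : index
//   %r = scf.if %c -> (f64) {
//     ... // lattice point body, yields the updated reduction
//   } else {
//     ... // filled in later by endIf/genWhileInduction
//   }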

/// Generates the end of the true branch of an if-statement within a
/// while-loop.
static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) {
  if (codegen.redVal) {
    rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
    updateReduc(merger, codegen, ifInput);
  }
  rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at the given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in the current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in the current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at the given level.
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when the sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
}
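
// For instance, a vectorized summation ends its loop sequence by collapsing
// the SIMD accumulator back into a scalar, roughly (illustrative sketch):
//
//   %red = vector.reduction "add", %vred : vector<16xf64> into f64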

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    Value ifInput = codegen.redVal;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, ifInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
}
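
// As an end-to-end illustration, for x(i) = a(i) + b(i) with both a and b
// sparse, this recursion produces the familiar three-part structure
// (hand-written pseudocode, not actual output):
//
//   while (pa < pa_end && pb < pb_end) {
//     i = min(ia, ib);
//     if (ia == i && ib == i)  x[i] = a[pa] + b[pb];
//     else if (ia == i)        x[i] = a[pa];
//     else                     x[i] = b[pb];
//     // advance pa and/or pb, reload ia/ib
//   }
//   for (; pa < pa_end; pa++) x[ia] = a[pa]; // leftover of a
//   for (; pb < pb_end; pb++) x[ib] = b[pb]; // leftover of b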

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissible tensor expression.
    OpOperand *sparseOut = nullptr;
    if (!isAdmissableTensorExp(merger, op, exp, &sparseOut))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops, sparseOut);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
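
// A minimal usage sketch (the surrounding pass and the way its options are
// obtained are hypothetical, not defined in this file):
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));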