//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
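//
// For example (an illustrative case, not the only supported form), a kernel
// such as x(i) = a(i) * b(i), expressed as a linalg.generic operation in
// which the type of a carries a sparse tensor encoding, is rewritten into
// loops that iterate only over the stored nonzero values of a.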
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Reduction kinds.
enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        redKind(kNoReduc), curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can use a scalarized reduction.
  unsigned redExp;
  Value redVal;
  Reduction redKind;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
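/// For example, with a dimension ordering (d0, d1) -> (d1, d0) in the
/// encoding, querying d = 0 yields position 1 (an illustrative case).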
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects affine expressions
/// that are not a direct index for annotated tensors.
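/// For example, a dense tensor may be subscripted with a compound
/// expression such as i + j, but an annotated sparse tensor only with
/// a direct index such as i.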
// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
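/// For example (illustrative), for matrix times vector, x(i) = A(i,j) * b(j),
/// with A stored row-wise, the subscripts of A impose the constraint i < j,
/// and the topological sort yields the loop order (i, j).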
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
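/// For example (illustrative syntax), a function argument annotated as
///   func @kernel(%arg0: tensor<32xf64> {linalg.inplaceable = true})
/// may be updated in place by the generated sparse code.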
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              linalg::comprehensive_bufferize::BufferizableOpInterface::
                  kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
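/// For example (illustrative), scaling all stored values in place,
/// x(i) = 2.0 * x(i), is "simply dynamic" and admissible when x allows
/// in-place updates.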
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
  if (merger.isConjunction(tensor, exp))
    return isInPlace(lhs->get());
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (reductions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to name encoding.
static StringRef getReductionName(Reduction kind) {
  switch (kind) {
  case kNoReduc:
    break;
  case kSum:
    return "add";
  case kProduct:
    return "mul";
  case kAnd:
    return "and";
  case kOr:
    return "or";
  case kXor:
    return "xor";
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
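/// For example (illustrative), a sum reduction with running scalar value r
/// and vector length 4 starts from | 0 | 0 | 0 | r |, so that a final
/// horizontal add over the lanes yields r plus all values reduced in the loop.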
static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
                                Location loc, VectorType vtp) {
  Value r = codegen.redVal;
  switch (codegen.redKind) {
  case kNoReduc:
    break;
  case kSum:
  case kXor: {
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    Attribute zero = rewriter.getZeroAttr(vtp);
    Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero);
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kProduct: {
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    Type etp = vtp.getElementType();
    Attribute one;
    if (etp.isa<FloatType>())
      one = rewriter.getFloatAttr(etp, 1.0);
    else
      one = rewriter.getIntegerAttr(etp, 1);
    Value vec = rewriter.create<arith::ConstantOp>(
        loc, vtp, DenseElementsAttr::get(vtp, one));
    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
  }
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
                               Location loc, VectorType vtp) {
  StringRef name = getReductionName(codegen.redKind);
  StringAttr kind = rewriter.getStringAttr(name);
  return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
                                              codegen.redVal, ValueRange{});
}

/// Updates scalarized reduction value.
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor simply could materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  if (isMaterializing(tensor)) {
    Type tp = denseTp.getElementType();
    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    rewriter.create<linalg::FillOp>(loc, zero, alloc);
    return alloc;
  }
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<memref::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<arith::ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (shape[p] == MemRefType::kDynamicSize)
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
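/// For example (illustrative numbers), with vector length 16, a mask of
/// min(hi - iv, 16) lanes keeps a trip count of, say, 41 from overrunning
/// the upper bound in the last vector iteration.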
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, genIntType(rewriter, 1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<arith::ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upperbound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass =
      rewriter.create<arith::ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return rewriter.create<arith::ConstantIndexOp>(loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
                                      codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getOutputOperand(0);
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vload, vectorType(codegen, genIntType(rewriter, 64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load =
          rewriter.create<arith::ExtUIOp>(loc, load, genIntType(rewriter, 64));
    load =
        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
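/// For example (illustrative), an all-dense annotated matrix stored as a
/// linearized 1-dim array is addressed as size(1) * i + j for element (i, j).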
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
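/// For example (illustrative), in a kernel with loop order (i, j), a load
/// a(i) has all its indices available once inside the i-loop and is hoisted
/// out of the inner j-loop rather than reloaded in every iteration.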
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool atStart,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    if (!atLevel)
      return;
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      // Start or end a scalarized reduction.
8407373cabcSAart Bik       if (atStart) {
8417373cabcSAart Bik         Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
8425da21338SAart Bik         codegen.redKind = getReduction(last);
8437373cabcSAart Bik         codegen.redExp = exp;
8447373cabcSAart Bik         updateReduc(merger, codegen, load);
8457373cabcSAart Bik       } else {
8467373cabcSAart Bik         Value redVal = codegen.redVal;
8477373cabcSAart Bik         updateReduc(merger, codegen, Value());
8487373cabcSAart Bik         codegen.redExp = -1u;
8497373cabcSAart Bik         codegen.redKind = kNoReduc;
8507373cabcSAart Bik         genTensorStore(merger, codegen, rewriter, op, redVal);
8517373cabcSAart Bik       }
8527373cabcSAart Bik     } else {
8537373cabcSAart Bik       // Start or end loop invariant hoisting of a tensor load.
854a2c9d4bbSAart Bik       merger.exp(exp).val =
8557373cabcSAart Bik           atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
856a2c9d4bbSAart Bik     }
857123e8dfcSAart Bik   } else if (merger.exp(exp).kind != Kind::kInvariant) {
858a2c9d4bbSAart Bik     // Traverse into the binary operations. Note that we only hoist
859a2c9d4bbSAart Bik     // tensor loads, since subsequent MLIR/LLVM passes know how to
860a2c9d4bbSAart Bik     // deal with all other kinds of derived loop invariants.
8615da21338SAart Bik     Kind last = merger.exp(exp).kind;
8624569c14aSGus Smith     unsigned e0 = merger.exp(exp).children.e0;
8634569c14aSGus Smith     unsigned e1 = merger.exp(exp).children.e1;
8647373cabcSAart Bik     genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
8657373cabcSAart Bik     genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
866a2c9d4bbSAart Bik   }
867a2c9d4bbSAart Bik }

/// Generates initialization code for the subsequent loop sequence at
/// the current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  return needsUniv;
}
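
// An illustrative sketch of the initialization above (invented names): for
// a tensor with a compressed (sparse) level at index idx, the loop bounds
// are obtained from the pointers array of that level, as in
//
//   %p1 = arith.addi %p0, %c1
//   %lo = memref.load %pointers[%p0]   // codegen.pidxs[tensor][idx]
//   %hi = memref.load %pointers[%p1]   // codegen.highs[tensor][idx]
//
// where %p0 is the position inherited from the closest enclosing sparse
// level (or the constant 0 at the outermost level). A dense level needs no
// such loads; it merely signals that the universal index is required.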

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access may have become
/// strided, which prevents effective vectorization.
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        // Report non-unit stride if innermost index appears at an outer
        // dimension (true non-unit stride) or if the innermost index appears
        // in a compound subscript in the innermost dimension. Even if the
        // latter is unit stride, it does not play well with scatter/gather.
        // TODO: accept unit stride affine innermost like a[i,j+k+1]?
        if (a.isFunctionOfDim(idx) &&
            ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
          return false;
      }
    }
  }
  return true;
}
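
// For example (illustrative only): with a dense 2-d tensor accessed as
// a(i,j) and innermost loop index j, the index j appears as a plain DimId
// in the last dimension, so the access is unit stride. Under access a(j,i)
// the innermost index j appears at the outer dimension, yielding a non-unit
// (row-sized) stride, and under a(i,j+k) it appears in a compound subscript;
// both of those cases report false above.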

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step =
      rewriter.create<arith::ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential or vector loop.
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    // In a vector loop, bring reduction into SIMD form, if not already.
    if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
      VectorType vtp = vectorType(codegen, codegen.redVal.getType());
      Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
      updateReduc(merger, codegen, vred);
    }
    operands.push_back(codegen.redVal);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (codegen.redVal)
    updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}
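
// An illustrative sketch of the loops emitted above (invented names). A
// sequential sparse loop that carries a scalar reduction threads the
// running reduction through an iter_arg:
//
//   %red = scf.for %p = %lo to %hi step %c1 iter_args(%r = %r0) -> (f64) {
//     ...
//     scf.yield %rNew : f64
//   }
//
// In a vector loop the step is the vector length, the reduction iter_arg
// is in SIMD form, and a mask guards a possibly partial last iteration.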

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (codegen.redVal) {
    types.push_back(codegen.redVal.getType());
    operands.push_back(codegen.redVal);
  }
  if (needsUniv) {
    types.push_back(indexType);
    operands.push_back(codegen.loops[idx]);
  }
  assert(types.size() == operands.size());
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (codegen.redVal)
    updateReduc(merger, codegen, after->getArgument(o++));
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}
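
// An illustrative sketch of the co-iteration emitted above (invented names,
// types abbreviated): for two sparse operands and no universal index, the
// "before" region tests that both positions remain in bounds:
//
//   %r:2 = scf.while (%pA = %pA0, %pB = %pB0) : (index, index) -> ... {
//     %cA = arith.cmpi ult, %pA, %hiA : index
//     %cB = arith.cmpi ult, %pB, %hiB : index
//     %c = arith.andi %cA, %cB : i1
//     scf.condition(%c) %pA, %pB : index, index
//   } do { ... }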

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}
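
// An illustrative sketch of the locals generated above (invented names):
// when co-iterating without a universal index, the merged loop index is
// recovered as the minimum over all sparse indices loaded at the current
// positions:
//
//   %iA = memref.load %indicesA[%pA]
//   %iB = memref.load %indicesB[%pB]
//   %c  = arith.cmpi ult, %iA, %iB
//   %i  = select %c, %iA, %iB    // becomes codegen.loops[idx]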

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction,
                              scf::WhileOp whileOp) {
  Location loc = op.getLoc();
  // Finalize each else branch of all if statements.
  if (codegen.redVal) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               rewriter.getInsertionBlock()->getParentOp())) {
      rewriter.create<scf::YieldOp>(loc, codegen.redVal);
      updateReduc(merger, codegen, ifOp.getResult(0));
      rewriter.setInsertionPointAfter(ifOp);
    }
  }
  rewriter.setInsertionPointToEnd(&whileOp.after().front());
  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(whileOp);
}
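
// An illustrative sketch of the induction emitted above (invented names):
// each sparse position advances only when its index matched the merged
// loop index, and the universal index, if present, always advances by one:
//
//   %hitA = arith.cmpi eq, %iA, %i
//   %pA1  = arith.addi %pA, %c1
//   %pAn  = select %hitA, %pA1, %pA
//   %in   = arith.addi %i, %c1
//   scf.yield %pAn, ..., %in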

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  assert(o == operands.size());
  if (o > 0)
    rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(loop);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      } else {
        clause = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}
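
// An illustrative sketch of the guard emitted above (invented names): a
// lattice point that involves sparse tensor A within a co-iterating
// while-loop is guarded by a test that A's index matches the merged loop
// index:
//
//   %hit = arith.cmpi eq, %iA, %i
//   scf.if %hit {
//     ... body for this lattice point ...
//   } else {
//     ...
//   }
//
// with a result type added when a reduction value is threaded through the
// if-statement.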

/// Generates the end of the true branch of an if-statement within a
/// while-loop.
static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) {
  if (codegen.redVal) {
    rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
    updateReduc(merger, codegen, ifInput);
  }
  rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at the given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}

/// Starts a single loop in the current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in the current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at the given level.
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when the sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    Value ifInput = codegen.redVal;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, ifInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
}
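
// An illustrative example of the recursion above: for a sparse vector
// addition x(i) = a(i) + b(i), the optimized lattice for index i contains
// three points (a and b, a only, b only). The loop sequence therefore
// consists of a while-loop that co-iterates a and b, with if-statements
// dispatching among the three cases, followed by for-loops that mop up
// the remainders of a and b, respectively.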

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should eventually never need this).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}
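
// An illustrative sketch of the reconstruction above (invented names): for
// a result carrying a sparse encoding, the pointer/index arrays of all
// sparse dimensions plus the value array are folded back into a tensor,
// roughly
//
//   %0 = sparse_tensor.tensor %ptr0, %idx0, %vals : ... to tensor<?xf64, #SV>
//
// whereas an unannotated result is simply reloaded from its buffer.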

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for a generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissible tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
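
// A minimal usage sketch (hypothetical driver code, assuming a pass body
// with access to a rewritable operation and its MLIRContext):
//
//   RewritePatternSet patterns(context);
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//
// which is essentially how a sparsification pass would wire this rewriting
// rule into the greedy pattern driver.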