1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering sparse tensor types to actual sparse code.
10 //
11 // The concept of letting a compiler generate sparse code automatically was
12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
14 // Algebra Compiler (TACO). The implementation in this file closely follows
15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting
16 // rule is applied to each tensor expression in linalg (MLIR's tensor index
17 // notation) where the sparsity of tensors is indicated with annotation using
18 // a per-dimension specification of sparse/dense storage together with a
19 // specification of the order on the dimensions. Subsequently, a topologically
20 // sorted iteration graph, reflecting the required order on indices with respect
21 // to the dimensions of each tensor, is constructed to ensure that all tensors
22 // are visited in natural index order. Next, iteration lattices are constructed
23 // for the tensor expression for every index in topological order. Each
24 // iteration lattice point consists of a conjunction of tensor indices together
25 // with a tensor (sub)expression that needs to be evaluated for that
26 // conjunction. Within the lattice, iteration points are ordered according to
// the way indices are exhausted. As such, these iteration lattices drive actual
28 // sparse code generation, which consists of a tedious but relatively
29 // straightforward one-to-one mapping from iteration lattices to combinations
30 // of for-loops, while-loops, and if-statements.
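//
// As a rough illustration (a sketch, not the exact IR emitted by this file),
// a matrix-vector kernel expressed in linalg as x(i) += A(i,j) * b(j), with
// the second dimension of A annotated as compressed, lowers to a loop nest
// of the form:
//
//   for (i = 0; i < m; i++)
//     for (jj = pointers[i]; jj < pointers[i+1]; jj++)
//       x[i] += values[jj] * b[indices[jj]];
//
// where pointers/indices/values constitute the sparse storage scheme of A.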
31 //
32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
36 // Proceedings of the ACM on Programming Languages, October 2017.
37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
38 // PhD thesis, MIT, February, 2020 (tensor-compiler.org).
39 //
40 // Implementation detail: We use llvm::SmallVector for vectors with
41 // variable lengths and std::vector for vectors with fixed lengths.
42 //===----------------------------------------------------------------------===//
43 
44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
45 #include "mlir/Dialect/Linalg/Utils/Utils.h"
46 #include "mlir/Dialect/SCF/SCF.h"
47 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
48 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
49 #include "mlir/Dialect/StandardOps/IR/Ops.h"
50 #include "mlir/Dialect/Vector/VectorOps.h"
51 #include "mlir/IR/Matchers.h"
52 #include "mlir/IR/TensorEncoding.h"
53 #include "llvm/ADT/SmallBitVector.h"
54 
55 using namespace mlir;
56 using namespace mlir::sparse_tensor;
57 
58 namespace {
59 
60 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
61 enum class Dim { kSparse, kDense, kSingle, kUndef };
62 
/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
/// stored directly. For binary operations, e0 and e1 denote the indices of the
/// child tensor expressions.
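/// For example (a sketch of the encoding, assuming A and b are block
/// arguments 0 and 1 and subexpressions are added in this order), the
/// expression A(i,j) * b(j) could be stored as:
///   exp 0: kTensor  e0 = 0        (tensor A)
///   exp 1: kTensor  e0 = 1        (tensor b)
///   exp 2: kMulF    e0 = 0, e1 = 1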
67 struct TensorExp {
68   TensorExp(Kind k, unsigned x, unsigned y, Value v)
69       : kind(k), e0(x), e1(y), val(v) {
70     assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
71            (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
72            (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
73   }
74   Kind kind;
75   /// Indices of children expression(s).
76   unsigned e0;
77   unsigned e1;
78   /// Direct link to IR for an invariant. During code generation,
79   /// field is used to cache "hoisted" loop invariant tensor loads.
80   Value val;
81 };
82 
83 /// Lattice point. Each lattice point consists of a conjunction of tensor
84 /// loop indices (encoded in a bitvector) and the index of the corresponding
85 /// tensor expression.
86 struct LatPoint {
87   LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
88     bits.set(b);
89   }
90   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
91   /// Conjunction of tensor loop indices as bitvector. This represents
  /// all indices involved in the tensor expression.
93   llvm::BitVector bits;
94   /// Simplified conjunction of tensor loop indices as bitvector. This
95   /// represents a simplified condition under which this tensor expression
96   /// must execute. Pre-computed during codegen to avoid repeated eval.
97   llvm::BitVector simple;
  /// Index of the tensor expression.
99   unsigned exp;
100 };
101 
102 /// A class to handle all iteration lattice operations. This class abstracts
103 /// away from some implementation details of storing iteration lattices and
104 /// tensor expressions. This allows for fine-tuning performance characteristics
105 /// independently from the basic algorithm if bottlenecks are identified.
106 class Merger {
107 public:
108   /// Constructs a merger for the given number of tensors and loops. The
109   /// user supplies the number of tensors involved in the kernel, with the
110   /// last tensor in this set denoting the output tensor. The merger adds an
111   /// additional synthetic tensor at the end of this set to represent all
112   /// invariant expressions in the kernel.
113   Merger(unsigned t, unsigned l)
114       : outTensor(t - 1), numTensors(t + 1), numLoops(l),
115         dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
116 
117   /// Adds a tensor expression. Returns its index.
118   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
119     unsigned e = tensorExps.size();
120     tensorExps.push_back(TensorExp(k, e0, e1, v));
121     return e;
122   }
123   unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
124 
125   /// Adds an iteration lattice point. Returns its index.
126   unsigned addLat(unsigned t, unsigned i, unsigned e) {
127     assert(t < numTensors && i < numLoops);
128     unsigned p = latPoints.size();
129     latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
130     return p;
131   }
132 
133   /// Adds a new, initially empty, set. Returns its index.
134   unsigned addSet() {
135     unsigned s = latSets.size();
136     latSets.emplace_back(SmallVector<unsigned, 16>());
137     return s;
138   }
139 
140   /// Computes a single conjunction of two lattice points by taking the "union"
141   /// of loop indices (effectively constructing a larger "intersection" of those
142   /// indices) with a newly constructed tensor (sub)expression of given kind.
143   /// Returns the index of the new lattice point.
144   unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
145     unsigned p = latPoints.size();
146     llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
147     nb |= latPoints[p1].bits;
148     unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
149     latPoints.push_back(LatPoint(nb, e));
150     return p;
151   }
152 
  /// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
  /// their cartesian product. Returns the index of the new set.
155   unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
156     unsigned s = addSet();
157     for (unsigned p0 : latSets[s0])
158       for (unsigned p1 : latSets[s1])
159         latSets[s].push_back(conjLatPoint(kind, p0, p1));
160     return s;
161   }
162 
163   /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
164   /// Returns the index of the new set.
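  /// For example (a sketch): for a(i) + b(i), the disjunction keeps the
  /// conjunction {a,b} as well as the points {a} and {b}, since the sum must
  /// still be computed where only one of the operands contributes; for
  /// a(i) * b(i), takeConj yields only {a,b}.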
165   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
166     unsigned s = takeConj(kind, s0, s1);
167     for (unsigned p : latSets[s0])
168       latSets[s].push_back(p);
169     for (unsigned p : latSets[s1])
170       latSets[s].push_back(p);
171     return s;
172   }
173 
174   /// Optimizes the iteration lattice points in the given set. This
175   /// method should be called right before code generation to avoid
176   /// generating redundant loops and conditions.
177   unsigned optimizeSet(unsigned s0) {
178     unsigned s = addSet();
179     assert(latSets[s0].size() != 0);
180     unsigned p0 = latSets[s0][0];
181     for (unsigned p1 : latSets[s0]) {
182       bool add = true;
183       if (p0 != p1) {
184         // Is this a straightforward copy?
185         unsigned e = latPoints[p1].exp;
186         if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
187           continue;
188         // Conjunction already covered?
189         for (unsigned p2 : latSets[s]) {
190           assert(!latGT(p1, p2)); // Lj => Li would be bad
191           if (onlyDenseDiff(p2, p1)) {
192             add = false;
193             break;
194           }
195         }
196         assert(!add || latGT(p0, p1));
197       }
198       if (add)
199         latSets[s].push_back(p1);
200     }
201     for (unsigned p : latSets[s])
202       latPoints[p].simple = simplifyCond(s, p);
203     return s;
204   }
205 
206   /// Simplifies the conditions in a conjunction of a given lattice point
207   /// within the given set using just two basic rules:
208   /// (1) multiple dense conditions are reduced to single dense, and
209   /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
210   llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
    // First determine if this lattice point is a *singleton*, i.e.,
    // the last point in the lattice, such that no other point is less than it.
213     bool isSingleton = true;
214     for (unsigned p1 : latSets[s]) {
215       if (p0 != p1 && latGT(p0, p1)) {
216         isSingleton = false;
217         break;
218       }
219     }
220     // Now apply the two basic rules.
221     llvm::BitVector simple = latPoints[p0].bits;
222     bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
223     for (unsigned b = 0, be = simple.size(); b < be; b++) {
224       if (simple[b] && !isDim(b, Dim::kSparse)) {
225         if (reset)
226           simple.reset(b);
227         reset = true;
228       }
229     }
230     return simple;
231   }
232 
233   /// Returns true if Li > Lj.
234   bool latGT(unsigned i, unsigned j) const {
235     const llvm::BitVector &bitsi = latPoints[i].bits;
236     const llvm::BitVector &bitsj = latPoints[j].bits;
237     assert(bitsi.size() == bitsj.size());
238     if (bitsi.count() > bitsj.count()) {
239       for (unsigned b = 0, be = bitsj.size(); b < be; b++)
240         if (bitsj[b] && !bitsi[b])
241           return false;
242       return true;
243     }
244     return false;
245   }
246 
247   /// Returns true if Li and Lj only differ in dense.
248   bool onlyDenseDiff(unsigned i, unsigned j) {
249     llvm::BitVector tmp = latPoints[j].bits;
250     tmp ^= latPoints[i].bits;
251     return !hasAnyDimOf(tmp, Dim::kSparse);
252   }
253 
254   /// Bit translation.
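  /// For example (a hypothetical configuration): with numTensors == 3, the
  /// bit b == 7 set by addLat (numTensors * i + t) decodes back to
  /// tensor(7) == 1 and index(7) == 2, i.e., loop index i_2 of tensor 1.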
255   unsigned tensor(unsigned b) const { return b % numTensors; }
256   unsigned index(unsigned b) const { return b / numTensors; }
257 
258   /// Returns true if bit corresponds to queried dim.
259   bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
260 
261   /// Returns true if bit corresponds to index of output tensor.
262   bool isOutTensor(unsigned b, unsigned i) const {
263     return tensor(b) == outTensor && index(b) == i;
264   }
265 
266   /// Returns true if tensor access at given index has queried dim.
267   bool isDim(unsigned t, unsigned i, Dim d) const {
268     assert(t < numTensors && i < numLoops);
269     return dims[t][i] == d;
270   }
271 
272   /// Returns true if any set bit corresponds to queried dim.
273   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
274     for (unsigned b = 0, be = bits.size(); b < be; b++)
275       if (bits[b] && isDim(b, d))
276         return true;
277     return false;
278   }
279 
  /// Setter.
281   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
282 
283   /// Getters.
284   TensorExp &exp(unsigned e) { return tensorExps[e]; }
285   LatPoint &lat(unsigned l) { return latPoints[l]; }
286   SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
287 
288 private:
289   const unsigned outTensor;
290   const unsigned numTensors;
291   const unsigned numLoops;
292 
293   std::vector<std::vector<Dim>> dims;
294   llvm::SmallVector<TensorExp, 32> tensorExps;
295   llvm::SmallVector<LatPoint, 16> latPoints;
296   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
297 };
298 
299 // Code generation.
300 struct CodeGen {
301   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
302       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
303         pointers(numTensors, std::vector<Value>(numLoops)),
304         indices(numTensors, std::vector<Value>(numLoops)),
305         highs(numTensors, std::vector<Value>(numLoops)),
306         pidxs(numTensors, std::vector<Value>(numLoops)),
307         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
308         curVecLength(1), curVecMask() {}
309   /// Sparsification options.
310   SparsificationOptions options;
311   /// Universal dense indices and upper bounds (by index). The loops array
312   /// is updated with the value of the universal dense index in the current
313   /// loop. The sizes array is set once with the inferred dimension sizes.
314   std::vector<Value> loops;
315   std::vector<Value> sizes;
316   /// Buffers for storing dense and sparse numerical values (by tensor).
317   /// This array is set once during bufferization of all tensors.
318   std::vector<Value> buffers;
319   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
320   /// This array is set once during bufferization of all sparse tensors.
321   std::vector<std::vector<Value>> pointers;
322   std::vector<std::vector<Value>> indices;
323   /// Sparse iteration information (by tensor and index). These arrays
324   /// are updated to remain current within the current loop.
325   std::vector<std::vector<Value>> highs;
326   std::vector<std::vector<Value>> pidxs;
327   std::vector<std::vector<Value>> idxs;
328   /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
330   // TODO: currently only done for (a chain of) innermost for-loops, where it
331   // is most effective; we could generalize to more outer and while-loops.
332   unsigned redExp;
333   Value redVal;
334   // Current vector length and mask.
335   unsigned curVecLength;
336   Value curVecMask;
337 };
338 
339 } // namespace
340 
341 // Helper method to apply dimension ordering permutation.
342 static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
343   if (enc) {
344     auto order = enc.getDimOrdering();
345     if (order) {
346       assert(order.isPermutation());
347       return order.getDimPosition(d);
348     }
349   }
350   return d;
351 }
352 
353 // Helper method to translate dim level type to internal representation.
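// For reference, with an annotation along the lines of
//   #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
// (syntax shown as a sketch), dimension 0 maps to Dim::kDense and
// dimension 1 maps to Dim::kSparse.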
354 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
355   if (enc) {
356     SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
357     if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
358       return Dim::kSparse;
359     if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
360       return Dim::kSingle;
361   }
362   return Dim::kDense;
363 }
364 
365 /// Helper method to inspect sparse encodings in the tensor types.
366 /// Fills the per-dimension sparsity information for all tensors.
367 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
368   bool annotated = false;
369   OpOperand *lhs = op.getOutputOperand(0);
370   for (OpOperand *t : op.getInputAndOutputOperands()) {
371     auto map = op.getTiedIndexingMap(t);
372     if (!map.isProjectedPermutation())
373       return false;
374     auto enc = getSparseTensorEncoding(t->get().getType());
375     if (enc)
376       annotated = true;
377     assert(map.getNumResults() == op.getRank(t));
378     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
379       unsigned idx = map.getDimPosition(perm(enc, d));
      Dim dim = toDim(enc, d);
      merger.setDim(t->getOperandNumber(), idx, dim);
382       // Accept only all-dense annotated "sparse" output.
383       // TODO: support truly sparse outputs too
384       if (t == lhs && dim != Dim::kDense)
385         return false;
386     }
387   }
388   return annotated;
389 }
390 
391 /// A DFS helper to compute a topological sort. Note that recursion is
392 /// bounded by the number of implicit loops, which is always small.
393 /// Returns false when a cycle is detected.
394 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
395                        std::vector<unsigned> &topSort,
396                        std::vector<std::vector<bool>> &adjM) {
397   if (visit[i] != 0)
398     return visit[i] != 1; // 1 denotes cycle!
399   visit[i] = 1;
400   for (unsigned j = 0, e = visit.size(); j < e; j++)
401     if (adjM[i][j])
402       if (!topSortDFS(j, visit, topSort, adjM))
403         return false;
404   visit[i] = 2;
405   topSort.push_back(i);
406   return true;
407 }
408 
409 /// Computes a topologically sorted iteration graph for the linalg operation.
410 /// Ensures all tensors are visited in natural index order. This is essential
411 /// for sparse storage formats since these only support access along fixed
412 /// dimensions. Even for dense storage formats, however, the natural index
413 /// order yields innermost unit-stride access with better spatial locality.
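/// For example (a sketch): a kernel that reads both A(i,j) and B(j,i) puts
/// the constraints i < j (from A) and j < i (from B) on the loop order,
/// yielding a cyclic iteration graph; with sparseOnly = true the dense
/// constraints are dropped in an attempt to break such cycles.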
414 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
415                                   std::vector<unsigned> &topSort,
416                                   bool sparseOnly) {
417   // Set up an n x n from/to adjacency matrix of the iteration graph
418   // for the implicit loop indices i_0 .. i_n-1.
419   unsigned n = op.getNumLoops();
420   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
421 
422   // Iterate over the indexing maps of every tensor in the tensor expression.
423   for (OpOperand *t : op.getInputAndOutputOperands()) {
424     auto map = op.getTiedIndexingMap(t);
425     auto enc = getSparseTensorEncoding(t->get().getType());
426     assert(map.getNumDims() == n);
427     // Skip dense tensor constraints when sparse only is requested.
428     if (sparseOnly && !enc)
429       continue;
430     // Each tensor expression and optional dimension ordering (row-major
431     // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
433     // on the loop indices if no explicit dimension ordering is given.
434     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
435       unsigned f = map.getDimPosition(perm(enc, d - 1));
436       unsigned t = map.getDimPosition(perm(enc, d));
437       adjM[f][t] = true;
438     }
439   }
440 
441   // Topologically sort the iteration graph to determine loop order.
442   // Report failure for a cyclic iteration graph.
443   topSort.clear();
444   topSort.reserve(n);
445   std::vector<unsigned> visit(n, 0);
446   for (unsigned i = 0; i < n; i++)
447     if (visit[i] == 0)
448       if (!topSortDFS(i, visit, topSort, adjM))
449         return false; // cycle!
450   std::reverse(std::begin(topSort), std::end(topSort));
451   return true;
452 }
453 
454 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
455 /// This simplifies constructing (sub)expressions during iteration lattice
456 /// building (compared to using the SSA representation everywhere).
457 static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
458                                          Value val) {
459   if (auto arg = val.dyn_cast<BlockArgument>()) {
460     unsigned argN = arg.getArgNumber();
461     // Any parameter of the generic op is considered a tensor,
462     // indexed by the implicit loop bounds.
463     if (arg.getOwner()->getParentOp() == op)
464       return merger.addExp(Kind::kTensor, argN);
465     // Any parameter of a higher op is invariant.
466     return merger.addExp(Kind::kInvariant, val);
467   }
468   Operation *def = val.getDefiningOp();
469   if (def->getBlock() != &op.region().front()) {
470     // Something defined outside is invariant.
471     return merger.addExp(Kind::kInvariant, val);
472   } else if (def->getNumOperands() == 2) {
473     // Construct binary operations if subexpressions could be built.
474     auto x = buildTensorExp(merger, op, def->getOperand(0));
475     auto y = buildTensorExp(merger, op, def->getOperand(1));
476     if (x.hasValue() && y.hasValue()) {
477       unsigned e0 = x.getValue();
478       unsigned e1 = y.getValue();
479       if (isa<MulFOp>(def))
480         return merger.addExp(Kind::kMulF, e0, e1);
481       if (isa<MulIOp>(def))
482         return merger.addExp(Kind::kMulI, e0, e1);
483       if (isa<AddFOp>(def))
484         return merger.addExp(Kind::kAddF, e0, e1);
485       if (isa<AddIOp>(def))
486         return merger.addExp(Kind::kAddI, e0, e1);
487     }
488   }
489   // Cannot build (yet).
490   return None;
491 }
492 
493 /// Builds the iteration lattices in a bottom-up traversal given the remaining
494 /// tensor (sub)expression and the next loop index in the iteration graph.
495 static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
496                               unsigned exp, unsigned idx) {
497   Kind kind = merger.exp(exp).kind;
498   if (kind == Kind::kTensor || kind == Kind::kInvariant) {
499     // Either the index is really used in the tensor expression, or it is
500     // set to the undefined index in that dimension. An invariant expression
501     // is set to a synthetic tensor with undefined indices only.
502     unsigned s = merger.addSet();
503     unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
504                                        : op.getNumInputsAndOutputs();
505     merger.set(s).push_back(merger.addLat(t, idx, exp));
506     return s;
507   }
508   unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
509   unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
510   switch (kind) {
511   case Kind::kTensor:
512   case Kind::kInvariant:
513     llvm_unreachable("handled above");
514   case Kind::kMulF:
515   case Kind::kMulI:
516     return merger.takeConj(kind, s0, s1);
517   case Kind::kAddF:
518   case Kind::kAddI:
519     return merger.takeDisj(kind, s0, s1);
520   }
521   llvm_unreachable("unexpected expression kind");
522 }
523 
524 /// Maps sparse integer option to actual integral storage type.
525 static Type genIntType(PatternRewriter &rewriter, unsigned width) {
526   if (width == 0)
527     return rewriter.getIndexType();
528   return rewriter.getIntegerType(width);
529 }
530 
531 /// Detects in-place annotation on tensor argument.
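/// For example, it recognizes a function argument annotated roughly as
///   func @kernel(%argx: tensor<?xf64> {linalg.inplaceable = true}) ...
/// (a sketch; the exact attribute syntax may differ).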
532 static bool getInPlace(Value val) {
533   if (auto arg = val.dyn_cast<BlockArgument>())
534     if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
535       if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
536               arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
537         return attr.getValue();
538   return false;
539 }
540 
541 /// Generates buffer for the output tensor.
542 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
543                              linalg::GenericOp op, MemRefType denseTp,
544                              ArrayRef<Value> args) {
545   Location loc = op.getLoc();
546   Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
548   // be generated for the tensor present in the outs() clause. This has
549   // the major advantage that the sparse kernel only updates the nonzero
550   // positions for the output tensor.
551   if (getInPlace(tensor))
552     return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
553   // By default, a new buffer is allocated which is initialized to the
554   // tensor defined in the outs() clause. This is always correct but
555   // introduces a dense initialization component that may negatively
556   // impact the running complexity of the sparse kernel.
557   Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
558   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
559   rewriter.create<linalg::CopyOp>(loc, init, alloc);
560   return alloc;
561 }
562 
563 /// Local bufferization of all dense and sparse data structures.
564 /// This code enables testing the first prototype sparse compiler.
565 // TODO: replace this with a proliferated bufferization strategy
566 static bool genBuffers(Merger &merger, CodeGen &codegen,
567                        PatternRewriter &rewriter, linalg::GenericOp op) {
568   Location loc = op.getLoc();
569   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
570   // For every tensor, find lower and upper bound on dimensions, set the
571   // same bounds on loop indices, and obtain dense or sparse buffer(s).
572   SmallVector<Value, 4> args;
573   for (OpOperand *t : op.getInputAndOutputOperands()) {
574     unsigned tensor = t->getOperandNumber();
575     auto shape = op.getShape(t);
576     auto map = op.getTiedIndexingMap(t);
577     auto enc = getSparseTensorEncoding(t->get().getType());
578     // Scan all dimensions of current tensor.
579     args.clear();
580     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
581       unsigned idx = map.getDimPosition(perm(enc, d));
582       // Handle sparse storage schemes.
583       if (merger.isDim(tensor, idx, Dim::kSparse)) {
584         auto dynShape = {ShapedType::kDynamicSize};
585         auto ptrTp = MemRefType::get(
586             dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
587         auto indTp = MemRefType::get(
588             dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
589         Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
591         codegen.pointers[tensor][idx] =
592             rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
593         codegen.indices[tensor][idx] =
594             rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
595       }
596       // Find lower and upper bound in current dimension.
597       Value up;
598       if (shape[d] == MemRefType::kDynamicSize) {
599         up = rewriter.create<memref::DimOp>(loc, t->get(), d);
600         args.push_back(up);
601       } else {
602         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
603       }
604       codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
605     }
606     // Perform the required bufferization. Dense inputs materialize
607     // from the input tensors. Dense outputs need special handling.
608     // Sparse inputs use sparse primitives to obtain the values.
609     // We also accept in-place all-dense annotated "sparse" outputs.
610     Type elementType = getElementTypeOrSelf(t->get().getType());
611     if (!enc) {
612       // Non-annotated dense tensors.
613       auto denseTp = MemRefType::get(shape, elementType);
614       if (tensor < op.getNumInputs())
615         codegen.buffers[tensor] =
616             rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
617       else
618         codegen.buffers[tensor] =
619             genOutputBuffer(codegen, rewriter, op, denseTp, args);
620     } else {
621       // Annotated sparse tensors.
622       if (tensor == op.getNumInputs() && !getInPlace(t->get()))
623         return false; // reject output if not in-place
624       auto dynShape = {ShapedType::kDynamicSize};
625       auto sparseTp = MemRefType::get(dynShape, elementType);
626       codegen.buffers[tensor] =
627           rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
628     }
629   }
630   return true;
631 }
632 
633 /// Constructs vector type.
634 static VectorType vectorType(CodeGen &codegen, Type etp) {
635   return VectorType::get(codegen.curVecLength, etp);
636 }
637 
638 /// Constructs vector type from pointer.
639 static VectorType vectorType(CodeGen &codegen, Value ptr) {
640   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
641 }
642 
643 /// Constructs vector iteration mask.
644 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
645                            Value iv, Value lo, Value hi, Value step) {
646   Location loc = iv.getLoc();
647   VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
648   // Special case if the vector length evenly divides the trip count (for
649   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
650   // so that all subsequent masked memory operations are immediately folded
651   // into unconditional memory operations.
652   IntegerAttr loInt, hiInt, stepInt;
653   if (matchPattern(lo, m_Constant(&loInt)) &&
654       matchPattern(hi, m_Constant(&hiInt)) &&
655       matchPattern(step, m_Constant(&stepInt))) {
656     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
657       return rewriter.create<vector::BroadcastOp>(
658           loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
659   }
660   // Otherwise, generate a vector mask that avoids overrunning the upperbound
661   // during vector execution. Here we rely on subsequent loop optimizations to
662   // avoid executing the mask in all iterations, for example, by splitting the
663   // loop into an unconditional vector loop and a scalar cleanup loop.
664   Value end = rewriter.create<SubIOp>(loc, hi, iv);
665   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
666 }
667 
668 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
669 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
670                            Value ptr, ArrayRef<Value> args) {
671   Location loc = ptr.getLoc();
672   VectorType vtp = vectorType(codegen, ptr);
673   Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
674   if (args.back().getType().isa<VectorType>()) {
675     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
676     Value indexVec = args.back();
677     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
678     return rewriter.create<vector::GatherOp>(
679         loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
680   }
681   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
682                                                codegen.curVecMask, pass);
683 }
684 
685 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
686 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
687                            Value rhs, Value ptr, ArrayRef<Value> args) {
688   Location loc = ptr.getLoc();
689   if (args.back().getType().isa<VectorType>()) {
690     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
691     Value indexVec = args.back();
692     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
693     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
694                                        codegen.curVecMask, rhs);
695     return;
696   }
697   rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
698                                          rhs);
699 }
700 
701 /// Generates a vectorized invariant. Here we rely on subsequent loop
702 /// optimizations to hoist the invariant broadcast out of the vector loop.
703 static Value genVectorInvariantValue(CodeGen &codegen,
704                                      PatternRewriter &rewriter, Value val) {
705   VectorType vtp = vectorType(codegen, val.getType());
706   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
707 }
708 
709 /// Generates a load on a dense or sparse tensor.
710 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
711                            PatternRewriter &rewriter, linalg::GenericOp op,
712                            unsigned exp) {
713   // Test if the load was hoisted to a higher loop nest.
714   Value val = merger.exp(exp).val;
715   if (val) {
716     if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
717       return genVectorInvariantValue(codegen, rewriter, val);
718     return val;
719   }
720   // Actual load.
721   SmallVector<Value, 4> args;
722   OpOperand *t = merger.exp(exp).e0 < op.getNumInputs()
723                      ? op.getInputOperand(merger.exp(exp).e0)
724                      : op.getOutputOperand(0);
725   unsigned tensor = t->getOperandNumber();
726   auto map = op.getTiedIndexingMap(t);
727   auto enc = getSparseTensorEncoding(t->get().getType());
728   unsigned rank = map.getNumResults();
729   if (enc) {
730     unsigned idx = map.getDimPosition(perm(enc, rank - 1));
731     assert(codegen.pidxs[tensor][idx] != nullptr);
732     args.push_back(codegen.pidxs[tensor][idx]); // position index
733   } else {
734     for (unsigned d = 0; d < rank; d++) {
735       unsigned idx = map.getDimPosition(d);
736       args.push_back(codegen.loops[idx]); // universal dense index
737     }
738   }
739   Location loc = op.getLoc();
740   Value ptr = codegen.buffers[tensor];
741   if (codegen.curVecLength > 1)
742     return genVectorLoad(codegen, rewriter, ptr, args);
743   return rewriter.create<memref::LoadOp>(loc, ptr, args);
744 }
745 
746 /// Generates a store on a dense or sparse tensor.
747 static void genTensorStore(Merger &merger, CodeGen &codegen,
748                            PatternRewriter &rewriter, linalg::GenericOp op,
749                            OpOperand *t, Value rhs) {
750   Location loc = op.getLoc();
751   // Test if this is a scalarized reduction.
752   OpOperand *lhs = op.getOutputOperand(0);
753   if (lhs == t && codegen.redVal) {
754     if (codegen.curVecLength > 1)
755       rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
756                                       codegen.redVal);
757     codegen.redVal = rhs;
758     return;
759   }
760   // Actual store.
761   SmallVector<Value, 4> args;
762   unsigned tensor = t->getOperandNumber();
763   auto map = op.getTiedIndexingMap(t);
764   auto enc = getSparseTensorEncoding(t->get().getType());
765   unsigned rank = map.getNumResults();
766   if (enc) {
767     unsigned idx = map.getDimPosition(perm(enc, rank - 1));
768     assert(codegen.pidxs[tensor][idx] != nullptr);
769     args.push_back(codegen.pidxs[tensor][idx]); // position index
770   } else {
771     for (unsigned d = 0; d < rank; d++) {
772       unsigned idx = map.getDimPosition(d);
773       args.push_back(codegen.loops[idx]); // universal dense index
774     }
775   }
776   Value ptr = codegen.buffers[tensor];
777   if (codegen.curVecLength > 1)
778     genVectorStore(codegen, rewriter, rhs, ptr, args);
779   else
780     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
781 }
782 
783 /// Generates a pointer/index load from the sparse storage scheme. Narrower
784 /// data types need to be zero extended before casting the value into the
785 /// index type used for looping and indexing.
786 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
787                      Value ptr, Value s) {
788   // See https://llvm.org/docs/GetElementPtr.html for some background on
789   // the complications described below.
790   if (codegen.curVecLength > 1) {
    // Since the index vector is used in a subsequent gather/scatter operation,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
800     // incorrect address calculations in the unlikely case we need such
801     // extremely large offsets.
802     Type etp = ptr.getType().cast<MemRefType>().getElementType();
803     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
804     if (!etp.isa<IndexType>()) {
805       if (etp.getIntOrFloatBitWidth() < 32)
806         vload = rewriter.create<ZeroExtendIOp>(
807             loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
808       else if (etp.getIntOrFloatBitWidth() < 64 &&
809                !codegen.options.enableSIMDIndex32)
810         vload = rewriter.create<ZeroExtendIOp>(
811             loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
812     }
813     return vload;
814   }
815   // For the scalar case, we simply zero extend narrower indices into 64-bit
816   // values before casting to index without a performance penalty. Here too,
817   // however, indices that already are 64-bit, in theory, cannot express the
818   // full range as explained above.
819   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
820   if (!load.getType().isa<IndexType>()) {
821     if (load.getType().getIntOrFloatBitWidth() < 64)
822       load = rewriter.create<ZeroExtendIOp>(loc, load,
823                                             rewriter.getIntegerType(64));
824     load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
825   }
826   return load;
827 }
828 
829 /// Generates an invariant value.
830 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
831                                PatternRewriter &rewriter, unsigned exp) {
832   Value val = merger.exp(exp).val;
833   if (codegen.curVecLength > 1)
834     return genVectorInvariantValue(codegen, rewriter, val);
835   return val;
836 }
837 
838 /// Generates an address computation "sz * p + i".
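/// For example (a sketch): when linearizing a dense dimension of size sz = n,
/// an outer position p and inner index i map to the flat position n * p + i.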
839 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
840                         Location loc, Value size, Value p, Value i) {
841   Value mul = rewriter.create<MulIOp>(loc, size, p);
842   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
843     Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
844     mul = genVectorInvariantValue(codegen, rewriter, inv);
845   }
846   return rewriter.create<AddIOp>(loc, mul, i);
847 }
848 
849 /// Generates start of a reduction.
850 static Value genReductionStart(Merger &merger, CodeGen &codegen,
851                                PatternRewriter &rewriter,
852                                linalg::GenericOp op) {
853   if (codegen.redVal)
854     return codegen.redVal; // chained with previous for-loop
855   if (codegen.curVecLength > 1) {
856     // TODO: assumes + reductions for now
857     VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
858     return rewriter.create<ConstantOp>(op.getLoc(), vtp,
859                                        rewriter.getZeroAttr(vtp));
860   }
861   return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
862 }
863 
864 /// Generates end of a reduction.
865 static void genReductionEnd(Merger &merger, CodeGen &codegen,
866                             PatternRewriter &rewriter, linalg::GenericOp op) {
867   Value red = codegen.redVal;
868   if (!red)
869     return;
870   assert(codegen.curVecLength == 1);
871   codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
872   OpOperand *lhs = op.getOutputOperand(0);
873   if (auto vtp = red.getType().dyn_cast<VectorType>()) {
874     // TODO: assumes + reductions for now
875     StringAttr kind = rewriter.getStringAttr("add");
876     Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
877     // Integer reductions don't accept an accumulator.
878     if (vtp.getElementType().isa<IntegerType>()) {
879       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
880                                                  kind, red, ValueRange{});
881       red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
882     } else {
883       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
884                                                  kind, red, ld);
885     }
886   }
887   genTensorStore(merger, codegen, rewriter, op, lhs, red);
888 }
889 
890 /// Recursively generates tensor expression.
891 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
892                     linalg::GenericOp op, unsigned exp) {
893   if (merger.exp(exp).kind == Kind::kTensor)
894     return genTensorLoad(merger, codegen, rewriter, op, exp);
895   else if (merger.exp(exp).kind == Kind::kInvariant)
896     return genInvariantValue(merger, codegen, rewriter, exp);
897   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
898   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
899   switch (merger.exp(exp).kind) {
900   case Kind::kTensor:
901   case Kind::kInvariant:
902     llvm_unreachable("handled above");
903   case Kind::kMulF:
904     return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
905   case Kind::kMulI:
906     return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
907   case Kind::kAddF:
908     return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
909   case Kind::kAddI:
910     return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
911   }
912   llvm_unreachable("unexpected expression kind");
913 }
914 
915 /// Hoists loop invariant tensor loads for which indices have been exhausted.
916 static void genInvariants(Merger &merger, CodeGen &codegen,
917                           PatternRewriter &rewriter, linalg::GenericOp op,
918                           unsigned exp, unsigned ldx, bool hoist) {
919   if (merger.exp(exp).kind == Kind::kTensor) {
920     // Inspect tensor indices.
921     bool atLevel = ldx == -1u;
922     OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs()
923                             ? op.getInputOperand(merger.exp(exp).e0)
924                             : op.getOutputOperand(0);
925     auto map = op.getTiedIndexingMap(tensor);
926     auto enc = getSparseTensorEncoding(tensor->get().getType());
927     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
928       unsigned idx = map.getDimPosition(perm(enc, d));
929       if (!codegen.loops[idx])
930         return; // still in play
931       else if (idx == ldx)
932         atLevel = true;
933     }
934     // All exhausted at this level (atLevel denotes exactly at this level).
935     OpOperand *lhs = op.getOutputOperand(0);
936     if (lhs == tensor) {
937       codegen.redExp = hoist ? exp : -1u;
938     } else if (atLevel) {
939       merger.exp(exp).val =
940           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
941     }
942   } else if (merger.exp(exp).kind != Kind::kInvariant) {
943     // Traverse into the binary operations. Note that we only hoist
944     // tensor loads, since subsequent MLIR/LLVM passes know how to
945     // deal with all other kinds of derived loop invariants.
946     unsigned e0 = merger.exp(exp).e0;
947     unsigned e1 = merger.exp(exp).e1;
948     genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
949     genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
950   }
951 }
952 
953 /// Generates initialization code for the subsequent loop sequence at
954 /// current index level. Returns true if the loop sequence needs to
955 /// maintain the universal index.
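/// For a compressed dimension this emits, roughly (a sketch of the generated
/// IR, not verbatim):
///   %pidx = memref.load %pointers[%p0]
///   %high = memref.load %pointers[%p0 + 1]
/// which bound the position range of the nonzero entries at this level.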
956 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
957                     linalg::GenericOp op, std::vector<unsigned> &topSort,
958                     unsigned at, llvm::BitVector &inits) {
959   bool needsUniv = false;
960   Location loc = op.getLoc();
961   unsigned idx = topSort[at];
962 
963   // Initialize sparse positions.
964   for (unsigned b = 0, be = inits.size(); b < be; b++) {
965     if (inits[b]) {
966       unsigned tensor = merger.tensor(b);
967       assert(idx == merger.index(b));
968       if (merger.isDim(b, Dim::kSparse)) {
969         // Initialize sparse index.
970         unsigned pat = at;
971         for (; pat != 0; pat--) {
972           if (codegen.pidxs[tensor][topSort[pat - 1]])
973             break;
974         }
975         Value ptr = codegen.pointers[tensor][idx];
976         Value one = rewriter.create<ConstantIndexOp>(loc, 1);
977         Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
978                               : codegen.pidxs[tensor][topSort[pat - 1]];
979         codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
980         Value p1 = rewriter.create<AddIOp>(loc, p0, one);
981         codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
982       } else {
983         // Dense index still in play.
984         needsUniv = true;
985       }
986     }
987   }
988 
989   // Initialize the universal dense index.
990   codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
991   return needsUniv;
992 }
993 
994 /// Returns vectorization strategy. Any implicit inner loop in the Linalg
995 /// operation is a candidate. Whether it is actually converted to SIMD code
996 /// depends on the requested strategy.
997 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
998   switch (codegen.options.vectorizationStrategy) {
999   case SparseVectorizationStrategy::kNone:
1000     return false;
1001   case SparseVectorizationStrategy::kDenseInnerLoop:
1002     return isInner && !isSparse;
1003   case SparseVectorizationStrategy::kAnyStorageInnerLoop:
1004     return isInner;
1005   }
1006   llvm_unreachable("unexpected vectorization strategy");
1007 }
1008 
1009 /// Returns parallelization strategy. Any implicit loop in the Linalg operation
1010 /// that is marked "parallel" is a candidate. Whether it is actually converted
1011 /// to a parallel operation depends on the requested strategy.
1012 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
1013                           bool isSparse, bool isVector) {
1014   switch (codegen.options.parallelizationStrategy) {
1015   case SparseParallelizationStrategy::kNone:
1016     return false;
1017   case SparseParallelizationStrategy::kDenseOuterLoop:
1018     return isOuter && !isSparse && !isReduction && !isVector;
1019   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
1020     return isOuter && !isReduction && !isVector;
1021   case SparseParallelizationStrategy::kDenseAnyLoop:
1022     return !isSparse && !isReduction && !isVector;
1023   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
1024     return !isReduction && !isVector;
1025   }
1026   llvm_unreachable("unexpected parallelization strategy");
1027 }
1028 
1029 /// Checks unit strides for dense tensors. The iteration graph may have ignored
1030 /// dense access patterns in order to avoid cycles (sparse access patterns are
1031 /// always placed innermost), but that means dense access has become strided.
1032 /// For now, we reject vectorization of such cases.
1033 /// TODO: implement strided load/stores on dense arrays
1034 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
1035                              unsigned idx) {
1036   for (OpOperand *t : op.getInputAndOutputOperands()) {
1037     if (!getSparseTensorEncoding(t->get().getType())) {
1038       auto map = op.getTiedIndexingMap(t);
1039       for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
1040         if (map.getDimPosition(d) == idx && d != rank - 1)
1041           return false;
1042       }
1043     }
1044   }
1045   return true;
1046 }
1047 
1048 /// Generates a for-loop on a single index.
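/// Depending on the requested strategies, this emits roughly (a sketch) an
/// scf.for, an scf.parallel, or a vector-masked scf.for over [lo, hi) with
/// the vector length as step.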
1049 static Operation *genFor(Merger &merger, CodeGen &codegen,
1050                          PatternRewriter &rewriter, linalg::GenericOp op,
1051                          bool isOuter, bool isInner, unsigned idx,
1052                          llvm::BitVector &indices) {
1053   unsigned fb = indices.find_first();
1054   unsigned tensor = merger.tensor(fb);
1055   assert(idx == merger.index(fb));
1056   auto iteratorTypes = op.iterator_types().getValue();
1057   bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
1058   bool isSparse = merger.isDim(fb, Dim::kSparse);
1059   bool isVector = isVectorFor(codegen, isInner, isSparse) &&
1060                   denseUnitStrides(merger, op, idx);
1061   bool isParallel =
1062       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
1063 
1064   // Prepare vector length.
1065   if (isVector)
1066     codegen.curVecLength = codegen.options.vectorLength;
1067 
1068   // Loop bounds and increment.
1069   Location loc = op.getLoc();
1070   Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
1071   Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
1072   Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);
1073 
1074   // Emit a parallel loop.
1075   if (isParallel) {
1076     assert(!isVector);
1077     scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
1078     if (isSparse)
1079       codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
1080     else
1081       codegen.loops[idx] = parOp.getInductionVars()[0];
1082     rewriter.setInsertionPointToStart(parOp.getBody());
1083     return parOp;
1084   }
1085 
1086   // Emit a sequential loop, potentially with a scalarized reduction.
1087   bool scalarRed = isInner && codegen.redExp != -1u;
1088   SmallVector<Value, 4> operands;
1089   if (scalarRed) {
1090     Value load = genReductionStart(merger, codegen, rewriter, op);
1091     operands.push_back(load);
1092   }
1093   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
1094   if (scalarRed) {
1095     codegen.redVal = merger.exp(codegen.redExp).val =
1096         forOp.getRegionIterArgs().front();
1097   }
1098   // Assign induction variable to sparse or dense index.
1099   Value iv = forOp.getInductionVar();
1100   if (isSparse)
1101     codegen.pidxs[tensor][idx] = iv;
1102   else
1103     codegen.loops[idx] = iv;
1104   rewriter.setInsertionPointToStart(forOp.getBody());
1105   // Share vector iteration mask between all subsequent loads/stores.
1106   if (isVector)
1107     codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
1108   return forOp;
1109 }
1110 
/// Emits a while-loop for co-iteration over multiple indices.
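/// For two sparse operands this emits, roughly (a sketch, not verbatim IR):
///   while (pA < hiA && pB < hiB) {
///     iA = indices_A[pA]; iB = indices_B[pB]; i = min(iA, iB);
///     ... body, guarded by if-statements on iA == i and/or iB == i ...
///     pA += (iA == i); pB += (iB == i);
///   }
/// with an optional extra "universal" index when dense dimensions co-iterate.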
1112 static Operation *genWhile(Merger &merger, CodeGen &codegen,
1113                            PatternRewriter &rewriter, linalg::GenericOp op,
1114                            unsigned idx, bool needsUniv,
1115                            llvm::BitVector &indices) {
1116   SmallVector<Type, 4> types;
1117   SmallVector<Value, 4> operands;
1118   // Construct the while-loop with a parameter for each index.
1119   Type indexType = rewriter.getIndexType();
1120   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1121     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1122       unsigned tensor = merger.tensor(b);
1123       assert(idx == merger.index(b));
1124       types.push_back(indexType);
1125       assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
1126              "type mismatch for sparse index");
1127       operands.push_back(codegen.pidxs[tensor][idx]);
1128     }
1129   }
1130   if (needsUniv) {
1131     types.push_back(indexType);
1132     assert(codegen.loops[idx].getType().isa<IndexType>() &&
1133            "type mismatch for universal index");
1134     operands.push_back(codegen.loops[idx]);
1135   }
1136   Location loc = op.getLoc();
1137   scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
1138   Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
1139   Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
1140 
1141   // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
1143   rewriter.setInsertionPointToStart(&whileOp.before().front());
1144   Value cond;
1145   unsigned o = 0;
1146   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1147     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1148       unsigned tensor = merger.tensor(b);
1149       assert(idx == merger.index(b));
1150       Value op1 = before->getArgument(o);
1151       Value op2 = codegen.highs[tensor][idx];
1152       Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
1153       cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
1154       codegen.pidxs[tensor][idx] = after->getArgument(o++);
1155     }
1156   }
1157   if (needsUniv)
1158     codegen.loops[idx] = after->getArgument(o++);
1159   assert(o == operands.size());
1160   rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
1161   rewriter.setInsertionPointToStart(&whileOp.after().front());
1162   return whileOp;
1163 }
1164 
1165 /// Generates a for-loop or a while-loop, depending on whether it implements
1166 /// singleton iteration or co-iteration over the given conjunction.
1167 static Operation *genLoop(Merger &merger, CodeGen &codegen,
1168                           PatternRewriter &rewriter, linalg::GenericOp op,
1169                           std::vector<unsigned> &topSort, unsigned at,
1170                           bool needsUniv, llvm::BitVector &indices) {
1171   unsigned idx = topSort[at];
1172   if (indices.count() == 1) {
1173     bool isOuter = at == 0;
1174     bool isInner = at == topSort.size() - 1;
1175     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
1176                   indices);
1177   }
1178   genReductionEnd(merger, codegen, rewriter, op); // cannot chain
1179   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
1180 }
1181 
1182 /// Generates the local variables for this loop, consisting of the sparse
1183 /// indices, restored universal dense index, and dense positions.
1184 static void genLocals(Merger &merger, CodeGen &codegen,
1185                       PatternRewriter &rewriter, linalg::GenericOp op,
1186                       std::vector<unsigned> &topSort, unsigned at,
1187                       bool needsUniv, llvm::BitVector &locals) {
1188   Location loc = op.getLoc();
1189   unsigned idx = topSort[at];
1190 
1191   // Initialize sparse indices.
1192   Value min;
1193   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1194     if (locals[b] && merger.isDim(b, Dim::kSparse)) {
1195       unsigned tensor = merger.tensor(b);
1196       assert(idx == merger.index(b));
1197       Value ptr = codegen.indices[tensor][idx];
1198       Value s = codegen.pidxs[tensor][idx];
1199       Value load = genLoad(codegen, rewriter, loc, ptr, s);
1200       codegen.idxs[tensor][idx] = load;
1201       if (!needsUniv) {
1202         if (min) {
1203           Value cmp =
1204               rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
1205           min = rewriter.create<SelectOp>(loc, cmp, load, min);
1206         } else {
1207           min = load;
1208         }
1209       }
1210     }
1211   }
1212 
1213   // Merge dense universal index over minimum.
1214   if (min) {
1215     assert(!needsUniv);
1216     codegen.loops[idx] = min;
1217   }
1218 
1219   // Initialize dense positions. Note that we generate dense indices of the
1220   // output tensor unconditionally, since they may not appear in the lattice,
1221   // but may be needed for linearized codegen.
1222   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1223     if ((locals[b] || merger.isOutTensor(b, idx)) &&
1224         merger.isDim(b, Dim::kDense)) {
1225       unsigned tensor = merger.tensor(b);
1226       assert(idx == merger.index(b));
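      // Find the innermost enclosing loop for which this tensor already
      // has a position; that position (or zero at the outermost level)
      // serves as the base of the linearized address computed below.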
1227       unsigned pat = at;
1228       for (; pat != 0; pat--)
1229         if (codegen.pidxs[tensor][topSort[pat - 1]])
1230           break;
1231       Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
1232                            : codegen.pidxs[tensor][topSort[pat - 1]];
1233       codegen.pidxs[tensor][idx] = genAddress(
1234           codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
1235     }
1236   }
1237 }
1238 
1239 /// Generates the induction structure for a while-loop.
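/// Positions whose index was consumed in this iteration are advanced, the
/// universal index, if any, is incremented, and the new values are exposed
/// as results of the while-loop to the code that follows it.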
1240 static void genWhileInduction(Merger &merger, CodeGen &codegen,
1241                               PatternRewriter &rewriter, linalg::GenericOp op,
1242                               unsigned idx, bool needsUniv,
1243                               llvm::BitVector &induction, ResultRange results) {
1244   Location loc = op.getLoc();
1245   unsigned o = 0;
1246   SmallVector<Value, 4> operands;
1247   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
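  // Advance the position of each sparse tensor whose index matched the
  // loop index in this iteration: pidx = select(idx == loop, pidx+1, pidx).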
1248   for (unsigned b = 0, be = induction.size(); b < be; b++) {
1249     if (induction[b] && merger.isDim(b, Dim::kSparse)) {
1250       unsigned tensor = merger.tensor(b);
1251       assert(idx == merger.index(b));
1252       Value op1 = codegen.idxs[tensor][idx];
1253       Value op2 = codegen.loops[idx];
1254       Value op3 = codegen.pidxs[tensor][idx];
1255       Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
1256       Value add = rewriter.create<AddIOp>(loc, op3, one);
1257       operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
1258       codegen.pidxs[tensor][idx] = results[o++];
1259     }
1260   }
1261   if (needsUniv) {
1262     operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
1263     codegen.loops[idx] = results[o++];
1264   }
1265   assert(o == operands.size());
1266   rewriter.create<scf::YieldOp>(loc, operands);
1267 }
1268 
1269 /// Generates a single if-statement within a while-loop.
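/// The condition forms a conjunction that tests, for every sparse tensor
/// in the lattice point, whether its current index coincides with the loop
/// index; dense dimensions contribute a trivially true clause.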
1270 static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
1271                        PatternRewriter &rewriter, linalg::GenericOp op,
1272                        unsigned idx, llvm::BitVector &conditions) {
1273   Location loc = op.getLoc();
1274   Value cond;
1275   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1276     if (conditions[b]) {
1277       unsigned tensor = merger.tensor(b);
1278       assert(idx == merger.index(b));
1279       Value clause;
1280       if (merger.isDim(b, Dim::kSparse)) {
1281         Value op1 = codegen.idxs[tensor][idx];
1282         Value op2 = codegen.loops[idx];
1283         clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
1284       } else {
1285         clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
1286       }
1287       cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
1288     }
1289   }
1290   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
1291   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
1292   return ifOp;
1293 }
1294 
1295 /// Recursively generates code while computing iteration lattices in order
1296 /// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
1298 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1299                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1300                     unsigned exp, unsigned at) {
1301   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1302   if (at == topSort.size()) {
1303     OpOperand *lhs = op.getOutputOperand(0);
1304     Value rhs = genExp(merger, codegen, rewriter, op, exp);
1305     genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
1306     return;
1307   }
1308   assert(codegen.curVecLength == 1);
1309 
1310   // Construct iteration lattices for current loop index, with L0 at top.
1311   // Then emit initialization code for the loop sequence at this level.
1312   // We maintain the universal dense index if dense indices are still
1313   // in play for a non-singleton loop sequence.
1314   Location loc = op.getLoc();
1315   unsigned idx = topSort[at];
1316   unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
1317   unsigned lsize = merger.set(lts).size();
1318   assert(lsize != 0);
1319   unsigned l0 = merger.set(lts)[0];
1320   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
1321   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
1322   bool needsUniv = false;
1323   if (genInit(merger, codegen, rewriter, op, topSort, at,
1324               merger.lat(l0).bits)) {
1325     // Maintain the universal index only if it is actually
1326     // consumed by a subsequent lattice point.
1327     for (unsigned i = 1; i < lsize; i++) {
1328       unsigned li = merger.set(lts)[i];
1329       if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
1330         needsUniv = true;
1331         break;
1332       }
1333     }
1334   }
1335 
1336   // Emit a loop for every lattice point L0 >= Li.
1337   for (unsigned i = 0; i < lsize; i++) {
1338     unsigned li = merger.set(lts)[i];
1339 
1340     // Emit loop.
1341     codegen.curVecLength = 1;
1342     llvm::BitVector indices = merger.lat(li).simple;
1343     Operation *loop =
1344         genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
1345     genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
1346               merger.lat(li).bits);
1347 
    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if-statements for co-iteration.
    bool isWhile = isa<scf::WhileOp>(loop);
1351     for (unsigned j = 0; j < lsize; j++) {
1352       unsigned lj = merger.set(lts)[j];
1353       unsigned ej = merger.lat(lj).exp;
1354       if (li == lj || merger.latGT(li, lj)) {
1355         // Recurse into body of each branch.
1356         if (isWhile) {
1357           scf::IfOp ifOp =
1358               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1359           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1360           rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
1361         } else {
1362           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1363         }
1364       }
1365     }
1366 
1367     // Wrap-up induction and restore insertion point.
1368     if (isWhile) {
1369       scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
1370       rewriter.setInsertionPointToEnd(&whileOp.after().front());
1371       genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
1372                         merger.lat(li).bits, whileOp.results());
1373     } else {
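      // After a for-loop, the universal index is no longer in play for the
      // remaining lattice points; a scalarized reduction, if any, is
      // chained through the for-loop result.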
1374       needsUniv = false;
1375       if (codegen.redVal) {
1376         rewriter.create<scf::YieldOp>(loc, codegen.redVal);
1377         codegen.redVal = loop->getResult(0);
1378       }
1379     }
1380     rewriter.setInsertionPointAfter(loop);
1381   }
1382 
1383   // Wrap-up loop sequence.
1384   codegen.curVecLength = 1;
1385   genReductionEnd(merger, codegen, rewriter, op);
1386   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
1387   codegen.loops[idx] = Value();
1388 }
1389 
1390 /// Converts the result computed by the sparse kernel into the required form.
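/// A sparse result is rematerialized from its underlying sparse storage,
/// whereas a dense result is loaded back from its buffer.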
1391 static void genResult(CodeGen &codegen, PatternRewriter &rewriter,
1392                       linalg::GenericOp op) {
1393   RankedTensorType resType = op.getOutputTensorTypes()[0];
1394   Value result = codegen.buffers.back();
1395   if (getSparseTensorEncoding(resType))
1396     result = rewriter.create<ToTensorOp>(op.getLoc(), resType, result);
1397   else
1398     result =
1399         rewriter.create<memref::TensorLoadOp>(op.getLoc(), resType, result);
1400   rewriter.replaceOp(op, result);
1401 }
1402 
1403 namespace {
1404 
/// Sparse rewriting rule for generic Linalg operation.
1406 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1407 public:
1408   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1409       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1410 
1411   LogicalResult matchAndRewrite(linalg::GenericOp op,
1412                                 PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
1414     // information for all tensors to loop indices in the kernel.
1415     assert(op.getNumOutputs() == 1);
1416     assert(llvm::none_of(op.getInputAndOutputOperands(),
1417                          [&](OpOperand *t) { return op.isScalar(t); }));
1418     unsigned numTensors = op.getNumInputsAndOutputs();
1419     unsigned numLoops = op.iterator_types().getValue().size();
1420     Merger merger(numTensors, numLoops);
1421     if (!findSparseAnnotations(merger, op))
1422       return failure();
1423 
1424     // Computes a topologically sorted iteration graph to ensure
1425     // tensors are visited in natural index order. Fails on cycles.
1426     // This assumes that higher-level passes have already put the
1427     // tensors in each tensor expression in a feasible order.
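    // An attempt is made to satisfy the dimension order of all tensors
    // first; if that introduces a cycle, only the constraints imposed by
    // the sparse tensors are enforced.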
1428     std::vector<unsigned> topSort;
1429     if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
1430         !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
1431       return failure();
1432 
1433     // Finds the terminating yield statement and builds the tensor
1434     // expression for the Linalg operation in SSA form.
1435     Operation *yield = op.region().front().getTerminator();
1436     Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
1437     if (!exp.hasValue())
1438       return failure(); // build failure
1439 
1440     // Recursively generates code.
1441     CodeGen codegen(options, numTensors, numLoops);
1442     if (!genBuffers(merger, codegen, rewriter, op))
1443       return failure(); // could not bufferize
1444     genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
1445     genResult(codegen, rewriter, op);
1446     return success();
1447   }
1448 
1449 private:
1450   /// Options to control sparse code generation.
1451   SparsificationOptions options;
1452 };
1453 
1454 } // namespace
1455 
1456 /// Populates the given patterns list with rewriting rules required for
1457 /// the sparsification of linear algebra operations.
1458 void mlir::populateSparsificationPatterns(
1459     RewritePatternSet &patterns, const SparsificationOptions &options) {
1460   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1461 }
1462