1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering sparse tensor types to actual sparse code.
10 //
11 // The concept of letting a compiler generate sparse code automatically was
12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
14 // Algebra Compiler (TACO). The implementation in this file closely follows
15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting
16 // rule is applied to each tensor expression in linalg (MLIR's tensor index
17 // notation) where the sparsity of tensors is indicated with annotation using
18 // a per-dimension specification of sparse/dense storage together with a
19 // specification of the order on the dimensions. Subsequently, a topologically
20 // sorted iteration graph, reflecting the required order on indices with respect
21 // to the dimensions of each tensor, is constructed to ensure that all tensors
22 // are visited in natural index order. Next, iteration lattices are constructed
23 // for the tensor expression for every index in topological order. Each
24 // iteration lattice point consists of a conjunction of tensor indices together
25 // with a tensor (sub)expression that needs to be evaluated for that
26 // conjunction. Within the lattice, iteration points are ordered according to
27 // the way indices are exhausted. As such these iteration lattices drive actual
28 // sparse code generation, which consists of a tedious but relatively
29 // straightforward one-to-one mapping from iteration lattices to combinations
30 // of for-loops, while-loops, and if-statements.
31 //
32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
36 // Proceedings of the ACM on Programming Languages, October 2017.
37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
38 // PhD thesis, MIT, February, 2020 (tensor-compiler.org).
39 //
40 // Implementation detail: We use llvm::SmallVector for vectors with
41 // variable lengths and std::vector for vectors with fixed lengths.
42 //===----------------------------------------------------------------------===//
43 
44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
45 #include "mlir/Dialect/Linalg/Utils/Utils.h"
46 #include "mlir/Dialect/SCF/SCF.h"
47 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
48 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
49 #include "mlir/Dialect/StandardOps/IR/Ops.h"
50 #include "mlir/Dialect/Vector/VectorOps.h"
51 #include "mlir/IR/Matchers.h"
52 #include "mlir/IR/TensorEncoding.h"
53 #include "llvm/ADT/SmallBitVector.h"
54 
55 using namespace mlir;
56 using namespace mlir::sparse_tensor;
57 
58 namespace {
59 
/// Kinds of tensor expressions. The two leaf kinds come first; the
/// TensorExp constructor assumes every kind from kMulF onward is binary.
enum class Kind {
  kTensor,
  kInvariant,
  kMulF,
  kMulI,
  kAddF,
  kAddI,
};
/// Per-dimension storage classification of a tensor access. Compressed
/// levels map to kSparse, singleton levels to kSingle, everything else to
/// kDense; kUndef marks (tensor, loop) pairs that were never set.
enum class Dim {
  kSparse,
  kDense,
  kSingle,
  kUndef,
};
62 
63 /// Tensor expression. Represents a MLIR expression in tensor index notation.
64 /// For tensors, e0 denotes the tensor index. For invariants, the IR value is
65 /// stored directly. For binary operations, e0 and e1 denote the index of the
66 /// children tensor expressions.
67 struct TensorExp {
68   TensorExp(Kind k, unsigned x, unsigned y, Value v)
69       : kind(k), e0(x), e1(y), val(v) {
70     assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
71            (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
72            (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
73   }
74   Kind kind;
75   /// Indices of children expression(s).
76   unsigned e0;
77   unsigned e1;
78   /// Direct link to IR for an invariant. During code generation,
79   /// field is used to cache "hoisted" loop invariant tensor loads.
80   Value val;
81 };
82 
83 /// Lattice point. Each lattice point consists of a conjunction of tensor
84 /// loop indices (encoded in a bitvector) and the index of the corresponding
85 /// tensor expression.
86 struct LatPoint {
87   LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
88     bits.set(b);
89   }
90   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
91   /// Conjunction of tensor loop indices as bitvector. This represents
92   /// all indices involved in the tensor expression
93   llvm::BitVector bits;
94   /// Simplified conjunction of tensor loop indices as bitvector. This
95   /// represents a simplified condition under which this tensor expression
96   /// must execute. Pre-computed during codegen to avoid repeated eval.
97   llvm::BitVector simple;
98   /// Index of the tensor expresssion.
99   unsigned exp;
100 };
101 
102 /// A class to handle all iteration lattice operations. This class abstracts
103 /// away from some implementation details of storing iteration lattices and
104 /// tensor expressions. This allows for fine-tuning performance characteristics
105 /// independently from the basic algorithm if bottlenecks are identified.
106 class Merger {
107 public:
108   /// Constructs a merger for the given number of tensors and loops. The
109   /// user supplies the number of tensors involved in the kernel, with the
110   /// last tensor in this set denoting the output tensor. The merger adds an
111   /// additional synthetic tensor at the end of this set to represent all
112   /// invariant expressions in the kernel.
113   Merger(unsigned t, unsigned l)
114       : outTensor(t - 1), numTensors(t + 1), numLoops(l),
115         dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
116 
117   /// Adds a tensor expression. Returns its index.
118   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
119     unsigned e = tensorExps.size();
120     tensorExps.push_back(TensorExp(k, e0, e1, v));
121     return e;
122   }
123   unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
124 
125   /// Adds an iteration lattice point. Returns its index.
126   unsigned addLat(unsigned t, unsigned i, unsigned e) {
127     assert(t < numTensors && i < numLoops);
128     unsigned p = latPoints.size();
129     latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
130     return p;
131   }
132 
133   /// Adds a new, initially empty, set. Returns its index.
134   unsigned addSet() {
135     unsigned s = latSets.size();
136     latSets.emplace_back(SmallVector<unsigned, 16>());
137     return s;
138   }
139 
140   /// Computes a single conjunction of two lattice points by taking the "union"
141   /// of loop indices (effectively constructing a larger "intersection" of those
142   /// indices) with a newly constructed tensor (sub)expression of given kind.
143   /// Returns the index of the new lattice point.
144   unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
145     unsigned p = latPoints.size();
146     llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
147     nb |= latPoints[p1].bits;
148     unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
149     latPoints.push_back(LatPoint(nb, e));
150     return p;
151   }
152 
153   /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
154   /// cartesian product. Returns the index of the new set.
155   unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
156     unsigned s = addSet();
157     for (unsigned p0 : latSets[s0])
158       for (unsigned p1 : latSets[s1])
159         latSets[s].push_back(conjLatPoint(kind, p0, p1));
160     return s;
161   }
162 
163   /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
164   /// Returns the index of the new set.
165   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
166     unsigned s = takeConj(kind, s0, s1);
167     for (unsigned p : latSets[s0])
168       latSets[s].push_back(p);
169     for (unsigned p : latSets[s1])
170       latSets[s].push_back(p);
171     return s;
172   }
173 
174   /// Optimizes the iteration lattice points in the given set. This
175   /// method should be called right before code generation to avoid
176   /// generating redundant loops and conditions.
177   unsigned optimizeSet(unsigned s0) {
178     unsigned s = addSet();
179     assert(latSets[s0].size() != 0);
180     unsigned p0 = latSets[s0][0];
181     for (unsigned p1 : latSets[s0]) {
182       bool add = true;
183       if (p0 != p1) {
184         // Is this a straightforward copy?
185         unsigned e = latPoints[p1].exp;
186         if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
187           continue;
188         // Conjunction already covered?
189         for (unsigned p2 : latSets[s]) {
190           assert(!latGT(p1, p2)); // Lj => Li would be bad
191           if (onlyDenseDiff(p2, p1)) {
192             add = false;
193             break;
194           }
195         }
196         assert(!add || latGT(p0, p1));
197       }
198       if (add)
199         latSets[s].push_back(p1);
200     }
201     for (unsigned p : latSets[s])
202       latPoints[p].simple = simplifyCond(s, p);
203     return s;
204   }
205 
206   /// Simplifies the conditions in a conjunction of a given lattice point
207   /// within the given set using just two basic rules:
208   /// (1) multiple dense conditions are reduced to single dense, and
209   /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
210   llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
211     // First determine if this lattice point is a *singleton*, i.e.,
212     // the last point in a lattice, no other is less than this one.
213     bool isSingleton = true;
214     for (unsigned p1 : latSets[s]) {
215       if (p0 != p1 && latGT(p0, p1)) {
216         isSingleton = false;
217         break;
218       }
219     }
220     // Now apply the two basic rules.
221     llvm::BitVector simple = latPoints[p0].bits;
222     bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
223     for (unsigned b = 0, be = simple.size(); b < be; b++) {
224       if (simple[b] && !isDim(b, Dim::kSparse)) {
225         if (reset)
226           simple.reset(b);
227         reset = true;
228       }
229     }
230     return simple;
231   }
232 
233   /// Returns true if Li > Lj.
234   bool latGT(unsigned i, unsigned j) const {
235     const llvm::BitVector &bitsi = latPoints[i].bits;
236     const llvm::BitVector &bitsj = latPoints[j].bits;
237     assert(bitsi.size() == bitsj.size());
238     if (bitsi.count() > bitsj.count()) {
239       for (unsigned b = 0, be = bitsj.size(); b < be; b++)
240         if (bitsj[b] && !bitsi[b])
241           return false;
242       return true;
243     }
244     return false;
245   }
246 
247   /// Returns true if Li and Lj only differ in dense.
248   bool onlyDenseDiff(unsigned i, unsigned j) {
249     llvm::BitVector tmp = latPoints[j].bits;
250     tmp ^= latPoints[i].bits;
251     return !hasAnyDimOf(tmp, Dim::kSparse);
252   }
253 
254   /// Bit translation.
255   unsigned tensor(unsigned b) const { return b % numTensors; }
256   unsigned index(unsigned b) const { return b / numTensors; }
257 
258   /// Returns true if bit corresponds to queried dim.
259   bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
260 
261   /// Returns true if tensor access at given index has queried dim.
262   bool isDim(unsigned t, unsigned i, Dim d) const {
263     assert(t < numTensors && i < numLoops);
264     return dims[t][i] == d;
265   }
266 
267   /// Returns true if any set bit corresponds to queried dim.
268   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
269     for (unsigned b = 0, be = bits.size(); b < be; b++)
270       if (bits[b] && isDim(b, d))
271         return true;
272     return false;
273   }
274 
275   /// Setter
276   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
277 
278   /// Getters.
279   TensorExp &exp(unsigned e) { return tensorExps[e]; }
280   LatPoint &lat(unsigned l) { return latPoints[l]; }
281   SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
282 
283 private:
284   const unsigned outTensor;
285   const unsigned numTensors;
286   const unsigned numLoops;
287 
288   std::vector<std::vector<Dim>> dims;
289   llvm::SmallVector<TensorExp, 32> tensorExps;
290   llvm::SmallVector<LatPoint, 16> latPoints;
291   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
292 };
293 
294 // Code generation.
295 struct CodeGen {
296   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
297       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
298         pointers(numTensors, std::vector<Value>(numLoops)),
299         indices(numTensors, std::vector<Value>(numLoops)),
300         highs(numTensors, std::vector<Value>(numLoops)),
301         pidxs(numTensors, std::vector<Value>(numLoops)),
302         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
303         curVecLength(1), curVecMask() {}
304   /// Sparsification options.
305   SparsificationOptions options;
306   /// Universal dense indices and upper bounds (by index). The loops array
307   /// is updated with the value of the universal dense index in the current
308   /// loop. The sizes array is set once with the inferred dimension sizes.
309   std::vector<Value> loops;
310   std::vector<Value> sizes;
311   /// Buffers for storing dense and sparse numerical values (by tensor).
312   /// This array is set once during bufferization of all tensors.
313   std::vector<Value> buffers;
314   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
315   /// This array is set once during bufferization of all sparse tensors.
316   std::vector<std::vector<Value>> pointers;
317   std::vector<std::vector<Value>> indices;
318   /// Sparse iteration information (by tensor and index). These arrays
319   /// are updated to remain current within the current loop.
320   std::vector<std::vector<Value>> highs;
321   std::vector<std::vector<Value>> pidxs;
322   std::vector<std::vector<Value>> idxs;
323   /// Current reduction, updated during code generation. When indices of a
324   /// reduction are exhausted,  all inner loops can "scalarize" the reduction.
325   // TODO: currently only done for (a chain of) innermost for-loops, where it
326   // is most effective; we could generalize to more outer and while-loops.
327   unsigned redExp;
328   Value redVal;
329   // Current vector length and mask.
330   unsigned curVecLength;
331   Value curVecMask;
332 };
333 
334 } // namespace
335 
336 // Helper method to apply dimension ordering permutation.
337 static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
338   if (enc) {
339     auto order = enc.getDimOrdering();
340     if (order) {
341       assert(order.isPermutation());
342       return order.getDimPosition(d);
343     }
344   }
345   return d;
346 }
347 
348 // Helper method to translate dim level type to internal representation.
349 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
350   if (enc) {
351     SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
352     if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
353       return Dim::kSparse;
354     if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
355       return Dim::kSingle;
356   }
357   return Dim::kDense;
358 }
359 
360 /// Helper method to inspect sparse encodings in the tensor types.
361 /// Fills the per-dimension sparsity information for all tensors.
362 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
363   bool annotated = false;
364   OpOperand *lhs = op.getOutputOperand(0);
365   for (OpOperand *t : op.getInputAndOutputOperands()) {
366     auto map = op.getTiedIndexingMap(t);
367     if (!map.isProjectedPermutation())
368       return false;
369     auto enc = getSparseTensorEncoding(t->get().getType());
370     if (enc) {
371       annotated = true;
372       if (t == lhs)
373         return false; // TODO: handle sparse outputs
374     }
375     assert(map.getNumResults() == op.getRank(t));
376     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
377       unsigned idx = map.getDimPosition(perm(enc, d));
378       merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
379     }
380   }
381   return annotated;
382 }
383 
/// Depth-first helper for a topological sort over the adjacency matrix
/// `adjM`. Node `i` is appended to `topSort` only after all of its
/// successors. Recursion depth is bounded by the number of implicit loops,
/// which is always small. Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  // Visit states: 0 = unvisited, 1 = on the current DFS path, else done.
  switch (visit[i]) {
  case 0:
    break;
  case 1:
    return false; // reached a node on the current path: cycle!
  default:
    return true;
  }
  visit[i] = 1;
  const unsigned numNodes = visit.size();
  for (unsigned next = 0; next < numNodes; ++next) {
    if (adjM[i][next] && !topSortDFS(next, visit, topSort, adjM))
      return false;
  }
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}
401 
402 /// Computes a topologically sorted iteration graph for the linalg operation.
403 /// Ensures all tensors are visited in natural index order. This is essential
404 /// for sparse storage formats since these only support access along fixed
405 /// dimensions. Even for dense storage formats, however, the natural index
406 /// order yields innermost unit-stride access with better spatial locality.
407 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
408                                   std::vector<unsigned> &topSort,
409                                   bool sparseOnly) {
410   // Set up an n x n from/to adjacency matrix of the iteration graph
411   // for the implicit loop indices i_0 .. i_n-1.
412   unsigned n = op.getNumLoops();
413   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
414 
415   // Iterate over the indexing maps of every tensor in the tensor expression.
416   for (OpOperand *t : op.getInputAndOutputOperands()) {
417     auto map = op.getTiedIndexingMap(t);
418     auto enc = getSparseTensorEncoding(t->get().getType());
419     assert(map.getNumDims() == n);
420     // Skip dense tensor constraints when sparse only is requested.
421     if (sparseOnly && !enc)
422       continue;
423     // Each tensor expression and optional dimension ordering (row-major
424     // by default) puts an ordering constraint on the loop indices. For
425     // example, the tensor expresion A_ijk forces the ordering i < j < k
426     // on the loop indices if no explicit dimension ordering is given.
427     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
428       unsigned f = map.getDimPosition(perm(enc, d - 1));
429       unsigned t = map.getDimPosition(perm(enc, d));
430       adjM[f][t] = true;
431     }
432   }
433 
434   // Topologically sort the iteration graph to determine loop order.
435   // Report failure for a cyclic iteration graph.
436   topSort.clear();
437   topSort.reserve(n);
438   std::vector<unsigned> visit(n, 0);
439   for (unsigned i = 0; i < n; i++)
440     if (visit[i] == 0)
441       if (!topSortDFS(i, visit, topSort, adjM))
442         return false; // cycle!
443   std::reverse(std::begin(topSort), std::end(topSort));
444   return true;
445 }
446 
447 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
448 /// This simplifies constructing (sub)expressions during iteration lattice
449 /// building (compared to using the SSA representation everywhere).
450 static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
451                                          Value val) {
452   if (auto arg = val.dyn_cast<BlockArgument>()) {
453     unsigned argN = arg.getArgNumber();
454     // Any parameter of the generic op is considered a tensor,
455     // indexed by the implicit loop bounds.
456     if (arg.getOwner()->getParentOp() == op)
457       return merger.addExp(Kind::kTensor, argN);
458     // Any parameter of a higher op is invariant.
459     return merger.addExp(Kind::kInvariant, val);
460   }
461   Operation *def = val.getDefiningOp();
462   if (def->getBlock() != &op.region().front()) {
463     // Something defined outside is invariant.
464     return merger.addExp(Kind::kInvariant, val);
465   } else if (def->getNumOperands() == 2) {
466     // Construct binary operations if subexpressions could be built.
467     auto x = buildTensorExp(merger, op, def->getOperand(0));
468     auto y = buildTensorExp(merger, op, def->getOperand(1));
469     if (x.hasValue() && y.hasValue()) {
470       unsigned e0 = x.getValue();
471       unsigned e1 = y.getValue();
472       if (isa<MulFOp>(def))
473         return merger.addExp(Kind::kMulF, e0, e1);
474       if (isa<MulIOp>(def))
475         return merger.addExp(Kind::kMulI, e0, e1);
476       if (isa<AddFOp>(def))
477         return merger.addExp(Kind::kAddF, e0, e1);
478       if (isa<AddIOp>(def))
479         return merger.addExp(Kind::kAddI, e0, e1);
480     }
481   }
482   // Cannot build (yet).
483   return None;
484 }
485 
486 /// Builds the iteration lattices in a bottom-up traversal given the remaining
487 /// tensor (sub)expression and the next loop index in the iteration graph.
488 static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
489                               unsigned exp, unsigned idx) {
490   Kind kind = merger.exp(exp).kind;
491   if (kind == Kind::kTensor || kind == Kind::kInvariant) {
492     // Either the index is really used in the tensor expression, or it is
493     // set to the undefined index in that dimension. An invariant expression
494     // is set to a synthetic tensor with undefined indices only.
495     unsigned s = merger.addSet();
496     unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
497                                        : op.getNumInputsAndOutputs();
498     merger.set(s).push_back(merger.addLat(t, idx, exp));
499     return s;
500   }
501   unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
502   unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
503   switch (kind) {
504   case Kind::kTensor:
505   case Kind::kInvariant:
506     llvm_unreachable("handled above");
507   case Kind::kMulF:
508   case Kind::kMulI:
509     return merger.takeConj(kind, s0, s1);
510   case Kind::kAddF:
511   case Kind::kAddI:
512     return merger.takeDisj(kind, s0, s1);
513   }
514   llvm_unreachable("unexpected expression kind");
515 }
516 
517 /// Maps sparse integer option to actual integral storage type.
518 static Type genIntType(PatternRewriter &rewriter, unsigned width) {
519   if (width == 0)
520     return rewriter.getIndexType();
521   return rewriter.getIntegerType(width);
522 }
523 
524 /// Detects in-place annotation on tensor argument.
525 static bool getInPlace(Value val) {
526   if (auto arg = val.dyn_cast<BlockArgument>())
527     if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
528       if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
529               arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
530         return attr.getValue();
531   return false;
532 }
533 
534 /// Generates buffer for the output tensor.
535 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
536                              linalg::GenericOp op, MemRefType denseTp,
537                              ArrayRef<Value> args) {
538   Location loc = op.getLoc();
539   Value tensor = op.getOutputOperand(0)->get();
540   // The output tensor simply could materialize from the buffer that will
541   // be generated for the tensor present in the outs() clause. This has
542   // the major advantage that the sparse kernel only updates the nonzero
543   // positions for the output tensor.
544   if (getInPlace(tensor))
545     return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
546   // By default, a new buffer is allocated which is initialized to the
547   // tensor defined in the outs() clause. This is always correct but
548   // introduces a dense initialization component that may negatively
549   // impact the running complexity of the sparse kernel.
550   Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
551   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
552   rewriter.create<linalg::CopyOp>(loc, init, alloc);
553   return alloc;
554 }
555 
556 /// Local bufferization of all dense and sparse data structures.
557 /// This code enables testing the first prototype sparse compiler.
558 // TODO: replace this with a proliferated bufferization strategy
559 static void genBuffers(Merger &merger, CodeGen &codegen,
560                        PatternRewriter &rewriter, linalg::GenericOp op) {
561   Location loc = op.getLoc();
562   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
563   // For every tensor, find lower and upper bound on dimensions, set the
564   // same bounds on loop indices, and obtain dense or sparse buffer(s).
565   SmallVector<Value, 4> args;
566   for (OpOperand *t : op.getInputAndOutputOperands()) {
567     auto shape = op.getShape(t);
568     auto map = op.getTiedIndexingMap(t);
569     auto enc = getSparseTensorEncoding(t->get().getType());
570     // Scan all dimensions of current tensor.
571     args.clear();
572     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
573       unsigned idx = map.getDimPosition(perm(enc, d));
574       // Handle sparse storage schemes.
575       if (merger.isDim(t->getOperandNumber(), idx, Dim::kSparse)) {
576         auto dynShape = {ShapedType::kDynamicSize};
577         auto ptrTp = MemRefType::get(
578             dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
579         auto indTp = MemRefType::get(
580             dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
581         Value dim = rewriter.create<ConstantIndexOp>(loc, d);
582         // Generate sparse primitives to obtains pointer and indices.
583         codegen.pointers[t->getOperandNumber()][idx] =
584             rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
585         codegen.indices[t->getOperandNumber()][idx] =
586             rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
587       }
588       // Find lower and upper bound in current dimension.
589       Value up;
590       if (shape[d] == MemRefType::kDynamicSize) {
591         up = rewriter.create<memref::DimOp>(loc, t->get(), d);
592         args.push_back(up);
593       } else {
594         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
595       }
596       codegen.sizes[idx] = codegen.highs[t->getOperandNumber()][idx] = up;
597     }
598     // Perform the required bufferization. All dense inputs materialize
599     // from the input tensor. The dense output tensor needs special
600     // handling. Sparse inputs use a sparse primitive to obtain the values.
601     Type elementType = getElementTypeOrSelf(t->get().getType());
602     if (!enc) {
603       auto denseTp = MemRefType::get(shape, elementType);
604       if (t->getOperandNumber() < op.getNumInputs())
605         codegen.buffers[t->getOperandNumber()] =
606             rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
607       else
608         codegen.buffers[t->getOperandNumber()] =
609             genOutputBuffer(codegen, rewriter, op, denseTp, args);
610     } else {
611       auto dynShape = {ShapedType::kDynamicSize};
612       auto sparseTp = MemRefType::get(dynShape, elementType);
613       codegen.buffers[t->getOperandNumber()] =
614           rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
615     }
616   }
617 }
618 
619 /// Constructs vector type.
620 static VectorType vectorType(CodeGen &codegen, Type etp) {
621   return VectorType::get(codegen.curVecLength, etp);
622 }
623 
624 /// Constructs vector type from pointer.
625 static VectorType vectorType(CodeGen &codegen, Value ptr) {
626   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
627 }
628 
629 /// Constructs vector iteration mask.
630 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
631                            Value iv, Value lo, Value hi, Value step) {
632   Location loc = iv.getLoc();
633   VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
634   // Special case if the vector length evenly divides the trip count (for
635   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
636   // so that all subsequent masked memory operations are immediately folded
637   // into unconditional memory operations.
638   IntegerAttr loInt, hiInt, stepInt;
639   if (matchPattern(lo, m_Constant(&loInt)) &&
640       matchPattern(hi, m_Constant(&hiInt)) &&
641       matchPattern(step, m_Constant(&stepInt))) {
642     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
643       return rewriter.create<vector::BroadcastOp>(
644           loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
645   }
646   // Otherwise, generate a vector mask that avoids overrunning the upperbound
647   // during vector execution. Here we rely on subsequent loop optimizations to
648   // avoid executing the mask in all iterations, for example, by splitting the
649   // loop into an unconditional vector loop and a scalar cleanup loop.
650   Value end = rewriter.create<SubIOp>(loc, hi, iv);
651   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
652 }
653 
654 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
655 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
656                            Value ptr, ArrayRef<Value> args) {
657   Location loc = ptr.getLoc();
658   VectorType vtp = vectorType(codegen, ptr);
659   Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
660   if (args.back().getType().isa<VectorType>()) {
661     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
662     Value indexVec = args.back();
663     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
664     return rewriter.create<vector::GatherOp>(
665         loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
666   }
667   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
668                                                codegen.curVecMask, pass);
669 }
670 
671 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
672 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
673                            Value rhs, Value ptr, ArrayRef<Value> args) {
674   Location loc = ptr.getLoc();
675   if (args.back().getType().isa<VectorType>()) {
676     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
677     Value indexVec = args.back();
678     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
679     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
680                                        codegen.curVecMask, rhs);
681     return;
682   }
683   rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
684                                          rhs);
685 }
686 
687 /// Generates a vectorized invariant. Here we rely on subsequent loop
688 /// optimizations to hoist the invariant broadcast out of the vector loop.
689 static Value genVectorInvariantValue(CodeGen &codegen,
690                                      PatternRewriter &rewriter, Value val) {
691   VectorType vtp = vectorType(codegen, val.getType());
692   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
693 }
694 
/// Generates a load on a dense or sparse tensor for tensor expression `exp`.
/// If the load was hoisted to an enclosing loop (merger.exp(exp).val is set),
/// the hoisted value is reused; it is broadcast to a vector when the current
/// loop is vectorized. Otherwise this emits a memref load, or a vector
/// load/gather when codegen.curVecLength > 1.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    // A hoisted scalar must be broadcast inside a vector loop. Values that
    // already are vectors are returned unchanged.
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  // Tensor numbers below getNumInputs() denote inputs; any other tensor
  // number denotes the (single) output operand.
  OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs()
                          ? op.getInputOperand(merger.exp(exp).e0)
                          : op.getOutputOperand(0);
  auto map = op.getTiedIndexingMap(tensor);
  auto enc = getSparseTensorEncoding(tensor->get().getType());
  // Visit the dimensions through perm(enc, d) (presumably the encoding's
  // dimension ordering; confirm at perm's definition). A dense tensor is
  // subscripted by the universal loop index of every dimension. For a sparse
  // tensor the arguments are reset at every dimension, so only the position
  // index of the last visited dimension remains as the single subscript.
  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
    unsigned idx = map.getDimPosition(perm(enc, d));
    args.push_back(codegen.loops[idx]); // universal dense index
    if (enc) {
      args.clear();
      args.push_back(
          codegen.pidxs[tensor->getOperandNumber()][idx]); // position index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor->getOperandNumber()];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}
728 
729 /// Generates a store on a dense tensor.
730 static void genTensorStore(Merger &merger, CodeGen &codegen,
731                            PatternRewriter &rewriter, linalg::GenericOp op,
732                            OpOperand *tensor, Value rhs) {
733   Location loc = op.getLoc();
734   // Test if this is a scalarized reduction.
735   OpOperand *lhs = op.getOutputOperand(0);
736   if (lhs == tensor && codegen.redVal) {
737     if (codegen.curVecLength > 1)
738       rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
739                                       codegen.redVal);
740     codegen.redVal = rhs;
741     return;
742   }
743   // Actual store.
744   SmallVector<Value, 4> args;
745   auto map = op.getTiedIndexingMap(tensor);
746   assert(!getSparseTensorEncoding(tensor->get().getType()));
747   for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
748     unsigned idx = map.getDimPosition(d);
749     args.push_back(codegen.loops[idx]); // universal dense index
750   }
751   Value ptr = codegen.buffers[tensor->getOperandNumber()];
752   if (codegen.curVecLength > 1)
753     genVectorStore(codegen, rewriter, rhs, ptr, args);
754   else
755     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
756 }
757 
/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
///
/// In vector mode (curVecLength > 1) the result is a vector of integers,
/// zero extended to 32 or 64 bits as described below but NOT cast to index.
/// In scalar mode the result is index-typed.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in a subsequent gather/scatter operation,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    // Index-typed elements are returned as loaded.
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}
803 
804 /// Generates an invariant value.
805 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
806                                PatternRewriter &rewriter, unsigned exp) {
807   Value val = merger.exp(exp).val;
808   if (codegen.curVecLength > 1)
809     return genVectorInvariantValue(codegen, rewriter, val);
810   return val;
811 }
812 
813 /// Generates an address computation "sz * p + i".
814 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
815                         Location loc, Value size, Value p, Value i) {
816   Value mul = rewriter.create<MulIOp>(loc, size, p);
817   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
818     Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
819     mul = genVectorInvariantValue(codegen, rewriter, inv);
820   }
821   return rewriter.create<AddIOp>(loc, mul, i);
822 }
823 
824 /// Generates start of a reduction.
825 static Value genReductionStart(Merger &merger, CodeGen &codegen,
826                                PatternRewriter &rewriter,
827                                linalg::GenericOp op) {
828   if (codegen.redVal)
829     return codegen.redVal; // chained with previous for-loop
830   if (codegen.curVecLength > 1) {
831     // TODO: assumes + reductions for now
832     VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
833     return rewriter.create<ConstantOp>(op.getLoc(), vtp,
834                                        rewriter.getZeroAttr(vtp));
835   }
836   return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
837 }
838 
/// Generates end of a reduction. If a scalarized reduction value is pending
/// in codegen.redVal, the reduction chain is terminated and the value is
/// stored back into the output tensor. A vector-typed reduction value is
/// first reduced horizontally ("add" only) and combined with the output's
/// current value. Does nothing when no reduction is pending.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  // Must be called outside vector code, so the load/store below are scalar.
  assert(codegen.curVecLength == 1);
  // Clearing both redVal and the expression's cached value ensures that the
  // load and store below access memory instead of the reduction chain.
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  OpOperand *lhs = op.getOutputOperand(0);
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}
864 
865 /// Recursively generates tensor expression.
866 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
867                     linalg::GenericOp op, unsigned exp) {
868   if (merger.exp(exp).kind == Kind::kTensor)
869     return genTensorLoad(merger, codegen, rewriter, op, exp);
870   else if (merger.exp(exp).kind == Kind::kInvariant)
871     return genInvariantValue(merger, codegen, rewriter, exp);
872   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
873   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
874   switch (merger.exp(exp).kind) {
875   case Kind::kTensor:
876   case Kind::kInvariant:
877     llvm_unreachable("handled above");
878   case Kind::kMulF:
879     return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
880   case Kind::kMulI:
881     return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
882   case Kind::kAddF:
883     return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
884   case Kind::kAddI:
885     return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
886   }
887   llvm_unreachable("unexpected expression kind");
888 }
889 
/// Hoists loop invariant tensor loads for which indices have been exhausted.
/// A tensor expression is a candidate once every loop index used by its
/// tensor has been assigned in codegen.loops. For an input tensor, the load
/// is emitted and cached in merger.exp(exp).val when hoist is true, but only
/// if one of those indices is `ldx` (or ldx == -1u at the outermost level).
/// With hoist false, the cached value is cleared again. For the output
/// tensor, the expression is instead recorded in codegen.redExp as a
/// scalarized reduction when hoist is true, and reset to -1u when hoist is
/// false. Binary expressions are traversed recursively; invariants are
/// skipped.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs()
                            ? op.getInputOperand(merger.exp(exp).e0)
                            : op.getOutputOperand(0);
    auto map = op.getTiedIndexingMap(tensor);
    auto enc = getSparseTensorEncoding(tensor->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == tensor) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).e0;
    unsigned e1 = merger.exp(exp).e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}
927 
/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
///
/// For every sparse dimension set in `inits`, the position range is loaded
/// from the pointers array. Its lower bound goes into codegen.pidxs and its
/// upper bound into codegen.highs. The universal dense index of this level
/// is always reset to 0.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        // Find the nearest enclosing loop level that already has a position
        // for this tensor; at the outermost level the parent position is 0.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        // The range is [pointers[p0], pointers[p0 + 1]).
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}
968 
969 /// Returns vectorization strategy. Any implicit inner loop in the Linalg
970 /// operation is a candidate. Whether it is actually converted to SIMD code
971 /// depends on the requested strategy.
972 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
973   switch (codegen.options.vectorizationStrategy) {
974   case SparseVectorizationStrategy::kNone:
975     return false;
976   case SparseVectorizationStrategy::kDenseInnerLoop:
977     return isInner && !isSparse;
978   case SparseVectorizationStrategy::kAnyStorageInnerLoop:
979     return isInner;
980   }
981   llvm_unreachable("unexpected vectorization strategy");
982 }
983 
984 /// Returns parallelization strategy. Any implicit loop in the Linalg operation
985 /// that is marked "parallel" is a candidate. Whether it is actually converted
986 /// to a parallel operation depends on the requested strategy.
987 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
988                           bool isSparse, bool isVector) {
989   switch (codegen.options.parallelizationStrategy) {
990   case SparseParallelizationStrategy::kNone:
991     return false;
992   case SparseParallelizationStrategy::kDenseOuterLoop:
993     return isOuter && !isSparse && !isReduction && !isVector;
994   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
995     return isOuter && !isReduction && !isVector;
996   case SparseParallelizationStrategy::kDenseAnyLoop:
997     return !isSparse && !isReduction && !isVector;
998   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
999     return !isReduction && !isVector;
1000   }
1001   llvm_unreachable("unexpected parallelization strategy");
1002 }
1003 
1004 /// Checks unit strides for dense tensors. The iteration graph may have ignored
1005 /// dense access patterns in order to avoid cycles (sparse access patterns are
1006 /// always placed innermost), but that means dense access has become strided.
1007 /// For now, we reject vectorization of such cases.
1008 /// TODO: implement strided load/stores on dense arrays
1009 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
1010                              unsigned idx) {
1011   for (OpOperand *t : op.getInputAndOutputOperands()) {
1012     if (!getSparseTensorEncoding(t->get().getType())) {
1013       auto map = op.getTiedIndexingMap(t);
1014       for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
1015         if (map.getDimPosition(d) == idx && d != rank - 1)
1016           return false;
1017       }
1018     }
1019   }
1020   return true;
1021 }
1022 
/// Generates a for-loop on a single index. Depending on the configured
/// strategies, the loop becomes either an scf.parallel (never for reduction
/// or vectorized loops) or an scf.for. An scf.for may carry a scalarized
/// reduction value as its iteration argument and may be vectorized with a
/// shared iteration mask. On return, the insertion point is at the start of
/// the loop body, and the induction variable has been assigned to the sparse
/// position (codegen.pidxs) or the dense index (codegen.loops).
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  // `indices` is expected to have exactly one bit set (genLoop only calls
  // genFor in that case); that bit determines the tensor and dimension kind.
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  // Sparse loops iterate over positions [pidxs, highs); dense loops iterate
  // from the universal index up to the dimension size. The step is
  // curVecLength, which was set to the vector length above when vectorizing.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  // A scalarized reduction is only carried through the innermost loop, and
  // only when genInvariants recorded a reduction expression in redExp.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    // Inside the loop, the reduction value is the loop-carried argument.
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}
1085 
/// Emit a while-loop for co-iteration over multiple indices. The loop
/// carries one index-typed argument per sparse dimension in `indices` (its
/// current position), followed by the universal dense index when needsUniv
/// is set. The loop continues while every sparse position is below its upper
/// bound (codegen.highs). On return, codegen.pidxs / codegen.loops refer to
/// the "after" region's block arguments, and the insertion point is at the
/// start of that region.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  // The same loop also rebinds each sparse position to the corresponding
  // "after" argument; the argument order matches the operand order above.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}
1139 
1140 /// Generates a for-loop or a while-loop, depending on whether it implements
1141 /// singleton iteration or co-iteration over the given conjunction.
1142 static Operation *genLoop(Merger &merger, CodeGen &codegen,
1143                           PatternRewriter &rewriter, linalg::GenericOp op,
1144                           std::vector<unsigned> &topSort, unsigned at,
1145                           bool needsUniv, llvm::BitVector &indices) {
1146   unsigned idx = topSort[at];
1147   if (indices.count() == 1) {
1148     bool isOuter = at == 0;
1149     bool isInner = at == topSort.size() - 1;
1150     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
1151                   indices);
1152   }
1153   genReductionEnd(merger, codegen, rewriter, op); // cannot chain
1154   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
1155 }
1156 
/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
///
/// For every sparse dimension in `locals`, the coordinate stored at the
/// current position is loaded into codegen.idxs. When no universal index is
/// maintained, the loop index becomes the (unsigned) minimum of those
/// coordinates. Each dense dimension then gets the linearized position
/// "size * parentPosition + loopIndex" in codegen.pidxs.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        // Running unsigned minimum over all sparse coordinates.
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      // Parent position: nearest enclosing level with a position for this
      // tensor, or 0 at the outermost level.
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}
1210 
/// Generates the induction structure for a while-loop: the scf.yield at the
/// end of the "after" region. Each sparse position advances by one only when
/// its coordinate equals the current loop index. The universal index, if
/// maintained, always advances by one. The yielded operands follow the same
/// order as the loop arguments built in genWhile. Afterwards, codegen.pidxs /
/// codegen.loops are rebound to the while-loop's results.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}
1240 
1241 /// Generates a single if-statement within a while-loop.
1242 static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
1243                        PatternRewriter &rewriter, linalg::GenericOp op,
1244                        unsigned idx, llvm::BitVector &conditions) {
1245   Location loc = op.getLoc();
1246   Value cond;
1247   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1248     if (conditions[b]) {
1249       unsigned tensor = merger.tensor(b);
1250       assert(idx == merger.index(b));
1251       Value clause;
1252       if (merger.isDim(b, Dim::kSparse)) {
1253         Value op1 = codegen.idxs[tensor][idx];
1254         Value op2 = codegen.loops[idx];
1255         clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
1256       } else {
1257         clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
1258       }
1259       cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
1260     }
1261   }
1262   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
1263   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
1264   return ifOp;
1265 }
1266 
/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
///
/// `exp` is the tensor (sub)expression to evaluate, and `at` is the position
/// in topSort of the loop index handled at this recursion level. When
/// at == topSort.size() (a leaf), the expression is evaluated and stored into
/// the output tensor.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattices points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          // Each branch goes into the then-region; the next candidate is
          // nested in the else-region of the previous if.
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      // The universal index is not maintained by loops after a for-loop.
      needsUniv = false;
      // Thread a scalarized reduction out of the for-loop via its result.
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}
1361 
namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  /// Rewrites a linalg.generic with sparse annotations into explicit loops.
  /// Fails (leaving the op untouched) when there are no sparse annotations,
  /// no acyclic iteration graph exists, or the tensor expression cannot be
  /// built.
  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    // If all constraints together yield a cycle, retry with only the
    // constraints of sparse tensors (dense access may then become strided).
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
    Operation *yield = op.region().front().getTerminator();
    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
    if (!exp.hasValue())
      return failure(); // build failure

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    // The single output is the last operand, so its buffer is the last
    // entry; convert it back into the tensor that replaces the op.
    Value result = rewriter.create<memref::TensorLoadOp>(
        op.getLoc(), codegen.buffers.back());
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace
1413 
1414 /// Populates the given patterns list with rewriting rules required for
1415 /// the sparsification of linear algebra operations.
1416 void mlir::populateSparsificationPatterns(
1417     RewritePatternSet &patterns, const SparsificationOptions &options) {
1418   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1419 }
1420