1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering sparse tensor types to actual sparse code.
10 //
11 // The concept of letting a compiler generate sparse code automatically was
12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
14 // Algebra Compiler (TACO). The implementation in this file closely follows
15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting
16 // rule is applied to each tensor expression in linalg (MLIR's tensor index
17 // notation) where the sparsity of tensors is indicated with an annotation
18 // giving a per-dimension specification of sparse/dense storage together with a
19 // specification of the order on the dimensions. Subsequently, a topologically
20 // sorted iteration graph, reflecting the required order on indices with respect
21 // to the dimensions of each tensor, is constructed to ensure that all tensors
22 // are visited in natural index order. Next, iteration lattices are constructed
23 // for the tensor expression for every index in topological order. Each
24 // iteration lattice point consists of a conjunction of tensor indices together
25 // with a tensor (sub)expression that needs to be evaluated for that
26 // conjunction. Within the lattice, iteration points are ordered according to
27 // the way indices are exhausted. As such these iteration lattices drive actual
28 // sparse code generation, which consists of a tedious but relatively
29 // straightforward one-to-one mapping from iteration lattices to combinations
30 // of for-loops, while-loops, and if-statements.
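//
// As an informal illustration (a sketch, not a transcript of the generated
// IR), a matrix-vector kernel x(i) += A(i,j) * b(j), with A annotated dense
// in its first dimension and compressed (sparse) in its second, is lowered
// to a loop nest of roughly the following shape, in which only the stored
// nonzeros of each row are visited:
//
//   for (i = 0; i < m; i++)                              // dense loop
//     for (jj = pointers[i]; jj < pointers[i+1]; jj++) { // sparse loop
//       j = indices[jj];
//       x[i] += values[jj] * b[j];
//     }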
31 //
32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
36 // Proceedings of the ACM on Programming Languages, October 2017.
37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
38 // PhD thesis, MIT, February 2020 (tensor-compiler.org).
39 //
40 // Implementation detail: We use llvm::SmallVector for vectors with
41 // variable lengths and std::vector for vectors with fixed lengths.
42 //===----------------------------------------------------------------------===//
43 
44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
45 #include "mlir/Dialect/Linalg/Utils/Utils.h"
46 #include "mlir/Dialect/SCF/SCF.h"
47 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
48 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
49 #include "mlir/Dialect/StandardOps/IR/Ops.h"
50 #include "mlir/Dialect/Vector/VectorOps.h"
51 #include "mlir/IR/Matchers.h"
52 #include "mlir/IR/TensorEncoding.h"
53 #include "llvm/ADT/SmallBitVector.h"
54 
55 using namespace mlir;
56 using namespace mlir::sparse_tensor;
57 
58 namespace {
59 
60 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
61 enum class Dim { kSparse, kDense, kSingle, kUndef };
62 
63 /// Tensor expression. Represents an MLIR expression in tensor index notation.
64 /// For tensors, e0 denotes the tensor index. For invariants, the IR value is
65 /// stored directly. For binary operations, e0 and e1 denote the indices of
66 /// the child tensor expressions.
67 struct TensorExp {
68   TensorExp(Kind k, unsigned x, unsigned y, Value v)
69       : kind(k), e0(x), e1(y), val(v) {
70     assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
71            (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
72            (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
73   }
74   Kind kind;
75   /// Indices of children expression(s).
76   unsigned e0;
77   unsigned e1;
78   /// Direct link to IR for an invariant. During code generation,
79   /// field is used to cache "hoisted" loop invariant tensor loads.
80   Value val;
81 };
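
// For example (illustrative only), the tensor expression a(i,j) * b(j) is
// stored as three TensorExp nodes: e0 = (kTensor, tensor 0), e1 = (kTensor,
// tensor 1), and e2 = (kMulF, e0, e1); the binary node refers to its
// children by index rather than by IR value.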
82 
83 /// Lattice point. Each lattice point consists of a conjunction of tensor
84 /// loop indices (encoded in a bitvector) and the index of the corresponding
85 /// tensor expression.
86 struct LatPoint {
87   LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
88     bits.set(b);
89   }
90   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
91   /// Conjunction of tensor loop indices as bitvector. This represents
92   /// all indices involved in the tensor expression.
93   llvm::BitVector bits;
94   /// Simplified conjunction of tensor loop indices as bitvector. This
95   /// represents a simplified condition under which this tensor expression
96   /// must execute. Pre-computed during codegen to avoid repeated eval.
97   llvm::BitVector simple;
98   /// Index of the tensor expression.
99   unsigned exp;
100 };
101 
102 /// A class to handle all iteration lattice operations. This class abstracts
103 /// away from some implementation details of storing iteration lattices and
104 /// tensor expressions. This allows for fine-tuning performance characteristics
105 /// independently from the basic algorithm if bottlenecks are identified.
106 class Merger {
107 public:
108   /// Constructs a merger for the given number of tensors and loops. The
109   /// user supplies the number of tensors involved in the kernel, with the
110   /// last tensor in this set denoting the output tensor. The merger adds an
111   /// additional synthetic tensor at the end of this set to represent all
112   /// invariant expressions in the kernel.
113   Merger(unsigned t, unsigned l)
114       : outTensor(t - 1), numTensors(t + 1), numLoops(l),
115         dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
116 
117   /// Adds a tensor expression. Returns its index.
118   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
119     unsigned e = tensorExps.size();
120     tensorExps.push_back(TensorExp(k, e0, e1, v));
121     return e;
122   }
123   unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
124 
125   /// Adds an iteration lattice point. Returns its index.
126   unsigned addLat(unsigned t, unsigned i, unsigned e) {
127     assert(t < numTensors && i < numLoops);
128     unsigned p = latPoints.size();
129     latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
130     return p;
131   }
132 
133   /// Adds a new, initially empty, set. Returns its index.
134   unsigned addSet() {
135     unsigned s = latSets.size();
136     latSets.emplace_back(SmallVector<unsigned, 16>());
137     return s;
138   }
139 
140   /// Computes a single conjunction of two lattice points by taking the "union"
141   /// of loop indices (effectively constructing a larger "intersection" of the
142   /// iteration spaces) with a new tensor (sub)expression of the given kind.
143   /// Returns the index of the new lattice point.
144   unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
145     unsigned p = latPoints.size();
146     llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
147     nb |= latPoints[p1].bits;
148     unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
149     latPoints.push_back(LatPoint(nb, e));
150     return p;
151   }
152 
153   /// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
154   /// their Cartesian product. Returns the index of the new set.
155   unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
156     unsigned s = addSet();
157     for (unsigned p0 : latSets[s0])
158       for (unsigned p1 : latSets[s1])
159         latSets[s].push_back(conjLatPoint(kind, p0, p1));
160     return s;
161   }
162 
163   /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
164   /// Returns the index of the new set.
165   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
166     unsigned s = takeConj(kind, s0, s1);
167     for (unsigned p : latSets[s0])
168       latSets[s].push_back(p);
169     for (unsigned p : latSets[s1])
170       latSets[s].push_back(p);
171     return s;
172   }
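
  // Illustration (informal): for two sparse operands a and b at loop index
  // i, the product a(i) * b(i) yields the single conjunction { a /\ b }, so
  // only positions where both operands have stored entries are visited,
  // whereas the sum a(i) + b(i) yields { a /\ b, a, b }, so positions where
  // either operand has a stored entry are visited as well.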
173 
174   /// Optimizes the iteration lattice points in the given set. This
175   /// method should be called right before code generation to avoid
176   /// generating redundant loops and conditions.
177   unsigned optimizeSet(unsigned s0) {
178     unsigned s = addSet();
179     assert(latSets[s0].size() != 0);
180     unsigned p0 = latSets[s0][0];
181     for (unsigned p1 : latSets[s0]) {
182       bool add = true;
183       if (p0 != p1) {
184         // Is this a straightforward copy?
185         unsigned e = latPoints[p1].exp;
186         if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
187           continue;
188         // Conjunction already covered?
189         for (unsigned p2 : latSets[s]) {
190           assert(!latGT(p1, p2)); // Lj => Li would be bad
191           if (onlyDenseDiff(p2, p1)) {
192             add = false;
193             break;
194           }
195         }
196         assert(!add || latGT(p0, p1));
197       }
198       if (add)
199         latSets[s].push_back(p1);
200     }
201     for (unsigned p : latSets[s])
202       latPoints[p].simple = simplifyCond(s, p);
203     return s;
204   }
205 
206   /// Simplifies the conditions in a conjunction of a given lattice point
207   /// within the given set using just two basic rules:
208   /// (1) multiple dense conditions are reduced to single dense, and
209   /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
210   llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
211     // First determine if this lattice point is a *singleton*, i.e.,
212     // the last point in the lattice such that no other point is less than it.
213     bool isSingleton = true;
214     for (unsigned p1 : latSets[s]) {
215       if (p0 != p1 && latGT(p0, p1)) {
216         isSingleton = false;
217         break;
218       }
219     }
220     // Now apply the two basic rules.
221     llvm::BitVector simple = latPoints[p0].bits;
222     bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
223     for (unsigned b = 0, be = simple.size(); b < be; b++) {
224       if (simple[b] && !isDim(b, Dim::kSparse)) {
225         if (reset)
226           simple.reset(b);
227         reset = true;
228       }
229     }
230     return simple;
231   }
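
  // For instance (informal), a non-singleton conjunction over two dense
  // dimensions simplifies to a single dense condition, while a singleton
  // conjunction over two dense dimensions and one sparse dimension
  // simplifies to just the sparse condition, since the sparse iteration
  // alone drives the loop at that point.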
232 
233   /// Returns true if Li > Lj.
234   bool latGT(unsigned i, unsigned j) const {
235     const llvm::BitVector &bitsi = latPoints[i].bits;
236     const llvm::BitVector &bitsj = latPoints[j].bits;
237     assert(bitsi.size() == bitsj.size());
238     if (bitsi.count() > bitsj.count()) {
239       for (unsigned b = 0, be = bitsj.size(); b < be; b++)
240         if (bitsj[b] && !bitsi[b])
241           return false;
242       return true;
243     }
244     return false;
245   }
246 
247   /// Returns true if Li and Lj only differ in dense.
248   bool onlyDenseDiff(unsigned i, unsigned j) {
249     llvm::BitVector tmp = latPoints[j].bits;
250     tmp ^= latPoints[i].bits;
251     return !hasAnyDimOf(tmp, Dim::kSparse);
252   }
253 
254   /// Bit translation.
255   unsigned tensor(unsigned b) const { return b % numTensors; }
256   unsigned index(unsigned b) const { return b / numTensors; }
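  // For example, with numTensors == 3, bit b == 7 encodes tensor 1 at loop
  // index 2, since addLat sets bit numTensors * i + t == 3 * 2 + 1.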
257 
258   /// Returns true if bit corresponds to queried dim.
259   bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
260 
261   /// Returns true if tensor access at given index has queried dim.
262   bool isDim(unsigned t, unsigned i, Dim d) const {
263     assert(t < numTensors && i < numLoops);
264     return dims[t][i] == d;
265   }
266 
267   /// Returns true if any set bit corresponds to queried dim.
268   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
269     for (unsigned b = 0, be = bits.size(); b < be; b++)
270       if (bits[b] && isDim(b, d))
271         return true;
272     return false;
273   }
274 
275   /// Setter.
276   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
277 
278   /// Getters.
279   TensorExp &exp(unsigned e) { return tensorExps[e]; }
280   LatPoint &lat(unsigned l) { return latPoints[l]; }
281   SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
282 
283 private:
284   const unsigned outTensor;
285   const unsigned numTensors;
286   const unsigned numLoops;
287 
288   std::vector<std::vector<Dim>> dims;
289   llvm::SmallVector<TensorExp, 32> tensorExps;
290   llvm::SmallVector<LatPoint, 16> latPoints;
291   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
292 };
293 
294 // Code generation.
295 struct CodeGen {
296   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
297       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
298         pointers(numTensors, std::vector<Value>(numLoops)),
299         indices(numTensors, std::vector<Value>(numLoops)),
300         highs(numTensors, std::vector<Value>(numLoops)),
301         pidxs(numTensors, std::vector<Value>(numLoops)),
302         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
303         curVecLength(1), curVecMask() {}
304   /// Sparsification options.
305   SparsificationOptions options;
306   /// Universal dense indices and upper bounds (by index). The loops array
307   /// is updated with the value of the universal dense index in the current
308   /// loop. The sizes array is set once with the inferred dimension sizes.
309   std::vector<Value> loops;
310   std::vector<Value> sizes;
311   /// Buffers for storing dense and sparse numerical values (by tensor).
312   /// This array is set once during bufferization of all tensors.
313   std::vector<Value> buffers;
314   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
315   /// This array is set once during bufferization of all sparse tensors.
316   std::vector<std::vector<Value>> pointers;
317   std::vector<std::vector<Value>> indices;
318   /// Sparse iteration information (by tensor and index). These arrays
319   /// are updated to remain current within the current loop.
320   std::vector<std::vector<Value>> highs;
321   std::vector<std::vector<Value>> pidxs;
322   std::vector<std::vector<Value>> idxs;
323   /// Current reduction, updated during code generation. When indices of a
324   /// reduction are exhausted, all inner loops can "scalarize" the reduction.
325   // TODO: currently only done for (a chain of) innermost for-loops, where it
326   // is most effective; we could generalize to more outer and while-loops.
327   unsigned redExp;
328   Value redVal;
329   // Current vector length and mask.
330   unsigned curVecLength;
331   Value curVecMask;
332 };
333 
334 } // namespace
335 
336 /// Helper method to translate dim level type to internal representation.
337 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
338   if (enc) {
339     SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
340     if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
341       return Dim::kSparse;
342     if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
343       return Dim::kSingle;
344   }
345   return Dim::kDense;
346 }
347 
348 /// Helper method to inspect sparse encodings in the tensor types.
349 /// Fills the per-dimension sparsity information for all tensors.
350 static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
351   unsigned numTensors = op.getNumShapedOperands();
352   for (unsigned t = 0; t < numTensors; t++) {
353     auto map = op.getIndexingMap(t);
354     unsigned rank = op.getShapedType(t).getRank();
355     auto enc = getSparseTensorEncoding(op.getShapedType(t));
356     for (unsigned d = 0; d < rank; d++) {
357       unsigned idx = map.getDimPosition(d);
358       merger.setDim(t, idx, toDim(enc, d));
359     }
360   }
361 }
362 
363 /// A DFS helper to compute a topological sort. Note that recursion is
364 /// bounded by the number of implicit loops, which is always small.
365 /// Returns false when a cycle is detected.
366 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
367                        std::vector<unsigned> &topSort,
368                        std::vector<std::vector<bool>> &adjM) {
369   if (visit[i] != 0)
370     return visit[i] != 1; // 1 denotes cycle!
371   visit[i] = 1;
372   for (unsigned j = 0, e = visit.size(); j < e; j++)
373     if (adjM[i][j])
374       if (!topSortDFS(j, visit, topSort, adjM))
375         return false;
376   visit[i] = 2;
377   topSort.push_back(i);
378   return true;
379 }
380 
381 /// Computes a topologically sorted iteration graph for the linalg operation.
382 /// Ensures all tensors are visited in natural index order. This is essential
383 /// for sparse storage formats since these only support access along fixed
384 /// dimensions. Even for dense storage formats, however, the natural index
385 /// order yields innermost unit-stride access with better spatial locality.
386 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
387                                   std::vector<unsigned> &topSort,
388                                   bool sparseOnly) {
389   // Set up an n x n from/to adjacency matrix of the iteration graph
390   // for the implicit loop indices i_0 .. i_n-1.
391   unsigned n = op.getNumLoops();
392   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
393 
394   // Iterate over the indexing maps of every tensor in the tensor expression.
395   unsigned numTensors = op.getNumShapedOperands();
396   for (unsigned t = 0; t < numTensors; t++) {
397     auto map = op.getIndexingMap(t);
398     assert(map.getNumDims() == n);
399     // Skip dense tensor constraints when sparse only is requested.
400     if (sparseOnly && !getSparseTensorEncoding(op.getShapedType(t)))
401       continue;
402     // At the moment, we take the index variables in the tensor access
403     // expression in the order in which they appear (conceptually a
404     // "row-major" layout of every tensor). So, a tensor access A_ijk
405     // forces the ordering i < j < k on the loop indices.
406     // TODO: support affine map to define alternative dimension orders.
407     for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
408       unsigned f = map.getDimPosition(d - 1);
409       unsigned t = map.getDimPosition(d);
410       adjM[f][t] = true;
411     }
412   }
413 
414   // Topologically sort the iteration graph to determine loop order.
415   // Report failure for a cyclic iteration graph.
416   topSort.clear();
417   topSort.reserve(n);
418   std::vector<unsigned> visit(n, 0);
419   for (unsigned i = 0; i < n; i++)
420     if (visit[i] == 0)
421       if (!topSortDFS(i, visit, topSort, adjM))
422         return false; // cycle!
423   std::reverse(std::begin(topSort), std::end(topSort));
424   return true;
425 }
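
// For example (illustrative only), for a kernel C(i,j) = A(i,k) * B(k,j) with
// all tensors taken in "row-major" access order, tensor A contributes the
// constraint i < k, tensor B contributes k < j, and tensor C contributes
// i < j, so the topological sort yields the loop order (i, k, j).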
426 
427 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
428 /// This simplifies constructing (sub)expressions during iteration lattice
429 /// building (compared to using the SSA representation everywhere).
430 static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
431                                          Value val) {
432   if (auto arg = val.dyn_cast<BlockArgument>()) {
433     unsigned argN = arg.getArgNumber();
434     if (arg.getOwner()->getParentOp() == op) {
435       // Any parameter of the generic op is considered a tensor,
436       // indexed by the implicit loop bounds.
437       auto map = op.getIndexingMap(argN);
438       if (map.isProjectedPermutation())
439         return merger.addExp(Kind::kTensor, argN);
440       // Cannot handle (yet).
441       return None;
442     }
443     // Any parameter of a higher op is invariant.
444     return merger.addExp(Kind::kInvariant, val);
445   }
446   Operation *def = val.getDefiningOp();
447   if (def->getBlock() != &op.region().front()) {
448     // Something defined outside is invariant.
449     return merger.addExp(Kind::kInvariant, val);
450   } else if (def->getNumOperands() == 2) {
451     // Construct binary operations if subexpressions could be built.
452     auto x = buildTensorExp(merger, op, def->getOperand(0));
453     auto y = buildTensorExp(merger, op, def->getOperand(1));
454     if (x.hasValue() && y.hasValue()) {
455       unsigned e0 = x.getValue();
456       unsigned e1 = y.getValue();
457       if (isa<MulFOp>(def))
458         return merger.addExp(Kind::kMulF, e0, e1);
459       if (isa<MulIOp>(def))
460         return merger.addExp(Kind::kMulI, e0, e1);
461       if (isa<AddFOp>(def))
462         return merger.addExp(Kind::kAddF, e0, e1);
463       if (isa<AddIOp>(def))
464         return merger.addExp(Kind::kAddI, e0, e1);
465     }
466   }
467   // Cannot build (yet).
468   return None;
469 }
470 
471 /// Builds the iteration lattices in a bottom-up traversal given the remaining
472 /// tensor (sub)expression and the next loop index in the iteration graph.
473 static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
474                               unsigned exp, unsigned idx) {
475   Kind kind = merger.exp(exp).kind;
476   if (kind == Kind::kTensor || kind == Kind::kInvariant) {
477     // Either the index is really used in the tensor expression, or it is
478     // set to the undefined index in that dimension. An invariant expression
479     // is set to a synthetic tensor with undefined indices only.
480     unsigned s = merger.addSet();
481     unsigned t =
482         kind == Kind::kTensor ? merger.exp(exp).e0 : op.getNumShapedOperands();
483     merger.set(s).push_back(merger.addLat(t, idx, exp));
484     return s;
485   }
486   unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
487   unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
488   switch (kind) {
489   case Kind::kTensor:
490   case Kind::kInvariant:
491     llvm_unreachable("handled above");
492   case Kind::kMulF:
493   case Kind::kMulI:
494     return merger.takeConj(kind, s0, s1);
495   case Kind::kAddF:
496   case Kind::kAddI:
497     return merger.takeDisj(kind, s0, s1);
498   }
499   llvm_unreachable("unexpected expression kind");
500 }
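
// For example (illustrative only), for the expression a(i) + b(i) with both
// operands sparse in dimension i, the lattice built for index i contains the
// points { a /\ b, a, b }: the first point drives a while-loop that
// co-iterates over both operands, and the remaining points handle the tails
// where only one of the operands still has stored entries.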
501 
502 /// Maps sparse integer option to actual integral storage type.
503 static Type genIntType(PatternRewriter &rewriter, unsigned width) {
504   if (width == 0)
505     return rewriter.getIndexType();
506   return rewriter.getIntegerType(width);
507 }
508 
509 /// Detects in-place annotation on tensor argument.
510 static bool getInPlace(Value val) {
511   if (auto arg = val.dyn_cast<BlockArgument>())
512     if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
513       if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
514               arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
515         return attr.getValue();
516   return false;
517 }
518 
519 /// Generates buffer for the output tensor.
520 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
521                              linalg::GenericOp op, MemRefType denseTp,
522                              ArrayRef<Value> args) {
523   Location loc = op.getLoc();
524   Value tensor = op.getOutput(0);
525   // The output tensor could simply materialize from the buffer that will
526   // be generated for the tensor present in the outs() clause. This has
527   // the major advantage that the sparse kernel only updates the nonzero
528   // positions for the output tensor.
529   if (getInPlace(tensor))
530     return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
531   // By default, a new buffer is allocated which is initialized to the
532   // tensor defined in the outs() clause. This is always correct but
533   // introduces a dense initialization component that may negatively
534   // impact the running complexity of the sparse kernel.
535   Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
536   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
537   rewriter.create<linalg::CopyOp>(loc, init, alloc);
538   return alloc;
539 }
540 
541 /// Local bufferization of all dense and sparse data structures.
542 /// This code enables testing the first prototype sparse compiler.
543 // TODO: replace this with a proper, more general bufferization strategy
544 static void genBuffers(Merger &merger, CodeGen &codegen,
545                        PatternRewriter &rewriter, linalg::GenericOp op) {
546   Location loc = op.getLoc();
547   unsigned numTensors = op.getNumShapedOperands();
548   unsigned numInputs = op.getNumInputs();
549   assert(numTensors == numInputs + 1);
550   // For every tensor, find lower and upper bound on dimensions, set the
551   // same bounds on loop indices, and obtain dense or sparse buffer(s).
552   SmallVector<Value, 4> args;
553   for (unsigned t = 0; t < numTensors; t++) {
554     Value tensor = t < numInputs ? op.getInput(t) : op.getOutput(0);
555     auto tensorType = op.getShapedType(t);
556     auto shape = tensorType.getShape();
557     auto map = op.getIndexingMap(t);
558     auto enc = getSparseTensorEncoding(tensorType);
559     // Scan all dimensions of current tensor.
560     args.clear();
561     for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
562       unsigned i = map.getDimPosition(d);
563       // Handle sparse storage schemes.
564       if (merger.isDim(t, i, Dim::kSparse)) {
565         auto dynShape = {ShapedType::kDynamicSize};
566         auto ptrTp = MemRefType::get(
567             dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
568         auto indTp = MemRefType::get(
569             dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
570         Value dim = rewriter.create<ConstantIndexOp>(loc, d);
571         // Generate sparse primitives to obtain pointers and indices.
572         codegen.pointers[t][i] =
573             rewriter.create<ToPointersOp>(loc, ptrTp, tensor, dim);
574         codegen.indices[t][i] =
575             rewriter.create<ToIndicesOp>(loc, indTp, tensor, dim);
576       }
577       // Find lower and upper bound in current dimension.
578       Value up;
579       if (shape[d] == MemRefType::kDynamicSize) {
580         up = rewriter.create<memref::DimOp>(loc, tensor, d);
581         args.push_back(up);
582       } else {
583         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
584       }
585       codegen.sizes[i] = codegen.highs[t][i] = up;
586     }
587     // Perform the required bufferization. All dense inputs materialize
588     // from the input tensor. The dense output tensor needs special
589     // handling. Sparse inputs use a sparse primitive to obtain the values.
590     if (!enc) {
591       auto denseTp = MemRefType::get(shape, tensorType.getElementType());
592       if (t < numInputs)
593         codegen.buffers[t] =
594             rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
595       else
596         codegen.buffers[t] =
597             genOutputBuffer(codegen, rewriter, op, denseTp, args);
598     } else {
599       auto dynShape = {ShapedType::kDynamicSize};
600       auto sparseTp = MemRefType::get(dynShape, tensorType.getElementType());
601       codegen.buffers[t] = rewriter.create<ToValuesOp>(loc, sparseTp, tensor);
602     }
603   }
604 }
605 
606 /// Constructs vector type.
607 static VectorType vectorType(CodeGen &codegen, Type etp) {
608   return VectorType::get(codegen.curVecLength, etp);
609 }
610 
611 /// Constructs vector type from pointer.
612 static VectorType vectorType(CodeGen &codegen, Value ptr) {
613   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
614 }
615 
616 /// Constructs vector iteration mask.
617 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
618                            Value iv, Value lo, Value hi, Value step) {
619   Location loc = iv.getLoc();
620   VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
621   // Special case if the vector length evenly divides the trip count (for
622   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
623   // so that all subsequent masked memory operations are immediately folded
624   // into unconditional memory operations.
625   IntegerAttr loInt, hiInt, stepInt;
626   if (matchPattern(lo, m_Constant(&loInt)) &&
627       matchPattern(hi, m_Constant(&hiInt)) &&
628       matchPattern(step, m_Constant(&stepInt))) {
629     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
630       return rewriter.create<vector::BroadcastOp>(
631           loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
632   }
633   // Otherwise, generate a vector mask that avoids overrunning the upper bound
634   // during vector execution. Here we rely on subsequent loop optimizations to
635   // avoid executing the mask in all iterations, for example, by splitting the
636   // loop into an unconditional vector loop and a scalar cleanup loop.
637   Value end = rewriter.create<SubIOp>(loc, hi, iv);
638   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
639 }
640 
641 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
642 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
643                            Value ptr, ArrayRef<Value> args) {
644   Location loc = ptr.getLoc();
645   VectorType vtp = vectorType(codegen, ptr);
646   Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
647   if (args.back().getType().isa<VectorType>()) {
648     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
649     Value indexVec = args.back();
650     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
651     return rewriter.create<vector::GatherOp>(
652         loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
653   }
654   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
655                                                codegen.curVecMask, pass);
656 }
657 
658 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
659 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
660                            Value rhs, Value ptr, ArrayRef<Value> args) {
661   Location loc = ptr.getLoc();
662   if (args.back().getType().isa<VectorType>()) {
663     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
664     Value indexVec = args.back();
665     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
666     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
667                                        codegen.curVecMask, rhs);
668     return;
669   }
670   rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
671                                          rhs);
672 }
673 
674 /// Generates a vectorized invariant. Here we rely on subsequent loop
675 /// optimizations to hoist the invariant broadcast out of the vector loop.
676 static Value genVectorInvariantValue(CodeGen &codegen,
677                                      PatternRewriter &rewriter, Value val) {
678   VectorType vtp = vectorType(codegen, val.getType());
679   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
680 }
681 
682 /// Generates a load on a dense or sparse tensor.
683 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
684                            PatternRewriter &rewriter, linalg::GenericOp op,
685                            unsigned exp) {
686   // Test if the load was hoisted to a higher loop nest.
687   Value val = merger.exp(exp).val;
688   if (val) {
689     if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
690       return genVectorInvariantValue(codegen, rewriter, val);
691     return val;
692   }
693   // Actual load.
694   SmallVector<Value, 4> args;
695   unsigned tensor = merger.exp(exp).e0;
696   auto map = op.getIndexingMap(tensor);
697   auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
698   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
699     unsigned idx = map.getDimPosition(i);
700     args.push_back(codegen.loops[idx]); // universal dense index
701     if (enc) {
702       args.clear();
703       args.push_back(codegen.pidxs[tensor][idx]); // position index
704     }
705   }
706   Location loc = op.getLoc();
707   Value ptr = codegen.buffers[tensor];
708   if (codegen.curVecLength > 1)
709     return genVectorLoad(codegen, rewriter, ptr, args);
710   return rewriter.create<memref::LoadOp>(loc, ptr, args);
711 }
712 
713 /// Generates a store on a dense tensor.
714 static void genTensorStore(Merger &merger, CodeGen &codegen,
715                            PatternRewriter &rewriter, linalg::GenericOp op,
716                            unsigned tensor, Value rhs) {
717   Location loc = op.getLoc();
718   // Test if this is a scalarized reduction.
719   unsigned lhs = op.getNumShapedOperands() - 1;
720   if (lhs == tensor && codegen.redVal) {
721     if (codegen.curVecLength > 1)
722       rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
723                                       codegen.redVal);
724     codegen.redVal = rhs;
725     return;
726   }
727   // Actual store.
728   SmallVector<Value, 4> args;
729   auto map = op.getIndexingMap(tensor);
730   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
731     unsigned idx = map.getDimPosition(i);
732     args.push_back(codegen.loops[idx]); // universal dense index
733   }
734   Value ptr = codegen.buffers[tensor];
735   if (codegen.curVecLength > 1)
736     genVectorStore(codegen, rewriter, rhs, ptr, args);
737   else
738     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
739 }
740 
741 /// Generates a pointer/index load from the sparse storage scheme. Narrower
742 /// data types need to be zero extended before casting the value into the
743 /// index type used for looping and indexing.
744 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
745                      Value ptr, Value s) {
746   // See https://llvm.org/docs/GetElementPtr.html for some background on
747   // the complications described below.
748   if (codegen.curVecLength > 1) {
749     // Since the index vector is used in subsequent gather/scatter operations,
750     // which effectively define an unsigned pointer + signed index, we must
751     // zero extend the vector to an index width. For 8-bit and 16-bit values,
752     // a 32-bit index width suffices. For 32-bit values, zero extending the
753     // elements into 64-bit loses some performance since the 32-bit indexed
754     // gather/scatter is more efficient than the 64-bit index variant (in
755     // the future, we could introduce a flag that states the negative space
756     // of 32-bit indices is unused). For 64-bit values, there is no good way
757     // to state that the indices are unsigned, which creates the potential for
758     // incorrect address calculations in the unlikely case we need such
759     // extremely large offsets.
760     Type etp = ptr.getType().cast<MemRefType>().getElementType();
761     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
762     if (!etp.isa<IndexType>()) {
763       if (etp.getIntOrFloatBitWidth() < 32)
764         vload = rewriter.create<ZeroExtendIOp>(
765             loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
766       else if (etp.getIntOrFloatBitWidth() < 64)
767         vload = rewriter.create<ZeroExtendIOp>(
768             loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
769     }
770     return vload;
771   }
772   // For the scalar case, we simply zero extend narrower indices into 64-bit
773   // values before casting to index without a performance penalty. Here too,
774   // however, indices that are already 64-bit, in theory, cannot express the
775   // full range as explained above.
776   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
777   if (!load.getType().isa<IndexType>()) {
778     if (load.getType().getIntOrFloatBitWidth() < 64)
779       load = rewriter.create<ZeroExtendIOp>(loc, load,
780                                             rewriter.getIntegerType(64));
781     load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
782   }
783   return load;
784 }
785 
786 /// Generates an invariant value.
787 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
788                                PatternRewriter &rewriter, unsigned exp) {
789   Value val = merger.exp(exp).val;
790   if (codegen.curVecLength > 1)
791     return genVectorInvariantValue(codegen, rewriter, val);
792   return val;
793 }
794 
795 /// Generates an address computation "sz * p + i".
796 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
797                         Location loc, Value size, Value p, Value i) {
798   Value mul = rewriter.create<MulIOp>(loc, size, p);
799   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
800     Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
801     mul = genVectorInvariantValue(codegen, rewriter, inv);
802   }
803   return rewriter.create<AddIOp>(loc, mul, i);
804 }
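
// For instance (informal), for a dense dimension at loop index i with current
// index value loops[i] and enclosing flattened position p, the new position
// becomes sizes[i] * p + loops[i], i.e. the usual row-major address step.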
805 
806 /// Generates start of a reduction.
807 static Value genReductionStart(Merger &merger, CodeGen &codegen,
808                                PatternRewriter &rewriter,
809                                linalg::GenericOp op) {
810   if (codegen.redVal)
811     return codegen.redVal; // chained with previous for-loop
812   if (codegen.curVecLength > 1) {
813     // TODO: assumes + reductions for now
814     VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
815     return rewriter.create<ConstantOp>(op.getLoc(), vtp,
816                                        rewriter.getZeroAttr(vtp));
817   }
818   return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
819 }
820 
821 /// Generates end of a reduction.
822 static void genReductionEnd(Merger &merger, CodeGen &codegen,
823                             PatternRewriter &rewriter, linalg::GenericOp op) {
824   Value red = codegen.redVal;
825   if (!red)
826     return;
827   assert(codegen.curVecLength == 1);
828   codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
829   unsigned lhs = op.getNumShapedOperands() - 1;
830   if (auto vtp = red.getType().dyn_cast<VectorType>()) {
831     // TODO: assumes + reductions for now
832     StringAttr kind = rewriter.getStringAttr("add");
833     Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
834     // Integer reductions don't accept an accumulator.
835     if (vtp.getElementType().isa<IntegerType>()) {
836       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
837                                                  kind, red, ValueRange{});
838       red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
839     } else {
840       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
841                                                  kind, red, ld);
842     }
843   }
844   genTensorStore(merger, codegen, rewriter, op, lhs, red);
845 }
846 
847 /// Recursively generates tensor expression.
848 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
849                     linalg::GenericOp op, unsigned exp) {
850   if (merger.exp(exp).kind == Kind::kTensor)
851     return genTensorLoad(merger, codegen, rewriter, op, exp);
852   else if (merger.exp(exp).kind == Kind::kInvariant)
853     return genInvariantValue(merger, codegen, rewriter, exp);
854   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
855   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
856   switch (merger.exp(exp).kind) {
857   case Kind::kTensor:
858   case Kind::kInvariant:
859     llvm_unreachable("handled above");
860   case Kind::kMulF:
861     return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
862   case Kind::kMulI:
863     return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
864   case Kind::kAddF:
865     return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
866   case Kind::kAddI:
867     return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
868   }
869   llvm_unreachable("unexpected expression kind");
870 }
871 
872 /// Hoists loop invariant tensor loads for which indices have been exhausted.
873 static void genInvariants(Merger &merger, CodeGen &codegen,
874                           PatternRewriter &rewriter, linalg::GenericOp op,
875                           unsigned exp, unsigned ldx, bool hoist) {
876   if (merger.exp(exp).kind == Kind::kTensor) {
877     // Inspect tensor indices.
878     bool atLevel = ldx == -1u;
879     unsigned tensor = merger.exp(exp).e0;
880     auto map = op.getIndexingMap(tensor);
881     for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
882       unsigned idx = map.getDimPosition(i);
883       if (!codegen.loops[idx])
884         return; // still in play
885       else if (idx == ldx)
886         atLevel = true;
887     }
888     // All exhausted at this level (atLevel denotes exactly at this level).
889     unsigned lhs = op.getNumShapedOperands() - 1;
890     if (lhs == tensor) {
891       codegen.redExp = hoist ? exp : -1u;
892     } else if (atLevel) {
893       merger.exp(exp).val =
894           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
895     }
896   } else if (merger.exp(exp).kind != Kind::kInvariant) {
897     // Traverse into the binary operations. Note that we only hoist
898     // tensor loads, since subsequent MLIR/LLVM passes know how to
899     // deal with all other kinds of derived loop invariants.
900     unsigned e0 = merger.exp(exp).e0;
901     unsigned e1 = merger.exp(exp).e1;
902     genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
903     genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
904   }
905 }
906 
907 /// Generates initialization code for the subsequent loop sequence at
908 /// current index level. Returns true if the loop sequence needs to
909 /// maintain the universal index.
910 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
911                     linalg::GenericOp op, std::vector<unsigned> &topSort,
912                     unsigned at, llvm::BitVector &inits) {
913   bool needsUniv = false;
914   Location loc = op.getLoc();
915   unsigned idx = topSort[at];
916 
917   // Initialize sparse positions.
918   for (unsigned b = 0, be = inits.size(); b < be; b++) {
919     if (inits[b]) {
920       unsigned tensor = merger.tensor(b);
921       assert(idx == merger.index(b));
922       if (merger.isDim(b, Dim::kSparse)) {
923         // Initialize sparse index.
924         unsigned pat = at;
925         for (; pat != 0; pat--) {
926           if (codegen.pidxs[tensor][topSort[pat - 1]])
927             break;
928         }
929         Value ptr = codegen.pointers[tensor][idx];
930         Value one = rewriter.create<ConstantIndexOp>(loc, 1);
931         Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
932                               : codegen.pidxs[tensor][topSort[pat - 1]];
933         codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
934         Value p1 = rewriter.create<AddIOp>(loc, p0, one);
935         codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
936       } else {
937         // Dense index still in play.
938         needsUniv = true;
939       }
940     }
941   }
942 
943   // Initialize the universal dense index.
944   codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
945   return needsUniv;
946 }
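
// For example (illustrative only), for a compressed dimension at loop index i
// whose enclosing dimensions yield flattened position p0, the code above sets
// pidxs to pointers[p0] and highs to pointers[p0 + 1], which delimit the range
// of stored entries visited by the subsequent loop over this dimension.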
947 
948 /// Returns true if the implicit inner loop should be vectorized. Any implicit
949 /// inner loop in the Linalg operation is a candidate; whether it is actually
950 /// converted to SIMD code depends on the requested strategy.
951 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
952   switch (codegen.options.vectorizationStrategy) {
953   case SparseVectorizationStrategy::kNone:
954     return false;
955   case SparseVectorizationStrategy::kDenseInnerLoop:
956     return isInner && !isSparse;
957   case SparseVectorizationStrategy::kAnyStorageInnerLoop:
958     return isInner;
959   }
960   llvm_unreachable("unexpected vectorization strategy");
961 }
962 
963 /// Returns true if the implicit loop should be parallelized. Any implicit loop
964 /// in the Linalg operation marked "parallel" is a candidate; whether it is
965 /// actually converted to a parallel operation depends on the requested strategy.
966 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
967                           bool isSparse, bool isVector) {
968   switch (codegen.options.parallelizationStrategy) {
969   case SparseParallelizationStrategy::kNone:
970     return false;
971   case SparseParallelizationStrategy::kDenseOuterLoop:
972     return isOuter && !isSparse && !isReduction && !isVector;
973   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
974     return isOuter && !isReduction && !isVector;
975   case SparseParallelizationStrategy::kDenseAnyLoop:
976     return !isSparse && !isReduction && !isVector;
977   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
978     return !isReduction && !isVector;
979   }
980   llvm_unreachable("unexpected parallelization strategy");
981 }
982 
983 /// Checks unit strides for dense tensors. The iteration graph may have ignored
984 /// dense access patterns in order to avoid cycles (sparse access patterns are
985 /// always placed innermost), but that means dense access has become strided.
986 /// For now, we reject vectorization of such cases.
987 /// TODO: implement strided load/stores on dense arrays
988 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
989                              unsigned idx) {
990   unsigned numTensors = op.getNumShapedOperands();
991   for (unsigned t = 0; t < numTensors; t++) {
992     if (!getSparseTensorEncoding(op.getShapedType(t))) {
993       auto map = op.getIndexingMap(t);
994       unsigned r = map.getNumResults();
995       for (unsigned i = 0; i < r; i++) {
996         if (map.getDimPosition(i) == idx && i != r - 1)
997           return false;
998       }
999     }
1000   }
1001   return true;
1002 }
1003 
1004 /// Generates a for-loop on a single index.
1005 static Operation *genFor(Merger &merger, CodeGen &codegen,
1006                          PatternRewriter &rewriter, linalg::GenericOp op,
1007                          bool isOuter, bool isInner, unsigned idx,
1008                          llvm::BitVector &indices) {
1009   unsigned fb = indices.find_first();
1010   unsigned tensor = merger.tensor(fb);
1011   assert(idx == merger.index(fb));
1012   auto iteratorTypes = op.iterator_types().getValue();
1013   bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
1014   bool isSparse = merger.isDim(fb, Dim::kSparse);
1015   bool isVector = isVectorFor(codegen, isInner, isSparse) &&
1016                   denseUnitStrides(merger, op, idx);
1017   bool isParallel =
1018       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
1019 
1020   // Prepare vector length.
1021   if (isVector)
1022     codegen.curVecLength = codegen.options.vectorLength;
1023 
1024   // Loop bounds and increment.
1025   Location loc = op.getLoc();
1026   Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
1027   Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
1028   Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);
1029 
1030   // Emit a parallel loop.
1031   if (isParallel) {
1032     assert(!isVector);
1033     scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
1034     if (isSparse)
1035       codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
1036     else
1037       codegen.loops[idx] = parOp.getInductionVars()[0];
1038     rewriter.setInsertionPointToStart(parOp.getBody());
1039     return parOp;
1040   }
1041 
1042   // Emit a sequential loop, potentially with a scalarized reduction.
1043   bool scalarRed = isInner && codegen.redExp != -1u;
1044   SmallVector<Value, 4> operands;
1045   if (scalarRed) {
1046     Value load = genReductionStart(merger, codegen, rewriter, op);
1047     operands.push_back(load);
1048   }
1049   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
1050   if (scalarRed) {
1051     codegen.redVal = merger.exp(codegen.redExp).val =
1052         forOp.getRegionIterArgs().front();
1053   }
1054   // Assign induction variable to sparse or dense index.
1055   Value iv = forOp.getInductionVar();
1056   if (isSparse)
1057     codegen.pidxs[tensor][idx] = iv;
1058   else
1059     codegen.loops[idx] = iv;
1060   rewriter.setInsertionPointToStart(forOp.getBody());
1061   // Share vector iteration mask between all subsequent loads/stores.
1062   if (isVector)
1063     codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
1064   return forOp;
1065 }
1066 
1067 /// Emits a while-loop for co-iteration over multiple indices.
1068 static Operation *genWhile(Merger &merger, CodeGen &codegen,
1069                            PatternRewriter &rewriter, linalg::GenericOp op,
1070                            unsigned idx, bool needsUniv,
1071                            llvm::BitVector &indices) {
1072   SmallVector<Type, 4> types;
1073   SmallVector<Value, 4> operands;
1074   // Construct the while-loop with a parameter for each index.
1075   Type indexType = rewriter.getIndexType();
1076   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1077     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1078       unsigned tensor = merger.tensor(b);
1079       assert(idx == merger.index(b));
1080       types.push_back(indexType);
1081       assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
1082              "type mismatch for sparse index");
1083       operands.push_back(codegen.pidxs[tensor][idx]);
1084     }
1085   }
1086   if (needsUniv) {
1087     types.push_back(indexType);
1088     assert(codegen.loops[idx].getType().isa<IndexType>() &&
1089            "type mismatch for universal index");
1090     operands.push_back(codegen.loops[idx]);
1091   }
1092   Location loc = op.getLoc();
1093   scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
1094   Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
1095   Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
1096 
1097   // Build the "before" region, which effectively consists
1098   // of a conjunction of "i < upper" tests on all induction variables.
1099   rewriter.setInsertionPointToStart(&whileOp.before().front());
1100   Value cond;
1101   unsigned o = 0;
1102   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1103     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1104       unsigned tensor = merger.tensor(b);
1105       assert(idx == merger.index(b));
1106       Value op1 = before->getArgument(o);
1107       Value op2 = codegen.highs[tensor][idx];
1108       Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
1109       cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
1110       codegen.pidxs[tensor][idx] = after->getArgument(o++);
1111     }
1112   }
1113   if (needsUniv)
1114     codegen.loops[idx] = after->getArgument(o++);
1115   assert(o == operands.size());
1116   rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
1117   rewriter.setInsertionPointToStart(&whileOp.after().front());
1118   return whileOp;
1119 }
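
// As an informal sketch, co-iterating two sparse operands a and b over index
// i yields a loop of roughly the following form:
//
//   while (pidx_a < high_a && pidx_b < high_b) {
//     i_a = indices_a[pidx_a]; i_b = indices_b[pidx_b];
//     i = min(i_a, i_b);
//     ... body, guarded by if-statements on i_a == i and i_b == i ...
//     pidx_a += (i_a == i); pidx_b += (i_b == i);
//   }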
1120 
1121 /// Generates a for-loop or a while-loop, depending on whether it implements
1122 /// singleton iteration or co-iteration over the given conjunction.
1123 static Operation *genLoop(Merger &merger, CodeGen &codegen,
1124                           PatternRewriter &rewriter, linalg::GenericOp op,
1125                           std::vector<unsigned> &topSort, unsigned at,
1126                           bool needsUniv, llvm::BitVector &indices) {
1127   unsigned idx = topSort[at];
1128   if (indices.count() == 1) {
1129     bool isOuter = at == 0;
1130     bool isInner = at == topSort.size() - 1;
1131     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
1132                   indices);
1133   }
1134   genReductionEnd(merger, codegen, rewriter, op); // cannot chain
1135   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
1136 }
1137 
1138 /// Generates the local variables for this loop, consisting of the sparse
1139 /// indices, restored universal dense index, and dense positions.
1140 static void genLocals(Merger &merger, CodeGen &codegen,
1141                       PatternRewriter &rewriter, linalg::GenericOp op,
1142                       std::vector<unsigned> &topSort, unsigned at,
1143                       bool needsUniv, llvm::BitVector &locals) {
1144   Location loc = op.getLoc();
1145   unsigned idx = topSort[at];
1146 
1147   // Initialize sparse indices.
1148   Value min;
1149   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1150     if (locals[b] && merger.isDim(b, Dim::kSparse)) {
1151       unsigned tensor = merger.tensor(b);
1152       assert(idx == merger.index(b));
1153       Value ptr = codegen.indices[tensor][idx];
1154       Value s = codegen.pidxs[tensor][idx];
1155       Value load = genLoad(codegen, rewriter, loc, ptr, s);
1156       codegen.idxs[tensor][idx] = load;
1157       if (!needsUniv) {
1158         if (min) {
1159           Value cmp =
1160               rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
1161           min = rewriter.create<SelectOp>(loc, cmp, load, min);
1162         } else {
1163           min = load;
1164         }
1165       }
1166     }
1167   }
1168 
1169   // Merge dense universal index over minimum.
1170   if (min) {
1171     assert(!needsUniv);
1172     codegen.loops[idx] = min;
1173   }
1174 
1175   // Initialize dense positions.
1176   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1177     if (locals[b] && merger.isDim(b, Dim::kDense)) {
1178       unsigned tensor = merger.tensor(b);
1179       assert(idx == merger.index(b));
1180       unsigned pat = at;
1181       for (; pat != 0; pat--)
1182         if (codegen.pidxs[tensor][topSort[pat - 1]])
1183           break;
1184       Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
1185                            : codegen.pidxs[tensor][topSort[pat - 1]];
1186       codegen.pidxs[tensor][idx] = genAddress(
1187           codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
1188     }
1189   }
1190 }
1191 
1192 /// Generates the induction structure for a while-loop.
1193 static void genWhileInduction(Merger &merger, CodeGen &codegen,
1194                               PatternRewriter &rewriter, linalg::GenericOp op,
1195                               unsigned idx, bool needsUniv,
1196                               llvm::BitVector &induction, ResultRange results) {
1197   Location loc = op.getLoc();
1198   unsigned o = 0;
1199   SmallVector<Value, 4> operands;
1200   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
1201   for (unsigned b = 0, be = induction.size(); b < be; b++) {
1202     if (induction[b] && merger.isDim(b, Dim::kSparse)) {
1203       unsigned tensor = merger.tensor(b);
1204       assert(idx == merger.index(b));
1205       Value op1 = codegen.idxs[tensor][idx];
1206       Value op2 = codegen.loops[idx];
1207       Value op3 = codegen.pidxs[tensor][idx];
1208       Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
1209       Value add = rewriter.create<AddIOp>(loc, op3, one);
1210       operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
1211       codegen.pidxs[tensor][idx] = results[o++];
1212     }
1213   }
1214   if (needsUniv) {
1215     operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
1216     codegen.loops[idx] = results[o++];
1217   }
1218   assert(o == operands.size());
1219   rewriter.create<scf::YieldOp>(loc, operands);
1220 }
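
// In other words (informal), each sparse position advances conditionally,
// as in pidx = (idx == loops[i]) ? pidx + 1 : pidx, and the universal dense
// index, when maintained, always advances by one.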
1221 
1222 /// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
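///
/// As an illustration only (a conceptual sketch, not the exact IR that is
/// emitted), co-iterating two sparse vectors a and b for x(i) = a(i) + b(i)
/// results in a while-loop over the union of their index sets, with
/// if-statements selecting the proper subexpression:
///
///   while (pa < pa_end && pb < pb_end) {
///     ia = a.indices[pa]; ib = b.indices[pb]; i = min(ia, ib);
///     if (ia == i && ib == i)
///       x[i] = a.values[pa] + b.values[pb];
///     else if (ia == i)
///       x[i] = a.values[pa];
///     else
///       x[i] = b.values[pb];
///     pa += (ia == i); pb += (ib == i);
///   }
///   // trailing loops handle whichever of a or b has entries left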
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    unsigned lhs = op.getNumShapedOperands() - 1;
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        // Chain the current reduction value through the yield of the for-loop.
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumShapedOperands();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    findSparseAnnotations(merger, op);

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
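    // The first attempt takes the dimension order of all tensors into
    // account; if that graph turns out to be cyclic, a second attempt
    // restricts the ordering constraints to the sparse tensors only
    // (hence the sparseOnly flag below).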
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
    Operation *yield = op.region().front().getTerminator();
    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
    if (!exp.hasValue())
      return failure(); // build failure

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    Value result = rewriter.create<memref::TensorLoadOp>(
        op.getLoc(), codegen.buffers.back());
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
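
// A minimal usage sketch (hypothetical pass body, not part of this file),
// assuming default-constructed options and the greedy pattern driver:
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));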