1 //===- Sparsification.cpp - Implementation of linalg sparsification -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements lowering of annotated linalg operations to sparse code.
10 //
11 // The concept of letting a compiler generate sparse code automatically was
12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
14 // Algebra Compiler (TACO). The implementation in this file closely follows
15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting
// rule is applied to each tensor expression in linalg (MLIR's tensor index
// notation), where the sparsity of tensors is indicated by an annotation that
// specifies, per dimension, sparse or dense storage together with an order
// on the dimensions. Subsequently, a topologically
20 // sorted iteration graph, reflecting the required order on indices with respect
21 // to the dimensions of each tensor, is constructed to ensure that all tensors
22 // are visited in natural index order. Next, iteration lattices are constructed
23 // for the tensor expression for every index in topological order. Each
24 // iteration lattice point consists of a conjunction of tensor indices together
25 // with a tensor (sub)expression that needs to be evaluated for that
26 // conjunction. Within the lattice, iteration points are ordered according to
27 // the way indices are exhausted. As such these iteration lattices drive actual
28 // sparse code generation, which consists of a tedious but relatively
29 // straightforward one-to-one mapping from iteration lattices to combinations
30 // of for-loops, while-loops, and if-statements.
31 //
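// As a rough illustration (a sketch of the generated code, not the exact IR
// produced by this pass), consider a matrix times vector kernel
// x(i) += A(i,j) * b(j) in which the second dimension of A is annotated
// sparse (a CSR-like scheme). The natural index order i < j is honored, and
// the inner loop visits only the stored nonzeros of each row, conceptually:
//
//   for (i = 0; i < m; i++)
//     for (jj = pointers[i]; jj < pointers[i+1]; jj++)
//       x[i] += values[jj] * b[indices[jj]];
//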
32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
36 // Proceedings of the ACM on Programming Languages, October 2017.
37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
38 // PhD thesis, MIT, February, 2020 (tensor-compiler.org).
39 //
40 // Implementation detail: We use llvm::SmallVector for vectors with
41 // variable lengths and std::vector for vectors with fixed lengths.
42 //===----------------------------------------------------------------------===//
43 
44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
45 #include "mlir/Dialect/Linalg/Utils/Utils.h"
46 #include "mlir/Dialect/SCF/SCF.h"
47 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
48 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
49 #include "mlir/Dialect/StandardOps/IR/Ops.h"
50 #include "mlir/Dialect/Vector/VectorOps.h"
51 #include "mlir/IR/Matchers.h"
52 #include "llvm/ADT/SmallBitVector.h"
53 
54 using namespace mlir;
55 
56 namespace {
57 
58 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
59 enum class Dim { kSparse, kDense, kUndef };
60 
/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
63 /// stored directly. For binary operations, e0 and e1 denote the index of the
64 /// children tensor expressions.
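/// As an illustrative sketch (expression indices are hypothetical), the
/// expression a(i) * b(i) + c(i) could be stored as three kTensor leaves for
/// a, b, and c, a multiplication node whose e0/e1 refer to the leaves of a
/// and b, and an addition node whose e0/e1 refer to the product and c.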
65 struct TensorExp {
66   TensorExp(Kind k, unsigned x, unsigned y, Value v)
67       : kind(k), e0(x), e1(y), val(v) {
68     assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
69            (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
70            (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
71   }
72   Kind kind;
73   /// Indices of children expression(s).
74   unsigned e0;
75   unsigned e1;
  /// Direct link to IR for an invariant. During code generation, this
  /// field is used to cache "hoisted" loop invariant tensor loads.
78   Value val;
79 };
80 
81 /// Lattice point. Each lattice point consists of a conjunction of tensor
82 /// loop indices (encoded in a bitvector) and the index of the corresponding
83 /// tensor expression.
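/// For example (a sketch with hypothetical indices), for a(i) * b(i) the
/// lattice point at loop index i sets the bits of both a and b and refers to
/// the tensor expression of the multiplication.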
84 struct LatPoint {
85   LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
86     bits.set(b);
87   }
88   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
  /// Conjunction of tensor loop indices as bitvector. This represents
  /// all indices involved in the tensor expression.
91   llvm::BitVector bits;
92   /// Simplified conjunction of tensor loop indices as bitvector. This
93   /// represents a simplified condition under which this tensor expression
  /// must execute. Pre-computed during code generation to avoid repeated
  /// evaluation.
95   llvm::BitVector simple;
  /// Index of the tensor expression.
97   unsigned exp;
98 };
99 
/// A class to handle all iteration lattice operations. This class abstracts
/// away some implementation details of storing iteration lattices and
102 /// tensor expressions. This allows for fine-tuning performance characteristics
103 /// independently from the basic algorithm if bottlenecks are identified.
104 class Merger {
105 public:
106   /// Constructs a merger for the given number of tensors and loops. The
107   /// user supplies the number of tensors involved in the kernel, with the
108   /// last tensor in this set denoting the output tensor. The merger adds an
109   /// additional synthetic tensor at the end of this set to represent all
110   /// invariant expressions in the kernel.
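  /// For example (an illustrative sketch, not tied to a particular kernel),
  /// a kernel with two inputs and one output constructs Merger(3, l), so
  /// tensors 0 and 1 are the inputs, tensor 2 is the output, and tensor 3 is
  /// the synthetic tensor that carries invariant expressions.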
111   Merger(unsigned t, unsigned l)
112       : outTensor(t - 1), numTensors(t + 1), numLoops(l),
113         dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
114 
115   /// Adds a tensor expression. Returns its index.
116   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
117     unsigned e = tensorExps.size();
118     tensorExps.push_back(TensorExp(k, e0, e1, v));
119     return e;
120   }
121   unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
122 
123   /// Adds an iteration lattice point. Returns its index.
124   unsigned addLat(unsigned t, unsigned i, unsigned e) {
125     assert(t < numTensors && i < numLoops);
126     unsigned p = latPoints.size();
127     latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
128     return p;
129   }
130 
131   /// Adds a new, initially empty, set. Returns its index.
132   unsigned addSet() {
133     unsigned s = latSets.size();
134     latSets.emplace_back(SmallVector<unsigned, 16>());
135     return s;
136   }
137 
138   /// Computes a single conjunction of two lattice points by taking the "union"
139   /// of loop indices (effectively constructing a larger "intersection" of those
140   /// indices) with a newly constructed tensor (sub)expression of given kind.
141   /// Returns the index of the new lattice point.
142   unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
143     unsigned p = latPoints.size();
144     llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
145     nb |= latPoints[p1].bits;
146     unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
147     latPoints.push_back(LatPoint(nb, e));
148     return p;
149   }
150 
  /// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
  /// the Cartesian product. Returns the index of the new set.
153   unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
154     unsigned s = addSet();
155     for (unsigned p0 : latSets[s0])
156       for (unsigned p1 : latSets[s1])
157         latSets[s].push_back(conjLatPoint(kind, p0, p1));
158     return s;
159   }
160 
161   /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
162   /// Returns the index of the new set.
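  /// For example (sketch only), with L0 = {a} and L1 = {b}, a conjunctive
  /// merge for a multiplication yields {a /\ b}, whereas a disjunctive merge
  /// for an addition yields {a /\ b, a, b}, reflecting that a sum contributes
  /// wherever either operand is nonzero.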
163   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
164     unsigned s = takeConj(kind, s0, s1);
165     for (unsigned p : latSets[s0])
166       latSets[s].push_back(p);
167     for (unsigned p : latSets[s1])
168       latSets[s].push_back(p);
169     return s;
170   }
171 
172   /// Optimizes the iteration lattice points in the given set. This
173   /// method should be called right before code generation to avoid
174   /// generating redundant loops and conditions.
175   unsigned optimizeSet(unsigned s0) {
176     unsigned s = addSet();
177     assert(latSets[s0].size() != 0);
178     unsigned p0 = latSets[s0][0];
179     for (unsigned p1 : latSets[s0]) {
180       bool add = true;
181       if (p0 != p1) {
182         // Is this a straightforward copy?
183         unsigned e = latPoints[p1].exp;
184         if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
185           continue;
186         // Conjunction already covered?
187         for (unsigned p2 : latSets[s]) {
188           assert(!latGT(p1, p2)); // Lj => Li would be bad
189           if (onlyDenseDiff(p2, p1)) {
190             add = false;
191             break;
192           }
193         }
194         assert(!add || latGT(p0, p1));
195       }
196       if (add)
197         latSets[s].push_back(p1);
198     }
199     for (unsigned p : latSets[s])
200       latPoints[p].simple = simplifyCond(s, p);
201     return s;
202   }
203 
204   /// Simplifies the conditions in a conjunction of a given lattice point
205   /// within the given set using just two basic rules:
206   /// (1) multiple dense conditions are reduced to single dense, and
207   /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
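  /// For example (sketch), a conjunction of one sparse condition of tensor a
  /// with dense conditions of tensors b and c only needs to test the sparse
  /// index of a, since the dense conditions always hold within that sparse
  /// iteration.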
208   llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
    // First determine if this lattice point is a *singleton*, i.e., the
    // last point in the lattice; no other point is less than this one.
211     bool isSingleton = true;
212     for (unsigned p1 : latSets[s]) {
213       if (p0 != p1 && latGT(p0, p1)) {
214         isSingleton = false;
215         break;
216       }
217     }
218     // Now apply the two basic rules.
219     llvm::BitVector simple = latPoints[p0].bits;
220     bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
221     for (unsigned b = 0, be = simple.size(); b < be; b++) {
222       if (simple[b] && !isDim(b, Dim::kSparse)) {
223         if (reset)
224           simple.reset(b);
225         reset = true;
226       }
227     }
228     return simple;
229   }
230 
231   /// Returns true if Li > Lj.
232   bool latGT(unsigned i, unsigned j) const {
233     const llvm::BitVector &bitsi = latPoints[i].bits;
234     const llvm::BitVector &bitsj = latPoints[j].bits;
235     assert(bitsi.size() == bitsj.size());
236     if (bitsi.count() > bitsj.count()) {
237       for (unsigned b = 0, be = bitsj.size(); b < be; b++)
238         if (bitsj[b] && !bitsi[b])
239           return false;
240       return true;
241     }
242     return false;
243   }
244 
245   /// Returns true if Li and Lj only differ in dense.
246   bool onlyDenseDiff(unsigned i, unsigned j) {
247     llvm::BitVector tmp = latPoints[j].bits;
248     tmp ^= latPoints[i].bits;
249     return !hasAnyDimOf(tmp, Dim::kSparse);
250   }
251 
252   /// Bit translation.
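  /// For example (hypothetical sizes), with numTensors == 3, bit b == 7
  /// denotes tensor 1 at loop index 2, since bits are assigned as
  /// b == i * numTensors + t (see addLat).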
253   unsigned tensor(unsigned b) const { return b % numTensors; }
254   unsigned index(unsigned b) const { return b / numTensors; }
255 
256   /// Returns true if bit corresponds to queried dim.
257   bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
258 
259   /// Returns true if tensor access at given index has queried dim.
260   bool isDim(unsigned t, unsigned i, Dim d) const {
261     assert(t < numTensors && i < numLoops);
262     return dims[t][i] == d;
263   }
264 
265   /// Returns true if any set bit corresponds to queried dim.
266   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
267     for (unsigned b = 0, be = bits.size(); b < be; b++)
268       if (bits[b] && isDim(b, d))
269         return true;
270     return false;
271   }
272 
273   /// Returns true if tensor has any sparse dimension.
274   bool isSparseTensor(unsigned t) const {
275     return llvm::any_of(dims[t], [](Dim d) { return d == Dim::kSparse; });
276   }
277 
278   /// Setter
279   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
280 
281   /// Getters.
282   TensorExp &exp(unsigned e) { return tensorExps[e]; }
283   LatPoint &lat(unsigned l) { return latPoints[l]; }
284   SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
285 
286 private:
287   const unsigned outTensor;
288   const unsigned numTensors;
289   const unsigned numLoops;
290 
291   std::vector<std::vector<Dim>> dims;
292   llvm::SmallVector<TensorExp, 32> tensorExps;
293   llvm::SmallVector<LatPoint, 16> latPoints;
294   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
295 };
296 
297 // Code generation.
298 struct CodeGen {
299   CodeGen(mlir::SparsificationOptions o, unsigned numTensors, unsigned numLoops)
300       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
301         pointers(numTensors, std::vector<Value>(numLoops)),
302         indices(numTensors, std::vector<Value>(numLoops)),
303         highs(numTensors, std::vector<Value>(numLoops)),
304         pidxs(numTensors, std::vector<Value>(numLoops)),
305         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
306         curVecLength(1), curVecMask() {}
307   /// Sparsification options.
308   mlir::SparsificationOptions options;
309   /// Universal dense indices and upper bounds (by index). The loops array
310   /// is updated with the value of the universal dense index in the current
311   /// loop. The sizes array is set once with the inferred dimension sizes.
312   std::vector<Value> loops;
313   std::vector<Value> sizes;
314   /// Buffers for storing dense and sparse numerical values (by tensor).
315   /// This array is set once during bufferization of all tensors.
316   std::vector<Value> buffers;
317   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
318   /// This array is set once during bufferization of all sparse tensors.
319   std::vector<std::vector<Value>> pointers;
320   std::vector<std::vector<Value>> indices;
321   /// Sparse iteration information (by tensor and index). These arrays
322   /// are updated to remain current within the current loop.
323   std::vector<std::vector<Value>> highs;
324   std::vector<std::vector<Value>> pidxs;
325   std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
328   // TODO: currently only done for (a chain of) innermost for-loops, where it
329   // is most effective; we could generalize to more outer and while-loops.
330   unsigned redExp;
331   Value redVal;
332   // Current vector length and mask.
333   unsigned curVecLength;
334   Value curVecMask;
335 };
336 
337 } // namespace
338 
339 /// Helper method to inspect sparse annotations in the linalg operation.
340 /// Fills the per-dimension sparsity information for all tensors.
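/// For example (a sketch of the concept, not the exact annotation syntax),
/// a matrix argument annotated dense in its first dimension and sparse in
/// its second dimension is traversed with a CSR-like storage scheme.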
341 static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
342   unsigned numTensors = op.getNumShapedOperands();
343   ArrayAttr sparseAttr = op.sparseAttr();
344   for (unsigned t = 0; t < numTensors; t++) {
345     auto map = op.getIndexingMap(t);
346     auto dimAttr = sparseAttr[t].cast<ArrayAttr>();
347     // For each tensor, we accept a per-dimension Sparse or Dense annotation.
348     // This is translated to the loop index that indexes that dimension.
349     unsigned rank = op.getShapedType(t).getRank();
350     for (unsigned d = 0; d < rank; d++) {
351       unsigned idx = map.getDimPosition(d);
352       if (isSparseDim(dimAttr[d])) {
353         merger.setDim(t, idx, Dim::kSparse);
354       } else {
355         assert(isDenseDim(dimAttr[d]));
356         merger.setDim(t, idx, Dim::kDense);
357       }
358     }
359   }
360 }
361 
362 /// Returns true if tensor was set up with sparse storage scheme.
363 static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
364   if (tensor < op.getNumInputs())
365     return isa_and_nonnull<sparse_tensor::FromPointerOp>(
366         op.getInput(tensor).getDefiningOp());
367   return false;
368 }
369 
370 /// A DFS helper to compute a topological sort. Note that recursion is
371 /// bounded by the number of implicit loops, which is always small.
372 /// Returns false when a cycle is detected.
373 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
374                        std::vector<unsigned> &topSort,
375                        std::vector<std::vector<bool>> &adjM) {
376   if (visit[i] != 0)
377     return visit[i] != 1; // 1 denotes cycle!
378   visit[i] = 1;
379   for (unsigned j = 0, e = visit.size(); j < e; j++)
380     if (adjM[i][j])
381       if (!topSortDFS(j, visit, topSort, adjM))
382         return false;
383   visit[i] = 2;
384   topSort.push_back(i);
385   return true;
386 }
387 
388 /// Computes a topologically sorted iteration graph for the linalg operation.
389 /// Ensures all tensors are visited in natural index order. This is essential
390 /// for sparse storage formats since these only support access along fixed
391 /// dimensions. Even for dense storage formats, however, the natural index
392 /// order yields innermost unit-stride access with better spatial locality.
393 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
394                                   std::vector<unsigned> &topSort,
395                                   bool sparseOnly) {
396   // Set up an n x n from/to adjacency matrix of the iteration graph
397   // for the implicit loop indices i_0 .. i_n-1.
398   unsigned n = op.getNumLoops();
399   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
400 
401   // Iterate over the indexing maps of every tensor in the tensor expression.
402   unsigned numTensors = op.getNumShapedOperands();
403   for (unsigned t = 0; t < numTensors; t++) {
404     auto map = op.getIndexingMap(t);
405     assert(map.getNumDims() == n);
406     // Skip dense tensor constraints when sparse only is requested.
407     if (sparseOnly && !merger.isSparseTensor(t) && !linkedSparse(op, t))
408       continue;
409     // At the moment, we take the index variables in the tensor access
410     // expression in the order in which they appear (conceptually a
411     // "row-major" layout of every tensor). So, a tensor access A_ijk
412     // forces the ordering i < j < k on the loop indices.
413     // TODO: support affine map to define alternative dimension orders.
414     for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
      unsigned from = map.getDimPosition(d - 1);
      unsigned to = map.getDimPosition(d);
      adjM[from][to] = true;
418     }
419   }
420 
421   // Topologically sort the iteration graph to determine loop order.
422   // Report failure for a cyclic iteration graph.
423   topSort.clear();
424   topSort.reserve(n);
425   std::vector<unsigned> visit(n, 0);
426   for (unsigned i = 0; i < n; i++)
427     if (visit[i] == 0)
428       if (!topSortDFS(i, visit, topSort, adjM))
429         return false; // cycle!
430   std::reverse(std::begin(topSort), std::end(topSort));
431   return true;
432 }
433 
434 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
435 /// This simplifies constructing (sub)expressions during iteration lattice
436 /// building (compared to using the SSA representation everywhere).
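/// For example (a sketch with hypothetical SSA names), a region body
///   %0 = mulf %a, %b : f64
///   %1 = addf %0, %c : f64
/// yields kTensor leaves for the block arguments %a, %b, and %c, a kMulF
/// node over the first two, and a kAddF node over the product and %c.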
437 static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
438                                          Value val) {
439   if (auto arg = val.dyn_cast<BlockArgument>()) {
440     unsigned argN = arg.getArgNumber();
441     if (arg.getOwner()->getParentOp() == op) {
442       // Any parameter of the generic op is considered a tensor,
443       // indexed by the implicit loop bounds.
444       auto map = op.getIndexingMap(argN);
445       if (map.isProjectedPermutation())
446         return merger.addExp(Kind::kTensor, argN);
447       // Cannot handle (yet).
448       return None;
449     }
450     // Any parameter of a higher op is invariant.
451     return merger.addExp(Kind::kInvariant, val);
452   }
453   Operation *def = val.getDefiningOp();
454   if (def->getBlock() != &op.region().front()) {
455     // Something defined outside is invariant.
456     return merger.addExp(Kind::kInvariant, val);
457   } else if (def->getNumOperands() == 2) {
458     // Construct binary operations if subexpressions could be built.
459     auto x = buildTensorExp(merger, op, def->getOperand(0));
460     auto y = buildTensorExp(merger, op, def->getOperand(1));
461     if (x.hasValue() && y.hasValue()) {
462       unsigned e0 = x.getValue();
463       unsigned e1 = y.getValue();
464       if (isa<MulFOp>(def))
465         return merger.addExp(Kind::kMulF, e0, e1);
466       if (isa<MulIOp>(def))
467         return merger.addExp(Kind::kMulI, e0, e1);
468       if (isa<AddFOp>(def))
469         return merger.addExp(Kind::kAddF, e0, e1);
470       if (isa<AddIOp>(def))
471         return merger.addExp(Kind::kAddI, e0, e1);
472     }
473   }
474   // Cannot build (yet).
475   return None;
476 }
477 
478 /// Builds the iteration lattices in a bottom-up traversal given the remaining
479 /// tensor (sub)expression and the next loop index in the iteration graph.
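/// For example (sketch), for a(i) + b(i) * c(i) at loop index i, the sets of
/// b and c are first merged conjunctively into {b /\ c}, after which the
/// result is merged disjunctively with a into {a /\ b /\ c, a, b /\ c}.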
480 static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
481                               unsigned exp, unsigned idx) {
482   Kind kind = merger.exp(exp).kind;
483   if (kind == Kind::kTensor || kind == Kind::kInvariant) {
484     // Either the index is really used in the tensor expression, or it is
485     // set to the undefined index in that dimension. An invariant expression
486     // is set to a synthetic tensor with undefined indices only.
487     unsigned s = merger.addSet();
488     unsigned t =
489         kind == Kind::kTensor ? merger.exp(exp).e0 : op.getNumShapedOperands();
490     merger.set(s).push_back(merger.addLat(t, idx, exp));
491     return s;
492   }
493   unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
494   unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
495   switch (kind) {
496   case Kind::kTensor:
497   case Kind::kInvariant:
498     llvm_unreachable("handled above");
499   case Kind::kMulF:
500   case Kind::kMulI:
501     return merger.takeConj(kind, s0, s1);
502   case Kind::kAddF:
503   case Kind::kAddI:
504     return merger.takeDisj(kind, s0, s1);
505   }
506   llvm_unreachable("unexpected expression kind");
507 }
508 
509 /// Maps sparse integer option to actual integral storage type.
510 static Type genIntType(PatternRewriter &rewriter, SparseIntType tp) {
511   switch (tp) {
512   case SparseIntType::kNative:
513     return rewriter.getIndexType();
514   case SparseIntType::kI64:
515     return rewriter.getIntegerType(64);
516   case SparseIntType::kI32:
517     return rewriter.getIntegerType(32);
518   case SparseIntType::kI16:
519     return rewriter.getIntegerType(16);
520   case SparseIntType::kI8:
521     return rewriter.getIntegerType(8);
522   }
523   llvm_unreachable("unexpected SparseIntType");
524 }
525 
526 /// Generates buffer for the output tensor.
527 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
528                              linalg::GenericOp op, MemRefType denseTp,
529                              ArrayRef<Value> args) {
530   Location loc = op.getLoc();
531   Value tensor = op.getOutput(0);
  // The output tensor could simply materialize from the buffer that will
533   // be generated for the tensor present in the outs() clause. This has
534   // the major advantage that the sparse kernel only updates the nonzero
535   // positions for the output tensor. Currently this results in functional,
536   // but slightly imprecise IR, so it is put under an experimental option.
537   if (codegen.options.fastOutput)
538     return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
539   // By default, a new buffer is allocated which is initialized to the
540   // tensor defined in the outs() clause. This is always correct but
541   // introduces a dense initialization component that may negatively
542   // impact the running complexity of the sparse kernel.
543   Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
544   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
545   rewriter.create<linalg::CopyOp>(loc, init, alloc);
546   return alloc;
547 }
548 
549 /// Local bufferization of all dense and sparse data structures.
550 /// This code enables testing the first prototype sparse compiler.
551 // TODO: replace this with a proliferated bufferization strategy
552 static void genBuffers(Merger &merger, CodeGen &codegen,
553                        PatternRewriter &rewriter, linalg::GenericOp op) {
554   Location loc = op.getLoc();
555   unsigned numTensors = op.getNumShapedOperands();
556   unsigned numInputs = op.getNumInputs();
557   assert(numTensors == numInputs + 1);
558   // For every tensor, find lower and upper bound on dimensions, set the
559   // same bounds on loop indices, and obtain dense or sparse buffer(s).
560   SmallVector<Value, 4> args;
561   for (unsigned t = 0; t < numTensors; t++) {
562     Value tensor = t < numInputs ? op.getInput(t) : op.getOutput(0);
563     auto tensorType = op.getShapedType(t);
564     auto shape = tensorType.getShape();
565     auto map = op.getIndexingMap(t);
566     // Scan all dimensions of current tensor.
567     bool dense = !linkedSparse(op, t);
568     args.clear();
569     for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
570       unsigned i = map.getDimPosition(d);
571       // Handle sparse storage schemes.
572       if (merger.isDim(t, i, Dim::kSparse)) {
573         dense = false;
574         auto dynShape = {ShapedType::kDynamicSize};
575         auto ptrTp = MemRefType::get(
576             dynShape, genIntType(rewriter, codegen.options.ptrType));
577         auto indTp = MemRefType::get(
578             dynShape, genIntType(rewriter, codegen.options.indType));
579         Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointer and indices.
581         codegen.pointers[t][i] = rewriter.create<sparse_tensor::ToPointersOp>(
582             loc, ptrTp, tensor, dim);
583         codegen.indices[t][i] = rewriter.create<sparse_tensor::ToIndicesOp>(
584             loc, indTp, tensor, dim);
585       }
586       // Find lower and upper bound in current dimension.
587       Value up;
588       if (shape[d] == MemRefType::kDynamicSize) {
589         up = rewriter.create<memref::DimOp>(loc, tensor, d);
590         args.push_back(up);
591       } else {
592         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
593       }
594       codegen.sizes[i] = codegen.highs[t][i] = up;
595     }
596     // Perform the required bufferization. All dense inputs materialize
597     // from the input tensor. The dense output tensor needs special
598     // handling. Sparse inputs use a sparse primitive to obtain the values.
599     if (dense) {
600       auto denseTp = MemRefType::get(shape, tensorType.getElementType());
601       if (t < numInputs)
602         codegen.buffers[t] =
603             rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
604       else
605         codegen.buffers[t] =
606             genOutputBuffer(codegen, rewriter, op, denseTp, args);
607     } else {
608       auto dynShape = {ShapedType::kDynamicSize};
609       auto sparseTp = MemRefType::get(dynShape, tensorType.getElementType());
610       codegen.buffers[t] =
611           rewriter.create<sparse_tensor::ToValuesOp>(loc, sparseTp, tensor);
612     }
613   }
614 }
615 
616 /// Constructs vector type.
617 static VectorType vectorType(CodeGen &codegen, Type etp) {
618   return VectorType::get(codegen.curVecLength, etp);
619 }
620 
621 /// Constructs vector type from pointer.
622 static VectorType vectorType(CodeGen &codegen, Value ptr) {
623   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
624 }
625 
626 /// Constructs vector iteration mask.
627 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
628                            Value iv, Value lo, Value hi, Value step) {
629   Location loc = iv.getLoc();
630   VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
631   // Special case if the vector length evenly divides the trip count (for
632   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
633   // so that all subsequent masked memory operations are immediately folded
634   // into unconditional memory operations.
635   IntegerAttr loInt, hiInt, stepInt;
636   if (matchPattern(lo, m_Constant(&loInt)) &&
637       matchPattern(hi, m_Constant(&hiInt)) &&
638       matchPattern(step, m_Constant(&stepInt))) {
639     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
640       return rewriter.create<vector::BroadcastOp>(
641           loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
642   }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
644   // during vector execution. Here we rely on subsequent loop optimizations to
645   // avoid executing the mask in all iterations, for example, by splitting the
646   // loop into an unconditional vector loop and a scalar cleanup loop.
647   Value end = rewriter.create<SubIOp>(loc, hi, iv);
648   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
649 }
650 
651 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
652 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
653                            Value ptr, ArrayRef<Value> args) {
654   Location loc = ptr.getLoc();
655   VectorType vtp = vectorType(codegen, ptr);
656   Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
657   if (args.back().getType().isa<VectorType>()) {
658     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
659     Value indexVec = args.back();
660     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
661     return rewriter.create<vector::GatherOp>(
662         loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
663   }
664   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
665                                                codegen.curVecMask, pass);
666 }
667 
668 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
669 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
670                            Value rhs, Value ptr, ArrayRef<Value> args) {
671   Location loc = ptr.getLoc();
672   if (args.back().getType().isa<VectorType>()) {
673     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
674     Value indexVec = args.back();
675     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
676     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
677                                        codegen.curVecMask, rhs);
678     return;
679   }
680   rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
681                                          rhs);
682 }
683 
684 /// Generates a vectorized invariant. Here we rely on subsequent loop
685 /// optimizations to hoist the invariant broadcast out of the vector loop.
686 static Value genVectorInvariantValue(CodeGen &codegen,
687                                      PatternRewriter &rewriter, Value val) {
688   VectorType vtp = vectorType(codegen, val.getType());
689   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
690 }
691 
692 /// Generates a load on a dense or sparse tensor.
693 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
694                            PatternRewriter &rewriter, linalg::GenericOp op,
695                            unsigned exp) {
696   // Test if the load was hoisted to a higher loop nest.
697   Value val = merger.exp(exp).val;
698   if (val) {
699     if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
700       return genVectorInvariantValue(codegen, rewriter, val);
701     return val;
702   }
703   // Actual load.
704   SmallVector<Value, 4> args;
705   unsigned tensor = merger.exp(exp).e0;
706   auto map = op.getIndexingMap(tensor);
707   bool sparse = linkedSparse(op, tensor);
708   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
709     unsigned idx = map.getDimPosition(i);
710     args.push_back(codegen.loops[idx]); // universal dense index
711     if (sparse || merger.isDim(tensor, idx, Dim::kSparse)) {
712       sparse = true;
713       args.clear();
714       args.push_back(codegen.pidxs[tensor][idx]); // position index
715     }
716   }
717   Location loc = op.getLoc();
718   Value ptr = codegen.buffers[tensor];
719   if (codegen.curVecLength > 1)
720     return genVectorLoad(codegen, rewriter, ptr, args);
721   return rewriter.create<memref::LoadOp>(loc, ptr, args);
722 }
723 
724 /// Generates a store on a dense tensor.
725 static void genTensorStore(Merger &merger, CodeGen &codegen,
726                            PatternRewriter &rewriter, linalg::GenericOp op,
727                            unsigned tensor, Value rhs) {
728   Location loc = op.getLoc();
729   // Test if this is a scalarized reduction.
730   unsigned lhs = op.getNumShapedOperands() - 1;
731   if (lhs == tensor && codegen.redVal) {
732     if (codegen.curVecLength > 1)
733       rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
734                                       codegen.redVal);
735     codegen.redVal = rhs;
736     return;
737   }
738   // Actual store.
739   SmallVector<Value, 4> args;
740   auto map = op.getIndexingMap(tensor);
741   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
742     unsigned idx = map.getDimPosition(i);
743     args.push_back(codegen.loops[idx]); // universal dense index
744   }
745   Value ptr = codegen.buffers[tensor];
746   if (codegen.curVecLength > 1)
747     genVectorStore(codegen, rewriter, rhs, ptr, args);
748   else
749     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
750 }
751 
752 /// Generates a pointer/index load from the sparse storage scheme. Narrower
753 /// data types need to be zero extended before casting the value into the
754 /// index type used for looping and indexing.
755 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
756                      Value ptr, Value s) {
757   // See https://llvm.org/docs/GetElementPtr.html for some background on
758   // the complications described below.
759   if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (in
    // the future, we could introduce a flag that states the negative space
    // of 32-bit indices is unused). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
771     Type etp = ptr.getType().cast<MemRefType>().getElementType();
772     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
773     if (!etp.isa<IndexType>()) {
774       if (etp.getIntOrFloatBitWidth() < 32)
775         vload = rewriter.create<ZeroExtendIOp>(
776             loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
777       else if (etp.getIntOrFloatBitWidth() < 64)
778         vload = rewriter.create<ZeroExtendIOp>(
779             loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
780     }
781     return vload;
782   }
783   // For the scalar case, we simply zero extend narrower indices into 64-bit
784   // values before casting to index without a performance penalty. Here too,
  // however, indices that are already 64-bit cannot, in theory, express the
786   // full range as explained above.
787   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
788   if (!load.getType().isa<IndexType>()) {
789     if (load.getType().getIntOrFloatBitWidth() < 64)
790       load = rewriter.create<ZeroExtendIOp>(loc, load,
791                                             rewriter.getIntegerType(64));
792     load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
793   }
794   return load;
795 }
796 
797 /// Generates an invariant value.
798 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
799                                PatternRewriter &rewriter, unsigned exp) {
800   Value val = merger.exp(exp).val;
801   if (codegen.curVecLength > 1)
802     return genVectorInvariantValue(codegen, rewriter, val);
803   return val;
804 }
805 
806 /// Generates an address computation "sz * p + i".
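/// For example (sketch), for a dense dimension of size sz nested below an
/// enclosing dimension at position p, the element at index i resides at the
/// linearized position sz * p + i within the dense storage.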
807 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
808                         Location loc, Value size, Value p, Value i) {
809   Value mul = rewriter.create<MulIOp>(loc, size, p);
810   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
811     Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
812     mul = genVectorInvariantValue(codegen, rewriter, inv);
813   }
814   return rewriter.create<AddIOp>(loc, mul, i);
815 }
816 
817 /// Generates start of a reduction.
818 static Value genReductionStart(Merger &merger, CodeGen &codegen,
819                                PatternRewriter &rewriter,
820                                linalg::GenericOp op) {
821   if (codegen.redVal)
822     return codegen.redVal; // chained with previous for-loop
823   if (codegen.curVecLength > 1) {
824     // TODO: assumes + reductions for now
825     VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
826     return rewriter.create<ConstantOp>(op.getLoc(), vtp,
827                                        rewriter.getZeroAttr(vtp));
828   }
829   return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
830 }
831 
832 /// Generates end of a reduction.
833 static void genReductionEnd(Merger &merger, CodeGen &codegen,
834                             PatternRewriter &rewriter, linalg::GenericOp op) {
835   Value red = codegen.redVal;
836   if (!red)
837     return;
838   assert(codegen.curVecLength == 1);
839   codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
840   unsigned lhs = op.getNumShapedOperands() - 1;
841   if (auto vtp = red.getType().dyn_cast<VectorType>()) {
842     // TODO: assumes + reductions for now
843     StringAttr kind = rewriter.getStringAttr("add");
844     Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
845     // Integer reductions don't accept an accumulator.
846     if (vtp.getElementType().isa<IntegerType>()) {
847       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
848                                                  kind, red, ValueRange{});
849       red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
850     } else {
851       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
852                                                  kind, red, ld);
853     }
854   }
855   genTensorStore(merger, codegen, rewriter, op, lhs, red);
856 }
857 
858 /// Recursively generates tensor expression.
859 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
860                     linalg::GenericOp op, unsigned exp) {
861   if (merger.exp(exp).kind == Kind::kTensor)
862     return genTensorLoad(merger, codegen, rewriter, op, exp);
863   else if (merger.exp(exp).kind == Kind::kInvariant)
864     return genInvariantValue(merger, codegen, rewriter, exp);
865   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
866   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
867   switch (merger.exp(exp).kind) {
868   case Kind::kTensor:
869   case Kind::kInvariant:
870     llvm_unreachable("handled above");
871   case Kind::kMulF:
872     return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
873   case Kind::kMulI:
874     return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
875   case Kind::kAddF:
876     return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
877   case Kind::kAddI:
878     return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
879   }
880   llvm_unreachable("unexpected expression kind");
881 }
882 
883 /// Hoists loop invariant tensor loads for which indices have been exhausted.
884 static void genInvariants(Merger &merger, CodeGen &codegen,
885                           PatternRewriter &rewriter, linalg::GenericOp op,
886                           unsigned exp, unsigned ldx, bool hoist) {
887   if (merger.exp(exp).kind == Kind::kTensor) {
888     // Inspect tensor indices.
889     bool atLevel = ldx == -1u;
890     unsigned tensor = merger.exp(exp).e0;
891     auto map = op.getIndexingMap(tensor);
892     for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
893       unsigned idx = map.getDimPosition(i);
894       if (!codegen.loops[idx])
895         return; // still in play
896       else if (idx == ldx)
897         atLevel = true;
898     }
899     // All exhausted at this level (atLevel denotes exactly at this level).
900     unsigned lhs = op.getNumShapedOperands() - 1;
901     if (lhs == tensor) {
902       codegen.redExp = hoist ? exp : -1u;
903     } else if (atLevel) {
904       merger.exp(exp).val =
905           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
906     }
907   } else if (merger.exp(exp).kind != Kind::kInvariant) {
908     // Traverse into the binary operations. Note that we only hoist
909     // tensor loads, since subsequent MLIR/LLVM passes know how to
910     // deal with all other kinds of derived loop invariants.
911     unsigned e0 = merger.exp(exp).e0;
912     unsigned e1 = merger.exp(exp).e1;
913     genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
914     genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
915   }
916 }
917 
918 /// Generates initialization code for the subsequent loop sequence at
919 /// current index level. Returns true if the loop sequence needs to
920 /// maintain the universal index.
921 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
922                     linalg::GenericOp op, std::vector<unsigned> &topSort,
923                     unsigned at, llvm::BitVector &inits) {
924   bool needsUniv = false;
925   Location loc = op.getLoc();
926   unsigned idx = topSort[at];
927 
928   // Initialize sparse positions.
929   for (unsigned b = 0, be = inits.size(); b < be; b++) {
930     if (inits[b]) {
931       unsigned tensor = merger.tensor(b);
932       assert(idx == merger.index(b));
933       if (merger.isDim(b, Dim::kSparse)) {
934         // Initialize sparse index.
935         unsigned pat = at;
936         for (; pat != 0; pat--) {
937           if (codegen.pidxs[tensor][topSort[pat - 1]])
938             break;
939         }
940         Value ptr = codegen.pointers[tensor][idx];
941         Value one = rewriter.create<ConstantIndexOp>(loc, 1);
942         Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
943                               : codegen.pidxs[tensor][topSort[pat - 1]];
944         codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
945         Value p1 = rewriter.create<AddIOp>(loc, p0, one);
946         codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
947       } else {
948         // Dense index still in play.
949         needsUniv = true;
950       }
951     }
952   }
953 
954   // Initialize the universal dense index.
955   codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
956   return needsUniv;
957 }
958 
959 /// Returns vectorization strategy. Any implicit inner loop in the Linalg
960 /// operation is a candidate. Whether it is actually converted to SIMD code
961 /// depends on the requested strategy.
962 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
963   switch (codegen.options.vectorizationStrategy) {
964   case SparseVectorizationStrategy::kNone:
965     return false;
966   case SparseVectorizationStrategy::kDenseInnerLoop:
967     return isInner && !isSparse;
968   case SparseVectorizationStrategy::kAnyStorageInnerLoop:
969     return isInner;
970   }
971   llvm_unreachable("unexpected vectorization strategy");
972 }
973 
974 /// Returns parallelization strategy. Any implicit loop in the Linalg operation
975 /// that is marked "parallel" is a candidate. Whether it is actually converted
976 /// to a parallel operation depends on the requested strategy.
977 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
978                           bool isSparse, bool isVector) {
979   switch (codegen.options.parallelizationStrategy) {
980   case SparseParallelizationStrategy::kNone:
981     return false;
982   case SparseParallelizationStrategy::kDenseOuterLoop:
983     return isOuter && !isSparse && !isReduction && !isVector;
984   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
985     return isOuter && !isReduction && !isVector;
986   case SparseParallelizationStrategy::kDenseAnyLoop:
987     return !isSparse && !isReduction && !isVector;
988   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
989     return !isReduction && !isVector;
990   }
991   llvm_unreachable("unexpected parallelization strategy");
992 }
993 
994 /// Checks unit strides for dense tensors. The iteration graph may have ignored
995 /// dense access patterns in order to avoid cycles (sparse access patterns are
996 /// always placed innermost), but that means dense access has become strided.
997 /// For now, we reject vectorization of such cases.
998 /// TODO: implement strided load/stores on dense arrays
999 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
1000                              unsigned idx) {
1001   unsigned numTensors = op.getNumShapedOperands();
1002   for (unsigned t = 0; t < numTensors; t++) {
1003     if (!merger.isSparseTensor(t) && !linkedSparse(op, t)) {
1004       auto map = op.getIndexingMap(t);
1005       unsigned r = map.getNumResults();
1006       for (unsigned i = 0; i < r; i++) {
1007         if (map.getDimPosition(i) == idx && i != r - 1)
1008           return false;
1009       }
1010     }
1011   }
1012   return true;
1013 }
1014 
1015 /// Generates a for-loop on a single index.
1016 static Operation *genFor(Merger &merger, CodeGen &codegen,
1017                          PatternRewriter &rewriter, linalg::GenericOp op,
1018                          bool isOuter, bool isInner, unsigned idx,
1019                          llvm::BitVector &indices) {
1020   unsigned fb = indices.find_first();
1021   unsigned tensor = merger.tensor(fb);
1022   assert(idx == merger.index(fb));
1023   auto iteratorTypes = op.iterator_types().getValue();
1024   bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
1025   bool isSparse = merger.isDim(fb, Dim::kSparse);
1026   bool isVector = isVectorFor(codegen, isInner, isSparse) &&
1027                   denseUnitStrides(merger, op, idx);
1028   bool isParallel =
1029       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
1030 
1031   // Prepare vector length.
1032   if (isVector)
1033     codegen.curVecLength = codegen.options.vectorLength;
1034 
1035   // Loop bounds and increment.
1036   Location loc = op.getLoc();
1037   Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
1038   Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
1039   Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);
1040 
1041   // Emit a parallel loop.
1042   if (isParallel) {
1043     assert(!isVector);
1044     scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
1045     if (isSparse)
1046       codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
1047     else
1048       codegen.loops[idx] = parOp.getInductionVars()[0];
1049     rewriter.setInsertionPointToStart(parOp.getBody());
1050     return parOp;
1051   }
1052 
1053   // Emit a sequential loop, potentially with a scalarized reduction.
1054   bool scalarRed = isInner && codegen.redExp != -1u;
1055   SmallVector<Value, 4> operands;
1056   if (scalarRed) {
1057     Value load = genReductionStart(merger, codegen, rewriter, op);
1058     operands.push_back(load);
1059   }
1060   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
1061   if (scalarRed) {
1062     codegen.redVal = merger.exp(codegen.redExp).val =
1063         forOp.getRegionIterArgs().front();
1064   }
1065   // Assign induction variable to sparse or dense index.
1066   Value iv = forOp.getInductionVar();
1067   if (isSparse)
1068     codegen.pidxs[tensor][idx] = iv;
1069   else
1070     codegen.loops[idx] = iv;
1071   rewriter.setInsertionPointToStart(forOp.getBody());
1072   // Share vector iteration mask between all subsequent loads/stores.
1073   if (isVector)
1074     codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
1075   return forOp;
1076 }
1077 
/// Emits a while-loop for co-iteration over multiple indices.
1079 static Operation *genWhile(Merger &merger, CodeGen &codegen,
1080                            PatternRewriter &rewriter, linalg::GenericOp op,
1081                            unsigned idx, bool needsUniv,
1082                            llvm::BitVector &indices) {
1083   SmallVector<Type, 4> types;
1084   SmallVector<Value, 4> operands;
1085   // Construct the while-loop with a parameter for each index.
1086   Type indexType = rewriter.getIndexType();
1087   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1088     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1089       unsigned tensor = merger.tensor(b);
1090       assert(idx == merger.index(b));
1091       types.push_back(indexType);
1092       assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
1093              "type mismatch for sparse index");
1094       operands.push_back(codegen.pidxs[tensor][idx]);
1095     }
1096   }
1097   if (needsUniv) {
1098     types.push_back(indexType);
1099     assert(codegen.loops[idx].getType().isa<IndexType>() &&
1100            "type mismatch for universal index");
1101     operands.push_back(codegen.loops[idx]);
1102   }
1103   Location loc = op.getLoc();
1104   scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
1105   Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
1106   Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
1107 
1108   // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
1110   rewriter.setInsertionPointToStart(&whileOp.before().front());
1111   Value cond;
1112   unsigned o = 0;
1113   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1114     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1115       unsigned tensor = merger.tensor(b);
1116       assert(idx == merger.index(b));
1117       Value op1 = before->getArgument(o);
1118       Value op2 = codegen.highs[tensor][idx];
1119       Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
1120       cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
1121       codegen.pidxs[tensor][idx] = after->getArgument(o++);
1122     }
1123   }
1124   if (needsUniv)
1125     codegen.loops[idx] = after->getArgument(o++);
1126   assert(o == operands.size());
1127   rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
1128   rewriter.setInsertionPointToStart(&whileOp.after().front());
1129   return whileOp;
1130 }
1131 
1132 /// Generates a for-loop or a while-loop, depending on whether it implements
1133 /// singleton iteration or co-iteration over the given conjunction.
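/// For example (sketch), a singleton iteration over one sparse tensor becomes
/// a for-loop over its position range, whereas co-iteration over several
/// sparse tensors becomes a while-loop that advances whichever positions
/// currently hold the minimum index.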
1134 static Operation *genLoop(Merger &merger, CodeGen &codegen,
1135                           PatternRewriter &rewriter, linalg::GenericOp op,
1136                           std::vector<unsigned> &topSort, unsigned at,
1137                           bool needsUniv, llvm::BitVector &indices) {
1138   unsigned idx = topSort[at];
1139   if (indices.count() == 1) {
1140     bool isOuter = at == 0;
1141     bool isInner = at == topSort.size() - 1;
1142     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
1143                   indices);
1144   }
1145   genReductionEnd(merger, codegen, rewriter, op); // cannot chain
1146   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
1147 }
1148 
1149 /// Generates the local variables for this loop, consisting of the sparse
1150 /// indices, restored universal dense index, and dense positions.
1151 static void genLocals(Merger &merger, CodeGen &codegen,
1152                       PatternRewriter &rewriter, linalg::GenericOp op,
1153                       std::vector<unsigned> &topSort, unsigned at,
1154                       bool needsUniv, llvm::BitVector &locals) {
1155   Location loc = op.getLoc();
1156   unsigned idx = topSort[at];
1157 
1158   // Initialize sparse indices.
1159   Value min;
1160   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1161     if (locals[b] && merger.isDim(b, Dim::kSparse)) {
1162       unsigned tensor = merger.tensor(b);
1163       assert(idx == merger.index(b));
1164       Value ptr = codegen.indices[tensor][idx];
1165       Value s = codegen.pidxs[tensor][idx];
1166       Value load = genLoad(codegen, rewriter, loc, ptr, s);
1167       codegen.idxs[tensor][idx] = load;
1168       if (!needsUniv) {
1169         if (min) {
1170           Value cmp =
1171               rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
1172           min = rewriter.create<SelectOp>(loc, cmp, load, min);
1173         } else {
1174           min = load;
1175         }
1176       }
1177     }
1178   }
1179 
1180   // Merge dense universal index over minimum.
1181   if (min) {
1182     assert(!needsUniv);
1183     codegen.loops[idx] = min;
1184   }
1185 
1186   // Initialize dense positions.
1187   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1188     if (locals[b] && merger.isDim(b, Dim::kDense)) {
1189       unsigned tensor = merger.tensor(b);
1190       assert(idx == merger.index(b));
1191       unsigned pat = at;
1192       for (; pat != 0; pat--)
1193         if (codegen.pidxs[tensor][topSort[pat - 1]])
1194           break;
1195       Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
1196                            : codegen.pidxs[tensor][topSort[pat - 1]];
1197       codegen.pidxs[tensor][idx] = genAddress(
1198           codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
1199     }
1200   }
1201 }
1202 
1203 /// Generates the induction structure for a while-loop.
1204 static void genWhileInduction(Merger &merger, CodeGen &codegen,
1205                               PatternRewriter &rewriter, linalg::GenericOp op,
1206                               unsigned idx, bool needsUniv,
1207                               llvm::BitVector &induction, ResultRange results) {
1208   Location loc = op.getLoc();
1209   unsigned o = 0;
1210   SmallVector<Value, 4> operands;
1211   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
1212   for (unsigned b = 0, be = induction.size(); b < be; b++) {
1213     if (induction[b] && merger.isDim(b, Dim::kSparse)) {
1214       unsigned tensor = merger.tensor(b);
1215       assert(idx == merger.index(b));
1216       Value op1 = codegen.idxs[tensor][idx];
1217       Value op2 = codegen.loops[idx];
1218       Value op3 = codegen.pidxs[tensor][idx];
1219       Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
1220       Value add = rewriter.create<AddIOp>(loc, op3, one);
1221       operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
1222       codegen.pidxs[tensor][idx] = results[o++];
1223     }
1224   }
1225   if (needsUniv) {
1226     operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
1227     codegen.loops[idx] = results[o++];
1228   }
1229   assert(o == operands.size());
1230   rewriter.create<scf::YieldOp>(loc, operands);
1231 }
1232 
1233 /// Generates a single if-statement within a while-loop.
1234 static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
1235                        PatternRewriter &rewriter, linalg::GenericOp op,
1236                        unsigned idx, llvm::BitVector &conditions) {
1237   Location loc = op.getLoc();
1238   Value cond;
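  // Build the condition as a conjunction over all dimensions in play: a
  // sparse dimension contributes an idx == i test, whereas a dense dimension
  // is trivially true.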
1239   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1240     if (conditions[b]) {
1241       unsigned tensor = merger.tensor(b);
1242       assert(idx == merger.index(b));
1243       Value clause;
1244       if (merger.isDim(b, Dim::kSparse)) {
1245         Value op1 = codegen.idxs[tensor][idx];
1246         Value op2 = codegen.loops[idx];
1247         clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
1248       } else {
1249         clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
1250       }
1251       cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
1252     }
1253   }
1254   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
1255   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
1256   return ifOp;
1257 }
1258 
1259 /// Recursively generates code while computing iteration lattices in order
1260 /// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
1262 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1263                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1264                     unsigned exp, unsigned at) {
1265   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1266   if (at == topSort.size()) {
1267     unsigned lhs = op.getNumShapedOperands() - 1;
1268     Value rhs = genExp(merger, codegen, rewriter, op, exp);
1269     genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
1270     return;
1271   }
1272   assert(codegen.curVecLength == 1);
1273 
1274   // Construct iteration lattices for current loop index, with L0 at top.
1275   // Then emit initialization code for the loop sequence at this level.
1276   // We maintain the universal dense index if dense indices are still
1277   // in play for a non-singleton loop sequence.
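  //
  // Schematically, for a disjunction such as A(i) + B(i) with both tensors
  // sparse in i, the emitted loop sequence looks roughly as follows (a
  // sketch only; pA/idxA/hiA denote the positions, indices, and bounds):
  //
  //    while (pA < hiA && pB < hiB) {      // lattice point {A,B}
  //      i = min(idxA, idxB)
  //      if (idxA == i && idxB == i) ...   // A(i) + B(i)
  //      else if (idxA == i)         ...   // A(i)
  //      else                        ...   // B(i)
  //      pA += (idxA == i); pB += (idxB == i)
  //    }
  //    for (; pA < hiA; pA++) ...          // remaining lattice point {A}
  //    for (; pB < hiB; pB++) ...          // remaining lattice point {B}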
1278   Location loc = op.getLoc();
1279   unsigned idx = topSort[at];
1280   unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
1281   unsigned lsize = merger.set(lts).size();
1282   assert(lsize != 0);
1283   unsigned l0 = merger.set(lts)[0];
1284   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
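  // Hoist loop-invariant tensor values out of the loop sequence about to be
  // emitted (undone again at the end of this sequence).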
1285   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
1286   bool needsUniv = false;
1287   if (genInit(merger, codegen, rewriter, op, topSort, at,
1288               merger.lat(l0).bits)) {
1289     // Maintain the universal index only if it is actually
1290     // consumed by a subsequent lattice point.
1291     for (unsigned i = 1; i < lsize; i++) {
1292       unsigned li = merger.set(lts)[i];
1293       if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
1294         needsUniv = true;
1295         break;
1296       }
1297     }
1298   }
1299 
  // Emit a loop for every lattice point Li with L0 >= Li.
1301   for (unsigned i = 0; i < lsize; i++) {
1302     unsigned li = merger.set(lts)[i];
1303 
1304     // Emit loop.
1305     codegen.curVecLength = 1;
1306     llvm::BitVector indices = merger.lat(li).simple;
1307     Operation *loop =
1308         genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
1309     genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
1310               merger.lat(li).bits);
1311 
    // Visit all lattice points Lj with Li >= Lj to generate the
    // loop-body, possibly with if-statements for co-iteration.
1314     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1315     for (unsigned j = 0; j < lsize; j++) {
1316       unsigned lj = merger.set(lts)[j];
1317       unsigned ej = merger.lat(lj).exp;
1318       if (li == lj || merger.latGT(li, lj)) {
1319         // Recurse into body of each branch.
1320         if (isWhile) {
1321           scf::IfOp ifOp =
1322               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1323           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1324           rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
1325         } else {
1326           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1327         }
1328       }
1329     }
1330 
1331     // Wrap-up induction and restore insertion point.
1332     if (isWhile) {
1333       scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
1334       rewriter.setInsertionPointToEnd(&whileOp.after().front());
1335       genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
1336                         merger.lat(li).bits, whileOp.results());
1337     } else {
1338       needsUniv = false;
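      // For a for-loop, the scalarized reduction value (if any) is chained
      // through the loop by yielding it and continuing with the loop result.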
1339       if (codegen.redVal) {
1340         rewriter.create<scf::YieldOp>(loc, codegen.redVal);
1341         codegen.redVal = loop->getResult(0);
1342       }
1343     }
1344     rewriter.setInsertionPointAfter(loop);
1345   }
1346 
1347   // Wrap-up loop sequence.
1348   codegen.curVecLength = 1;
1349   genReductionEnd(merger, codegen, rewriter, op);
1350   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
1351   codegen.loops[idx] = Value();
1352 }
1353 
1354 namespace {
1355 
/// Sparse rewriting rule for generic Linalg operation.
1357 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1358 public:
1359   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1360       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1361 
1362   LogicalResult matchAndRewrite(linalg::GenericOp op,
1363                                 PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
1366     if (!op.hasSparseSemantics())
1367       return failure();
1368     assert(op.getNumOutputs() == 1);
1369     unsigned numTensors = op.getNumShapedOperands();
1370     unsigned numLoops = op.iterator_types().getValue().size();
1371     Merger merger(numTensors, numLoops);
1372     findSparseAnnotations(merger, op);
1373 
1374     // Computes a topologically sorted iteration graph to ensure
1375     // tensors are visited in natural index order. Fails on cycles.
1376     // This assumes that higher-level passes have already put the
1377     // tensors in each tensor expression in a feasible order.
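    // The second attempt relaxes the constraints to sparse dimensions only,
    // trading a potentially less efficient dense access order for feasibility.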
1378     std::vector<unsigned> topSort;
1379     if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
1380         !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
1381       return failure();
1382 
1383     // Finds the terminating yield statement and builds the tensor
1384     // expression for the Linalg operation in SSA form.
1385     Operation *yield = op.region().front().getTerminator();
1386     Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
1387     if (!exp.hasValue())
1388       return failure(); // build failure
1389 
1390     // Recursively generates code.
1391     CodeGen codegen(options, numTensors, numLoops);
1392     genBuffers(merger, codegen, rewriter, op);
1393     genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
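    // Converts the output buffer back into a tensor value that replaces the
    // original operation.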
1394     Value result = rewriter.create<memref::TensorLoadOp>(
1395         op.getLoc(), codegen.buffers.back());
1396     rewriter.replaceOp(op, result);
1397     return success();
1398   }
1399 
1400 private:
1401   /// Options to control sparse code generation.
1402   SparsificationOptions options;
1403 };
1404 
1405 } // namespace
1406 
1407 /// Populates the given patterns list with rewriting rules required for
1408 /// the sparsification of linear algebra operations.
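///
/// A typical driver pass collects these patterns and applies them greedily,
/// roughly as follows (a sketch only; the surrounding pass and its `options`
/// value are assumed, not defined here):
///
///   RewritePatternSet patterns(&getContext());
///   populateSparsificationPatterns(patterns, options);
///   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));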
1409 void mlir::populateSparsificationPatterns(
1410     RewritePatternSet &patterns, const SparsificationOptions &options) {
1411   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1412 }
1413