//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering sparse tensor types to actual sparse code.
//
// The concept of letting a compiler generate sparse code automatically was
// pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
// formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
// Algebra Compiler (TACO). The implementation in this file closely follows
// the "sparse iteration theory" that forms the foundation of TACO. A rewriting
// rule is applied to each tensor expression in linalg (MLIR's tensor index
// notation) where the sparsity of tensors is indicated with annotation using
// a per-dimension specification of sparse/dense storage together with a
// specification of the order on the dimensions. Subsequently, a topologically
// sorted iteration graph, reflecting the required order on indices with respect
// to the dimensions of each tensor, is constructed to ensure that all tensors
// are visited in natural index order. Next, iteration lattices are constructed
// for the tensor expression for every index in topological order. Each
// iteration lattice point consists of a conjunction of tensor indices together
// with a tensor (sub)expression that needs to be evaluated for that
// conjunction. Within the lattice, iteration points are ordered according to
// the way indices are exhausted. As such these iteration lattices drive actual
// sparse code generation, which consists of a tedious but relatively
// straightforward one-to-one mapping from iteration lattices to combinations
// of for-loops, while-loops, and if-statements.
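//
// As an informal sketch (names are illustrative, not normative): for a kernel
// x(i) = a(i) + b(i) with a annotated sparse and b dense, the lattice for
// index i contains the points {a(i) /\ b(i)}, {a(i)}, and {b(i)}; after
// optimization, this drives code roughly of the shape
//
//   while (pa < pa_end) {          // co-iterate sparse a with dense index i
//     ia = a.indices[pa];
//     if (ia == i) x[i] = a.values[pa] + b[i]; else x[i] = b[i];
//     pa += (ia == i) ? 1 : 0; i++;
//   }
//   for (; i < n; i++)             // remaining dense iterations
//     x[i] = b[i];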
//
// [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
// PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
// [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
// David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
// Proceedings of the ACM on Programming Languages, October 2017.
// [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
// PhD thesis, MIT, February, 2020 (tensor-compiler.org).
//
// Implementation detail: We use llvm::SmallVector for vectors with
// variable lengths and std::vector for vectors with fixed lengths.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
enum class Dim { kSparse, kDense, kSingle, kUndef };

/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
/// stored directly. For binary operations, e0 and e1 denote the index of the
/// children tensor expressions.
struct TensorExp {
  TensorExp(Kind k, unsigned x, unsigned y, Value v)
      : kind(k), e0(x), e1(y), val(v) {
    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
  }
  Kind kind;
  /// Indices of children expression(s).
  unsigned e0;
  unsigned e1;
  /// Direct link to IR for an invariant. During code generation, this
  /// field is used to cache "hoisted" loop invariant tensor loads.
  Value val;
};

/// Lattice point. Each lattice point consists of a conjunction of tensor
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
    bits.set(b);
  }
  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
  /// Conjunction of tensor loop indices as bitvector. This represents
  /// all indices involved in the tensor expression.
  llvm::BitVector bits;
  /// Simplified conjunction of tensor loop indices as bitvector. This
  /// represents a simplified condition under which this tensor expression
  /// must execute. Pre-computed during codegen to avoid repeated eval.
  llvm::BitVector simple;
  /// Index of the tensor expression.
  unsigned exp;
};

/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
class Merger {
public:
  /// Constructs a merger for the given number of tensors and loops. The
  /// user supplies the number of tensors involved in the kernel, with the
  /// last tensor in this set denoting the output tensor. The merger adds an
  /// additional synthetic tensor at the end of this set to represent all
  /// invariant expressions in the kernel.
  Merger(unsigned t, unsigned l)
      : outTensor(t - 1), numTensors(t + 1), numLoops(l),
        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}

  /// Adds a tensor expression. Returns its index.
  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
    unsigned e = tensorExps.size();
    tensorExps.push_back(TensorExp(k, e0, e1, v));
    return e;
  }
  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }

  /// Adds an iteration lattice point. Returns its index.
  unsigned addLat(unsigned t, unsigned i, unsigned e) {
    assert(t < numTensors && i < numLoops);
    unsigned p = latPoints.size();
    latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
    return p;
  }

  /// Adds a new, initially empty, set. Returns its index.
  unsigned addSet() {
    unsigned s = latSets.size();
    latSets.emplace_back(SmallVector<unsigned, 16>());
    return s;
  }

  /// Computes a single conjunction of two lattice points by taking the "union"
  /// of loop indices (effectively constructing a larger "intersection" of those
  /// indices) with a newly constructed tensor (sub)expression of given kind.
  /// Returns the index of the new lattice point.
  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
    unsigned p = latPoints.size();
    llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
    nb |= latPoints[p1].bits;
    unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
    latPoints.push_back(LatPoint(nb, e));
    return p;
  }

  /// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
  /// their cartesian product. Returns the index of the new set.
  unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
    unsigned s = addSet();
    for (unsigned p0 : latSets[s0])
      for (unsigned p1 : latSets[s1])
        latSets[s].push_back(conjLatPoint(kind, p0, p1));
    return s;
  }

  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
  /// Returns the index of the new set.
  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
    unsigned s = takeConj(kind, s0, s1);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
    return s;
  }
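
  // Example (a sketch, not normative): for a multiplication a(i) * b(i),
  // takeConj merges the singleton sets {a(i)} and {b(i)} into the single
  // point {a(i) /\ b(i)}, since a nonzero result requires both operands.
  // For an addition a(i) + b(i), takeDisj keeps that conjunction but also
  // retains {a(i)} and {b(i)}, since either operand alone contributes.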

  /// Optimizes the iteration lattice points in the given set. This
  /// method should be called right before code generation to avoid
  /// generating redundant loops and conditions.
  unsigned optimizeSet(unsigned s0) {
    unsigned s = addSet();
    assert(latSets[s0].size() != 0);
    unsigned p0 = latSets[s0][0];
    for (unsigned p1 : latSets[s0]) {
      bool add = true;
      if (p0 != p1) {
        // Is this a straightforward copy?
        unsigned e = latPoints[p1].exp;
        if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
          continue;
        // Conjunction already covered?
        for (unsigned p2 : latSets[s]) {
          assert(!latGT(p1, p2)); // Lj => Li would be bad
          if (onlyDenseDiff(p2, p1)) {
            add = false;
            break;
          }
        }
        assert(!add || latGT(p0, p1));
      }
      if (add)
        latSets[s].push_back(p1);
    }
    for (unsigned p : latSets[s])
      latPoints[p].simple = simplifyCond(s, p);
    return s;
  }

  /// Simplifies the conditions in a conjunction of a given lattice point
  /// within the given set using just two basic rules:
  /// (1) multiple dense conditions are reduced to single dense, and
  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
  llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
    // First determine if this lattice point is a *singleton*, i.e., the
    // last point in the lattice: no other point is less than this one.
    bool isSingleton = true;
    for (unsigned p1 : latSets[s]) {
      if (p0 != p1 && latGT(p0, p1)) {
        isSingleton = false;
        break;
      }
    }
    // Now apply the two basic rules.
    llvm::BitVector simple = latPoints[p0].bits;
    bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
    for (unsigned b = 0, be = simple.size(); b < be; b++) {
      if (simple[b] && !isDim(b, Dim::kSparse)) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
    return simple;
  }
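
  // To illustrate the two rules (an informal sketch): for a non-singleton
  // point over one sparse and two dense dimensions, the second dense bit is
  // cleared so that only a single dense condition remains (rule 1). For a
  // singleton point that contains at least one sparse dimension, all dense
  // bits are cleared, so iteration is driven by the sparse dimension while
  // the dense dimensions are handled through random access (rule 2).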

  /// Returns true if Li > Lj.
  bool latGT(unsigned i, unsigned j) const {
    const llvm::BitVector &bitsi = latPoints[i].bits;
    const llvm::BitVector &bitsj = latPoints[j].bits;
    assert(bitsi.size() == bitsj.size());
    if (bitsi.count() > bitsj.count()) {
      for (unsigned b = 0, be = bitsj.size(); b < be; b++)
        if (bitsj[b] && !bitsi[b])
          return false;
      return true;
    }
    return false;
  }

  /// Returns true if Li and Lj only differ in dense.
  bool onlyDenseDiff(unsigned i, unsigned j) {
    llvm::BitVector tmp = latPoints[j].bits;
    tmp ^= latPoints[i].bits;
    return !hasAnyDimOf(tmp, Dim::kSparse);
  }

  /// Bit translation.
  unsigned tensor(unsigned b) const { return b % numTensors; }
  unsigned index(unsigned b) const { return b / numTensors; }
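
  // For example (a sketch): addLat() encodes tensor t at loop index i as
  // bit b = i * numTensors + t, so with numTensors == 3, bit 7 denotes
  // tensor(7) = 7 % 3 = 1 at loop index(7) = 7 / 3 = 2.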

  /// Returns true if bit corresponds to queried dim.
  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }

  /// Returns true if tensor access at given index has queried dim.
  bool isDim(unsigned t, unsigned i, Dim d) const {
    assert(t < numTensors && i < numLoops);
    return dims[t][i] == d;
  }

  /// Returns true if any set bit corresponds to queried dim.
  bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
    for (unsigned b = 0, be = bits.size(); b < be; b++)
      if (bits[b] && isDim(b, d))
        return true;
    return false;
  }

  /// Setter.
  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }

  /// Getters.
  TensorExp &exp(unsigned e) { return tensorExps[e]; }
  LatPoint &lat(unsigned l) { return latPoints[l]; }
  SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }

private:
  const unsigned outTensor;
  const unsigned numTensors;
  const unsigned numLoops;

  std::vector<std::vector<Dim>> dims;
  llvm::SmallVector<TensorExp, 32> tensorExps;
  llvm::SmallVector<LatPoint, 16> latPoints;
  llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
};

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}
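
// For instance (a sketch, assuming the usual CSR-style annotation): a tensor
// type annotated with dimLevelType = [ "dense", "compressed" ] maps dimension
// 0 to Dim::kDense and dimension 1 to Dim::kSparse, while a tensor without a
// sparse encoding is treated as all-dense.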

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  unsigned numTensors = op.getNumShapedOperands();
  for (unsigned t = 0; t < numTensors; t++) {
    auto map = op.getIndexingMap(t);
    unsigned rank = op.getShapedType(t).getRank();
    auto enc = getSparseTensorEncoding(op.getShapedType(t));
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      merger.setDim(t, idx, toDim(enc, d));
    }
  }
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  bool sparseOnly) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  unsigned numTensors = op.getNumShapedOperands();
  for (unsigned t = 0; t < numTensors; t++) {
    auto map = op.getIndexingMap(t);
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when sparse only is requested.
    if (sparseOnly && !getSparseTensorEncoding(op.getShapedType(t)))
      continue;
    // At the moment, we take the index variables in the tensor access
    // expression in the order in which they appear (conceptually a
    // "row-major" layout of every tensor). So, a tensor access A_ijk
    // forces the ordering i < j < k on the loop indices.
    // TODO: support affine map to define alternative dimension orders.
    for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
      unsigned f = map.getDimPosition(d - 1);
      unsigned t = map.getDimPosition(d);
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}
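
// As an example (a sketch): for a matrix-vector product x(i) += A(i,j) * b(j),
// the access A(i,j) adds the constraint i < j, so the topological sort yields
// the loop order (i, j). If another tensor were accessed as (j, i), the
// combined constraints would form a cycle and this routine would report
// failure for that set of constraints.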

/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
/// This simplifies constructing (sub)expressions during iteration lattice
/// building (compared to using the SSA representation everywhere).
static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
                                         Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    if (arg.getOwner()->getParentOp() == op) {
      // Any parameter of the generic op is considered a tensor,
      // indexed by the implicit loop bounds.
      auto map = op.getIndexingMap(argN);
      if (map.isProjectedPermutation())
        return merger.addExp(Kind::kTensor, argN);
      // Cannot handle (yet).
      return None;
    }
    // Any parameter of a higher op is invariant.
    return merger.addExp(Kind::kInvariant, val);
  }
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front()) {
    // Something defined outside is invariant.
    return merger.addExp(Kind::kInvariant, val);
  } else if (def->getNumOperands() == 2) {
    // Construct binary operations if subexpressions could be built.
    auto x = buildTensorExp(merger, op, def->getOperand(0));
    auto y = buildTensorExp(merger, op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return merger.addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return merger.addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return merger.addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return merger.addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build (yet).
  return None;
}

/// Builds the iteration lattices in a bottom-up traversal given the remaining
/// tensor (sub)expression and the next loop index in the iteration graph.
static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
                              unsigned exp, unsigned idx) {
  Kind kind = merger.exp(exp).kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = merger.addSet();
    unsigned t =
        kind == Kind::kTensor ? merger.exp(exp).e0 : op.getNumShapedOperands();
    merger.set(s).push_back(merger.addLat(t, idx, exp));
    return s;
  }
  unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
  unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
  case Kind::kMulI:
    return merger.takeConj(kind, s0, s1);
  case Kind::kAddF:
  case Kind::kAddI:
    return merger.takeDisj(kind, s0, s1);
  }
  llvm_unreachable("unexpected expression kind");
}
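
// As a brief illustration (a sketch): for a(i) * c with scalar invariant c,
// each leaf contributes a singleton lattice set, where the invariant is
// attributed to the synthetic tensor added by the merger; the multiplication
// then conjoins the two sets into a single lattice point {a(i) /\ c} whose
// expression is the rebuilt multiplication.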

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutput(0);
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor. Currently this results in functional,
  // but slightly imprecise IR, so it is put under an experimental option.
  if (codegen.options.fastOutput)
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<linalg::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  unsigned numTensors = op.getNumShapedOperands();
  unsigned numInputs = op.getNumInputs();
  assert(numTensors == numInputs + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (unsigned t = 0; t < numTensors; t++) {
    Value tensor = t < numInputs ? op.getInput(t) : op.getOutput(0);
    auto tensorType = op.getShapedType(t);
    auto shape = tensorType.getShape();
    auto map = op.getIndexingMap(t);
    auto enc = getSparseTensorEncoding(tensorType);
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
      unsigned i = map.getDimPosition(d);
      // Handle sparse storage schemes.
      if (merger.isDim(t, i, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointer and indices.
        codegen.pointers[t][i] =
            rewriter.create<ToPointersOp>(loc, ptrTp, tensor, dim);
        codegen.indices[t][i] =
            rewriter.create<ToIndicesOp>(loc, indTp, tensor, dim);
      }
      // Find lower and upper bound in current dimension.
      Value up;
      if (shape[d] == MemRefType::kDynamicSize) {
        up = rewriter.create<memref::DimOp>(loc, tensor, d);
        args.push_back(up);
      } else {
        up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
      }
      codegen.sizes[i] = codegen.highs[t][i] = up;
    }
    // Perform the required bufferization. All dense inputs materialize
    // from the input tensor. The dense output tensor needs special
    // handling. Sparse inputs use a sparse primitive to obtain the values.
    if (!enc) {
      auto denseTp = MemRefType::get(shape, tensorType.getElementType());
      if (t < numInputs)
        codegen.buffers[t] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
      else
        codegen.buffers[t] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, tensorType.getElementType());
      codegen.buffers[t] = rewriter.create<ToValuesOp>(loc, sparseTp, tensor);
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upperbound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  Value end = rewriter.create<SubIOp>(loc, hi, iv);
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}
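
// For instance (a sketch): with vector length 16 and bounds 0..100, the mask
// for the iteration starting at iv is vector.create_mask(100 - iv), which
// enables all 16 lanes except in the final, partial iteration; for bounds
// 0..128 an all-true mask is emitted directly.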

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  unsigned tensor = merger.exp(exp).e0;
  auto map = op.getIndexingMap(tensor);
  auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
    unsigned idx = map.getDimPosition(i);
    args.push_back(codegen.loops[idx]); // universal dense index
    if (enc) {
      args.clear();
      args.push_back(codegen.pidxs[tensor][idx]); // position index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned tensor, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  unsigned lhs = op.getNumShapedOperands() - 1;
  if (lhs == tensor && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  auto map = op.getIndexingMap(tensor);
  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
    unsigned idx = map.getDimPosition(i);
    args.push_back(codegen.loops[idx]); // universal dense index
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (in
    // the future, we could introduce a flag that states the negative space
    // of 32-bit indices is unused). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}
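
// A brief example (a sketch): for a dense dimension of size sz nested under
// an enclosing position p, the linearized position of element i is computed
// as sz * p + i, which mirrors row-major addressing of the dense suffix of
// the storage scheme.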

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  unsigned lhs = op.getNumShapedOperands() - 1;
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  else if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
    return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    unsigned tensor = merger.exp(exp).e0;
    auto map = op.getIndexingMap(tensor);
    for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
      unsigned idx = map.getDimPosition(i);
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    unsigned lhs = op.getNumShapedOperands() - 1;
    if (lhs == tensor) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).e0;
    unsigned e1 = merger.exp(exp).e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}
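
// For example (a sketch of CSR-style storage): for the compressed dimension
// of a matrix with pointers array "pointers", the loop over index idx at an
// enclosing position p starts at pidx = pointers[p] and runs up to, but not
// including, high = pointers[p + 1].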

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  unsigned numTensors = op.getNumShapedOperands();
  for (unsigned t = 0; t < numTensors; t++) {
    if (!getSparseTensorEncoding(op.getShapedType(t))) {
      auto map = op.getIndexingMap(t);
      unsigned r = map.getNumResults();
      for (unsigned i = 0; i < r; i++) {
        if (map.getDimPosition(i) == idx && i != r - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}
/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists of a
  // conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}
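
// As an illustration (a rough IR sketch with illustrative names): co-iterating
// two sparse tensors a and b over one index yields approximately
//
//   %r:2 = scf.while (%pa = %pa0, %pb = %pb0) : (index, index) -> ... {
//     %ca = cmpi ult, %pa, %pa_end : index
//     %cb = cmpi ult, %pb, %pb_end : index
//     %c = and %ca, %cb : i1
//     scf.condition(%c) %pa, %pb : index, index
//   } do { ... }
//
// where the body (built by genLocals/genIf/genWhileInduction below) computes
// the minimum index, guards each contribution with an if-statement, and
// advances only the positions whose index matched.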

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}
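
// For example (a sketch): the position of each sparse tensor is advanced
// conditionally, pidx' = (idx == loop index) ? pidx + 1 : pidx, so that
// only the tensors whose index was consumed in this iteration move forward,
// while the universal dense index (if maintained) is always incremented.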

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
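/// As a running example (a sketch in the spirit of the TACO-style lattices,
/// not literal output of this pass), a sparse vector addition
/// x(i) = a(i) + b(i) with both a and b sparse over i expands into one
/// co-iterating while-loop for the lattice point {a,b}, with if-statements
/// selecting a+b, a, or b, followed by simpler loops for the leftover
/// points {a} and {b}:
///
///   while (pa < pa_end && pb < pb_end) {
///     i = min(ia[pa], ib[pb]);
///     if (ia[pa] == i && ib[pb] == i) x[i] = a[pa] + b[pb];
///     else if (ia[pa] == i)           x[i] = a[pa];
///     else                            x[i] = b[pb];
///     // advance pa and pb as in genWhileInduction
///   }
///   while (pa < pa_end) { x[ia[pa]] = a[pa]; pa++; }
///   while (pb < pb_end) { x[ib[pb]] = b[pb]; pb++; }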
1242 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1243                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1244                     unsigned exp, unsigned at) {
1245   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1246   if (at == topSort.size()) {
1247     unsigned lhs = op.getNumShapedOperands() - 1;
1248     Value rhs = genExp(merger, codegen, rewriter, op, exp);
1249     genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
1250     return;
1251   }
1252   assert(codegen.curVecLength == 1);
1253 
  // Construct iteration lattices for the current loop index, with L0 at top.
1255   // Then emit initialization code for the loop sequence at this level.
1256   // We maintain the universal dense index if dense indices are still
1257   // in play for a non-singleton loop sequence.
1258   Location loc = op.getLoc();
1259   unsigned idx = topSort[at];
1260   unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
1261   unsigned lsize = merger.set(lts).size();
1262   assert(lsize != 0);
1263   unsigned l0 = merger.set(lts)[0];
1264   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
1265   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
1266   bool needsUniv = false;
1267   if (genInit(merger, codegen, rewriter, op, topSort, at,
1268               merger.lat(l0).bits)) {
1269     // Maintain the universal index only if it is actually
1270     // consumed by a subsequent lattice point.
1271     for (unsigned i = 1; i < lsize; i++) {
1272       unsigned li = merger.set(lts)[i];
1273       if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
1274         needsUniv = true;
1275         break;
1276       }
1277     }
1278   }
1279 
1280   // Emit a loop for every lattice point L0 >= Li.
1281   for (unsigned i = 0; i < lsize; i++) {
1282     unsigned li = merger.set(lts)[i];
1283 
1284     // Emit loop.
1285     codegen.curVecLength = 1;
1286     llvm::BitVector indices = merger.lat(li).simple;
1287     Operation *loop =
1288         genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
1289     genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
1290               merger.lat(li).bits);
1291 
    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = isa<scf::WhileOp>(loop);
1295     for (unsigned j = 0; j < lsize; j++) {
1296       unsigned lj = merger.set(lts)[j];
1297       unsigned ej = merger.lat(lj).exp;
1298       if (li == lj || merger.latGT(li, lj)) {
1299         // Recurse into body of each branch.
1300         if (isWhile) {
1301           scf::IfOp ifOp =
1302               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1303           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1304           rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
1305         } else {
1306           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1307         }
1308       }
1309     }
1310 
    // Wrap up induction and restore insertion point.
1312     if (isWhile) {
1313       scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
1314       rewriter.setInsertionPointToEnd(&whileOp.after().front());
1315       genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
1316                         merger.lat(li).bits, whileOp.results());
1317     } else {
1318       needsUniv = false;
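      // Carry an active scalar reduction through the for-loop by yielding
      // the running value and continuing with the loop's result.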
1319       if (codegen.redVal) {
1320         rewriter.create<scf::YieldOp>(loc, codegen.redVal);
1321         codegen.redVal = loop->getResult(0);
1322       }
1323     }
1324     rewriter.setInsertionPointAfter(loop);
1325   }
1326 
  // Wrap up the loop sequence.
1328   codegen.curVecLength = 1;
1329   genReductionEnd(merger, codegen, rewriter, op);
1330   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
1331   codegen.loops[idx] = Value();
1332 }
1333 
1334 namespace {
1335 
/// Sparse rewriting rule for a generic Linalg operation.
1337 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1338 public:
1339   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1340       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1341 
1342   LogicalResult matchAndRewrite(linalg::GenericOp op,
1343                                 PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
1346     assert(op.getNumOutputs() == 1);
1347     unsigned numTensors = op.getNumShapedOperands();
1348     unsigned numLoops = op.iterator_types().getValue().size();
1349     Merger merger(numTensors, numLoops);
1350     findSparseAnnotations(merger, op);
1351 
1352     // Computes a topologically sorted iteration graph to ensure
1353     // tensors are visited in natural index order. Fails on cycles.
1354     // This assumes that higher-level passes have already put the
1355     // tensors in each tensor expression in a feasible order.
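    // The first attempt takes the dimension order of all tensors into
    // account; the fallback only enforces the order imposed by the
    // sparse tensors (the "sparseOnly" constraints).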
1356     std::vector<unsigned> topSort;
1357     if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
1358         !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
1359       return failure();
1360 
1361     // Finds the terminating yield statement and builds the tensor
1362     // expression for the Linalg operation in SSA form.
1363     Operation *yield = op.region().front().getTerminator();
1364     Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
1365     if (!exp.hasValue())
1366       return failure(); // build failure
1367 
1368     // Recursively generates code.
1369     CodeGen codegen(options, numTensors, numLoops);
1370     genBuffers(merger, codegen, rewriter, op);
1371     genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
1372     Value result = rewriter.create<memref::TensorLoadOp>(
1373         op.getLoc(), codegen.buffers.back());
1374     rewriter.replaceOp(op, result);
1375     return success();
1376   }
1377 
1378 private:
1379   /// Options to control sparse code generation.
1380   SparsificationOptions options;
1381 };
1382 
1383 } // namespace
1384 
1385 /// Populates the given patterns list with rewriting rules required for
1386 /// the sparsification of linear algebra operations.
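/// A typical use from within a rewriting pass looks as follows (a sketch;
/// the enclosing pass and the variable names are illustrative, the pattern
/// driver calls are the standard MLIR API):
///
///   RewritePatternSet patterns(ctx);
///   populateSparsificationPatterns(patterns, options);
///   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));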
1387 void mlir::populateSparsificationPatterns(
1388     RewritePatternSet &patterns, const SparsificationOptions &options) {
1389   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1390 }
1391