1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering sparse tensor types to actual sparse code.
10 //
11 // The concept of letting a compiler generate sparse code automatically was
12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
14 // Algebra Compiler (TACO). The implementation in this file closely follows
15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting
16 // rule is applied to each tensor expression in linalg (MLIR's tensor index
// notation) where the sparsity of tensors is indicated with annotations using
18 // a per-dimension specification of sparse/dense storage together with a
19 // specification of the order on the dimensions. Subsequently, a topologically
20 // sorted iteration graph, reflecting the required order on indices with respect
21 // to the dimensions of each tensor, is constructed to ensure that all tensors
22 // are visited in natural index order. Next, iteration lattices are constructed
23 // for the tensor expression for every index in topological order. Each
24 // iteration lattice point consists of a conjunction of tensor indices together
25 // with a tensor (sub)expression that needs to be evaluated for that
26 // conjunction. Within the lattice, iteration points are ordered according to
27 // the way indices are exhausted. As such these iteration lattices drive actual
28 // sparse code generation, which consists of a tedious but relatively
29 // straightforward one-to-one mapping from iteration lattices to combinations
30 // of for-loops, while-loops, and if-statements.
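//
// As an illustrative sketch (not the exact emitted IR), a kernel such as
// x(i) += A(i,j) * b(j), with A annotated dense in the i-dimension and
// compressed (sparse) in the j-dimension, is lowered to loops of roughly
// the following shape, where pointers/indices/values constitute the sparse
// storage scheme of A:
//
//   for (i = 0; i < m; i++)
//     for (jj = pointers[i]; jj < pointers[i + 1]; jj++) {
//       j = indices[jj];
//       x[i] += values[jj] * b[j];
//     }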
31 //
32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
36 // Proceedings of the ACM on Programming Languages, October 2017.
37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
// PhD thesis, MIT, February 2020 (tensor-compiler.org).
39 //
40 // Implementation detail: We use llvm::SmallVector for vectors with
// variable lengths and std::vector for vectors with fixed lengths.
//
42 //===----------------------------------------------------------------------===//
43 
44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
45 #include "mlir/Dialect/Linalg/Utils/Utils.h"
46 #include "mlir/Dialect/SCF/SCF.h"
47 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
48 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
49 #include "mlir/Dialect/StandardOps/IR/Ops.h"
50 #include "mlir/Dialect/Vector/VectorOps.h"
51 #include "mlir/IR/Matchers.h"
52 #include "mlir/IR/TensorEncoding.h"
53 #include "llvm/ADT/SmallBitVector.h"
54 
55 using namespace mlir;
56 using namespace mlir::sparse_tensor;
57 
58 namespace {
59 
60 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
61 enum class Dim { kSparse, kDense, kSingle, kUndef };
62 
/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
/// stored directly. For binary operations, e0 and e1 denote the indices of
/// the child tensor expressions.
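///
/// For example (an illustrative sketch of the encoding), the expression
/// a(i) * b(i) + c(i), built left to right, could be stored as:
///   exp 0: kTensor    e0 = 0          (tensor a)
///   exp 1: kTensor    e0 = 1          (tensor b)
///   exp 2: kMulF      e0 = 0, e1 = 1  (a * b)
///   exp 3: kTensor    e0 = 2          (tensor c)
///   exp 4: kAddF      e0 = 2, e1 = 3  ((a * b) + c)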
67 struct TensorExp {
68   TensorExp(Kind k, unsigned x, unsigned y, Value v)
69       : kind(k), e0(x), e1(y), val(v) {
70     assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
71            (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
72            (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
73   }
74   Kind kind;
75   /// Indices of children expression(s).
76   unsigned e0;
77   unsigned e1;
78   /// Direct link to IR for an invariant. During code generation,
79   /// field is used to cache "hoisted" loop invariant tensor loads.
80   Value val;
81 };
82 
83 /// Lattice point. Each lattice point consists of a conjunction of tensor
84 /// loop indices (encoded in a bitvector) and the index of the corresponding
85 /// tensor expression.
86 struct LatPoint {
87   LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
88     bits.set(b);
89   }
90   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
91   /// Conjunction of tensor loop indices as bitvector. This represents
  /// all indices involved in the tensor expression.
93   llvm::BitVector bits;
94   /// Simplified conjunction of tensor loop indices as bitvector. This
95   /// represents a simplified condition under which this tensor expression
  /// must execute. Pre-computed during codegen to avoid repeated evaluation.
97   llvm::BitVector simple;
  /// Index of the tensor expression.
99   unsigned exp;
100 };
101 
102 /// A class to handle all iteration lattice operations. This class abstracts
103 /// away from some implementation details of storing iteration lattices and
104 /// tensor expressions. This allows for fine-tuning performance characteristics
105 /// independently from the basic algorithm if bottlenecks are identified.
106 class Merger {
107 public:
108   /// Constructs a merger for the given number of tensors and loops. The
109   /// user supplies the number of tensors involved in the kernel, with the
110   /// last tensor in this set denoting the output tensor. The merger adds an
111   /// additional synthetic tensor at the end of this set to represent all
112   /// invariant expressions in the kernel.
113   Merger(unsigned t, unsigned l)
114       : outTensor(t - 1), numTensors(t + 1), numLoops(l),
115         dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
116 
117   /// Adds a tensor expression. Returns its index.
118   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
119     unsigned e = tensorExps.size();
120     tensorExps.push_back(TensorExp(k, e0, e1, v));
121     return e;
122   }
123   unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
124 
125   /// Adds an iteration lattice point. Returns its index.
126   unsigned addLat(unsigned t, unsigned i, unsigned e) {
127     assert(t < numTensors && i < numLoops);
128     unsigned p = latPoints.size();
129     latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
130     return p;
131   }
132 
133   /// Adds a new, initially empty, set. Returns its index.
134   unsigned addSet() {
135     unsigned s = latSets.size();
136     latSets.emplace_back(SmallVector<unsigned, 16>());
137     return s;
138   }
139 
140   /// Computes a single conjunction of two lattice points by taking the "union"
141   /// of loop indices (effectively constructing a larger "intersection" of those
142   /// indices) with a newly constructed tensor (sub)expression of given kind.
143   /// Returns the index of the new lattice point.
144   unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
145     unsigned p = latPoints.size();
146     llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
147     nb |= latPoints[p1].bits;
148     unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
149     latPoints.push_back(LatPoint(nb, e));
150     return p;
151   }
152 
  /// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
  /// the Cartesian product. Returns the index of the new set.
155   unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
156     unsigned s = addSet();
157     for (unsigned p0 : latSets[s0])
158       for (unsigned p1 : latSets[s1])
159         latSets[s].push_back(conjLatPoint(kind, p0, p1));
160     return s;
161   }
162 
163   /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
164   /// Returns the index of the new set.
165   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
166     unsigned s = takeConj(kind, s0, s1);
167     for (unsigned p : latSets[s0])
168       latSets[s].push_back(p);
169     for (unsigned p : latSets[s1])
170       latSets[s].push_back(p);
171     return s;
172   }
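
  // For example (illustrative), for a single loop index i with lattice sets
  // La = { {a} } and Lb = { {b} } built for tensors a(i) and b(i):
  //   takeConj(kMulF, La, Lb) yields { {a,b} }, since a(i) * b(i) only
  //   contributes where both a and b have stored entries, whereas
  //   takeDisj(kAddF, La, Lb) yields { {a,b}, {a}, {b} }, since a(i) + b(i)
  //   contributes wherever either operand has a stored entry.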
173 
174   /// Optimizes the iteration lattice points in the given set. This
175   /// method should be called right before code generation to avoid
176   /// generating redundant loops and conditions.
177   unsigned optimizeSet(unsigned s0) {
178     unsigned s = addSet();
179     assert(latSets[s0].size() != 0);
180     unsigned p0 = latSets[s0][0];
181     for (unsigned p1 : latSets[s0]) {
182       bool add = true;
183       if (p0 != p1) {
184         // Is this a straightforward copy?
185         unsigned e = latPoints[p1].exp;
186         if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
187           continue;
188         // Conjunction already covered?
189         for (unsigned p2 : latSets[s]) {
190           assert(!latGT(p1, p2)); // Lj => Li would be bad
191           if (onlyDenseDiff(p2, p1)) {
192             add = false;
193             break;
194           }
195         }
196         assert(!add || latGT(p0, p1));
197       }
198       if (add)
199         latSets[s].push_back(p1);
200     }
201     for (unsigned p : latSets[s])
202       latPoints[p].simple = simplifyCond(s, p);
203     return s;
204   }
205 
206   /// Simplifies the conditions in a conjunction of a given lattice point
207   /// within the given set using just two basic rules:
208   /// (1) multiple dense conditions are reduced to single dense, and
209   /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
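  /// For example (an illustrative sketch), if the lattice point is the last
  /// one in the set and its conjunction covers a sparse tensor A together
  /// with dense tensors B and C, the simplified condition keeps only the
  /// sparse iteration over A, with B and C addressed by random access
  /// (rule 2); without any sparse tensor, the multiple dense conditions
  /// collapse into a single dense loop condition (rule 1).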
210   llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
    // First determine if this lattice point is a *singleton*, i.e.,
    // the last point in the lattice: no other point is less than this one.
213     bool isSingleton = true;
214     for (unsigned p1 : latSets[s]) {
215       if (p0 != p1 && latGT(p0, p1)) {
216         isSingleton = false;
217         break;
218       }
219     }
220     // Now apply the two basic rules.
221     llvm::BitVector simple = latPoints[p0].bits;
222     bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
223     for (unsigned b = 0, be = simple.size(); b < be; b++) {
224       if (simple[b] && !isDim(b, Dim::kSparse)) {
225         if (reset)
226           simple.reset(b);
227         reset = true;
228       }
229     }
230     return simple;
231   }
232 
233   /// Returns true if Li > Lj.
234   bool latGT(unsigned i, unsigned j) const {
235     const llvm::BitVector &bitsi = latPoints[i].bits;
236     const llvm::BitVector &bitsj = latPoints[j].bits;
237     assert(bitsi.size() == bitsj.size());
238     if (bitsi.count() > bitsj.count()) {
239       for (unsigned b = 0, be = bitsj.size(); b < be; b++)
240         if (bitsj[b] && !bitsi[b])
241           return false;
242       return true;
243     }
244     return false;
245   }
246 
247   /// Returns true if Li and Lj only differ in dense.
248   bool onlyDenseDiff(unsigned i, unsigned j) {
249     llvm::BitVector tmp = latPoints[j].bits;
250     tmp ^= latPoints[i].bits;
251     return !hasAnyDimOf(tmp, Dim::kSparse);
252   }
253 
254   /// Bit translation.
255   unsigned tensor(unsigned b) const { return b % numTensors; }
256   unsigned index(unsigned b) const { return b / numTensors; }
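  // For example (illustrative), with numTensors == 3, the bit value
  // b == 7 == numTensors * 2 + 1 set by addLat() decodes back into
  // tensor(7) == 1 and index(7) == 2, i.e. loop index i2 of tensor t1.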
257 
258   /// Returns true if bit corresponds to queried dim.
259   bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
260 
261   /// Returns true if bit corresponds to index of output tensor.
262   bool isOutTensor(unsigned b, unsigned i) const {
263     return tensor(b) == outTensor && index(b) == i;
264   }
265 
266   /// Returns true if tensor access at given index has queried dim.
267   bool isDim(unsigned t, unsigned i, Dim d) const {
268     assert(t < numTensors && i < numLoops);
269     return dims[t][i] == d;
270   }
271 
272   /// Returns true if any set bit corresponds to queried dim.
273   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
274     for (unsigned b = 0, be = bits.size(); b < be; b++)
275       if (bits[b] && isDim(b, d))
276         return true;
277     return false;
278   }
279 
  /// Setter.
281   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
282 
283   /// Getters.
284   TensorExp &exp(unsigned e) { return tensorExps[e]; }
285   LatPoint &lat(unsigned l) { return latPoints[l]; }
286   SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
287 
288 private:
289   const unsigned outTensor;
290   const unsigned numTensors;
291   const unsigned numLoops;
292 
293   std::vector<std::vector<Dim>> dims;
294   llvm::SmallVector<TensorExp, 32> tensorExps;
295   llvm::SmallVector<LatPoint, 16> latPoints;
296   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
297 };
298 
299 // Code generation.
300 struct CodeGen {
301   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
302       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
303         pointers(numTensors, std::vector<Value>(numLoops)),
304         indices(numTensors, std::vector<Value>(numLoops)),
305         highs(numTensors, std::vector<Value>(numLoops)),
306         pidxs(numTensors, std::vector<Value>(numLoops)),
307         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
308         curVecLength(1), curVecMask() {}
309   /// Sparsification options.
310   SparsificationOptions options;
311   /// Universal dense indices and upper bounds (by index). The loops array
312   /// is updated with the value of the universal dense index in the current
313   /// loop. The sizes array is set once with the inferred dimension sizes.
314   std::vector<Value> loops;
315   std::vector<Value> sizes;
316   /// Buffers for storing dense and sparse numerical values (by tensor).
317   /// This array is set once during bufferization of all tensors.
318   std::vector<Value> buffers;
319   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
320   /// This array is set once during bufferization of all sparse tensors.
321   std::vector<std::vector<Value>> pointers;
322   std::vector<std::vector<Value>> indices;
323   /// Sparse iteration information (by tensor and index). These arrays
324   /// are updated to remain current within the current loop.
325   std::vector<std::vector<Value>> highs;
326   std::vector<std::vector<Value>> pidxs;
327   std::vector<std::vector<Value>> idxs;
328   /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
330   // TODO: currently only done for (a chain of) innermost for-loops, where it
331   // is most effective; we could generalize to more outer and while-loops.
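  // As an illustrative sketch, for x(i) += A(i,j) * b(j) the innermost j-loop
  // carries the partial sum of row i as a loop-carried value (iter_args)
  // instead of loading/storing x[i] on every iteration:
  //   red = x[i];
  //   for (j = ...) red += A[i][j] * b[j];
  //   x[i] = red;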
332   unsigned redExp;
333   Value redVal;
334   // Current vector length and mask.
335   unsigned curVecLength;
336   Value curVecMask;
337 };
338 
339 } // namespace
340 
341 // Helper method to apply dimension ordering permutation.
342 static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
343   if (enc) {
344     auto order = enc.getDimOrdering();
345     if (order) {
346       assert(order.isPermutation());
347       return order.getDimPosition(d);
348     }
349   }
350   return d;
351 }
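
// For example (illustrative), with a dimension ordering given by the
// permutation affine_map<(i, j) -> (j, i)>, perm(enc, 0) returns 1 and
// perm(enc, 1) returns 0, i.e. the dimensions are visited in the permuted
// (transposed) order.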
352 
353 // Helper method to translate dim level type to internal representation.
354 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
355   if (enc) {
356     SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
357     if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
358       return Dim::kSparse;
359     if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
360       return Dim::kSingle;
361   }
362   return Dim::kDense;
363 }
364 
365 /// Helper method to inspect sparse encodings in the tensor types.
366 /// Fills the per-dimension sparsity information for all tensors.
367 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
368   bool annotated = false;
369   OpOperand *lhs = op.getOutputOperand(0);
370   for (OpOperand *t : op.getInputAndOutputOperands()) {
371     auto map = op.getTiedIndexingMap(t);
372     if (!map.isProjectedPermutation())
373       return false;
374     auto enc = getSparseTensorEncoding(t->get().getType());
375     if (enc)
376       annotated = true;
377     assert(map.getNumResults() == op.getRank(t));
378     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
379       unsigned idx = map.getDimPosition(perm(enc, d));
380       Dim dim = toDim(enc, d);
      merger.setDim(t->getOperandNumber(), idx, dim);
382       // Accept only all-dense annotated "sparse" output.
383       // TODO: support truly sparse outputs too
384       if (t == lhs && dim != Dim::kDense)
385         return false;
386     }
387   }
388   return annotated;
389 }
390 
391 /// A DFS helper to compute a topological sort. Note that recursion is
392 /// bounded by the number of implicit loops, which is always small.
393 /// Returns false when a cycle is detected.
394 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
395                        std::vector<unsigned> &topSort,
396                        std::vector<std::vector<bool>> &adjM) {
397   if (visit[i] != 0)
398     return visit[i] != 1; // 1 denotes cycle!
399   visit[i] = 1;
400   for (unsigned j = 0, e = visit.size(); j < e; j++)
401     if (adjM[i][j])
402       if (!topSortDFS(j, visit, topSort, adjM))
403         return false;
404   visit[i] = 2;
405   topSort.push_back(i);
406   return true;
407 }
408 
409 /// Computes a topologically sorted iteration graph for the linalg operation.
410 /// Ensures all tensors are visited in natural index order. This is essential
411 /// for sparse storage formats since these only support access along fixed
412 /// dimensions. Even for dense storage formats, however, the natural index
413 /// order yields innermost unit-stride access with better spatial locality.
414 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
415                                   std::vector<unsigned> &topSort,
416                                   bool sparseOnly) {
417   // Set up an n x n from/to adjacency matrix of the iteration graph
418   // for the implicit loop indices i_0 .. i_n-1.
419   unsigned n = op.getNumLoops();
420   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
421 
422   // Iterate over the indexing maps of every tensor in the tensor expression.
423   for (OpOperand *t : op.getInputAndOutputOperands()) {
424     auto map = op.getTiedIndexingMap(t);
425     auto enc = getSparseTensorEncoding(t->get().getType());
426     assert(map.getNumDims() == n);
427     // Skip dense tensor constraints when sparse only is requested.
428     if (sparseOnly && !enc)
429       continue;
430     // Each tensor expression and optional dimension ordering (row-major
431     // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
433     // on the loop indices if no explicit dimension ordering is given.
434     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      unsigned from = map.getDimPosition(perm(enc, d - 1));
      unsigned to = map.getDimPosition(perm(enc, d));
      adjM[from][to] = true;
438     }
439   }
440 
441   // Topologically sort the iteration graph to determine loop order.
442   // Report failure for a cyclic iteration graph.
443   topSort.clear();
444   topSort.reserve(n);
445   std::vector<unsigned> visit(n, 0);
446   for (unsigned i = 0; i < n; i++)
447     if (visit[i] == 0)
448       if (!topSortDFS(i, visit, topSort, adjM))
449         return false; // cycle!
450   std::reverse(std::begin(topSort), std::end(topSort));
451   return true;
452 }
453 
454 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
455 /// This simplifies constructing (sub)expressions during iteration lattice
456 /// building (compared to using the SSA representation everywhere).
457 static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
458                                          Value val) {
459   if (auto arg = val.dyn_cast<BlockArgument>()) {
460     unsigned argN = arg.getArgNumber();
461     // Any argument of the generic op that is not marked as a scalar
462     // argument is considered a tensor, indexed by the implicit loop
463     // bounds. This includes rank-0 tensor arguments.
464     if (arg.getOwner()->getParentOp() == op) {
465       OpOperand *t = op.getInputAndOutputOperands()[argN];
466       if (!op.isScalar(t))
467         return merger.addExp(Kind::kTensor, argN);
468       val = t->get(); // get scalar value
469     }
470     // Any other argument (marked as scalar argument for the generic op
471     // or belonging to an enveloping op) is considered invariant.
472     return merger.addExp(Kind::kInvariant, val);
473   }
474   Operation *def = val.getDefiningOp();
475   if (def->getBlock() != &op.region().front()) {
476     // Something defined outside is invariant.
477     return merger.addExp(Kind::kInvariant, val);
478   } else if (def->getNumOperands() == 2) {
479     // Construct binary operations if subexpressions could be built.
480     auto x = buildTensorExp(merger, op, def->getOperand(0));
481     auto y = buildTensorExp(merger, op, def->getOperand(1));
482     if (x.hasValue() && y.hasValue()) {
483       unsigned e0 = x.getValue();
484       unsigned e1 = y.getValue();
485       if (isa<MulFOp>(def))
486         return merger.addExp(Kind::kMulF, e0, e1);
487       if (isa<MulIOp>(def))
488         return merger.addExp(Kind::kMulI, e0, e1);
489       if (isa<AddFOp>(def))
490         return merger.addExp(Kind::kAddF, e0, e1);
491       if (isa<AddIOp>(def))
492         return merger.addExp(Kind::kAddI, e0, e1);
493     }
494   }
495   // Cannot build (yet).
496   return None;
497 }
498 
499 /// Builds the iteration lattices in a bottom-up traversal given the remaining
500 /// tensor (sub)expression and the next loop index in the iteration graph.
501 static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
502                               unsigned exp, unsigned idx) {
503   Kind kind = merger.exp(exp).kind;
504   if (kind == Kind::kTensor || kind == Kind::kInvariant) {
505     // Either the index is really used in the tensor expression, or it is
506     // set to the undefined index in that dimension. An invariant expression
507     // is set to a synthetic tensor with undefined indices only.
508     unsigned s = merger.addSet();
509     unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
510                                        : op.getNumInputsAndOutputs();
511     merger.set(s).push_back(merger.addLat(t, idx, exp));
512     return s;
513   }
514   unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
515   unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
516   switch (kind) {
517   case Kind::kTensor:
518   case Kind::kInvariant:
519     llvm_unreachable("handled above");
520   case Kind::kMulF:
521   case Kind::kMulI:
522     return merger.takeConj(kind, s0, s1);
523   case Kind::kAddF:
524   case Kind::kAddI:
525     return merger.takeDisj(kind, s0, s1);
526   }
527   llvm_unreachable("unexpected expression kind");
528 }
529 
530 /// Maps sparse integer option to actual integral storage type.
531 static Type genIntType(PatternRewriter &rewriter, unsigned width) {
532   if (width == 0)
533     return rewriter.getIndexType();
534   return rewriter.getIntegerType(width);
535 }
536 
537 /// Detects in-place annotation on tensor argument.
538 static bool getInPlace(Value val) {
539   if (auto arg = val.dyn_cast<BlockArgument>())
540     if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
541       if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
542               arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
543         return attr.getValue();
544   return false;
545 }
546 
547 /// Generates buffer for the output tensor.
548 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
549                              linalg::GenericOp op, MemRefType denseTp,
550                              ArrayRef<Value> args) {
551   Location loc = op.getLoc();
552   Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
554   // be generated for the tensor present in the outs() clause. This has
555   // the major advantage that the sparse kernel only updates the nonzero
556   // positions for the output tensor.
557   if (getInPlace(tensor))
558     return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
559   // By default, a new buffer is allocated which is initialized to the
560   // tensor defined in the outs() clause. This is always correct but
561   // introduces a dense initialization component that may negatively
562   // impact the running complexity of the sparse kernel.
563   Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
564   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
565   rewriter.create<linalg::CopyOp>(loc, init, alloc);
566   return alloc;
567 }
568 
569 /// Local bufferization of all dense and sparse data structures.
570 /// This code enables testing the first prototype sparse compiler.
571 // TODO: replace this with a proliferated bufferization strategy
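//
// As an illustrative sketch, for a 2-D tensor annotated dense in dimension 0
// and compressed in dimension 1 (a CSR-like scheme), the code below obtains
// the pointers and indices arrays of the compressed dimension through
// ToPointersOp/ToIndicesOp and a single values array through ToValuesOp,
// whereas a non-annotated dense tensor materializes as one memref through
// memref::BufferCastOp (or genOutputBuffer for the output).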
572 static bool genBuffers(Merger &merger, CodeGen &codegen,
573                        PatternRewriter &rewriter, linalg::GenericOp op) {
574   Location loc = op.getLoc();
575   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
576   // For every tensor, find lower and upper bound on dimensions, set the
577   // same bounds on loop indices, and obtain dense or sparse buffer(s).
578   SmallVector<Value, 4> args;
579   for (OpOperand *t : op.getInputAndOutputOperands()) {
580     unsigned tensor = t->getOperandNumber();
581     auto shape = op.getShape(t);
582     auto map = op.getTiedIndexingMap(t);
583     auto enc = getSparseTensorEncoding(t->get().getType());
584     // Scan all dimensions of current tensor.
585     args.clear();
586     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
587       unsigned idx = map.getDimPosition(perm(enc, d));
588       // Handle sparse storage schemes.
589       if (merger.isDim(tensor, idx, Dim::kSparse)) {
590         auto dynShape = {ShapedType::kDynamicSize};
591         auto ptrTp = MemRefType::get(
592             dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
593         auto indTp = MemRefType::get(
594             dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
595         Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
597         codegen.pointers[tensor][idx] =
598             rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
599         codegen.indices[tensor][idx] =
600             rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
601       }
602       // Find lower and upper bound in current dimension.
603       Value up;
604       if (shape[d] == MemRefType::kDynamicSize) {
605         up = rewriter.create<memref::DimOp>(loc, t->get(), d);
606         args.push_back(up);
607       } else {
608         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
609       }
610       codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
611     }
612     // Perform the required bufferization. Dense inputs materialize
613     // from the input tensors. Dense outputs need special handling.
614     // Sparse inputs use sparse primitives to obtain the values.
615     // We also accept in-place all-dense annotated "sparse" outputs.
616     Type elementType = getElementTypeOrSelf(t->get().getType());
617     if (!enc) {
618       // Non-annotated dense tensors.
619       auto denseTp = MemRefType::get(shape, elementType);
620       if (tensor < op.getNumInputs())
621         codegen.buffers[tensor] =
622             rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
623       else
624         codegen.buffers[tensor] =
625             genOutputBuffer(codegen, rewriter, op, denseTp, args);
626     } else {
627       // Annotated sparse tensors.
628       if (tensor == op.getNumInputs() && !getInPlace(t->get()))
629         return false; // reject output if not in-place
630       auto dynShape = {ShapedType::kDynamicSize};
631       auto sparseTp = MemRefType::get(dynShape, elementType);
632       codegen.buffers[tensor] =
633           rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
634     }
635   }
636   return true;
637 }
638 
639 /// Constructs vector type.
640 static VectorType vectorType(CodeGen &codegen, Type etp) {
641   return VectorType::get(codegen.curVecLength, etp);
642 }
643 
644 /// Constructs vector type from pointer.
645 static VectorType vectorType(CodeGen &codegen, Value ptr) {
646   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
647 }
648 
649 /// Constructs vector iteration mask.
650 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
651                            Value iv, Value lo, Value hi, Value step) {
652   Location loc = iv.getLoc();
653   VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
654   // Special case if the vector length evenly divides the trip count (for
655   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
656   // so that all subsequent masked memory operations are immediately folded
657   // into unconditional memory operations.
658   IntegerAttr loInt, hiInt, stepInt;
659   if (matchPattern(lo, m_Constant(&loInt)) &&
660       matchPattern(hi, m_Constant(&hiInt)) &&
661       matchPattern(step, m_Constant(&stepInt))) {
662     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
663       return rewriter.create<vector::BroadcastOp>(
664           loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
665   }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
667   // during vector execution. Here we rely on subsequent loop optimizations to
668   // avoid executing the mask in all iterations, for example, by splitting the
669   // loop into an unconditional vector loop and a scalar cleanup loop.
670   Value end = rewriter.create<SubIOp>(loc, hi, iv);
671   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
672 }
673 
674 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
675 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
676                            Value ptr, ArrayRef<Value> args) {
677   Location loc = ptr.getLoc();
678   VectorType vtp = vectorType(codegen, ptr);
679   Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
680   if (args.back().getType().isa<VectorType>()) {
681     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
682     Value indexVec = args.back();
683     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
684     return rewriter.create<vector::GatherOp>(
685         loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
686   }
687   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
688                                                codegen.curVecMask, pass);
689 }
690 
691 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
692 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
693                            Value rhs, Value ptr, ArrayRef<Value> args) {
694   Location loc = ptr.getLoc();
695   if (args.back().getType().isa<VectorType>()) {
696     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
697     Value indexVec = args.back();
698     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
699     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
700                                        codegen.curVecMask, rhs);
701     return;
702   }
703   rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
704                                          rhs);
705 }
706 
707 /// Generates a vectorized invariant. Here we rely on subsequent loop
708 /// optimizations to hoist the invariant broadcast out of the vector loop.
709 static Value genVectorInvariantValue(CodeGen &codegen,
710                                      PatternRewriter &rewriter, Value val) {
711   VectorType vtp = vectorType(codegen, val.getType());
712   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
713 }
714 
715 /// Generates a load on a dense or sparse tensor.
716 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
717                            PatternRewriter &rewriter, linalg::GenericOp op,
718                            unsigned exp) {
719   // Test if the load was hoisted to a higher loop nest.
720   Value val = merger.exp(exp).val;
721   if (val) {
722     if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
723       return genVectorInvariantValue(codegen, rewriter, val);
724     return val;
725   }
726   // Actual load.
727   SmallVector<Value, 4> args;
728   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
729   unsigned tensor = t->getOperandNumber();
730   auto map = op.getTiedIndexingMap(t);
731   auto enc = getSparseTensorEncoding(t->get().getType());
732   unsigned rank = map.getNumResults();
733   if (enc) {
734     unsigned idx = map.getDimPosition(perm(enc, rank - 1));
735     assert(codegen.pidxs[tensor][idx] != nullptr);
736     args.push_back(codegen.pidxs[tensor][idx]); // position index
737   } else {
738     for (unsigned d = 0; d < rank; d++) {
739       unsigned idx = map.getDimPosition(d);
740       args.push_back(codegen.loops[idx]); // universal dense index
741     }
742   }
743   Location loc = op.getLoc();
744   Value ptr = codegen.buffers[tensor];
745   if (codegen.curVecLength > 1)
746     return genVectorLoad(codegen, rewriter, ptr, args);
747   return rewriter.create<memref::LoadOp>(loc, ptr, args);
748 }
749 
750 /// Generates a store on a dense or sparse tensor.
751 static void genTensorStore(Merger &merger, CodeGen &codegen,
752                            PatternRewriter &rewriter, linalg::GenericOp op,
753                            OpOperand *t, Value rhs) {
754   Location loc = op.getLoc();
755   // Test if this is a scalarized reduction.
756   OpOperand *lhs = op.getOutputOperand(0);
757   if (lhs == t && codegen.redVal) {
758     if (codegen.curVecLength > 1)
759       rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
760                                       codegen.redVal);
761     codegen.redVal = rhs;
762     return;
763   }
764   // Actual store.
765   SmallVector<Value, 4> args;
766   unsigned tensor = t->getOperandNumber();
767   auto map = op.getTiedIndexingMap(t);
768   auto enc = getSparseTensorEncoding(t->get().getType());
769   unsigned rank = map.getNumResults();
770   if (enc) {
771     unsigned idx = map.getDimPosition(perm(enc, rank - 1));
772     assert(codegen.pidxs[tensor][idx] != nullptr);
773     args.push_back(codegen.pidxs[tensor][idx]); // position index
774   } else {
775     for (unsigned d = 0; d < rank; d++) {
776       unsigned idx = map.getDimPosition(d);
777       args.push_back(codegen.loops[idx]); // universal dense index
778     }
779   }
780   Value ptr = codegen.buffers[tensor];
781   if (codegen.curVecLength > 1)
782     genVectorStore(codegen, rewriter, rhs, ptr, args);
783   else
784     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
785 }
786 
787 /// Generates a pointer/index load from the sparse storage scheme. Narrower
788 /// data types need to be zero extended before casting the value into the
789 /// index type used for looping and indexing.
790 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
791                      Value ptr, Value s) {
792   // See https://llvm.org/docs/GetElementPtr.html for some background on
793   // the complications described below.
794   if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
796     // which effectively defines an unsigned pointer + signed index, we must
797     // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
799     // elements into 64-bit loses some performance since the 32-bit indexed
800     // gather/scatter is more efficient than the 64-bit index variant (if the
801     // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
802     // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
804     // incorrect address calculations in the unlikely case we need such
805     // extremely large offsets.
806     Type etp = ptr.getType().cast<MemRefType>().getElementType();
807     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
808     if (!etp.isa<IndexType>()) {
809       if (etp.getIntOrFloatBitWidth() < 32)
810         vload = rewriter.create<ZeroExtendIOp>(
811             loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
812       else if (etp.getIntOrFloatBitWidth() < 64 &&
813                !codegen.options.enableSIMDIndex32)
814         vload = rewriter.create<ZeroExtendIOp>(
815             loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
816     }
817     return vload;
818   }
819   // For the scalar case, we simply zero extend narrower indices into 64-bit
820   // values before casting to index without a performance penalty. Here too,
821   // however, indices that already are 64-bit, in theory, cannot express the
822   // full range as explained above.
823   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
824   if (!load.getType().isa<IndexType>()) {
825     if (load.getType().getIntOrFloatBitWidth() < 64)
826       load = rewriter.create<ZeroExtendIOp>(loc, load,
827                                             rewriter.getIntegerType(64));
828     load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
829   }
830   return load;
831 }
832 
833 /// Generates an invariant value.
834 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
835                                PatternRewriter &rewriter, unsigned exp) {
836   Value val = merger.exp(exp).val;
837   if (codegen.curVecLength > 1)
838     return genVectorInvariantValue(codegen, rewriter, val);
839   return val;
840 }
841 
842 /// Generates an address computation "sz * p + i".
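/// For example (illustrative), for a row-major m x n dense tensor, the
/// linearized position of element (i, j) is n * i + j, which this helper
/// builds up one dimension at a time as size * p + i.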
843 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
844                         Location loc, Value size, Value p, Value i) {
845   Value mul = rewriter.create<MulIOp>(loc, size, p);
846   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
847     Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
848     mul = genVectorInvariantValue(codegen, rewriter, inv);
849   }
850   return rewriter.create<AddIOp>(loc, mul, i);
851 }
852 
853 /// Generates start of a reduction.
854 static Value genReductionStart(Merger &merger, CodeGen &codegen,
855                                PatternRewriter &rewriter,
856                                linalg::GenericOp op) {
857   if (codegen.redVal)
858     return codegen.redVal; // chained with previous for-loop
859   if (codegen.curVecLength > 1) {
860     // TODO: assumes + reductions for now
861     VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
862     return rewriter.create<ConstantOp>(op.getLoc(), vtp,
863                                        rewriter.getZeroAttr(vtp));
864   }
865   return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
866 }
867 
868 /// Generates end of a reduction.
869 static void genReductionEnd(Merger &merger, CodeGen &codegen,
870                             PatternRewriter &rewriter, linalg::GenericOp op) {
871   Value red = codegen.redVal;
872   if (!red)
873     return;
874   assert(codegen.curVecLength == 1);
875   codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
876   OpOperand *lhs = op.getOutputOperand(0);
877   if (auto vtp = red.getType().dyn_cast<VectorType>()) {
878     // TODO: assumes + reductions for now
879     StringAttr kind = rewriter.getStringAttr("add");
880     Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
881     // Integer reductions don't accept an accumulator.
882     if (vtp.getElementType().isa<IntegerType>()) {
883       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
884                                                  kind, red, ValueRange{});
885       red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
886     } else {
887       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
888                                                  kind, red, ld);
889     }
890   }
891   genTensorStore(merger, codegen, rewriter, op, lhs, red);
892 }
893 
894 /// Recursively generates tensor expression.
895 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
896                     linalg::GenericOp op, unsigned exp) {
897   if (merger.exp(exp).kind == Kind::kTensor)
898     return genTensorLoad(merger, codegen, rewriter, op, exp);
899   else if (merger.exp(exp).kind == Kind::kInvariant)
900     return genInvariantValue(merger, codegen, rewriter, exp);
901   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
902   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
903   switch (merger.exp(exp).kind) {
904   case Kind::kTensor:
905   case Kind::kInvariant:
906     llvm_unreachable("handled above");
907   case Kind::kMulF:
908     return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
909   case Kind::kMulI:
910     return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
911   case Kind::kAddF:
912     return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
913   case Kind::kAddI:
914     return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
915   }
916   llvm_unreachable("unexpected expression kind");
917 }
918 
919 /// Hoists loop invariant tensor loads for which indices have been exhausted.
920 static void genInvariants(Merger &merger, CodeGen &codegen,
921                           PatternRewriter &rewriter, linalg::GenericOp op,
922                           unsigned exp, unsigned ldx, bool hoist) {
923   if (merger.exp(exp).kind == Kind::kTensor) {
924     // Inspect tensor indices.
925     bool atLevel = ldx == -1u;
926     OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
927     auto map = op.getTiedIndexingMap(t);
928     auto enc = getSparseTensorEncoding(t->get().getType());
929     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
930       unsigned idx = map.getDimPosition(perm(enc, d));
931       if (!codegen.loops[idx])
932         return; // still in play
933       else if (idx == ldx)
934         atLevel = true;
935     }
936     // All exhausted at this level (atLevel denotes exactly at this level).
937     OpOperand *lhs = op.getOutputOperand(0);
938     if (lhs == t) {
939       codegen.redExp = hoist ? exp : -1u;
940     } else if (atLevel) {
941       merger.exp(exp).val =
942           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
943     }
944   } else if (merger.exp(exp).kind != Kind::kInvariant) {
945     // Traverse into the binary operations. Note that we only hoist
946     // tensor loads, since subsequent MLIR/LLVM passes know how to
947     // deal with all other kinds of derived loop invariants.
948     unsigned e0 = merger.exp(exp).e0;
949     unsigned e1 = merger.exp(exp).e1;
950     genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
951     genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
952   }
953 }
954 
955 /// Generates initialization code for the subsequent loop sequence at
956 /// current index level. Returns true if the loop sequence needs to
957 /// maintain the universal index.
958 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
959                     linalg::GenericOp op, std::vector<unsigned> &topSort,
960                     unsigned at, llvm::BitVector &inits) {
961   bool needsUniv = false;
962   Location loc = op.getLoc();
963   unsigned idx = topSort[at];
964 
965   // Initialize sparse positions.
966   for (unsigned b = 0, be = inits.size(); b < be; b++) {
967     if (inits[b]) {
968       unsigned tensor = merger.tensor(b);
969       assert(idx == merger.index(b));
970       if (merger.isDim(b, Dim::kSparse)) {
971         // Initialize sparse index.
972         unsigned pat = at;
973         for (; pat != 0; pat--) {
974           if (codegen.pidxs[tensor][topSort[pat - 1]])
975             break;
976         }
977         Value ptr = codegen.pointers[tensor][idx];
978         Value one = rewriter.create<ConstantIndexOp>(loc, 1);
979         Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
980                               : codegen.pidxs[tensor][topSort[pat - 1]];
981         codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
982         Value p1 = rewriter.create<AddIOp>(loc, p0, one);
983         codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
984       } else {
985         // Dense index still in play.
986         needsUniv = true;
987       }
988     }
989   }
990 
991   // Initialize the universal dense index.
992   codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
993   return needsUniv;
994 }
995 
/// Returns true if the implicit inner loop is to be vectorized. Any implicit
/// inner loop in the Linalg operation is a candidate; whether it is actually
/// converted to SIMD code depends on the requested strategy.
999 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
1000   switch (codegen.options.vectorizationStrategy) {
1001   case SparseVectorizationStrategy::kNone:
1002     return false;
1003   case SparseVectorizationStrategy::kDenseInnerLoop:
1004     return isInner && !isSparse;
1005   case SparseVectorizationStrategy::kAnyStorageInnerLoop:
1006     return isInner;
1007   }
1008   llvm_unreachable("unexpected vectorization strategy");
1009 }
1010 
/// Returns true if the implicit loop is to be parallelized. Any implicit loop
/// in the Linalg operation that is marked "parallel" is a candidate; whether
/// it is actually converted to a parallel operation depends on the requested
/// strategy.
1014 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
1015                           bool isSparse, bool isVector) {
1016   switch (codegen.options.parallelizationStrategy) {
1017   case SparseParallelizationStrategy::kNone:
1018     return false;
1019   case SparseParallelizationStrategy::kDenseOuterLoop:
1020     return isOuter && !isSparse && !isReduction && !isVector;
1021   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
1022     return isOuter && !isReduction && !isVector;
1023   case SparseParallelizationStrategy::kDenseAnyLoop:
1024     return !isSparse && !isReduction && !isVector;
1025   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
1026     return !isReduction && !isVector;
1027   }
1028   llvm_unreachable("unexpected parallelization strategy");
1029 }
1030 
1031 /// Checks unit strides for dense tensors. The iteration graph may have ignored
1032 /// dense access patterns in order to avoid cycles (sparse access patterns are
1033 /// always placed innermost), but that means dense access has become strided.
1034 /// For now, we reject vectorization of such cases.
1035 /// TODO: implement strided load/stores on dense arrays
1036 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
1037                              unsigned idx) {
1038   for (OpOperand *t : op.getInputAndOutputOperands()) {
1039     if (!getSparseTensorEncoding(t->get().getType())) {
1040       auto map = op.getTiedIndexingMap(t);
1041       for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
1042         if (map.getDimPosition(d) == idx && d != rank - 1)
1043           return false;
1044       }
1045     }
1046   }
1047   return true;
1048 }
1049 
1050 /// Generates a for-loop on a single index.
1051 static Operation *genFor(Merger &merger, CodeGen &codegen,
1052                          PatternRewriter &rewriter, linalg::GenericOp op,
1053                          bool isOuter, bool isInner, unsigned idx,
1054                          llvm::BitVector &indices) {
1055   unsigned fb = indices.find_first();
1056   unsigned tensor = merger.tensor(fb);
1057   assert(idx == merger.index(fb));
1058   auto iteratorTypes = op.iterator_types().getValue();
1059   bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
1060   bool isSparse = merger.isDim(fb, Dim::kSparse);
1061   bool isVector = isVectorFor(codegen, isInner, isSparse) &&
1062                   denseUnitStrides(merger, op, idx);
1063   bool isParallel =
1064       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
1065 
1066   // Prepare vector length.
1067   if (isVector)
1068     codegen.curVecLength = codegen.options.vectorLength;
1069 
1070   // Loop bounds and increment.
1071   Location loc = op.getLoc();
1072   Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
1073   Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
1074   Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);
1075 
1076   // Emit a parallel loop.
1077   if (isParallel) {
1078     assert(!isVector);
1079     scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
1080     if (isSparse)
1081       codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
1082     else
1083       codegen.loops[idx] = parOp.getInductionVars()[0];
1084     rewriter.setInsertionPointToStart(parOp.getBody());
1085     return parOp;
1086   }
1087 
1088   // Emit a sequential loop, potentially with a scalarized reduction.
1089   bool scalarRed = isInner && codegen.redExp != -1u;
1090   SmallVector<Value, 4> operands;
1091   if (scalarRed) {
1092     Value load = genReductionStart(merger, codegen, rewriter, op);
1093     operands.push_back(load);
1094   }
1095   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
1096   if (scalarRed) {
1097     codegen.redVal = merger.exp(codegen.redExp).val =
1098         forOp.getRegionIterArgs().front();
1099   }
1100   // Assign induction variable to sparse or dense index.
1101   Value iv = forOp.getInductionVar();
1102   if (isSparse)
1103     codegen.pidxs[tensor][idx] = iv;
1104   else
1105     codegen.loops[idx] = iv;
1106   rewriter.setInsertionPointToStart(forOp.getBody());
1107   // Share vector iteration mask between all subsequent loads/stores.
1108   if (isVector)
1109     codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
1110   return forOp;
1111 }
1112 
/// Emits a while-loop for co-iteration over multiple indices.
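/// As an illustrative sketch (C-like pseudocode, not the exact emitted IR),
/// co-iterating two sparse tensors a(i) and b(i) yields a while-loop of the
/// following shape, with one position per sparse tensor (and optionally the
/// universal dense index) carried as loop arguments:
///   while (pa < ha && pb < hb) {
///     ia = a.indices[pa]; ib = b.indices[pb]; i = min(ia, ib);
///     if (ia == i && ib == i) { /* both */ }
///     else if (ia == i)       { /* a only */ }
///     else                    { /* b only */ }
///     pa += (ia == i); pb += (ib == i);
///   }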
1114 static Operation *genWhile(Merger &merger, CodeGen &codegen,
1115                            PatternRewriter &rewriter, linalg::GenericOp op,
1116                            unsigned idx, bool needsUniv,
1117                            llvm::BitVector &indices) {
1118   SmallVector<Type, 4> types;
1119   SmallVector<Value, 4> operands;
1120   // Construct the while-loop with a parameter for each index.
1121   Type indexType = rewriter.getIndexType();
1122   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1123     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1124       unsigned tensor = merger.tensor(b);
1125       assert(idx == merger.index(b));
1126       types.push_back(indexType);
1127       assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
1128              "type mismatch for sparse index");
1129       operands.push_back(codegen.pidxs[tensor][idx]);
1130     }
1131   }
1132   if (needsUniv) {
1133     types.push_back(indexType);
1134     assert(codegen.loops[idx].getType().isa<IndexType>() &&
1135            "type mismatch for universal index");
1136     operands.push_back(codegen.loops[idx]);
1137   }
1138   Location loc = op.getLoc();
1139   scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
1140   Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
1141   Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
1142 
1143   // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
1145   rewriter.setInsertionPointToStart(&whileOp.before().front());
1146   Value cond;
1147   unsigned o = 0;
1148   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1149     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1150       unsigned tensor = merger.tensor(b);
1151       assert(idx == merger.index(b));
1152       Value op1 = before->getArgument(o);
1153       Value op2 = codegen.highs[tensor][idx];
1154       Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
1155       cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
1156       codegen.pidxs[tensor][idx] = after->getArgument(o++);
1157     }
1158   }
1159   if (needsUniv)
1160     codegen.loops[idx] = after->getArgument(o++);
1161   assert(o == operands.size());
1162   rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
1163   rewriter.setInsertionPointToStart(&whileOp.after().front());
1164   return whileOp;
1165 }
1166 
1167 /// Generates a for-loop or a while-loop, depending on whether it implements
1168 /// singleton iteration or co-iteration over the given conjunction.
1169 static Operation *genLoop(Merger &merger, CodeGen &codegen,
1170                           PatternRewriter &rewriter, linalg::GenericOp op,
1171                           std::vector<unsigned> &topSort, unsigned at,
1172                           bool needsUniv, llvm::BitVector &indices) {
1173   unsigned idx = topSort[at];
1174   if (indices.count() == 1) {
1175     bool isOuter = at == 0;
1176     bool isInner = at == topSort.size() - 1;
1177     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
1178                   indices);
1179   }
1180   genReductionEnd(merger, codegen, rewriter, op); // cannot chain
1181   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
1182 }
1183 
1184 /// Generates the local variables for this loop, consisting of the sparse
1185 /// indices, restored universal dense index, and dense positions.
1186 static void genLocals(Merger &merger, CodeGen &codegen,
1187                       PatternRewriter &rewriter, linalg::GenericOp op,
1188                       std::vector<unsigned> &topSort, unsigned at,
1189                       bool needsUniv, llvm::BitVector &locals) {
1190   Location loc = op.getLoc();
1191   unsigned idx = topSort[at];
1192 
1193   // Initialize sparse indices.
1194   Value min;
1195   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1196     if (locals[b] && merger.isDim(b, Dim::kSparse)) {
1197       unsigned tensor = merger.tensor(b);
1198       assert(idx == merger.index(b));
1199       Value ptr = codegen.indices[tensor][idx];
1200       Value s = codegen.pidxs[tensor][idx];
1201       Value load = genLoad(codegen, rewriter, loc, ptr, s);
1202       codegen.idxs[tensor][idx] = load;
1203       if (!needsUniv) {
1204         if (min) {
1205           Value cmp =
1206               rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
1207           min = rewriter.create<SelectOp>(loc, cmp, load, min);
1208         } else {
1209           min = load;
1210         }
1211       }
1212     }
1213   }
1214 
1215   // Merge dense universal index over minimum.
1216   if (min) {
1217     assert(!needsUniv);
1218     codegen.loops[idx] = min;
1219   }
1220 
1221   // Initialize dense positions. Note that we generate dense indices of the
1222   // output tensor unconditionally, since they may not appear in the lattice,
1223   // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
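/// Conceptually, every sparse position advances only when its index matched
/// the current loop index, i.e. pidx' = (idx == i) ? pidx + 1 : pidx, and the
/// universal index, when still needed, simply advances as i' = i + 1; this is
/// a sketch of the compare/select/add logic emitted below.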
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
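/// The guard is the conjunction of an "idx == i" test for every sparse
/// dimension in the condition set, while dense dimensions contribute a
/// constant true. For example, co-iterating two sparse tensors yields an
/// if-statement guarded by, roughly, (a_idx == i) && (b_idx == i), where the
/// names are illustrative only.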
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
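/// In pseudo-code, the recursion below proceeds roughly as follows:
///   for each lattice point Li in the optimized set:
///     emit the loop for Li (for-loop or while-loop)
///     for each lattice point Lj with Li >= Lj:
///       emit an if-guard when co-iterating, then recurse at level at+1
///     emit the while-induction (or reduction yield) and continue after
///     the loop
///   wrap up the loop sequence (reduction end, unhoisted invariants)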
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
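/// Depending on the annotation of the output tensor, the result is either a
/// conversion of the underlying sparse storage back into a tensor value or a
/// plain tensor load of the dense output buffer (see the two cases below).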
static void genResult(CodeGen &codegen, PatternRewriter &rewriter,
                      linalg::GenericOp op) {
  RankedTensorType resType = op.getOutputTensorTypes()[0];
  Value result = codegen.buffers.back();
  if (getSparseTensorEncoding(resType))
    result = rewriter.create<ToTensorOp>(op.getLoc(), resType, result);
  else
    result =
        rewriter.create<memref::TensorLoadOp>(op.getLoc(), resType, result);
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
    Operation *yield = op.region().front().getTerminator();
    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
    if (!exp.hasValue())
      return failure(); // build failure

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
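/// As an illustrative sketch only (the surrounding driver shown here is not
/// part of this file), a pass would typically use this entry point as:
///   RewritePatternSet patterns(ctx);
///   populateSparsificationPatterns(patterns, options);
///   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));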
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}