1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering sparse tensor types to actual sparse code.
10 //
11 // The concept of letting a compiler generate sparse code automatically was
12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
14 // Algebra Compiler (TACO). The implementation in this file closely follows
15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting
16 // rule is applied to each tensor expression in linalg (MLIR's tensor index
17 // notation) where the sparsity of tensors is indicated with an annotation
18 // giving a per-dimension specification of sparse/dense storage together with a
19 // specification of the order on the dimensions. Subsequently, a topologically
20 // sorted iteration graph, reflecting the required order on indices with respect
21 // to the dimensions of each tensor, is constructed to ensure that all tensors
22 // are visited in natural index order. Next, iteration lattices are constructed
23 // for the tensor expression for every index in topological order. Each
24 // iteration lattice point consists of a conjunction of tensor indices together
25 // with a tensor (sub)expression that needs to be evaluated for that
26 // conjunction. Within the lattice, iteration points are ordered according to
27 // the way indices are exhausted. As such these iteration lattices drive actual
28 // sparse code generation, which consists of a tedious but relatively
29 // straightforward one-to-one mapping from iteration lattices to combinations
30 // of for-loops, while-loops, and if-statements.
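//
// As a flavor of the generated code, consider an illustrative C-like sketch
// (not verbatim compiler output) for a vector addition x(i) = a(i) + b(i)
// with a stored sparse and b dense: the stored entries of a are co-iterated
// with a universal dense index i, followed by a scalar cleanup loop:
//
//   while (pa < pa_end) {
//     ia = a_indices[pa];
//     if (ia == i)
//       x[i] = a_values[pa++] + b[i];
//     else
//       x[i] = b[i];
//     i++;
//   }
//   for (; i < n; i++)
//     x[i] = b[i];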
31 //
32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
36 // Proceedings of the ACM on Programming Languages, October 2017.
37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
38 // PhD thesis, MIT, February, 2020 (tensor-compiler.org).
39 //
40 // Implementation detail: We use llvm::SmallVector for vectors with
41 // variable lengths and std::vector for vectors with fixed lengths.
42 //===----------------------------------------------------------------------===//
43 
44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
45 #include "mlir/Dialect/Linalg/Utils/Utils.h"
46 #include "mlir/Dialect/MemRef/IR/MemRef.h"
47 #include "mlir/Dialect/SCF/SCF.h"
48 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
49 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
50 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
51 #include "mlir/Dialect/StandardOps/IR/Ops.h"
52 #include "mlir/Dialect/Vector/VectorOps.h"
53 #include "mlir/IR/Matchers.h"
54 #include "mlir/IR/TensorEncoding.h"
55 #include "llvm/ADT/SmallBitVector.h"
56 
57 using namespace mlir;
58 using namespace mlir::sparse_tensor;
59 
60 namespace {
61 
62 // Code generation.
63 struct CodeGen {
64   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
65       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
66         pointers(numTensors, std::vector<Value>(numLoops)),
67         indices(numTensors, std::vector<Value>(numLoops)),
68         highs(numTensors, std::vector<Value>(numLoops)),
69         pidxs(numTensors, std::vector<Value>(numLoops)),
70         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
71         curVecLength(1), curVecMask() {}
72   /// Sparsification options.
73   SparsificationOptions options;
74   /// Universal dense indices and upper bounds (by index). The loops array
75   /// is updated with the value of the universal dense index in the current
76   /// loop. The sizes array is set once with the inferred dimension sizes.
77   std::vector<Value> loops;
78   std::vector<Value> sizes;
79   /// Buffers for storing dense and sparse numerical values (by tensor).
80   /// This array is set once during bufferization of all tensors.
81   std::vector<Value> buffers;
82   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
83   /// These arrays are set once during bufferization of all sparse tensors.
84   std::vector<std::vector<Value>> pointers;
85   std::vector<std::vector<Value>> indices;
86   /// Sparse iteration information (by tensor and index). These arrays
87   /// are updated to remain current within the current loop.
88   std::vector<std::vector<Value>> highs;
89   std::vector<std::vector<Value>> pidxs;
90   std::vector<std::vector<Value>> idxs;
91   /// Current reduction, updated during code generation. When indices of a
92   /// reduction are exhausted, all inner loops can "scalarize" the reduction.
93   // TODO: currently only done for (a chain of) innermost for-loops, where it
94   // is most effective; we could generalize to more outer and while-loops.
95   unsigned redExp;
96   Value redVal;
97   // Current vector length and mask.
98   unsigned curVecLength;
99   Value curVecMask;
100 };
101 
102 } // namespace
103 
104 // Helper method to apply dimension ordering permutation.
105 static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
106   if (enc) {
107     auto order = enc.getDimOrdering();
108     if (order) {
109       assert(order.isPermutation());
110       return order.getDimPosition(d);
111     }
112   }
113   return d;
114 }
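// For example, given a 2-d encoding whose dimOrdering is
// affine_map<(i,j) -> (j,i)>, perm(enc, 0) yields 1 and perm(enc, 1) yields 0,
// i.e. the storage visits the column dimension before the row dimension.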
115 
116 // Helper method to translate dim level type to internal representation.
117 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
118   if (enc) {
119     SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
120     if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
121       return Dim::kSparse;
122     if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
123       return Dim::kSingle;
124   }
125   return Dim::kDense;
126 }
127 
128 /// Helper method to inspect sparse encodings in the tensor types.
129 /// Fills the per-dimension sparsity information for all tensors.
130 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
131   bool annotated = false;
132   for (OpOperand *t : op.getInputAndOutputOperands()) {
133     auto map = op.getTiedIndexingMap(t);
134     if (!map.isProjectedPermutation())
135       return false;
136     auto enc = getSparseTensorEncoding(t->get().getType());
137     if (enc)
138       annotated = true;
139     assert(map.getNumResults() == op.getRank(t));
140     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
141       unsigned idx = map.getDimPosition(perm(enc, d));
142       merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
143     }
144   }
145   return annotated;
146 }
147 
148 /// A DFS helper to compute a topological sort. Note that recursion is
149 /// bounded by the number of implicit loops, which is always small.
150 /// Returns false when a cycle is detected.
151 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
152                        std::vector<unsigned> &topSort,
153                        std::vector<std::vector<bool>> &adjM) {
154   if (visit[i] != 0)
155     return visit[i] != 1; // 1 denotes cycle!
156   visit[i] = 1;
157   for (unsigned j = 0, e = visit.size(); j < e; j++)
158     if (adjM[i][j])
159       if (!topSortDFS(j, visit, topSort, adjM))
160         return false;
161   visit[i] = 2;
162   topSort.push_back(i);
163   return true;
164 }
165 
166 /// Computes a topologically sorted iteration graph for the linalg operation.
167 /// Ensures all tensors are visited in natural index order. This is essential
168 /// for sparse storage formats since these only support access along fixed
169 /// dimensions. Even for dense storage formats, however, the natural index
170 /// order yields innermost unit-stride access with better spatial locality.
171 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
172                                   std::vector<unsigned> &topSort,
173                                   bool sparseOnly) {
174   // Set up an n x n from/to adjacency matrix of the iteration graph
175   // for the implicit loop indices i_0 .. i_n-1.
176   unsigned n = op.getNumLoops();
177   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
178 
179   // Iterate over the indexing maps of every tensor in the tensor expression.
180   for (OpOperand *t : op.getInputAndOutputOperands()) {
181     auto map = op.getTiedIndexingMap(t);
182     auto enc = getSparseTensorEncoding(t->get().getType());
183     assert(map.getNumDims() == n);
184     // Skip dense tensor constraints when sparse only is requested.
185     if (sparseOnly && !enc)
186       continue;
187     // Each tensor expression and optional dimension ordering (row-major
188     // by default) puts an ordering constraint on the loop indices. For
189     // example, the tensor expression A_ijk forces the ordering i < j < k
190     // on the loop indices if no explicit dimension ordering is given.
191     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
192       unsigned f = map.getDimPosition(perm(enc, d - 1));
193       unsigned t = map.getDimPosition(perm(enc, d));
194       adjM[f][t] = true;
195     }
196   }
197 
198   // Topologically sort the iteration graph to determine loop order.
199   // Report failure for a cyclic iteration graph.
200   topSort.clear();
201   topSort.reserve(n);
202   std::vector<unsigned> visit(n, 0);
203   for (unsigned i = 0; i < n; i++)
204     if (visit[i] == 0)
205       if (!topSortDFS(i, visit, topSort, adjM))
206         return false; // cycle!
207   std::reverse(std::begin(topSort), std::end(topSort));
208   return true;
209 }
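// For example (an illustrative case), for x(i) = A(i,j) * b(j) with A stored
// in its default row-major order, the indexing map of A adds the constraint
// i < j, so the topological sort yields loop order (i, j). A kernel that also
// accesses some tensor as B(j,i) adds the opposite constraint j < i; keeping
// both constraints forms a cycle and is reported as failure, which is why the
// caller may retry with sparseOnly set to drop the dense constraints.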
210 
211 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
212 /// This simplifies constructing (sub)expressions during iteration lattice
213 /// building (compared to using the SSA representation everywhere).
214 static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
215                                          Value val) {
216   if (auto arg = val.dyn_cast<BlockArgument>()) {
217     unsigned argN = arg.getArgNumber();
218     // Any argument of the generic op that is not marked as a scalar
219     // argument is considered a tensor, indexed by the implicit loop
220     // bounds. This includes rank-0 tensor arguments.
221     if (arg.getOwner()->getParentOp() == op) {
222       OpOperand *t = op.getInputAndOutputOperands()[argN];
223       if (!op.isScalar(t))
224         return merger.addExp(Kind::kTensor, argN);
225       val = t->get(); // get scalar value
226     }
227     // Any other argument (marked as scalar argument for the generic op
228     // or belonging to an enveloping op) is considered invariant.
229     return merger.addExp(Kind::kInvariant, val);
230   }
231   Operation *def = val.getDefiningOp();
232   if (def->getBlock() != &op.region().front()) {
233     // Something defined outside is invariant.
234     return merger.addExp(Kind::kInvariant, val);
235   } else if (def->getNumOperands() == 2) {
236     // Construct binary operations if subexpressions could be built.
237     auto x = buildTensorExp(merger, op, def->getOperand(0));
238     auto y = buildTensorExp(merger, op, def->getOperand(1));
239     if (x.hasValue() && y.hasValue()) {
240       unsigned e0 = x.getValue();
241       unsigned e1 = y.getValue();
242       if (isa<MulFOp>(def))
243         return merger.addExp(Kind::kMulF, e0, e1);
244       if (isa<MulIOp>(def))
245         return merger.addExp(Kind::kMulI, e0, e1);
246       if (isa<AddFOp>(def))
247         return merger.addExp(Kind::kAddF, e0, e1);
248       if (isa<AddIOp>(def))
249         return merger.addExp(Kind::kAddI, e0, e1);
250     }
251   }
252   // Cannot build (yet).
253   return None;
254 }
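// For example (illustrative), for a region that yields
//   %0 = mulf %a, %b : f64
//   %1 = addf %0, %c : f64
//   linalg.yield %1 : f64
// the traversal builds kTensor leaves for the block arguments %a, %b, and %c,
// and combines them bottom-up into a kMulF node and a kAddF root, whose index
// in the merger is returned.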
255 
256 /// Returns true if the given tensor co-iterates with conjunction only.
257 /// For the output tensor, this defines a "simply dynamic" operation.
258 /// For instance: A(I) = A(I) * B(I) * C(I)
259 static bool isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
260   switch (merger.exp(exp).kind) {
261   case Kind::kTensor:
262     return merger.exp(exp).e0 == tensor;
263   case Kind::kMulF:
264   case Kind::kMulI:
265     return isConjunction(merger, tensor, merger.exp(exp).e0) ||
266            isConjunction(merger, tensor, merger.exp(exp).e1);
267   default:
268     return false;
269   }
270 }
271 
272 /// Returns true when the tensor expression is admissable for codegen.
273 /// Since all sparse input tensors are admissable, we just need to check
274 /// whether the output tensor in the tensor expression codegen is admissable.
275 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
276                                   unsigned exp) {
277   OpOperand *lhs = op.getOutputOperand(0);
278   unsigned tensor = lhs->getOperandNumber();
279   auto enc = getSparseTensorEncoding(lhs->get().getType());
280   // A non-annotated output tensor is assumed dense, and becomes a random
281   // access n-dim memref. Admissable since insertions cannot occur.
282   if (!enc)
283     return true;
284   // An all-dense annotated "sparse" output tensor becomes a linearized random
285   // access 1-dim memref. Also admissable since insertions cannot occur.
286   bool allDense = true;
287   unsigned numLoops = op.iterator_types().getValue().size();
288   for (unsigned i = 0; i < numLoops; i++)
289     if (merger.isDim(tensor, i, Dim::kSparse)) {
290       allDense = false;
291       break;
292     }
293   if (allDense)
294     return true;
295   // A tensor expression with a sparse output tensor that changes its values
296   // but not its nonzero structure, an operation called "simply dynamic" in
297   // [Bik96,Ch9], is also admissable without special codegen.
298   if (isConjunction(merger, tensor, exp))
299     return true;
300   // Reject for now since this requires changes to the nonzero structure.
301   // TODO: implement "workspaces" [Kjolstad2019]
302   return false;
303 }
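// For example, with a sparse-annotated output x, the expression
// x(i) = x(i) * a(i) is admissable ("simply dynamic": only stored values
// change), whereas x(i) = a(i) + b(i) is rejected for now, since the addition
// may introduce new nonzeros into x.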
304 
305 /// Maps sparse integer option to actual integral storage type.
306 static Type genIntType(PatternRewriter &rewriter, unsigned width) {
307   if (width == 0)
308     return rewriter.getIndexType();
309   return rewriter.getIntegerType(width);
310 }
311 
312 /// Detects in-place annotation on tensor argument.
313 static bool getInPlace(Value val) {
314   if (auto arg = val.dyn_cast<BlockArgument>())
315     if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
316       if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
317               arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
318         return attr.getValue();
319   return false;
320 }
321 
322 /// Generates buffer for the output tensor.
323 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
324                              linalg::GenericOp op, MemRefType denseTp,
325                              ArrayRef<Value> args) {
326   Location loc = op.getLoc();
327   Value tensor = op.getOutputOperand(0)->get();
328   // The output tensor could simply materialize from the buffer that will
329   // be generated for the tensor present in the outs() clause. This has
330   // the major advantage that the sparse kernel only updates the nonzero
331   // positions for the output tensor.
332   if (getInPlace(tensor))
333     return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
334   // By default, a new buffer is allocated which is initialized to the
335   // tensor defined in the outs() clause. This is always correct but
336   // introduces a dense initialization component that may negatively
337   // impact the running complexity of the sparse kernel.
338   Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
339   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
340   rewriter.create<linalg::CopyOp>(loc, init, alloc);
341   return alloc;
342 }
343 
344 /// Local bufferization of all dense and sparse data structures.
345 /// This code enables testing the first prototype sparse compiler.
346 // TODO: replace this with a proliferated bufferization strategy
347 static bool genBuffers(Merger &merger, CodeGen &codegen,
348                        PatternRewriter &rewriter, linalg::GenericOp op) {
349   Location loc = op.getLoc();
350   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
351   // For every tensor, find lower and upper bound on dimensions, set the
352   // same bounds on loop indices, and obtain dense or sparse buffer(s).
353   SmallVector<Value, 4> args;
354   for (OpOperand *t : op.getInputAndOutputOperands()) {
355     unsigned tensor = t->getOperandNumber();
356     auto shape = op.getShape(t);
357     auto map = op.getTiedIndexingMap(t);
358     auto enc = getSparseTensorEncoding(t->get().getType());
359     // Scan all dimensions of current tensor.
360     args.clear();
361     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
362       unsigned idx = map.getDimPosition(perm(enc, d));
363       // Handle sparse storage schemes.
364       if (merger.isDim(tensor, idx, Dim::kSparse)) {
365         auto dynShape = {ShapedType::kDynamicSize};
366         auto ptrTp = MemRefType::get(
367             dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
368         auto indTp = MemRefType::get(
369             dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
370         Value dim = rewriter.create<ConstantIndexOp>(loc, d);
371         // Generate sparse primitives to obtain pointer and indices.
372         codegen.pointers[tensor][idx] =
373             rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
374         codegen.indices[tensor][idx] =
375             rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
376       }
377       // Find lower and upper bound in current dimension.
378       Value up;
379       if (shape[d] == MemRefType::kDynamicSize) {
380         up = rewriter.create<memref::DimOp>(loc, t->get(), d);
381         args.push_back(up);
382       } else {
383         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
384       }
385       codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
386     }
387     // Perform the required bufferization. Dense inputs materialize
388     // from the input tensors. Dense outputs need special handling.
389     // Sparse inputs use sparse primitives to obtain the values.
390     // We also accept in-place all-dense annotated "sparse" outputs.
391     Type elementType = getElementTypeOrSelf(t->get().getType());
392     if (!enc) {
393       // Non-annotated dense tensors.
394       auto denseTp = MemRefType::get(shape, elementType);
395       if (tensor < op.getNumInputs())
396         codegen.buffers[tensor] =
397             rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
398       else
399         codegen.buffers[tensor] =
400             genOutputBuffer(codegen, rewriter, op, denseTp, args);
401     } else {
402       // Annotated sparse tensors.
403       if (tensor == op.getNumInputs() && !getInPlace(t->get()))
404         return false; // reject output if not in-place
405       auto dynShape = {ShapedType::kDynamicSize};
406       auto sparseTp = MemRefType::get(dynShape, elementType);
407       codegen.buffers[tensor] =
408           rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
409     }
410   }
411   return true;
412 }
413 
414 /// Constructs vector type.
415 static VectorType vectorType(CodeGen &codegen, Type etp) {
416   return VectorType::get(codegen.curVecLength, etp);
417 }
418 
419 /// Constructs vector type from pointer.
420 static VectorType vectorType(CodeGen &codegen, Value ptr) {
421   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
422 }
423 
424 /// Constructs vector iteration mask.
425 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
426                            Value iv, Value lo, Value hi, Value step) {
427   Location loc = iv.getLoc();
428   VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
429   // Special case if the vector length evenly divides the trip count (for
430   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
431   // so that all subsequent masked memory operations are immediately folded
432   // into unconditional memory operations.
433   IntegerAttr loInt, hiInt, stepInt;
434   if (matchPattern(lo, m_Constant(&loInt)) &&
435       matchPattern(hi, m_Constant(&hiInt)) &&
436       matchPattern(step, m_Constant(&stepInt))) {
437     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
438       return rewriter.create<vector::BroadcastOp>(
439           loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
440   }
441   // Otherwise, generate a vector mask that avoids overrunning the upper bound
442   // during vector execution. Here we rely on subsequent loop optimizations to
443   // avoid executing the mask in all iterations, for example, by splitting the
444   // loop into an unconditional vector loop and a scalar cleanup loop.
445   Value end = rewriter.create<SubIOp>(loc, hi, iv);
446   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
447 }
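// For example (a rough sketch of the emitted IR), with vector length 16 and
// constant bounds lo=0, hi=128, step=16 the mask folds into
//   %mask = vector.broadcast %true : i1 to vector<16xi1>
// whereas a trip count that is not a multiple of 16 yields the
// iteration-dependent mask
//   %end  = subi %hi, %iv : index
//   %mask = vector.create_mask %end : vector<16xi1>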
448 
449 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
450 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
451                            Value ptr, ArrayRef<Value> args) {
452   Location loc = ptr.getLoc();
453   VectorType vtp = vectorType(codegen, ptr);
454   Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
455   if (args.back().getType().isa<VectorType>()) {
456     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
457     Value indexVec = args.back();
458     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
459     return rewriter.create<vector::GatherOp>(
460         loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
461   }
462   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
463                                                codegen.curVecMask, pass);
464 }
465 
466 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
467 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
468                            Value rhs, Value ptr, ArrayRef<Value> args) {
469   Location loc = ptr.getLoc();
470   if (args.back().getType().isa<VectorType>()) {
471     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
472     Value indexVec = args.back();
473     scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
474     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
475                                        codegen.curVecMask, rhs);
476     return;
477   }
478   rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
479                                          rhs);
480 }
481 
482 /// Generates a vectorized invariant. Here we rely on subsequent loop
483 /// optimizations to hoist the invariant broadcast out of the vector loop.
484 static Value genVectorInvariantValue(CodeGen &codegen,
485                                      PatternRewriter &rewriter, Value val) {
486   VectorType vtp = vectorType(codegen, val.getType());
487   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
488 }
489 
490 /// Generates a load on a dense or sparse tensor.
491 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
492                            PatternRewriter &rewriter, linalg::GenericOp op,
493                            unsigned exp) {
494   // Test if the load was hoisted to a higher loop nest.
495   Value val = merger.exp(exp).val;
496   if (val) {
497     if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
498       return genVectorInvariantValue(codegen, rewriter, val);
499     return val;
500   }
501   // Actual load.
502   SmallVector<Value, 4> args;
503   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
504   unsigned tensor = t->getOperandNumber();
505   auto map = op.getTiedIndexingMap(t);
506   auto enc = getSparseTensorEncoding(t->get().getType());
507   unsigned rank = map.getNumResults();
508   if (enc) {
509     unsigned idx = map.getDimPosition(perm(enc, rank - 1));
510     assert(codegen.pidxs[tensor][idx] != nullptr);
511     args.push_back(codegen.pidxs[tensor][idx]); // position index
512   } else {
513     for (unsigned d = 0; d < rank; d++) {
514       unsigned idx = map.getDimPosition(d);
515       args.push_back(codegen.loops[idx]); // universal dense index
516     }
517   }
518   Location loc = op.getLoc();
519   Value ptr = codegen.buffers[tensor];
520   if (codegen.curVecLength > 1)
521     return genVectorLoad(codegen, rewriter, ptr, args);
522   return rewriter.create<memref::LoadOp>(loc, ptr, args);
523 }
524 
525 /// Generates a store on a dense or sparse tensor.
526 static void genTensorStore(Merger &merger, CodeGen &codegen,
527                            PatternRewriter &rewriter, linalg::GenericOp op,
528                            OpOperand *t, Value rhs) {
529   Location loc = op.getLoc();
530   // Test if this is a scalarized reduction.
531   OpOperand *lhs = op.getOutputOperand(0);
532   if (lhs == t && codegen.redVal) {
533     if (codegen.curVecLength > 1)
534       rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
535                                       codegen.redVal);
536     codegen.redVal = rhs;
537     return;
538   }
539   // Actual store.
540   SmallVector<Value, 4> args;
541   unsigned tensor = t->getOperandNumber();
542   auto map = op.getTiedIndexingMap(t);
543   auto enc = getSparseTensorEncoding(t->get().getType());
544   unsigned rank = map.getNumResults();
545   if (enc) {
546     unsigned idx = map.getDimPosition(perm(enc, rank - 1));
547     assert(codegen.pidxs[tensor][idx] != nullptr);
548     args.push_back(codegen.pidxs[tensor][idx]); // position index
549   } else {
550     for (unsigned d = 0; d < rank; d++) {
551       unsigned idx = map.getDimPosition(d);
552       args.push_back(codegen.loops[idx]); // universal dense index
553     }
554   }
555   Value ptr = codegen.buffers[tensor];
556   if (codegen.curVecLength > 1)
557     genVectorStore(codegen, rewriter, rhs, ptr, args);
558   else
559     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
560 }
561 
562 /// Generates a pointer/index load from the sparse storage scheme. Narrower
563 /// data types need to be zero extended before casting the value into the
564 /// index type used for looping and indexing.
565 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
566                      Value ptr, Value s) {
567   // See https://llvm.org/docs/GetElementPtr.html for some background on
568   // the complications described below.
569   if (codegen.curVecLength > 1) {
570     // Since the index vector is used in subsequent gather/scatter operations,
571     // which effectively define an unsigned pointer + signed index, we must
572     // zero extend the vector to an index width. For 8-bit and 16-bit values,
573     // a 32-bit index width suffices. For 32-bit values, zero extending the
574     // elements into 64-bit loses some performance since the 32-bit indexed
575     // gather/scatter is more efficient than the 64-bit index variant (if the
576     // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
577     // preserve this performance). For 64-bit values, there is no good way
578     // to state that the indices are unsigned, which creates the potential of
579     // incorrect address calculations in the unlikely case we need such
580     // extremely large offsets.
581     Type etp = ptr.getType().cast<MemRefType>().getElementType();
582     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
583     if (!etp.isa<IndexType>()) {
584       if (etp.getIntOrFloatBitWidth() < 32)
585         vload = rewriter.create<ZeroExtendIOp>(
586             loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
587       else if (etp.getIntOrFloatBitWidth() < 64 &&
588                !codegen.options.enableSIMDIndex32)
589         vload = rewriter.create<ZeroExtendIOp>(
590             loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
591     }
592     return vload;
593   }
594   // For the scalar case, we simply zero extend narrower indices into 64-bit
595   // values before casting to index without a performance penalty. Here too,
596   // however, indices that already are 64-bit, in theory, cannot express the
597   // full range as explained above.
598   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
599   if (!load.getType().isa<IndexType>()) {
600     if (load.getType().getIntOrFloatBitWidth() < 64)
601       load = rewriter.create<ZeroExtendIOp>(loc, load,
602                                             rewriter.getIntegerType(64));
603     load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
604   }
605   return load;
606 }
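// For example, a pointer array stored with 8-bit elements and vector length 16
// is loaded as vector<16xi8> and zero extended to vector<16xi32> before it
// feeds a gather/scatter, while in the scalar case an i8 value is zero
// extended to i64 and then cast to index.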
607 
608 /// Generates an invariant value.
609 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
610                                PatternRewriter &rewriter, unsigned exp) {
611   Value val = merger.exp(exp).val;
612   if (codegen.curVecLength > 1)
613     return genVectorInvariantValue(codegen, rewriter, val);
614   return val;
615 }
616 
617 /// Generates an address computation "sz * p + i".
618 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
619                         Location loc, Value size, Value p, Value i) {
620   Value mul = rewriter.create<MulIOp>(loc, size, p);
621   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
622     Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
623     mul = genVectorInvariantValue(codegen, rewriter, inv);
624   }
625   return rewriter.create<AddIOp>(loc, mul, i);
626 }
627 
628 /// Generates start of a reduction.
629 static Value genReductionStart(Merger &merger, CodeGen &codegen,
630                                PatternRewriter &rewriter,
631                                linalg::GenericOp op) {
632   if (codegen.redVal)
633     return codegen.redVal; // chained with previous for-loop
634   if (codegen.curVecLength > 1) {
635     // TODO: assumes + reductions for now
636     VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
637     return rewriter.create<ConstantOp>(op.getLoc(), vtp,
638                                        rewriter.getZeroAttr(vtp));
639   }
640   return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
641 }
642 
643 /// Generates end of a reduction.
644 static void genReductionEnd(Merger &merger, CodeGen &codegen,
645                             PatternRewriter &rewriter, linalg::GenericOp op) {
646   Value red = codegen.redVal;
647   if (!red)
648     return;
649   assert(codegen.curVecLength == 1);
650   codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
651   OpOperand *lhs = op.getOutputOperand(0);
652   if (auto vtp = red.getType().dyn_cast<VectorType>()) {
653     // TODO: assumes + reductions for now
654     StringAttr kind = rewriter.getStringAttr("add");
655     Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
656     // Integer reductions don't accept an accumulator.
657     if (vtp.getElementType().isa<IntegerType>()) {
658       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
659                                                  kind, red, ValueRange{});
660       red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
661     } else {
662       red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
663                                                  kind, red, ld);
664     }
665   }
666   genTensorStore(merger, codegen, rewriter, op, lhs, red);
667 }
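// For example, a row-wise sum x(i) = SUM_j A(i,j) keeps the partial sum in the
// iter_args of the innermost j-loop (and, when vectorized, folds the SIMD
// lanes with vector.reduction), so that only the final value is stored back
// into x[i] once the chain of inner loops ends.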
668 
669 /// Recursively generates tensor expression.
670 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
671                     linalg::GenericOp op, unsigned exp) {
672   if (merger.exp(exp).kind == Kind::kTensor)
673     return genTensorLoad(merger, codegen, rewriter, op, exp);
674   else if (merger.exp(exp).kind == Kind::kInvariant)
675     return genInvariantValue(merger, codegen, rewriter, exp);
676   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
677   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
678   switch (merger.exp(exp).kind) {
679   case Kind::kTensor:
680   case Kind::kInvariant:
681     llvm_unreachable("handled above");
682   case Kind::kMulF:
683     return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
684   case Kind::kMulI:
685     return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
686   case Kind::kAddF:
687     return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
688   case Kind::kAddI:
689     return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
690   }
691   llvm_unreachable("unexpected expression kind");
692 }
693 
694 /// Hoists loop invariant tensor loads for which indices have been exhausted.
695 static void genInvariants(Merger &merger, CodeGen &codegen,
696                           PatternRewriter &rewriter, linalg::GenericOp op,
697                           unsigned exp, unsigned ldx, bool hoist) {
698   if (merger.exp(exp).kind == Kind::kTensor) {
699     // Inspect tensor indices.
700     bool atLevel = ldx == -1u;
701     OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
702     auto map = op.getTiedIndexingMap(t);
703     auto enc = getSparseTensorEncoding(t->get().getType());
704     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
705       unsigned idx = map.getDimPosition(perm(enc, d));
706       if (!codegen.loops[idx])
707         return; // still in play
708       else if (idx == ldx)
709         atLevel = true;
710     }
711     // All exhausted at this level (atLevel denotes exactly at this level).
712     OpOperand *lhs = op.getOutputOperand(0);
713     if (lhs == t) {
714       codegen.redExp = hoist ? exp : -1u;
715     } else if (atLevel) {
716       merger.exp(exp).val =
717           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
718     }
719   } else if (merger.exp(exp).kind != Kind::kInvariant) {
720     // Traverse into the binary operations. Note that we only hoist
721     // tensor loads, since subsequent MLIR/LLVM passes know how to
722     // deal with all other kinds of derived loop invariants.
723     unsigned e0 = merger.exp(exp).e0;
724     unsigned e1 = merger.exp(exp).e1;
725     genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
726     genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
727   }
728 }
729 
730 /// Generates initialization code for the subsequent loop sequence at
731 /// current index level. Returns true if the loop sequence needs to
732 /// maintain the universal index.
733 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
734                     linalg::GenericOp op, std::vector<unsigned> &topSort,
735                     unsigned at, llvm::BitVector &inits) {
736   bool needsUniv = false;
737   Location loc = op.getLoc();
738   unsigned idx = topSort[at];
739 
740   // Initialize sparse positions.
741   for (unsigned b = 0, be = inits.size(); b < be; b++) {
742     if (inits[b]) {
743       unsigned tensor = merger.tensor(b);
744       assert(idx == merger.index(b));
745       if (merger.isDim(b, Dim::kSparse)) {
746         // Initialize sparse index.
747         unsigned pat = at;
748         for (; pat != 0; pat--) {
749           if (codegen.pidxs[tensor][topSort[pat - 1]])
750             break;
751         }
752         Value ptr = codegen.pointers[tensor][idx];
753         Value one = rewriter.create<ConstantIndexOp>(loc, 1);
754         Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
755                               : codegen.pidxs[tensor][topSort[pat - 1]];
756         codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
757         Value p1 = rewriter.create<AddIOp>(loc, p0, one);
758         codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
759       } else {
760         // Dense index still in play.
761         needsUniv = true;
762       }
763     }
764   }
765 
766   // Initialize the universal dense index.
767   codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
768   return needsUniv;
769 }
770 
771 /// Returns vectorization strategy. Any implicit inner loop in the Linalg
772 /// operation is a candidate. Whether it is actually converted to SIMD code
773 /// depends on the requested strategy.
774 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
775   switch (codegen.options.vectorizationStrategy) {
776   case SparseVectorizationStrategy::kNone:
777     return false;
778   case SparseVectorizationStrategy::kDenseInnerLoop:
779     return isInner && !isSparse;
780   case SparseVectorizationStrategy::kAnyStorageInnerLoop:
781     return isInner;
782   }
783   llvm_unreachable("unexpected vectorization strategy");
784 }
785 
786 /// Returns parallelization strategy. Any implicit loop in the Linalg operation
787 /// that is marked "parallel" is a candidate. Whether it is actually converted
788 /// to a parallel operation depends on the requested strategy.
789 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
790                           bool isSparse, bool isVector) {
791   switch (codegen.options.parallelizationStrategy) {
792   case SparseParallelizationStrategy::kNone:
793     return false;
794   case SparseParallelizationStrategy::kDenseOuterLoop:
795     return isOuter && !isSparse && !isReduction && !isVector;
796   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
797     return isOuter && !isReduction && !isVector;
798   case SparseParallelizationStrategy::kDenseAnyLoop:
799     return !isSparse && !isReduction && !isVector;
800   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
801     return !isReduction && !isVector;
802   }
803   llvm_unreachable("unexpected parallelization strategy");
804 }
805 
806 /// Checks unit strides for dense tensors. The iteration graph may have ignored
807 /// dense access patterns in order to avoid cycles (sparse access patterns are
808 /// always placed innermost), but that means dense access has become strided.
809 /// For now, we reject vectorization of such cases.
810 /// TODO: implement strided load/stores on dense arrays
811 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
812                              unsigned idx) {
813   for (OpOperand *t : op.getInputAndOutputOperands()) {
814     if (!getSparseTensorEncoding(t->get().getType())) {
815       auto map = op.getTiedIndexingMap(t);
816       for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
817         if (map.getDimPosition(d) == idx && d != rank - 1)
818           return false;
819       }
820     }
821   }
822   return true;
823 }
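// For example, vectorizing the i-loop of a kernel that accesses a dense matrix
// as B(i,j) would require strided accesses on B, since i does not index B's
// innermost dimension; such loops are currently not vectorized.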
824 
825 /// Generates a for-loop on a single index.
826 static Operation *genFor(Merger &merger, CodeGen &codegen,
827                          PatternRewriter &rewriter, linalg::GenericOp op,
828                          bool isOuter, bool isInner, unsigned idx,
829                          llvm::BitVector &indices) {
830   unsigned fb = indices.find_first();
831   unsigned tensor = merger.tensor(fb);
832   assert(idx == merger.index(fb));
833   auto iteratorTypes = op.iterator_types().getValue();
834   bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
835   bool isSparse = merger.isDim(fb, Dim::kSparse);
836   bool isVector = isVectorFor(codegen, isInner, isSparse) &&
837                   denseUnitStrides(merger, op, idx);
838   bool isParallel =
839       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
840 
841   // Prepare vector length.
842   if (isVector)
843     codegen.curVecLength = codegen.options.vectorLength;
844 
845   // Loop bounds and increment.
846   Location loc = op.getLoc();
847   Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
848   Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
849   Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);
850 
851   // Emit a parallel loop.
852   if (isParallel) {
853     assert(!isVector);
854     scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
855     if (isSparse)
856       codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
857     else
858       codegen.loops[idx] = parOp.getInductionVars()[0];
859     rewriter.setInsertionPointToStart(parOp.getBody());
860     return parOp;
861   }
862 
863   // Emit a sequential loop, potentially with a scalarized reduction.
864   bool scalarRed = isInner && codegen.redExp != -1u;
865   SmallVector<Value, 4> operands;
866   if (scalarRed) {
867     Value load = genReductionStart(merger, codegen, rewriter, op);
868     operands.push_back(load);
869   }
870   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
871   if (scalarRed) {
872     codegen.redVal = merger.exp(codegen.redExp).val =
873         forOp.getRegionIterArgs().front();
874   }
875   // Assign induction variable to sparse or dense index.
876   Value iv = forOp.getInductionVar();
877   if (isSparse)
878     codegen.pidxs[tensor][idx] = iv;
879   else
880     codegen.loops[idx] = iv;
881   rewriter.setInsertionPointToStart(forOp.getBody());
882   // Share vector iteration mask between all subsequent loads/stores.
883   if (isVector)
884     codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
885   return forOp;
886 }
887 
888 /// Emit a while-loop for co-iteration over multiple indices.
889 static Operation *genWhile(Merger &merger, CodeGen &codegen,
890                            PatternRewriter &rewriter, linalg::GenericOp op,
891                            unsigned idx, bool needsUniv,
892                            llvm::BitVector &indices) {
893   SmallVector<Type, 4> types;
894   SmallVector<Value, 4> operands;
895   // Construct the while-loop with a parameter for each index.
896   Type indexType = rewriter.getIndexType();
897   for (unsigned b = 0, be = indices.size(); b < be; b++) {
898     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
899       unsigned tensor = merger.tensor(b);
900       assert(idx == merger.index(b));
901       types.push_back(indexType);
902       assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
903              "type mismatch for sparse index");
904       operands.push_back(codegen.pidxs[tensor][idx]);
905     }
906   }
907   if (needsUniv) {
908     types.push_back(indexType);
909     assert(codegen.loops[idx].getType().isa<IndexType>() &&
910            "type mismatch for universal index");
911     operands.push_back(codegen.loops[idx]);
912   }
913   Location loc = op.getLoc();
914   scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
915   Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
916   Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
917 
918   // Build the "before" region, which effectively consists
919   // of a conjunction of "i < upper" tests on all induction variables.
920   rewriter.setInsertionPointToStart(&whileOp.before().front());
921   Value cond;
922   unsigned o = 0;
923   for (unsigned b = 0, be = indices.size(); b < be; b++) {
924     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
925       unsigned tensor = merger.tensor(b);
926       assert(idx == merger.index(b));
927       Value op1 = before->getArgument(o);
928       Value op2 = codegen.highs[tensor][idx];
929       Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
930       cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
931       codegen.pidxs[tensor][idx] = after->getArgument(o++);
932     }
933   }
934   if (needsUniv)
935     codegen.loops[idx] = after->getArgument(o++);
936   assert(o == operands.size());
937   rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
938   rewriter.setInsertionPointToStart(&whileOp.after().front());
939   return whileOp;
940 }
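// For example (a rough sketch), co-iterating two sparse tensors over index i
// yields a loop of the form
//   scf.while (%pA = %a_lo, %pB = %b_lo) : (index, index) -> (index, index) {
//     %cA = cmpi ult, %pA, %a_hi : index
//     %cB = cmpi ult, %pB, %b_hi : index
//     %c  = and %cA, %cB : i1
//     scf.condition(%c) %pA, %pB : index, index
//   } do {
//   ^bb0(%pA: index, %pB: index):
//     // loop body and induction are appended by genLocals/genWhileInduction
//     ...
//   }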
941 
942 /// Generates a for-loop or a while-loop, depending on whether it implements
943 /// singleton iteration or co-iteration over the given conjunction.
944 static Operation *genLoop(Merger &merger, CodeGen &codegen,
945                           PatternRewriter &rewriter, linalg::GenericOp op,
946                           std::vector<unsigned> &topSort, unsigned at,
947                           bool needsUniv, llvm::BitVector &indices) {
948   unsigned idx = topSort[at];
949   if (indices.count() == 1) {
950     bool isOuter = at == 0;
951     bool isInner = at == topSort.size() - 1;
952     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
953                   indices);
954   }
955   genReductionEnd(merger, codegen, rewriter, op); // cannot chain
956   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
957 }
958 
959 /// Generates the local variables for this loop, consisting of the sparse
960 /// indices, restored universal dense index, and dense positions.
961 static void genLocals(Merger &merger, CodeGen &codegen,
962                       PatternRewriter &rewriter, linalg::GenericOp op,
963                       std::vector<unsigned> &topSort, unsigned at,
964                       bool needsUniv, llvm::BitVector &locals) {
965   Location loc = op.getLoc();
966   unsigned idx = topSort[at];
967 
968   // Initialize sparse indices.
969   Value min;
970   for (unsigned b = 0, be = locals.size(); b < be; b++) {
971     if (locals[b] && merger.isDim(b, Dim::kSparse)) {
972       unsigned tensor = merger.tensor(b);
973       assert(idx == merger.index(b));
974       Value ptr = codegen.indices[tensor][idx];
975       Value s = codegen.pidxs[tensor][idx];
976       Value load = genLoad(codegen, rewriter, loc, ptr, s);
977       codegen.idxs[tensor][idx] = load;
978       if (!needsUniv) {
979         if (min) {
980           Value cmp =
981               rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
982           min = rewriter.create<SelectOp>(loc, cmp, load, min);
983         } else {
984           min = load;
985         }
986       }
987     }
988   }
989 
990   // Merge dense universal index over minimum.
991   if (min) {
992     assert(!needsUniv);
993     codegen.loops[idx] = min;
994   }
995 
996   // Initialize dense positions. Note that we generate dense indices of the
997   // output tensor unconditionally, since they may not appear in the lattice,
998   // but may be needed for linearized codegen.
999   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1000     if ((locals[b] || merger.isOutTensor(b, idx)) &&
1001         merger.isDim(b, Dim::kDense)) {
1002       unsigned tensor = merger.tensor(b);
1003       assert(idx == merger.index(b));
1004       unsigned pat = at;
1005       for (; pat != 0; pat--)
1006         if (codegen.pidxs[tensor][topSort[pat - 1]])
1007           break;
1008       Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
1009                            : codegen.pidxs[tensor][topSort[pat - 1]];
1010       codegen.pidxs[tensor][idx] = genAddress(
1011           codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
1012     }
1013   }
1014 }
1015 
1016 /// Generates the induction structure for a while-loop.
1017 static void genWhileInduction(Merger &merger, CodeGen &codegen,
1018                               PatternRewriter &rewriter, linalg::GenericOp op,
1019                               unsigned idx, bool needsUniv,
1020                               llvm::BitVector &induction, ResultRange results) {
1021   Location loc = op.getLoc();
1022   unsigned o = 0;
1023   SmallVector<Value, 4> operands;
1024   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
1025   for (unsigned b = 0, be = induction.size(); b < be; b++) {
1026     if (induction[b] && merger.isDim(b, Dim::kSparse)) {
1027       unsigned tensor = merger.tensor(b);
1028       assert(idx == merger.index(b));
1029       Value op1 = codegen.idxs[tensor][idx];
1030       Value op2 = codegen.loops[idx];
1031       Value op3 = codegen.pidxs[tensor][idx];
1032       Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
1033       Value add = rewriter.create<AddIOp>(loc, op3, one);
1034       operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
1035       codegen.pidxs[tensor][idx] = results[o++];
1036     }
1037   }
1038   if (needsUniv) {
1039     operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
1040     codegen.loops[idx] = results[o++];
1041   }
1042   assert(o == operands.size());
1043   rewriter.create<scf::YieldOp>(loc, operands);
1044 }
1045 
1046 /// Generates a single if-statement within a while-loop.
1047 static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
1048                        PatternRewriter &rewriter, linalg::GenericOp op,
1049                        unsigned idx, llvm::BitVector &conditions) {
1050   Location loc = op.getLoc();
1051   Value cond;
1052   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1053     if (conditions[b]) {
1054       unsigned tensor = merger.tensor(b);
1055       assert(idx == merger.index(b));
1056       Value clause;
1057       if (merger.isDim(b, Dim::kSparse)) {
1058         Value op1 = codegen.idxs[tensor][idx];
1059         Value op2 = codegen.loops[idx];
1060         clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
1061       } else {
1062         clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
1063       }
1064       cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
1065     }
1066   }
1067   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
1068   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
1069   return ifOp;
1070 }
1071 
1072 /// Recursively generates code while computing iteration lattices in order
1073 /// to manage the complexity of implementing co-iteration over unions
1074 /// and intersections of sparse iterations spaces.
1075 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1076                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1077                     unsigned exp, unsigned at) {
1078   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1079   if (at == topSort.size()) {
1080     OpOperand *lhs = op.getOutputOperand(0);
1081     Value rhs = genExp(merger, codegen, rewriter, op, exp);
1082     genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
1083     return;
1084   }
1085   assert(codegen.curVecLength == 1);
1086 
1087   // Construct iteration lattices for current loop index, with L0 at top.
1088   // Then emit initialization code for the loop sequence at this level.
1089   // We maintain the universal dense index if dense indices are still
1090   // in play for a non-singleton loop sequence.
1091   Location loc = op.getLoc();
1092   unsigned idx = topSort[at];
1093   unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
1094   unsigned lsize = merger.set(lts).size();
1095   assert(lsize != 0);
1096   unsigned l0 = merger.set(lts)[0];
1097   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
1098   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
1099   bool needsUniv = false;
1100   if (genInit(merger, codegen, rewriter, op, topSort, at,
1101               merger.lat(l0).bits)) {
1102     // Maintain the universal index only if it is actually
1103     // consumed by a subsequent lattice point.
1104     for (unsigned i = 1; i < lsize; i++) {
1105       unsigned li = merger.set(lts)[i];
1106       if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
1107         needsUniv = true;
1108         break;
1109       }
1110     }
1111   }
1112 
1113   // Emit a loop for every lattice point L0 >= Li.
1114   for (unsigned i = 0; i < lsize; i++) {
1115     unsigned li = merger.set(lts)[i];
1116 
1117     // Emit loop.
1118     codegen.curVecLength = 1;
1119     llvm::BitVector indices = merger.lat(li).simple;
1120     Operation *loop =
1121         genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
1122     genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
1123               merger.lat(li).bits);
1124 
1125     // Visit all lattice points with Li >= Lj to generate the
1126     // loop-body, possibly with if statements for co-iteration.
1127     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1128     for (unsigned j = 0; j < lsize; j++) {
1129       unsigned lj = merger.set(lts)[j];
1130       unsigned ej = merger.lat(lj).exp;
1131       if (li == lj || merger.latGT(li, lj)) {
1132         // Recurse into body of each branch.
1133         if (isWhile) {
1134           scf::IfOp ifOp =
1135               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1136           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1137           rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
1138         } else {
1139           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1140         }
1141       }
1142     }
1143 
1144     // Wrap-up induction and restore insertion point.
1145     if (isWhile) {
1146       scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
1147       rewriter.setInsertionPointToEnd(&whileOp.after().front());
1148       genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
1149                         merger.lat(li).bits, whileOp.results());
1150     } else {
1151       needsUniv = false;
1152       if (codegen.redVal) {
1153         rewriter.create<scf::YieldOp>(loc, codegen.redVal);
1154         codegen.redVal = loop->getResult(0);
1155       }
1156     }
1157     rewriter.setInsertionPointAfter(loop);
1158   }
1159 
1160   // Wrap-up loop sequence.
1161   codegen.curVecLength = 1;
1162   genReductionEnd(merger, codegen, rewriter, op);
1163   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
1164   codegen.loops[idx] = Value();
1165 }
1166 
1167 /// Converts the result computed by the sparse kernel into the required form.
1168 static void genResult(Merger &merger, CodeGen &codegen,
1169                       PatternRewriter &rewriter, linalg::GenericOp op) {
1170   Location loc = op.getLoc();
1171   OpOperand *lhs = op.getOutputOperand(0);
1172   Type resType = lhs->get().getType();
1173   unsigned tensor = lhs->getOperandNumber();
1174   auto map = op.getTiedIndexingMap(lhs);
1175   auto enc = getSparseTensorEncoding(resType);
1176   Value result = codegen.buffers.back(); // value array
1177   if (enc) {
1178     // The sparse annotation unambiguously defines the arrays needed
1179     // to "reconstruct" the sparse tensor from the storage scheme
1180     // (even though lowering should never need this eventually).
1181     SmallVector<Value, 4> args;
1182     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
1183       unsigned idx = map.getDimPosition(perm(enc, d));
1184       if (merger.isDim(tensor, idx, Dim::kSparse)) {
1185         args.push_back(codegen.pointers[tensor][idx]);
1186         args.push_back(codegen.indices[tensor][idx]);
1187       }
1188     }
1189     args.push_back(result);
1190     result = rewriter.create<ToTensorOp>(loc, resType, args);
1191   } else {
1192     // To "reconstruct" a non-annotated tensor, simply load it
1193     // from the bufferized value.
1194     result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
1195   }
1196   rewriter.replaceOp(op, result);
1197 }
1198 
1199 namespace {
1200 
1201 /// Sparse rewriting rule for generic Linalg operation.
1202 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1203 public:
1204   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1205       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1206 
1207   LogicalResult matchAndRewrite(linalg::GenericOp op,
1208                                 PatternRewriter &rewriter) const override {
1209     // Detects sparse annotations and translates the per-dimension sparsity
1210     // information for all tensors to loop indices in the kernel.
1211     assert(op.getNumOutputs() == 1);
1212     unsigned numTensors = op.getNumInputsAndOutputs();
1213     unsigned numLoops = op.iterator_types().getValue().size();
1214     Merger merger(numTensors, numLoops);
1215     if (!findSparseAnnotations(merger, op))
1216       return failure();
1217 
1218     // Computes a topologically sorted iteration graph to ensure
1219     // tensors are visited in natural index order. Fails on cycles.
1220     // This assumes that higher-level passes have already put the
1221     // tensors in each tensor expression in a feasible order.
1222     std::vector<unsigned> topSort;
1223     if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
1224         !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
1225       return failure();
1226 
1227     // Finds the terminating yield statement and builds the tensor
1228     // expression for the Linalg operation in SSA form.
1229     Operation *yield = op.region().front().getTerminator();
1230     Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
1231     if (!exp.hasValue())
1232       return failure(); // build failure
1233 
1234     // Reject an inadmissable tensor expression.
1235     if (!isAdmissableTensorExp(merger, op, exp.getValue()))
1236       return failure();
1237 
1238     // Recursively generates code.
1239     CodeGen codegen(options, numTensors, numLoops);
1240     if (!genBuffers(merger, codegen, rewriter, op))
1241       return failure(); // could not bufferize
1242     genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
1243     genResult(merger, codegen, rewriter, op);
1244     return success();
1245   }
1246 
1247 private:
1248   /// Options to control sparse code generation.
1249   SparsificationOptions options;
1250 };
1251 
1252 } // namespace
1253 
1254 /// Populates the given patterns list with rewriting rules required for
1255 /// the sparsification of linear algebra operations.
1256 void mlir::populateSparsificationPatterns(
1257     RewritePatternSet &patterns, const SparsificationOptions &options) {
1258   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1259 }
1260