1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements lowering annotated linalg operations on sparse tensors
// to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodegenUtils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
17 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Dialect/SCF/Transforms.h"
23 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
24 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
25 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/IR/Matchers.h"
29 #include "mlir/IR/TensorEncoding.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 
32 using namespace mlir;
33 using namespace mlir::sparse_tensor;
34 
35 //===----------------------------------------------------------------------===//
36 // Declarations of data structures.
37 //===----------------------------------------------------------------------===//
38 
39 namespace {
40 
41 // Iteration graph sorting.
42 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
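// The sort masks may be combined bitwise; for example, requesting
// kIncludeDense | kIncludeUndef adds both dense and undefined-dimension
// constraints when building the iteration graph.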
43 
44 // Reduction kinds.
45 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
46 
47 // Code generation.
48 struct CodeGen {
49   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
50           OpOperand *op, unsigned nest)
51       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
52         pointers(numTensors, std::vector<Value>(numLoops)),
53         indices(numTensors, std::vector<Value>(numLoops)),
54         highs(numTensors, std::vector<Value>(numLoops)),
55         pidxs(numTensors, std::vector<Value>(numLoops)),
56         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
57         redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(),
58         expValues(), expFilled(), expAdded(), expCount(), curVecLength(1),
59         curVecMask() {}
60   /// Sparsification options.
61   SparsificationOptions options;
62   /// Universal dense indices and upper bounds (by index). The loops array
63   /// is updated with the value of the universal dense index in the current
64   /// loop. The sizes array is set once with the inferred dimension sizes.
65   std::vector<Value> loops;
66   std::vector<Value> sizes;
67   /// Buffers for storing dense and sparse numerical values (by tensor).
68   /// This array is set once during bufferization of all tensors.
69   std::vector<Value> buffers;
70   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
71   /// This array is set once during bufferization of all sparse tensors.
72   std::vector<std::vector<Value>> pointers;
73   std::vector<std::vector<Value>> indices;
74   /// Sparse iteration information (by tensor and index). These arrays
75   /// are updated to remain current within the current loop.
76   std::vector<std::vector<Value>> highs;
77   std::vector<std::vector<Value>> pidxs;
78   std::vector<std::vector<Value>> idxs;
79   /// Current reduction, updated during code generation. When indices of a
80   /// reduction are exhausted, all inner loops can use a scalarized reduction.
81   unsigned redExp;
82   Value redVal;
83   Reduction redKind;
84   // Sparse tensor as output. Implemented either through direct injective
85   // insertion in lexicographic index order (where indices are updated
86   // in the temporary array `lexIdx`) or through access pattern expansion
87   // in the innermost loop nest (`expValues` through `expCount`).
88   OpOperand *sparseOut;
89   unsigned outerParNest;
90   Value lexIdx;
91   Value expValues;
92   Value expFilled;
93   Value expAdded;
94   Value expCount;
95   // Current vector length and mask.
96   unsigned curVecLength;
97   Value curVecMask;
98 };
99 
100 } // namespace
101 
102 //===----------------------------------------------------------------------===//
103 // Sparse compiler analysis methods.
104 //===----------------------------------------------------------------------===//
105 
106 /// Helper method to apply dimension ordering permutation.
107 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
108   if (enc) {
109     auto order = enc.getDimOrdering();
110     if (order) {
111       assert(order.isPermutation());
112       return order.getDimPosition(d);
113     }
114   }
115   return d;
116 }
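
// For illustration (hypothetical annotation, not taken from this file): with
// a dimension ordering such as
//   dimOrdering = affine_map<(i, j) -> (j, i)>
// perm(enc, 0) returns 1 and perm(enc, 1) returns 0, so that callers can use
// map.getResult(perm(enc, d)) to select the subscript expression that belongs
// to the d-th stored dimension.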
117 
118 /// Helper method to translate dim level type to internal representation.
119 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
120   if (enc) {
121     SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
122     if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
123       return Dim::kSparse;
124     if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
125       return Dim::kSingle;
126   }
127   return Dim::kDense;
128 }
129 
130 /// Helper method to inspect affine expressions. Rejects cases where the
131 /// same index is used more than once. Also rejects affine expressions
132 /// that are not a direct index for annotated tensors.
133 // TODO: accept more affine cases for sparse tensors
134 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
135                        bool isDense) {
136   switch (a.getKind()) {
137   case AffineExprKind::DimId: {
138     unsigned idx = a.cast<AffineDimExpr>().getPosition();
139     if (!merger.isDim(tensor, idx, Dim::kUndef))
140       return false; // used more than once
141     merger.setDim(tensor, idx, dim);
142     return true;
143   }
144   case AffineExprKind::Add:
145   case AffineExprKind::Mul: {
146     if (!isDense)
147       return false;
148     auto binOp = a.cast<AffineBinaryOpExpr>();
149     return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
150            findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
151   }
152   case AffineExprKind::Constant:
153     return isDense;
154   default:
155     return false;
156   }
157 }
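
// For example (illustrative): the dense subscript a[i + j] is accepted and
// registers both i and j as dense dimensions, whereas the same compound
// subscript on an annotated sparse tensor is rejected, as is any repeated
// index such as a[i][i].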
158 
159 /// Helper method to inspect sparse encodings in the tensor types.
160 /// Fills the per-dimension sparsity information for all tensors.
161 /// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
164 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
165   bool annotated = false;
166   for (OpOperand *t : op.getInputAndOutputOperands()) {
167     auto map = op.getTiedIndexingMap(t);
168     auto enc = getSparseTensorEncoding(t->get().getType());
169     if (enc)
170       annotated = true;
171     assert(map.getNumResults() == op.getRank(t));
172     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
173       unsigned tensor = t->getOperandNumber();
174       AffineExpr a = map.getResult(perm(enc, d));
175       if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
177     }
178   }
179   return annotated;
180 }
181 
182 /// A DFS helper to compute a topological sort. Note that recursion is
183 /// bounded by the number of implicit loops, which is always small.
184 /// Returns false when a cycle is detected.
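/// Visit states used below: 0 = not yet visited, 1 = currently on the DFS
/// stack (revisiting such a node indicates a cycle), 2 = finished.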
185 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
186                        std::vector<unsigned> &topSort,
187                        std::vector<std::vector<bool>> &adjM) {
188   if (visit[i] != 0)
189     return visit[i] != 1; // 1 denotes cycle!
190   visit[i] = 1;
191   for (unsigned j = 0, e = visit.size(); j < e; j++)
192     if (adjM[i][j])
193       if (!topSortDFS(j, visit, topSort, adjM))
194         return false;
195   visit[i] = 2;
196   topSort.push_back(i);
197   return true;
198 }
199 
200 /// Helper method to add all constraints from the indices in one affine
201 /// expression before all indices in the other affine expression. For
202 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
203 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
204                                AffineExpr a, AffineExpr b, unsigned fidx) {
205   switch (a.getKind()) {
206   case AffineExprKind::DimId: {
207     unsigned idx = a.cast<AffineDimExpr>().getPosition();
208     if (b)
209       addAffineOrderings(adjM, b, AffineExpr(), idx);
210     else
211       adjM[fidx][idx] = true;
212     break;
213   }
214   case AffineExprKind::Add:
215   case AffineExprKind::Mul: {
216     auto binOp = a.cast<AffineBinaryOpExpr>();
217     addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
218     addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
219     break;
220   }
221   default:
222     break;
223   }
224 }
225 
226 /// Computes a topologically sorted iteration graph for the linalg operation.
227 /// Ensures all tensors are visited in natural index order. This is essential
228 /// for sparse storage formats since these only support access along fixed
229 /// dimensions. Even for dense storage formats, however, the natural index
230 /// order yields innermost unit-stride access with better spatial locality.
231 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
232                                   std::vector<unsigned> &topSort,
233                                   unsigned mask) {
234   // Set up an n x n from/to adjacency matrix of the iteration graph
235   // for the implicit loop indices i_0 .. i_n-1.
236   unsigned n = op.getNumLoops();
237   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
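  // Entry adjM[i][j] == true encodes the constraint that loop index i must
  // be ordered before loop index j in the final topological sort.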
238 
239   // Iterate over the indexing maps of every tensor in the tensor expression.
240   for (OpOperand *t : op.getInputAndOutputOperands()) {
241     auto map = op.getTiedIndexingMap(t);
242     auto enc = getSparseTensorEncoding(t->get().getType());
243     assert(map.getNumDims() == n);
244     // Skip dense tensor constraints when not requested.
245     if (!(mask & SortMask::kIncludeDense) && !enc)
246       continue;
247     // Each tensor expression and optional dimension ordering (row-major
248     // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
250     // on the loop indices if no explicit dimension ordering is given.
251     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
252       AffineExpr f = map.getResult(perm(enc, d - 1));
253       AffineExpr t = map.getResult(perm(enc, d));
254       addAffineOrderings(adjM, f, t, 0);
255     }
256     // Push unrelated loops into sparse iteration space, so these
257     // will be skipped more often.
258     if (mask & SortMask::kIncludeUndef) {
259       unsigned tensor = t->getOperandNumber();
260       for (unsigned i = 0; i < n; i++)
261         if (merger.isDim(tensor, i, Dim::kSparse))
262           for (unsigned j = 0; j < n; j++)
263             if (merger.isDim(tensor, j, Dim::kUndef))
264               adjM[i][j] = true;
265     }
266   }
267 
268   // Topologically sort the iteration graph to determine loop order.
269   // Report failure for a cyclic iteration graph.
270   topSort.clear();
271   topSort.reserve(n);
272   std::vector<unsigned> visit(n, 0);
273   for (unsigned i = 0; i < n; i++)
274     if (visit[i] == 0)
275       if (!topSortDFS(i, visit, topSort, adjM))
276         return false; // cycle!
277   std::reverse(std::begin(topSort), std::end(topSort));
278   return true;
279 }
280 
281 /// Returns true if tensor has an in-place annotation.
282 static bool isInPlace(Value val) {
283   if (auto arg = val.dyn_cast<BlockArgument>())
284     if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
285       if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
286               arg.getArgNumber(),
287               bufferization::BufferizableOpInterface::kInplaceableAttrName))
288         return attr.getValue();
289   return false;
290 }
291 
292 /// Returns true if tensor materializes uninitialized into the computation.
293 static bool isMaterializing(Value val) {
294   return val.getDefiningOp<linalg::InitTensorOp>() ||
295          val.getDefiningOp<InitOp>();
296 }
297 
/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
301 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
302 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
303 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
304                                   std::vector<unsigned> &topSort, unsigned exp,
305                                   OpOperand **sparseOut,
306                                   unsigned &outerParNest) {
307   OpOperand *lhs = op.getOutputOperand(0);
308   unsigned tensor = lhs->getOperandNumber();
309   auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
312   if (!enc)
313     return true;
314   // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
316   bool allDense = true;
317   auto iteratorTypes = op.iterator_types().getValue();
318   unsigned numLoops = iteratorTypes.size();
319   for (unsigned i = 0; i < numLoops; i++)
320     if (merger.isDim(tensor, i, Dim::kSparse)) {
321       allDense = false;
322       break;
323     }
324   if (allDense)
325     return true;
326   // A tensor expression with a sparse output tensor that changes its values
327   // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen, provided
329   // the tensor's underlying sparse storage scheme can be modified in place.
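  // (A hypothetical example of such a "simply dynamic" kernel is
  //  x(i) = x(i) * 2 on an in-place sparse x: every stored value changes,
  //  but the nonzero structure of x stays the same.)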
330   if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get()))
331     return true;
332   // Accept "truly dynamic" if the output tensor materializes uninitialized
333   // into the computation and insertions occur in lexicographic index order.
334   if (isMaterializing(lhs->get())) {
335     unsigned nest = 0;
336     for (unsigned i = 0; i < numLoops; i++) {
337       if (isReductionIterator(iteratorTypes[topSort[i]]))
338         break; // terminate at first reduction
339       nest++;
340     }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
344     if (nest >= op.getRank(lhs) - 1) {
345       *sparseOut = lhs;
346       outerParNest = nest;
347       return true;
348     }
349   }
350   return false;
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // Sparse compiler synthesis methods (reductions).
355 //===----------------------------------------------------------------------===//
356 
357 /// Maps reduction kind to name encoding.
358 static StringRef getReductionName(Reduction kind) {
359   switch (kind) {
360   case kNoReduc:
361     break;
362   case kSum:
363     return "add";
364   case kProduct:
365     return "mul";
366   case kAnd:
367     return "and";
368   case kOr:
369     return "or";
370   case kXor:
371     return "xor";
372   }
373   llvm_unreachable("unknown reduction kind");
374 }
375 
376 /// Maps operation to reduction.
377 static Reduction getReduction(Kind kind) {
378   switch (kind) {
379   case Kind::kAddF:
380   case Kind::kAddI:
381   case Kind::kSubF:
382   case Kind::kSubI:
383     return kSum;
384   case Kind::kMulF:
385   case Kind::kMulI:
386     return kProduct;
387   case Kind::kAndI:
388     return kAnd;
389   case Kind::kOrI:
390     return kOr;
391   case Kind::kXorI:
392     return kXor;
393   default:
394     llvm_unreachable("unexpected reduction operator");
395   }
396 }
397 
398 /// Generates an initial value for a vector reduction, following the scheme
399 /// given in Chapter 5 of "The Software Vectorization Handbook", where the
400 /// initial scalar value is correctly embedded in the vector reduction value,
401 /// and a straightforward horizontal reduction will complete the operation.
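/// For example (illustrative): for a sum reduction of vector length 4 with
/// running scalar value r, the initial vector is <r, 0, 0, 0>; partial sums
/// then accumulate per lane, and the final horizontal add folds all lanes
/// back together with r.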
402 static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
403                                 Location loc, VectorType vtp) {
404   Value r = codegen.redVal;
405   switch (codegen.redKind) {
406   case kNoReduc:
407     break;
408   case kSum:
409   case kXor:
410     // Initialize reduction vector to: | 0 | .. | 0 | r |
411     return rewriter.create<vector::InsertElementOp>(
412         loc, r, constantZero(rewriter, loc, vtp),
413         constantIndex(rewriter, loc, 0));
414   case kProduct:
415     // Initialize reduction vector to: | 1 | .. | 1 | r |
416     return rewriter.create<vector::InsertElementOp>(
417         loc, r, constantOne(rewriter, loc, vtp),
418         constantIndex(rewriter, loc, 0));
419   case kAnd:
420   case kOr:
421     // Initialize reduction vector to: | r | .. | r | r |
422     return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
423   }
424   llvm_unreachable("unknown reduction kind");
425 }
426 
427 /// Generates final value for a vector reduction.
428 static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
429                                Location loc, VectorType vtp) {
430   StringRef name = getReductionName(codegen.redKind);
431   StringAttr kind = rewriter.getStringAttr(name);
432   return rewriter.create<vector::ReductionOp>(loc, vtp.getElementType(), kind,
433                                               codegen.redVal, ValueRange{});
434 }
435 
436 /// Updates scalarized reduction value.
437 static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
438   assert(codegen.redKind != kNoReduc);
439   codegen.redVal = merger.exp(codegen.redExp).val = reduc;
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // Sparse compiler synthesis methods (statements and expressions).
444 //===----------------------------------------------------------------------===//
445 
446 /// Generates buffer for the output tensor. Note that all sparse kernels
447 /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
452 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
453                              linalg::GenericOp op, MemRefType denseTp,
454                              ArrayRef<Value> args) {
455   Location loc = op.getLoc();
456   Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
458   // be generated for the tensor present in the outs() clause. This has
459   // the major advantage that the sparse kernel only updates the nonzero
460   // positions for the output tensor.
461   if (isInPlace(tensor))
462     return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
463   // By default, a new buffer is allocated which is initialized to the
464   // tensor defined in the outs() clause. This is always correct but
465   // introduces a dense initialization component that may negatively
466   // impact the running complexity of the sparse kernel. If the tensor
467   // materializes into the computation, we need to preserve the zero
468   // initialization assumption of all sparse output buffers.
469   if (isMaterializing(tensor)) {
470     Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
471     Value zero = constantZero(rewriter, loc, denseTp.getElementType());
472     rewriter.create<linalg::FillOp>(loc, zero, alloc);
473     return alloc;
474   }
475   Value init = rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
476   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
477   rewriter.create<memref::CopyOp>(loc, init, alloc);
478   return alloc;
479 }
480 
481 /// Local bufferization of all dense and sparse data structures.
482 /// This code enables testing the first prototype sparse compiler.
483 // TODO: replace this with a proliferated bufferization strategy
484 static void genBuffers(Merger &merger, CodeGen &codegen,
485                        PatternRewriter &rewriter, linalg::GenericOp op) {
486   Location loc = op.getLoc();
487   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
488   // For every tensor, find lower and upper bound on dimensions, set the
489   // same bounds on loop indices, and obtain dense or sparse buffer(s).
490   SmallVector<Value, 4> args;
491   for (OpOperand *t : op.getInputAndOutputOperands()) {
492     unsigned tensor = t->getOperandNumber();
493     auto shape = op.getShape(t);
494     auto map = op.getTiedIndexingMap(t);
495     auto enc = getSparseTensorEncoding(t->get().getType());
496     // Scan all dimensions of current tensor.
497     args.clear();
498     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
499       AffineExpr a = map.getResult(perm(enc, d));
500       if (a.getKind() != AffineExprKind::DimId)
501         continue; // compound
502       unsigned idx = a.cast<AffineDimExpr>().getPosition();
503       // Handle sparse storage schemes.
504       if (merger.isDim(tensor, idx, Dim::kSparse)) {
505         auto dynShape = {ShapedType::kDynamicSize};
506         auto ptrTp =
507             MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc));
508         auto indTp =
509             MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc));
510         Value dim = constantIndex(rewriter, loc, d);
        // Generate sparse primitives to obtain pointers and indices.
512         codegen.pointers[tensor][idx] =
513             rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
514         codegen.indices[tensor][idx] =
515             rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
516       }
517       // Find upper bound in current dimension.
518       unsigned p = perm(enc, d);
519       Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
520       if (ShapedType::isDynamic(shape[p]))
521         args.push_back(up);
522       assert(codegen.highs[tensor][idx] == nullptr);
523       codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
524     }
525     // Perform the required bufferization. Dense inputs materialize
526     // from the input tensors. Dense outputs need special handling.
527     // Sparse inputs use sparse primitives to obtain the values.
528     // We also accept in-place all-dense annotated "sparse" outputs.
529     Type elementType = getElementTypeOrSelf(t->get().getType());
530     if (!enc) {
531       // Non-annotated dense tensors.
532       auto denseTp = MemRefType::get(shape, elementType);
533       if (tensor < op.getNumInputs())
534         codegen.buffers[tensor] =
535             rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
536       else
537         codegen.buffers[tensor] =
538             genOutputBuffer(codegen, rewriter, op, denseTp, args);
539     } else if (t == codegen.sparseOut) {
540       // True sparse output needs a lexIdx array.
541       Value rank = constantIndex(rewriter, loc, op.getRank(t));
542       auto dynShape = {ShapedType::kDynamicSize};
543       auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
544       codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
545     } else {
546       // Annotated sparse tensors.
547       auto dynShape = {ShapedType::kDynamicSize};
548       auto sparseTp = MemRefType::get(dynShape, elementType);
549       codegen.buffers[tensor] =
550           rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
551     }
552   }
553 }
554 
555 /// Constructs vector type.
556 static VectorType vectorType(CodeGen &codegen, Type etp) {
557   return VectorType::get(codegen.curVecLength, etp);
558 }
559 
560 /// Constructs vector type from pointer.
561 static VectorType vectorType(CodeGen &codegen, Value ptr) {
562   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
563 }
564 
565 /// Constructs vector iteration mask.
566 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
567                            Value iv, Value lo, Value hi, Value step) {
568   Location loc = iv.getLoc();
569   VectorType mtp = vectorType(codegen, rewriter.getI1Type());
570   // Special case if the vector length evenly divides the trip count (for
571   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
572   // so that all subsequent masked memory operations are immediately folded
573   // into unconditional memory operations.
574   IntegerAttr loInt, hiInt, stepInt;
575   if (matchPattern(lo, m_Constant(&loInt)) &&
576       matchPattern(hi, m_Constant(&hiInt)) &&
577       matchPattern(step, m_Constant(&stepInt))) {
578     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
579       return rewriter.create<vector::BroadcastOp>(
580           loc, mtp, constantI1(rewriter, loc, true));
581   }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
583   // during vector execution. Here we rely on subsequent loop optimizations to
584   // avoid executing the mask in all iterations, for example, by splitting the
585   // loop into an unconditional vector loop and a scalar cleanup loop.
586   auto minMap = AffineMap::get(
587       /*dimCount=*/2, /*symbolCount=*/1,
588       {rewriter.getAffineSymbolExpr(0),
589        rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
590       rewriter.getContext());
591   Value end =
592       rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
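  // With operands (hi, iv)[step], the min map above evaluates to
  // min(step, hi - iv), so the mask enables all lanes except possibly in the
  // final, partial vector iteration.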
593   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
594 }
595 
596 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
597 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
598                            Value ptr, ArrayRef<Value> args) {
599   Location loc = ptr.getLoc();
600   VectorType vtp = vectorType(codegen, ptr);
601   Value pass = constantZero(rewriter, loc, vtp);
602   if (args.back().getType().isa<VectorType>()) {
603     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
604     Value indexVec = args.back();
605     scalarArgs.back() = constantIndex(rewriter, loc, 0);
606     return rewriter.create<vector::GatherOp>(
607         loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
608   }
609   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
610                                                codegen.curVecMask, pass);
611 }
612 
613 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
614 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
615                            Value rhs, Value ptr, ArrayRef<Value> args) {
616   Location loc = ptr.getLoc();
617   if (args.back().getType().isa<VectorType>()) {
618     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
619     Value indexVec = args.back();
620     scalarArgs.back() = constantIndex(rewriter, loc, 0);
621     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
622                                        codegen.curVecMask, rhs);
623     return;
624   }
625   rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
626                                          rhs);
627 }
628 
629 /// Generates a vectorized invariant. Here we rely on subsequent loop
630 /// optimizations to hoist the invariant broadcast out of the vector loop.
631 static Value genVectorInvariantValue(CodeGen &codegen,
632                                      PatternRewriter &rewriter, Value val) {
633   VectorType vtp = vectorType(codegen, val.getType());
634   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
635 }
636 
637 /// Generates an affine expression.
638 //
639 // TODO: generalize for sparse tensor subscripts
640 //
641 static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
642                        AffineExpr a, Location loc) {
643   switch (a.getKind()) {
644   case AffineExprKind::DimId: {
645     unsigned idx = a.cast<AffineDimExpr>().getPosition();
646     return codegen.loops[idx]; // universal dense index
647   }
648   case AffineExprKind::Add: {
649     auto binOp = a.cast<AffineBinaryOpExpr>();
650     return rewriter.create<arith::AddIOp>(
651         loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
652         genAffine(codegen, rewriter, binOp.getRHS(), loc));
653   }
654   case AffineExprKind::Mul: {
655     auto binOp = a.cast<AffineBinaryOpExpr>();
656     return rewriter.create<arith::MulIOp>(
657         loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
658         genAffine(codegen, rewriter, binOp.getRHS(), loc));
659   }
660   case AffineExprKind::Constant: {
661     int64_t c = a.cast<AffineConstantExpr>().getValue();
662     return constantIndex(rewriter, loc, c);
663   }
664   default:
665     llvm_unreachable("unexpected affine subscript");
666   }
667 }
668 
669 /// Generates index for load/store on sparse tensor.
670 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
671   auto map = op.getTiedIndexingMap(t);
672   auto enc = getSparseTensorEncoding(t->get().getType());
673   AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1));
674   assert(a.getKind() == AffineExprKind::DimId);
675   unsigned idx = a.cast<AffineDimExpr>().getPosition();
676   return codegen.loops[idx];
677 }
678 
679 /// Generates subscript for load/store on a dense or sparse tensor.
680 static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
681                           linalg::GenericOp op, OpOperand *t,
682                           SmallVector<Value, 4> &args) {
683   unsigned tensor = t->getOperandNumber();
684   auto map = op.getTiedIndexingMap(t);
685   auto enc = getSparseTensorEncoding(t->get().getType());
686   unsigned rank = map.getNumResults();
687   if (enc) {
688     // Note that currently, all sparse subscripts are simple.
689     // TODO: accept affine too?
690     AffineExpr a = map.getResult(perm(enc, rank - 1));
691     assert(a.getKind() == AffineExprKind::DimId);
692     unsigned idx = a.cast<AffineDimExpr>().getPosition();
693     assert(codegen.pidxs[tensor][idx] != nullptr);
694     args.push_back(codegen.pidxs[tensor][idx]); // position index
695   } else {
696     for (unsigned d = 0; d < rank; d++) {
697       AffineExpr a = map.getResult(perm(enc, d));
698       args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
699     }
700   }
701   return codegen.buffers[tensor];
702 }
703 
704 /// Generates insertion code to implement dynamic tensor load.
705 static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter,
706                               linalg::GenericOp op, OpOperand *t) {
707   Location loc = op.getLoc();
708   // Direct lexicographic index order, tensor loads as zero.
709   if (!codegen.expValues) {
710     Type tp = getElementTypeOrSelf(t->get().getType());
711     return constantZero(rewriter, loc, tp);
712   }
713   // Load from expanded access pattern.
714   Value index = genIndex(codegen, op, t);
715   return rewriter.create<memref::LoadOp>(loc, codegen.expValues, index);
716 }
717 
718 /// Generates insertion code to implement dynamic tensor store.
719 static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter,
720                               linalg::GenericOp op, OpOperand *t, Value rhs) {
721   Location loc = op.getLoc();
722   // Direct insertion in lexicographic index order.
723   if (!codegen.expValues) {
724     rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
725     return;
726   }
727   // Generates insertion code along expanded access pattern.
728   //   if (!expFilled[i]) then
729   //     expFilled[i] = true
730   //     expAdded[inserts++] = i
731   //   endif
732   //   values[i] = rhs
733   Value index = genIndex(codegen, op, t);
734   Value fval = constantI1(rewriter, loc, false);
735   Value tval = constantI1(rewriter, loc, true);
736   // If statement.
737   Value filled = rewriter.create<memref::LoadOp>(loc, codegen.expFilled, index);
738   Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
739                                               filled, fval);
740   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, rewriter.getIndexType(),
741                                               cond, /*else=*/true);
742   // True branch.
743   rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
744   rewriter.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
745   rewriter.create<memref::StoreOp>(loc, index, codegen.expAdded,
746                                    codegen.expCount);
747   Value one = constantIndex(rewriter, loc, 1);
748   Value add = rewriter.create<arith::AddIOp>(loc, codegen.expCount, one);
749   rewriter.create<scf::YieldOp>(loc, add);
750   // False branch.
751   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
752   rewriter.create<scf::YieldOp>(loc, codegen.expCount);
753   rewriter.setInsertionPointAfter(ifOp);
754   // Value assignment.
755   codegen.expCount = ifOp.getResult(0);
756   rewriter.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
757 }
758 
759 /// Generates a load on a dense or sparse tensor.
760 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
761                            PatternRewriter &rewriter, linalg::GenericOp op,
762                            unsigned exp) {
763   // Test if the load was hoisted to a higher loop nest.
764   Value val = merger.exp(exp).val;
765   if (val) {
766     if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
767       return genVectorInvariantValue(codegen, rewriter, val);
768     return val;
769   }
770   // Load during insertion.
771   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
772   if (t == codegen.sparseOut)
773     return genInsertionLoad(codegen, rewriter, op, t);
774   // Actual load.
775   SmallVector<Value, 4> args;
776   Value ptr = genSubscript(codegen, rewriter, op, t, args);
777   if (codegen.curVecLength > 1)
778     return genVectorLoad(codegen, rewriter, ptr, args);
779   return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
780 }
781 
782 /// Generates a store on a dense or sparse tensor.
783 static void genTensorStore(Merger &merger, CodeGen &codegen,
784                            PatternRewriter &rewriter, linalg::GenericOp op,
785                            Value rhs) {
786   Location loc = op.getLoc();
787   // Test if this is a scalarized reduction.
788   if (codegen.redVal) {
789     if (codegen.curVecLength > 1)
790       rhs = rewriter.create<arith::SelectOp>(loc, codegen.curVecMask, rhs,
791                                              codegen.redVal);
792     updateReduc(merger, codegen, rhs);
793     return;
794   }
795   // Store during insertion.
796   OpOperand *t = op.getOutputOperand(0);
797   if (t == codegen.sparseOut) {
798     genInsertionStore(codegen, rewriter, op, t, rhs);
799     return;
800   }
801   // Actual store.
802   SmallVector<Value, 4> args;
803   Value ptr = genSubscript(codegen, rewriter, op, t, args);
804   if (codegen.curVecLength > 1)
805     genVectorStore(codegen, rewriter, rhs, ptr, args);
806   else
807     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
808 }
809 
810 /// Generates a pointer/index load from the sparse storage scheme. Narrower
811 /// data types need to be zero extended before casting the value into the
812 /// index type used for looping and indexing.
813 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
814                      Value ptr, Value s) {
815   // See https://llvm.org/docs/GetElementPtr.html for some background on
816   // the complications described below.
817   if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
819     // which effectively defines an unsigned pointer + signed index, we must
820     // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
822     // elements into 64-bit loses some performance since the 32-bit indexed
823     // gather/scatter is more efficient than the 64-bit index variant (if the
824     // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
825     // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
827     // incorrect address calculations in the unlikely case we need such
828     // extremely large offsets.
829     Type etp = ptr.getType().cast<MemRefType>().getElementType();
830     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
831     if (!etp.isa<IndexType>()) {
832       if (etp.getIntOrFloatBitWidth() < 32)
833         vload = rewriter.create<arith::ExtUIOp>(
834             loc, vectorType(codegen, rewriter.getI32Type()), vload);
835       else if (etp.getIntOrFloatBitWidth() < 64 &&
836                !codegen.options.enableSIMDIndex32)
837         vload = rewriter.create<arith::ExtUIOp>(
838             loc, vectorType(codegen, rewriter.getI64Type()), vload);
839     }
840     return vload;
841   }
842   // For the scalar case, we simply zero extend narrower indices into 64-bit
843   // values before casting to index without a performance penalty. Here too,
  // however, indices that are already 64-bit cannot, in theory, express the
  // full range, as explained above.
846   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
847   if (!load.getType().isa<IndexType>()) {
848     if (load.getType().getIntOrFloatBitWidth() < 64)
849       load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
850     load =
851         rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
852   }
853   return load;
854 }
855 
856 /// Generates an invariant value.
857 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
858                                PatternRewriter &rewriter, unsigned exp) {
859   Value val = merger.exp(exp).val;
860   if (codegen.curVecLength > 1)
861     return genVectorInvariantValue(codegen, rewriter, val);
862   return val;
863 }
864 
865 /// Generates an address computation "sz * p + i".
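/// For illustration: when a dense dimension of size sz is nested under an
/// enveloping dimension whose current position is p, element i of the dense
/// dimension lives at linearized position sz * p + i of the underlying
/// storage.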
866 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
867                         Location loc, Value size, Value p, Value i) {
868   Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
869   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
870     Value inv =
871         rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
872     mul = genVectorInvariantValue(codegen, rewriter, inv);
873   }
874   return rewriter.create<arith::AddIOp>(loc, mul, i);
875 }
876 
877 /// Recursively generates tensor expression.
878 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
879                     linalg::GenericOp op, unsigned exp) {
880   Location loc = op.getLoc();
881   if (exp == -1u)
882     return Value();
883   if (merger.exp(exp).kind == Kind::kTensor)
884     return genTensorLoad(merger, codegen, rewriter, op, exp);
885   if (merger.exp(exp).kind == Kind::kInvariant)
886     return genInvariantValue(merger, codegen, rewriter, exp);
887   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
888   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
889   return merger.buildExp(rewriter, loc, exp, v0, v1);
890 }
891 
892 /// Determines if affine expression is invariant.
893 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
894                               unsigned ldx, bool &atLevel) {
895   switch (a.getKind()) {
896   case AffineExprKind::DimId: {
897     unsigned idx = a.cast<AffineDimExpr>().getPosition();
898     if (idx == ldx)
899       atLevel = true;
900     return codegen.loops[idx] != nullptr; // no longer in play?
901   }
902   case AffineExprKind::Add:
903   case AffineExprKind::Mul: {
904     auto binOp = a.cast<AffineBinaryOpExpr>();
905     return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
906            isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
907   }
908   default:
909     return true;
910   }
911 }
912 
913 /// Hoists loop invariant tensor loads for which indices have been exhausted.
914 static void genInvariants(Merger &merger, CodeGen &codegen,
915                           PatternRewriter &rewriter, linalg::GenericOp op,
916                           unsigned exp, unsigned ldx, bool atStart,
917                           Kind last = Kind::kTensor) {
918   if (exp == -1u)
919     return;
920   if (merger.exp(exp).kind == Kind::kTensor) {
921     // Inspect tensor indices.
922     bool atLevel = ldx == -1u;
923     OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
924     auto map = op.getTiedIndexingMap(t);
925     auto enc = getSparseTensorEncoding(t->get().getType());
926     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
927       AffineExpr a = map.getResult(perm(enc, d));
928       if (!isInvariantAffine(codegen, a, ldx, atLevel))
929         return; // still in play
930     }
931     // All exhausted at this level (atLevel denotes exactly at this level).
932     if (!atLevel)
933       return;
934     OpOperand *lhs = op.getOutputOperand(0);
935     if (lhs == t) {
936       // Start or end a scalarized reduction
937       if (atStart) {
938         Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
939         codegen.redKind = getReduction(last);
940         codegen.redExp = exp;
941         updateReduc(merger, codegen, load);
942       } else {
943         Value redVal = codegen.redVal;
944         updateReduc(merger, codegen, Value());
945         codegen.redExp = -1u;
946         codegen.redKind = kNoReduc;
947         genTensorStore(merger, codegen, rewriter, op, redVal);
948       }
949     } else {
950       // Start or end loop invariant hoisting of a tensor load.
951       merger.exp(exp).val =
952           atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
953     }
954   } else if (merger.exp(exp).kind != Kind::kInvariant) {
955     // Traverse into the binary operations. Note that we only hoist
956     // tensor loads, since subsequent MLIR/LLVM passes know how to
957     // deal with all other kinds of derived loop invariants.
958     Kind last = merger.exp(exp).kind;
959     unsigned e0 = merger.exp(exp).children.e0;
960     unsigned e1 = merger.exp(exp).children.e1;
961     genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
962     genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
963   }
964 }
965 
966 /// Generates an expanded access pattern in innermost dimension.
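/// This is essentially the classic expanded sparse accumulator (workspace)
/// technique: expValues holds scattered values, expFilled the occupancy
/// flags, expAdded the list of indices added so far, and expCount their
/// number.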
967 static void genExpansion(Merger &merger, CodeGen &codegen,
968                          PatternRewriter &rewriter, linalg::GenericOp op,
969                          unsigned at, bool atStart) {
970   OpOperand *lhs = codegen.sparseOut;
971   if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
972       at != codegen.outerParNest)
973     return; // not needed at this level
974   // Generate start or end of an expanded access pattern.
975   Value tensor = lhs->get();
976   Location loc = op.getLoc();
977   if (atStart) {
978     auto dynShape = {ShapedType::kDynamicSize};
979     Type etp = tensor.getType().cast<ShapedType>().getElementType();
980     Type t1 = MemRefType::get(dynShape, etp);
981     Type t2 = MemRefType::get(dynShape, rewriter.getI1Type());
982     Type t3 = MemRefType::get(dynShape, rewriter.getIndexType());
983     Type t4 = rewriter.getIndexType();
984     auto res =
985         rewriter.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
986     assert(res.getNumResults() == 4);
987     assert(!codegen.expValues);
988     codegen.expValues = res.getResult(0);
989     codegen.expFilled = res.getResult(1);
990     codegen.expAdded = res.getResult(2);
991     codegen.expCount = res.getResult(3);
992   } else {
993     assert(codegen.expValues);
994     rewriter.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues,
995                                 codegen.expFilled, codegen.expAdded,
996                                 codegen.expCount);
997     codegen.expValues = codegen.expFilled = codegen.expAdded =
998         codegen.expCount = Value();
999   }
1000 }
1001 
1002 /// Generates initialization code for the subsequent loop sequence at
1003 /// current index level. Returns true if the loop sequence needs to
1004 /// maintain the universal index.
1005 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1006                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1007                     unsigned at, BitVector &inits) {
1008   bool needsUniv = false;
1009   Location loc = op.getLoc();
1010   unsigned idx = topSort[at];
1011 
1012   // Initialize sparse positions.
1013   for (unsigned b = 0, be = inits.size(); b < be; b++) {
1014     if (inits[b]) {
1015       unsigned tensor = merger.tensor(b);
1016       assert(idx == merger.index(b));
1017       if (merger.isDim(b, Dim::kSparse)) {
1018         // Initialize sparse index.
1019         unsigned pat = at;
1020         for (; pat != 0; pat--) {
1021           if (codegen.pidxs[tensor][topSort[pat - 1]])
1022             break;
1023         }
1024         Value ptr = codegen.pointers[tensor][idx];
1025         Value one = constantIndex(rewriter, loc, 1);
1026         Value p0 = (pat == 0) ? constantIndex(rewriter, loc, 0)
1027                               : codegen.pidxs[tensor][topSort[pat - 1]];
1028         codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
1029         Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
1030         codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
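        // The two loads above establish the half-open position range
        // [ pointers[p0], pointers[p0 + 1] ) holding the entries of this
        // compressed dimension under the current enveloping index prefix.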
1031       } else {
1032         // Dense index still in play.
1033         needsUniv = true;
1034       }
1035     }
1036   }
1037 
1038   // Initialize the universal dense index.
1039   codegen.loops[idx] = constantIndex(rewriter, loc, 0);
1040   return needsUniv;
1041 }
1042 
/// Determines whether the current loop becomes a SIMD loop. Any implicit inner
/// loop in the Linalg operation is a candidate; whether it is actually
/// converted to SIMD code depends on the requested vectorization strategy.
1046 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
1047   switch (codegen.options.vectorizationStrategy) {
1048   case SparseVectorizationStrategy::kNone:
1049     return false;
1050   case SparseVectorizationStrategy::kDenseInnerLoop:
1051     return isInner && !isSparse;
1052   case SparseVectorizationStrategy::kAnyStorageInnerLoop:
1053     return isInner;
1054   }
1055   llvm_unreachable("unexpected vectorization strategy");
1056 }
1057 
/// Determines whether the current loop becomes a parallel loop. Any implicit
/// loop in the Linalg operation that is marked "parallel" is a candidate;
/// whether it is actually converted to a parallel operation depends on the
/// requested parallelization strategy.
1061 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
1062                           bool isSparse, bool isVector) {
1063   switch (codegen.options.parallelizationStrategy) {
1064   case SparseParallelizationStrategy::kNone:
1065     return false;
1066   case SparseParallelizationStrategy::kDenseOuterLoop:
1067     return isOuter && !isSparse && !isReduction && !isVector;
1068   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
1069     return isOuter && !isReduction && !isVector;
1070   case SparseParallelizationStrategy::kDenseAnyLoop:
1071     return !isSparse && !isReduction && !isVector;
1072   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
1073     return !isReduction && !isVector;
1074   }
1075   llvm_unreachable("unexpected parallelization strategy");
1076 }
1077 
1078 /// Checks unit stride for dense tensors. The iteration graph may have ignored
1079 /// dense access patterns in order to avoid cycles (sparse access patterns are
1080 /// always placed innermost), but that means dense access has become strided.
1081 /// This prevents effective vectorization.
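/// For illustration: if loop index i ends up innermost but a dense tensor is
/// accessed as B[i][j], then i selects the outer dimension, consecutive i
/// iterations are a full row apart in memory, and vectorization along i is
/// therefore rejected.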
1082 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
1083                              unsigned idx) {
1084   for (OpOperand *t : op.getInputAndOutputOperands()) {
1085     if (!getSparseTensorEncoding(t->get().getType())) {
1086       auto map = op.getTiedIndexingMap(t);
1087       for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
1088         AffineExpr a = map.getResult(d);
1089         // Report non-unit stride if innermost index appears at an outer
1090         // dimension (true non-unit stride) or if the innermost index appears
1091         // in a compound subscript in the innermost dimension. Even if the
1092         // latter is unit stride, it does not play well with scatter/gather.
1093         // TODO: accept unit stride affine innermost like a[i,j+k+1]?
1094         if (a.isFunctionOfDim(idx) &&
1095             ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
1096           return false;
1097       }
1098     }
1099   }
1100   return true;
1101 }
1102 
1103 /// Generates a for-loop on a single index.
1104 static Operation *genFor(Merger &merger, CodeGen &codegen,
1105                          PatternRewriter &rewriter, linalg::GenericOp op,
1106                          bool isOuter, bool isInner, unsigned idx,
1107                          BitVector &indices) {
1108   unsigned fb = indices.find_first();
1109   unsigned tensor = merger.tensor(fb);
1110   assert(idx == merger.index(fb));
1111   auto iteratorTypes = op.iterator_types().getValue();
1112   bool isReduction = isReductionIterator(iteratorTypes[idx]);
1113   bool isSparse = merger.isDim(fb, Dim::kSparse);
1114   bool isVector = !codegen.sparseOut &&
1115                   isVectorFor(codegen, isInner, isSparse) &&
1116                   denseUnitStrides(merger, op, idx);
1117   bool isParallel =
1118       !codegen.sparseOut &&
1119       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
1120 
1121   // Prepare vector length.
1122   if (isVector)
1123     codegen.curVecLength = codegen.options.vectorLength;
1124 
1125   // Loop bounds and increment.
1126   Location loc = op.getLoc();
1127   Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
1128   Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
1129   Value step = constantIndex(rewriter, loc, codegen.curVecLength);
1130 
1131   // Emit a parallel loop.
1132   if (isParallel) {
1133     assert(!isVector);
1134     scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
1135     if (isSparse)
1136       codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
1137     else
1138       codegen.loops[idx] = parOp.getInductionVars()[0];
1139     rewriter.setInsertionPointToStart(parOp.getBody());
1140     return parOp;
1141   }
1142 
1143   // Emit a sequential or vector loop.
1144   SmallVector<Value, 4> operands;
1145   if (codegen.redVal) {
1146     // In a vector loop, bring reduction into SIMD form, if not already.
1147     if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
1148       VectorType vtp = vectorType(codegen, codegen.redVal.getType());
1149       Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
1150       updateReduc(merger, codegen, vred);
1151     }
1152     operands.push_back(codegen.redVal);
1153   }
1154   if (codegen.expValues)
1155     operands.push_back(codegen.expCount);
1156   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
1157   if (codegen.redVal)
1158     updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
1159   if (codegen.expValues)
1160     codegen.expCount = forOp.getRegionIterArgs().back();
1161   // Assign induction variable to sparse or dense index.
1162   Value iv = forOp.getInductionVar();
1163   if (isSparse)
1164     codegen.pidxs[tensor][idx] = iv;
1165   else
1166     codegen.loops[idx] = iv;
1167   rewriter.setInsertionPointToStart(forOp.getBody());
1168   // Share vector iteration mask between all subsequent loads/stores.
1169   if (isVector)
1170     codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
1171   return forOp;
1172 }
1173 
/// Emits a while-loop for co-iteration over multiple indices.
1175 static Operation *genWhile(Merger &merger, CodeGen &codegen,
1176                            PatternRewriter &rewriter, linalg::GenericOp op,
1177                            unsigned idx, bool needsUniv,
1178                            BitVector &indices) {
1179   SmallVector<Type, 4> types;
1180   SmallVector<Value, 4> operands;
1181   // Construct the while-loop with a parameter for each index.
1182   Type indexType = rewriter.getIndexType();
1183   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1184     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1185       unsigned tensor = merger.tensor(b);
1186       assert(idx == merger.index(b));
1187       types.push_back(indexType);
1188       operands.push_back(codegen.pidxs[tensor][idx]);
1189     }
1190   }
1191   if (codegen.redVal) {
1192     types.push_back(codegen.redVal.getType());
1193     operands.push_back(codegen.redVal);
1194   }
1195   if (codegen.expValues) {
1196     types.push_back(indexType);
1197     operands.push_back(codegen.expCount);
1198   }
1199   if (needsUniv) {
1200     types.push_back(indexType);
1201     operands.push_back(codegen.loops[idx]);
1202   }
1203   assert(types.size() == operands.size());
1204   Location loc = op.getLoc();
1205   scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
1206 
1207   SmallVector<Location> locs(types.size(), loc);
1208   Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, locs);
1209   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types, locs);
1210 
1211   // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
1213   rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
1214   Value cond;
1215   unsigned o = 0;
1216   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1217     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1218       unsigned tensor = merger.tensor(b);
1219       assert(idx == merger.index(b));
1220       Value op1 = before->getArgument(o);
1221       Value op2 = codegen.highs[tensor][idx];
1222       Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
1223                                                  op1, op2);
1224       cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
1225       codegen.pidxs[tensor][idx] = after->getArgument(o++);
1226     }
1227   }
1228   if (codegen.redVal)
1229     updateReduc(merger, codegen, after->getArgument(o++));
1230   if (codegen.expValues)
1231     codegen.expCount = after->getArgument(o++);
1232   if (needsUniv)
1233     codegen.loops[idx] = after->getArgument(o++);
1234   assert(o == operands.size());
1235   rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
1236   rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
1237   return whileOp;
1238 }
1239 
1240 /// Generates a for-loop or a while-loop, depending on whether it implements
1241 /// singleton iteration or co-iteration over the given conjunction.
1242 static Operation *genLoop(Merger &merger, CodeGen &codegen,
1243                           PatternRewriter &rewriter, linalg::GenericOp op,
1244                           std::vector<unsigned> &topSort, unsigned at,
1245                           bool needsUniv, BitVector &indices) {
1246   unsigned idx = topSort[at];
1247   if (indices.count() == 1) {
1248     bool isOuter = at == 0;
1249     bool isInner = at == topSort.size() - 1;
1250     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
1251                   indices);
1252   }
1253   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
1254 }
1255 
1256 /// Generates the local variables for this loop, consisting of the sparse
1257 /// indices, restored universal dense index, and dense positions.
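///
/// For a sparse dimension, this loads the index at the current position and,
/// when co-iterating without a universal index, folds the loaded indices into
/// a running minimum, roughly (a sketch, in scalar form):
///
///   %i    = memref.load %indices[%p] : memref<?xindex>
///   %cmp  = arith.cmpi ult, %i, %min : index
///   %min2 = arith.select %cmp, %i, %min : index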
1258 static void genLocals(Merger &merger, CodeGen &codegen,
1259                       PatternRewriter &rewriter, linalg::GenericOp op,
1260                       std::vector<unsigned> &topSort, unsigned at,
1261                       bool needsUniv, BitVector &locals) {
1262   Location loc = op.getLoc();
1263   unsigned idx = topSort[at];
1264 
1265   // Initialize sparse indices.
1266   Value min;
1267   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1268     if (locals[b] && merger.isDim(b, Dim::kSparse)) {
1269       unsigned tensor = merger.tensor(b);
1270       assert(idx == merger.index(b));
1271       Value ptr = codegen.indices[tensor][idx];
1272       Value s = codegen.pidxs[tensor][idx];
1273       Value load = genLoad(codegen, rewriter, loc, ptr, s);
1274       codegen.idxs[tensor][idx] = load;
1275       if (!needsUniv) {
1276         if (min) {
1277           Value cmp = rewriter.create<arith::CmpIOp>(
1278               loc, arith::CmpIPredicate::ult, load, min);
1279           min = rewriter.create<arith::SelectOp>(loc, cmp, load, min);
1280         } else {
1281           min = load;
1282         }
1283       }
1284     }
1285   }
1286 
1287   // Merge dense universal index over minimum.
1288   if (min) {
1289     assert(!needsUniv);
1290     codegen.loops[idx] = min;
1291   }
1292 
1293   // Initialize dense positions. Note that we generate dense indices of the
1294   // output tensor unconditionally, since they may not appear in the lattice,
1295   // but may be needed for linearized codegen.
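  // In scalar form, the dense position is essentially linearized as
  // p * sizes[idx] + loops[idx] (see genAddress).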
1296   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1297     if ((locals[b] || merger.isOutTensor(b, idx)) &&
1298         merger.isDim(b, Dim::kDense)) {
1299       unsigned tensor = merger.tensor(b);
1300       assert(idx == merger.index(b));
1301       unsigned pat = at;
1302       for (; pat != 0; pat--)
1303         if (codegen.pidxs[tensor][topSort[pat - 1]])
1304           break;
1305       Value p = (pat == 0) ? constantIndex(rewriter, loc, 0)
1306                            : codegen.pidxs[tensor][topSort[pat - 1]];
1307       codegen.pidxs[tensor][idx] = genAddress(
1308           codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
1309     }
1310   }
1311 
1312   // Move the insertion indices in lexicographic index order. During access
1313   // pattern expansion, we can skip setting the innermost dimension.
1314   if (codegen.sparseOut && !codegen.expValues) {
1315     Value pos = constantIndex(rewriter, loc, at);
1316     rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
1317                                      pos);
1318   }
1319 }
1320 
1321 /// Generates the induction structure for a while-loop.
1322 static void genWhileInduction(Merger &merger, CodeGen &codegen,
1323                               PatternRewriter &rewriter, linalg::GenericOp op,
1324                               unsigned idx, bool needsUniv,
1325                               BitVector &induction,
1326                               scf::WhileOp whileOp) {
1327   Location loc = op.getLoc();
1328   // Finalize each else branch of all if statements.
1329   if (codegen.redVal || codegen.expValues) {
1330     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
1331                rewriter.getInsertionBlock()->getParentOp())) {
1332       unsigned y = 0;
1333       SmallVector<Value, 4> yields;
1334       if (codegen.redVal) {
1335         yields.push_back(codegen.redVal);
1336         updateReduc(merger, codegen, ifOp.getResult(y++));
1337       }
1338       if (codegen.expValues) {
1339         yields.push_back(codegen.expCount);
1340         codegen.expCount = ifOp->getResult(y++);
1341       }
1342       assert(y == yields.size());
1343       rewriter.create<scf::YieldOp>(loc, yields);
1344       rewriter.setInsertionPointAfter(ifOp);
1345     }
1346   }
1347   rewriter.setInsertionPointToEnd(&whileOp.getAfter().front());
1348   // Finalize the induction. Note that the induction could be performed
1349   // in the individual if-branches to avoid re-evaluating the conditions.
1350   // However, that would result in a rather elaborate forest of yield
1351   // instructions during code generation. Moreover, performing the induction
1352   // after the if-statements more closely resembles code generated by TACO.
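  // In scalar form, each sparse position is advanced conditionally, roughly
  // as follows (a sketch):
  //   %cmp = arith.cmpi eq, %idx_t, %cur : index
  //   %add = arith.addi %pidx, %c1 : index
  //   %new = arith.select %cmp, %add, %pidx : index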
1353   unsigned o = 0;
1354   SmallVector<Value, 4> operands;
1355   Value one = constantIndex(rewriter, loc, 1);
1356   for (unsigned b = 0, be = induction.size(); b < be; b++) {
1357     if (induction[b] && merger.isDim(b, Dim::kSparse)) {
1358       unsigned tensor = merger.tensor(b);
1359       assert(idx == merger.index(b));
1360       Value op1 = codegen.idxs[tensor][idx];
1361       Value op2 = codegen.loops[idx];
1362       Value op3 = codegen.pidxs[tensor][idx];
1363       Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1364                                                  op1, op2);
1365       Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
1366       operands.push_back(rewriter.create<arith::SelectOp>(loc, cmp, add, op3));
1367       codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
1368     }
1369   }
1370   if (codegen.redVal) {
1371     operands.push_back(codegen.redVal);
1372     updateReduc(merger, codegen, whileOp->getResult(o++));
1373   }
1374   if (codegen.expValues) {
1375     operands.push_back(codegen.expCount);
1376     codegen.expCount = whileOp->getResult(o++);
1377   }
1378   if (needsUniv) {
1379     operands.push_back(
1380         rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
1381     codegen.loops[idx] = whileOp->getResult(o++);
1382   }
1383   assert(o == operands.size());
1384   rewriter.create<scf::YieldOp>(loc, operands);
1385   rewriter.setInsertionPointAfter(whileOp);
1386 }
1387 
1388 /// Generates the induction structure for a for-loop.
1389 static void genForInduction(Merger &merger, CodeGen &codegen,
1390                             PatternRewriter &rewriter, linalg::GenericOp op,
1391                             Operation *loop) {
1392   Location loc = op.getLoc();
1393   unsigned o = 0;
1394   SmallVector<Value, 4> operands;
1395   if (codegen.redVal) {
1396     operands.push_back(codegen.redVal);
1397     updateReduc(merger, codegen, loop->getResult(o++));
1398   }
1399   if (codegen.expValues) {
1400     operands.push_back(codegen.expCount);
1401     codegen.expCount = loop->getResult(o++);
1402   }
1403   assert(o == operands.size());
1404   if (o > 0)
1405     rewriter.create<scf::YieldOp>(loc, operands);
1406   rewriter.setInsertionPointAfter(loop);
1407 }
1408 
/// Generates a single if-statement within a while-loop.
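///
/// For a sparse condition, the test checks whether the tensor's current
/// index matches the loop index, roughly (a sketch; result types are added
/// for a pending reduction or expansion count):
///
///   %c = arith.cmpi eq, %idx_t, %cur : index
///   scf.if %c {
///     ... then branch ...
///   } else {
///     ... else branch ...
///   }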
1410 static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
1411                        PatternRewriter &rewriter, linalg::GenericOp op,
1412                        unsigned idx, BitVector &conditions) {
1413   Location loc = op.getLoc();
1414   SmallVector<Type, 4> types;
1415   Value cond;
1416   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1417     if (conditions[b]) {
1418       unsigned tensor = merger.tensor(b);
1419       assert(idx == merger.index(b));
1420       Value clause;
1421       if (merger.isDim(b, Dim::kSparse)) {
1422         Value op1 = codegen.idxs[tensor][idx];
1423         Value op2 = codegen.loops[idx];
1424         clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1425                                                 op1, op2);
1426       } else {
1427         clause = constantI1(rewriter, loc, true);
1428       }
1429       cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
1430     }
1431   }
1432   if (codegen.redVal)
1433     types.push_back(codegen.redVal.getType());
1434   if (codegen.expValues)
1435     types.push_back(rewriter.getIndexType());
1436   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
1437   rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1438   return ifOp;
1439 }
1440 
/// Generates end of the true branch of an if-statement within a while-loop.
1442 static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1443                   linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
1444                   Value redInput, Value cntInput) {
1445   SmallVector<Value, 4> operands;
1446   if (codegen.redVal) {
1447     operands.push_back(codegen.redVal);
1448     updateReduc(merger, codegen, redInput);
1449   }
1450   if (codegen.expValues) {
1451     operands.push_back(codegen.expCount);
1452     codegen.expCount = cntInput;
1453   }
1454   if (!operands.empty())
1455     rewriter.create<scf::YieldOp>(op.getLoc(), operands);
1456   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1457 }
1458 
1459 //===----------------------------------------------------------------------===//
1460 // Sparse compiler synthesis methods (loop sequence).
1461 //===----------------------------------------------------------------------===//
1462 
/// Starts a loop sequence at the given level. Returns true if
/// the universal loop index must be maintained at this level.
1465 static bool startLoopSeq(Merger &merger, CodeGen &codegen,
1466                          PatternRewriter &rewriter, linalg::GenericOp op,
1467                          std::vector<unsigned> &topSort, unsigned exp,
1468                          unsigned at, unsigned idx, unsigned ldx,
1469                          unsigned lts) {
1470   assert(codegen.curVecLength == 1);
1471   assert(!codegen.loops[idx]);
1472   // Emit invariants at this loop sequence level.
1473   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
1474   // Emit access pattern expansion for sparse tensor output.
1475   genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
1477   unsigned l0 = merger.set(lts)[0];
1478   bool needsUniv =
1479       genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
1480   // Maintain the universal index only if it is actually
1481   // consumed by a subsequent lattice point.
1482   if (needsUniv) {
1483     unsigned lsize = merger.set(lts).size();
1484     for (unsigned i = 1; i < lsize; i++) {
1485       unsigned li = merger.set(lts)[i];
1486       if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
1487         return true;
1488     }
1489   }
1490   return false;
1491 }
1492 
/// Starts a single loop in the current sequence.
1494 static Operation *startLoop(Merger &merger, CodeGen &codegen,
1495                             PatternRewriter &rewriter, linalg::GenericOp op,
1496                             std::vector<unsigned> &topSort, unsigned at,
1497                             unsigned li, bool needsUniv) {
1498   assert(codegen.curVecLength == 1);
1499   // Emit the for/while-loop control.
1500   Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
1501                             needsUniv, merger.lat(li).simple);
1502   // Emit the locals for this loop.
1503   genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
1504             merger.lat(li).bits);
1505   return loop;
1506 }
1507 
/// Ends a single loop in the current sequence. Returns the new value
/// for needsUniv.
1509 static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1510                     linalg::GenericOp op, Operation *loop, unsigned idx,
1511                     unsigned li, bool needsUniv) {
1512   codegen.curVecLength = 1;
1513   // End a while-loop.
1514   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1515     genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
1516                       merger.lat(li).bits, whileOp);
1517     return needsUniv;
1518   }
1519   // End a for-loop.
1520   genForInduction(merger, codegen, rewriter, op, loop);
1521   return false;
1522 }
1523 
/// Ends a loop sequence at the given level.
1525 static void endLoopSeq(Merger &merger, CodeGen &codegen,
1526                        PatternRewriter &rewriter, linalg::GenericOp op,
1527                        unsigned exp, unsigned at, unsigned idx, unsigned ldx) {
1528   assert(codegen.curVecLength == 1);
1529   codegen.loops[idx] = Value();
1530   // Bring a pending reduction back from SIMD form when sequence ends.
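  // (typically by emitting a horizontal vector.reduction over the lanes).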
1531   if (codegen.redVal)
1532     if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
1533       updateReduc(merger, codegen,
1534                   genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
1535   // Unmark bookkeeping of invariants and loop index.
1536   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
1537   // Finalize access pattern expansion for sparse tensor output.
1538   genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false);
1539 }
1540 
1541 /// Recursively generates code while computing iteration lattices in order
1542 /// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
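///
/// For example (conceptually), a sparse vector addition x(i) = a(i) + b(i)
/// is lowered to a while-loop that co-iterates over the nonzeros of a and b,
/// with if-statements for the overlap cases (both present, only a, only b),
/// followed by for-loops that handle the remaining nonzeros of either input.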
1544 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1545                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1546                     unsigned exp, unsigned at) {
1547   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1548   if (at == topSort.size()) {
1549     Value rhs = genExp(merger, codegen, rewriter, op, exp);
1550     genTensorStore(merger, codegen, rewriter, op, rhs);
1551     return;
1552   }
1553 
1554   // Construct iteration lattices for current loop index, with L0 at top.
1555   unsigned idx = topSort[at];
1556   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
1557   unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
1558 
1559   // Start a loop sequence.
1560   bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
1561                                 idx, ldx, lts);
1562 
1563   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1564   unsigned lsize = merger.set(lts).size();
1565   for (unsigned i = 0; i < lsize; i++) {
1566     // Start a loop.
1567     unsigned li = merger.set(lts)[i];
1568     Operation *loop =
1569         startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);
1570 
    // Visit all lattice points Lj such that Li >= Lj to generate the
    // loop-body, possibly with if-statements for co-iteration.
1573     Value redInput = codegen.redVal;
1574     Value cntInput = codegen.expCount;
1575     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1576     for (unsigned j = 0; j < lsize; j++) {
1577       unsigned lj = merger.set(lts)[j];
1578       unsigned ej = merger.lat(lj).exp;
1579       if (li == lj || merger.latGT(li, lj)) {
1580         // Recurse into body of each branch.
1581         if (isWhile) {
1582           scf::IfOp ifOp =
1583               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1584           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1585           endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
1586         } else {
1587           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1588         }
1589       }
1590     }
1591 
1592     // End a loop.
1593     needsUniv =
1594         endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
1595   }
1596 
1597   // End a loop sequence.
1598   endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
1599 }
1600 
1601 /// Converts the result computed by the sparse kernel into the required form.
1602 static void genResult(Merger &merger, CodeGen &codegen,
1603                       PatternRewriter &rewriter, linalg::GenericOp op) {
1604   OpOperand *lhs = op.getOutputOperand(0);
1605   Type resType = lhs->get().getType();
1606   if (getSparseTensorEncoding(resType)) {
1607     // The sparse tensor rematerializes from the original sparse tensor's
1608     // underlying sparse storage format.
1609     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
1610                                         codegen.sparseOut == lhs);
1611   } else {
    // To rematerialize a non-annotated tensor, simply load it
1613     // from the bufferized value.
1614     Value val = codegen.buffers.back(); // value array
1615     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1616   }
1617 }
1618 
1619 //===----------------------------------------------------------------------===//
1620 // Sparse compiler rewriting methods.
1621 //===----------------------------------------------------------------------===//
1622 
1623 namespace {
1624 
/// Sparse rewriting rule for generic Linalg operation.
1626 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1627 public:
1628   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1629       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1630 
1631   LogicalResult matchAndRewrite(linalg::GenericOp op,
1632                                 PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
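    // For example, an operand type annotated with an encoding along the
    // lines of
    //   #CSR = #sparse_tensor.encoding<
    //     { dimLevelType = [ "dense", "compressed" ] }>
    // contributes a dense dimension d0 and a compressed (sparse) dimension d1.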
1635     assert(op.getNumOutputs() == 1);
1636     unsigned numTensors = op.getNumInputsAndOutputs();
1637     unsigned numLoops = op.iterator_types().getValue().size();
1638     Merger merger(numTensors, numLoops);
1639     if (!findSparseAnnotations(merger, op))
1640       return failure();
1641 
1642     // Computes a topologically sorted iteration graph to ensure
1643     // tensors are visited in natural index order. Fails on cycles.
1644     // This assumes that higher-level passes have already put the
1645     // tensors in each tensor expression in a feasible order.
1646     std::vector<unsigned> topSort;
1647     if (!computeIterationGraph(merger, op, topSort,
1648                                SortMask::kIncludeUndef |
1649                                    SortMask::kIncludeDense) &&
1650         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
1651         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
1652         !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
1653       return failure();
1654 
1655     // Builds the tensor expression for the Linalg operation in SSA form.
1656     Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
1657     if (!optExp.hasValue())
1658       return failure();
1659     unsigned exp = optExp.getValue();
1660 
    // Rejects an inadmissible tensor expression.
1662     OpOperand *sparseOut = nullptr;
1663     unsigned outerParNest = 0;
1664     if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
1665                                outerParNest))
1666       return failure();
1667 
1668     // Recursively generates code.
1669     merger.setHasSparseOut(sparseOut != nullptr);
1670     CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
1671     genBuffers(merger, codegen, rewriter, op);
1672     genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
1673     genResult(merger, codegen, rewriter, op);
1674     return success();
1675   }
1676 
1677 private:
1678   /// Options to control sparse code generation.
1679   SparsificationOptions options;
1680 };
1681 
1682 } // namespace
1683 
1684 /// Populates the given patterns list with rewriting rules required for
1685 /// the sparsification of linear algebra operations.
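///
/// A typical driver populates and applies the patterns roughly as follows
/// (a sketch; pass boilerplate and option setup omitted):
///
///   RewritePatternSet patterns(ctx);
///   populateSparsificationPatterns(patterns, options);
///   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));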
1686 void mlir::populateSparsificationPatterns(
1687     RewritePatternSet &patterns, const SparsificationOptions &options) {
1688   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1689 }
1690