1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements converting linalg operations on sparse tensor types
// to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodegenUtils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
17 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Dialect/SCF/Transforms.h"
23 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
24 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
25 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
26 #include "mlir/Dialect/Vector/IR/VectorOps.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/IR/TensorEncoding.h"
29 #include "llvm/ADT/SmallBitVector.h"
30 
31 using namespace mlir;
32 using namespace mlir::sparse_tensor;
33 
34 //===----------------------------------------------------------------------===//
35 // Declarations of data structures.
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 
40 // Iteration graph sorting.
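// The mask values can be OR-ed together: kIncludeDense and kIncludeUndef add
// dense resp. undefined dimension constraints on top of the sparse-only ones.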
41 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
42 
43 // Reduction kinds.
44 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
45 
46 // Code generation.
47 struct CodeGen {
48   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
49           OpOperand *op, unsigned nest)
50       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
51         pointers(numTensors, std::vector<Value>(numLoops)),
52         indices(numTensors, std::vector<Value>(numLoops)),
53         highs(numTensors, std::vector<Value>(numLoops)),
54         pidxs(numTensors, std::vector<Value>(numLoops)),
55         idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op),
56         outerParNest(nest), lexIdx(), expValues(), expFilled(), expAdded(),
57         expCount(), curVecMask() {}
58   /// Sparsification options.
59   SparsificationOptions options;
60   /// Universal dense indices and upper bounds (by index). The loops array
61   /// is updated with the value of the universal dense index in the current
62   /// loop. The sizes array is set once with the inferred dimension sizes.
63   std::vector<Value> loops;
64   std::vector<Value> sizes;
65   /// Buffers for storing dense and sparse numerical values (by tensor).
66   /// This array is set once during bufferization of all tensors.
67   std::vector<Value> buffers;
68   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
69   /// This array is set once during bufferization of all sparse tensors.
70   std::vector<std::vector<Value>> pointers;
71   std::vector<std::vector<Value>> indices;
72   /// Sparse iteration information (by tensor and index). These arrays
73   /// are updated to remain current within the current loop.
74   std::vector<std::vector<Value>> highs;
75   std::vector<std::vector<Value>> pidxs;
76   std::vector<std::vector<Value>> idxs;
77   /// Current reduction, updated during code generation. When indices of a
78   /// reduction are exhausted, all inner loops can use a scalarized reduction.
79   unsigned redExp = -1u;
80   Value redVal;
81   Reduction redKind = kNoReduc;
82   // Sparse tensor as output. Implemented either through direct injective
83   // insertion in lexicographic index order (where indices are updated
84   // in the temporary array `lexIdx`) or through access pattern expansion
85   // in the innermost loop nest (`expValues` through `expCount`).
86   OpOperand *sparseOut;
87   unsigned outerParNest;
88   Value lexIdx;
89   Value expValues;
90   Value expFilled;
91   Value expAdded;
92   Value expCount;
93   // Current vector length and mask.
94   unsigned curVecLength = 1;
95   Value curVecMask;
96 };
97 
98 } // namespace
99 
100 //===----------------------------------------------------------------------===//
101 // Sparse compiler analysis methods.
102 //===----------------------------------------------------------------------===//
103 
104 /// Helper method to apply dimension ordering permutation.
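/// For example, with dimOrdering = affine_map<(i,j) -> (j,i)>, perm(enc, 0)
/// yields 1 and perm(enc, 1) yields 0. Without an encoding or an explicit
/// ordering, the identity permutation is used.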
105 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
106   if (enc) {
107     auto order = enc.getDimOrdering();
108     if (order) {
109       assert(order.isPermutation());
110       return order.getDimPosition(d);
111     }
112   }
113   return d;
114 }
115 
116 /// Helper method to translate dim level type to internal representation.
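/// For example, a CSR-style encoding with dimLevelType = ["dense",
/// "compressed"] maps d=0 to Dim::kDense and d=1 to Dim::kSparse.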
117 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
118   if (enc) {
119     SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
120     if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
121       return Dim::kSparse;
122     if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
123       return Dim::kSingle;
124   }
125   return Dim::kDense;
126 }
127 
128 /// Helper method to inspect affine expressions. Rejects cases where the
129 /// same index is used more than once. Also rejects affine expressions
130 /// that are not a direct index for annotated tensors.
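/// For example, subscript a(i,j) is accepted for any tensor, a(i,i) is
/// rejected because index i occurs twice, and a(i+1,j) is only accepted
/// for non-annotated (dense) tensors.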
131 // TODO: accept more affine cases for sparse tensors
132 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
133                        bool isDense) {
134   switch (a.getKind()) {
135   case AffineExprKind::DimId: {
136     unsigned idx = a.cast<AffineDimExpr>().getPosition();
137     if (!merger.isDim(tensor, idx, Dim::kUndef))
138       return false; // used more than once
139     merger.setDim(tensor, idx, dim);
140     return true;
141   }
142   case AffineExprKind::Add:
143   case AffineExprKind::Mul: {
144     if (!isDense)
145       return false;
146     auto binOp = a.cast<AffineBinaryOpExpr>();
147     return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
148            findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
149   }
150   case AffineExprKind::Constant:
151     return isDense;
152   default:
153     return false;
154   }
155 }
156 
157 /// Helper method to inspect sparse encodings in the tensor types.
158 /// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
162 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
163   bool annotated = false;
164   for (OpOperand *t : op.getInputAndOutputOperands()) {
165     auto map = op.getTiedIndexingMap(t);
166     auto enc = getSparseTensorEncoding(t->get().getType());
167     if (enc)
168       annotated = true;
169     assert(map.getNumResults() == op.getRank(t));
170     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
171       unsigned tensor = t->getOperandNumber();
172       AffineExpr a = map.getResult(perm(enc, d));
173       if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
175     }
176   }
177   return annotated;
178 }
179 
180 /// A DFS helper to compute a topological sort. Note that recursion is
181 /// bounded by the number of implicit loops, which is always small.
182 /// Returns false when a cycle is detected.
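/// Visit markers: 0 = not visited, 1 = on the current DFS path (so seeing
/// it again denotes a cycle), and 2 = fully processed.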
183 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
184                        std::vector<unsigned> &topSort,
185                        std::vector<std::vector<bool>> &adjM) {
186   if (visit[i] != 0)
187     return visit[i] != 1; // 1 denotes cycle!
188   visit[i] = 1;
189   for (unsigned j = 0, e = visit.size(); j < e; j++)
190     if (adjM[i][j])
191       if (!topSortDFS(j, visit, topSort, adjM))
192         return false;
193   visit[i] = 2;
194   topSort.push_back(i);
195   return true;
196 }
197 
198 /// Helper method to add all constraints from the indices in one affine
199 /// expression before all indices in the other affine expression. For
200 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
201 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
202                                AffineExpr a, AffineExpr b, unsigned fidx) {
203   switch (a.getKind()) {
204   case AffineExprKind::DimId: {
205     unsigned idx = a.cast<AffineDimExpr>().getPosition();
206     if (b)
207       addAffineOrderings(adjM, b, AffineExpr(), idx);
208     else
209       adjM[fidx][idx] = true;
210     break;
211   }
212   case AffineExprKind::Add:
213   case AffineExprKind::Mul: {
214     auto binOp = a.cast<AffineBinaryOpExpr>();
215     addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
216     addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
217     break;
218   }
219   default:
220     break;
221   }
222 }
223 
224 /// Computes a topologically sorted iteration graph for the linalg operation.
225 /// Ensures all tensors are visited in natural index order. This is essential
226 /// for sparse storage formats since these only support access along fixed
227 /// dimensions. Even for dense storage formats, however, the natural index
228 /// order yields innermost unit-stride access with better spatial locality.
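/// For example (schematically), for x(i) = SUM_j A(i,j) * b(j) with a sparse
/// matrix A in default row-major order, the constraint i < j yields the
/// loop order (i, j).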
229 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
230                                   std::vector<unsigned> &topSort,
231                                   unsigned mask) {
232   // Set up an n x n from/to adjacency matrix of the iteration graph
233   // for the implicit loop indices i_0 .. i_n-1.
234   unsigned n = op.getNumLoops();
235   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
236 
237   // Iterate over the indexing maps of every tensor in the tensor expression.
238   for (OpOperand *t : op.getInputAndOutputOperands()) {
239     auto map = op.getTiedIndexingMap(t);
240     auto enc = getSparseTensorEncoding(t->get().getType());
241     assert(map.getNumDims() == n);
242     // Skip dense tensor constraints when not requested.
243     if (!(mask & SortMask::kIncludeDense) && !enc)
244       continue;
245     // Each tensor expression and optional dimension ordering (row-major
246     // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
248     // on the loop indices if no explicit dimension ordering is given.
249     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
250       AffineExpr f = map.getResult(perm(enc, d - 1));
251       AffineExpr t = map.getResult(perm(enc, d));
252       addAffineOrderings(adjM, f, t, 0);
253     }
254     // Push unrelated loops into sparse iteration space, so these
255     // will be skipped more often.
256     if (mask & SortMask::kIncludeUndef) {
257       unsigned tensor = t->getOperandNumber();
258       for (unsigned i = 0; i < n; i++)
259         if (merger.isDim(tensor, i, Dim::kSparse))
260           for (unsigned j = 0; j < n; j++)
261             if (merger.isDim(tensor, j, Dim::kUndef))
262               adjM[i][j] = true;
263     }
264   }
265 
266   // Topologically sort the iteration graph to determine loop order.
267   // Report failure for a cyclic iteration graph.
268   topSort.clear();
269   topSort.reserve(n);
270   std::vector<unsigned> visit(n, 0);
271   for (unsigned i = 0; i < n; i++)
272     if (visit[i] == 0)
273       if (!topSortDFS(i, visit, topSort, adjM))
274         return false; // cycle!
275   std::reverse(std::begin(topSort), std::end(topSort));
276   return true;
277 }
278 
279 /// Returns true if tensor has an in-place annotation.
280 static bool isInPlace(Value val) {
281   if (auto arg = val.dyn_cast<BlockArgument>())
282     if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
283       if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
284               arg.getArgNumber(),
285               bufferization::BufferizableOpInterface::kInplaceableAttrName))
286         return attr.getValue();
287   return false;
288 }
289 
290 /// Returns true if tensor materializes uninitialized into the computation.
291 static bool isMaterializing(Value val) {
292   return val.getDefiningOp<linalg::InitTensorOp>() ||
293          val.getDefiningOp<InitOp>();
294 }
295 
/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
299 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
300 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
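/// For example, a kernel that merely scales the nonzero values of an in-place
/// sparse output is "simply dynamic", whereas a materializing sparse output
/// that receives new insertions needs the "truly dynamic" treatment below.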
301 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
302                                   std::vector<unsigned> &topSort, unsigned exp,
303                                   OpOperand **sparseOut,
304                                   unsigned &outerParNest) {
305   OpOperand *lhs = op.getOutputOperand(0);
306   unsigned tensor = lhs->getOperandNumber();
307   auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
310   if (!enc)
311     return true;
312   // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
314   bool allDense = true;
315   auto iteratorTypes = op.iterator_types().getValue();
316   unsigned numLoops = iteratorTypes.size();
317   for (unsigned i = 0; i < numLoops; i++)
318     if (merger.isDim(tensor, i, Dim::kSparse)) {
319       allDense = false;
320       break;
321     }
322   if (allDense)
323     return true;
324   // A tensor expression with a sparse output tensor that changes its values
325   // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen, provided
327   // the tensor's underlying sparse storage scheme can be modified in place.
328   if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get()))
329     return true;
330   // Accept "truly dynamic" if the output tensor materializes uninitialized
331   // into the computation and insertions occur in lexicographic index order.
332   if (isMaterializing(lhs->get())) {
333     unsigned nest = 0;
334     for (unsigned i = 0; i < numLoops; i++) {
335       if (isReductionIterator(iteratorTypes[topSort[i]]))
336         break; // terminate at first reduction
337       nest++;
338     }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
342     if (nest >= op.getRank(lhs) - 1) {
343       *sparseOut = lhs;
344       outerParNest = nest;
345       return true;
346     }
347   }
348   return false;
349 }
350 
351 //===----------------------------------------------------------------------===//
352 // Sparse compiler synthesis methods (reductions).
353 //===----------------------------------------------------------------------===//
354 
355 /// Maps reduction kind to vector::CombiningKind.
356 static vector::CombiningKind getCombiningKind(Reduction kind) {
357   switch (kind) {
358   case kNoReduc:
359     break;
360   case kSum:
361     return vector::CombiningKind::ADD;
362   case kProduct:
363     return vector::CombiningKind::MUL;
364   case kAnd:
365     return vector::CombiningKind::AND;
366   case kOr:
367     return vector::CombiningKind::OR;
368   case kXor:
369     return vector::CombiningKind::XOR;
370   }
371   llvm_unreachable("unknown reduction kind");
372 }
373 
374 /// Maps operation to reduction.
375 static Reduction getReduction(Kind kind) {
376   switch (kind) {
377   case Kind::kAddF:
378   case Kind::kAddI:
379   case Kind::kSubF:
380   case Kind::kSubI:
381     return kSum;
382   case Kind::kMulF:
383   case Kind::kMulI:
384     return kProduct;
385   case Kind::kAndI:
386     return kAnd;
387   case Kind::kOrI:
388     return kOr;
389   case Kind::kXorI:
390     return kXor;
391   default:
392     llvm_unreachable("unexpected reduction operator");
393   }
394 }
395 
396 /// Generates an initial value for a vector reduction, following the scheme
397 /// given in Chapter 5 of "The Software Vectorization Handbook", where the
398 /// initial scalar value is correctly embedded in the vector reduction value,
399 /// and a straightforward horizontal reduction will complete the operation.
400 static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
401                                 Location loc, VectorType vtp) {
402   Value r = codegen.redVal;
403   switch (codegen.redKind) {
404   case kNoReduc:
405     break;
406   case kSum:
407   case kXor:
408     // Initialize reduction vector to: | 0 | .. | 0 | r |
409     return rewriter.create<vector::InsertElementOp>(
410         loc, r, constantZero(rewriter, loc, vtp),
411         constantIndex(rewriter, loc, 0));
412   case kProduct:
413     // Initialize reduction vector to: | 1 | .. | 1 | r |
414     return rewriter.create<vector::InsertElementOp>(
415         loc, r, constantOne(rewriter, loc, vtp),
416         constantIndex(rewriter, loc, 0));
417   case kAnd:
418   case kOr:
419     // Initialize reduction vector to: | r | .. | r | r |
420     return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
421   }
422   llvm_unreachable("unknown reduction kind");
423 }
424 
425 /// Generates final value for a vector reduction.
426 static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
427                                Location loc, VectorType vtp) {
428   vector::CombiningKind kind = getCombiningKind(codegen.redKind);
429   return rewriter.create<vector::ReductionOp>(loc, kind, codegen.redVal);
430 }
431 
432 /// Updates scalarized reduction value.
433 static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
434   assert(codegen.redKind != kNoReduc);
435   codegen.redVal = merger.exp(codegen.redExp).val = reduc;
436 }
437 
438 //===----------------------------------------------------------------------===//
439 // Sparse compiler synthesis methods (statements and expressions).
440 //===----------------------------------------------------------------------===//
441 
/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
448 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
449                              linalg::GenericOp op, MemRefType denseTp,
450                              ArrayRef<Value> args) {
451   Location loc = op.getLoc();
452   Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
454   // be generated for the tensor present in the outs() clause. This has
455   // the major advantage that the sparse kernel only updates the nonzero
456   // positions for the output tensor.
457   if (isInPlace(tensor))
458     return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
459   // By default, a new buffer is allocated which is initialized to the
460   // tensor defined in the outs() clause. This is always correct but
461   // introduces a dense initialization component that may negatively
462   // impact the running complexity of the sparse kernel. If the tensor
463   // materializes into the computation, we need to preserve the zero
464   // initialization assumption of all sparse output buffers.
465   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
466   if (isMaterializing(tensor)) {
467     Value zero = constantZero(rewriter, loc, denseTp.getElementType());
468     rewriter.create<linalg::FillOp>(loc, zero, alloc);
469   } else {
470     Value init =
471         rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
472     rewriter.create<memref::CopyOp>(loc, init, alloc);
473   }
474   return alloc;
475 }
476 
477 /// Local bufferization of all dense and sparse data structures.
478 /// This code enables testing the first prototype sparse compiler.
479 // TODO: replace this with a proliferated bufferization strategy
480 static void genBuffers(Merger &merger, CodeGen &codegen,
481                        PatternRewriter &rewriter, linalg::GenericOp op) {
482   Location loc = op.getLoc();
483   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
484   // For every tensor, find lower and upper bound on dimensions, set the
485   // same bounds on loop indices, and obtain dense or sparse buffer(s).
486   SmallVector<Value, 4> args;
487   for (OpOperand *t : op.getInputAndOutputOperands()) {
488     unsigned tensor = t->getOperandNumber();
489     auto shape = op.getShape(t);
490     auto map = op.getTiedIndexingMap(t);
491     auto enc = getSparseTensorEncoding(t->get().getType());
492     // Scan all dimensions of current tensor.
493     args.clear();
494     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
495       AffineExpr a = map.getResult(perm(enc, d));
496       if (a.getKind() != AffineExprKind::DimId)
497         continue; // compound
498       unsigned idx = a.cast<AffineDimExpr>().getPosition();
499       // Handle sparse storage schemes.
500       if (merger.isDim(tensor, idx, Dim::kSparse)) {
501         auto dynShape = {ShapedType::kDynamicSize};
502         auto ptrTp =
503             MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc));
504         auto indTp =
505             MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc));
506         Value dim = constantIndex(rewriter, loc, d);
        // Generate sparse primitives to obtain pointers and indices.
508         codegen.pointers[tensor][idx] =
509             rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
510         codegen.indices[tensor][idx] =
511             rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
512       }
513       // Find upper bound in current dimension.
514       unsigned p = perm(enc, d);
515       Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
516       if (ShapedType::isDynamic(shape[p]))
517         args.push_back(up);
518       assert(codegen.highs[tensor][idx] == nullptr);
519       codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
520     }
521     // Perform the required bufferization. Dense inputs materialize
522     // from the input tensors. Dense outputs need special handling.
523     // Sparse inputs use sparse primitives to obtain the values.
524     // We also accept in-place all-dense annotated "sparse" outputs.
525     Type elementType = getElementTypeOrSelf(t->get().getType());
526     if (!enc) {
527       // Non-annotated dense tensors.
528       auto denseTp = MemRefType::get(shape, elementType);
529       if (tensor < op.getNumInputs())
530         codegen.buffers[tensor] =
531             rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
532       else
533         codegen.buffers[tensor] =
534             genOutputBuffer(codegen, rewriter, op, denseTp, args);
535     } else if (t == codegen.sparseOut) {
536       // True sparse output needs a lexIdx array.
537       Value rank = constantIndex(rewriter, loc, op.getRank(t));
538       auto dynShape = {ShapedType::kDynamicSize};
539       auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
540       codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
541     } else {
542       // Annotated sparse tensors.
543       auto dynShape = {ShapedType::kDynamicSize};
544       auto sparseTp = MemRefType::get(dynShape, elementType);
545       codegen.buffers[tensor] =
546           rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
547     }
548   }
549 }
550 
551 /// Constructs vector type.
552 static VectorType vectorType(CodeGen &codegen, Type etp) {
553   return VectorType::get(codegen.curVecLength, etp);
554 }
555 
556 /// Constructs vector type from pointer.
557 static VectorType vectorType(CodeGen &codegen, Value ptr) {
558   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
559 }
560 
561 /// Constructs vector iteration mask.
562 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
563                            Value iv, Value lo, Value hi, Value step) {
564   Location loc = iv.getLoc();
565   VectorType mtp = vectorType(codegen, rewriter.getI1Type());
566   // Special case if the vector length evenly divides the trip count (for
567   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
568   // so that all subsequent masked memory operations are immediately folded
569   // into unconditional memory operations.
570   IntegerAttr loInt, hiInt, stepInt;
571   if (matchPattern(lo, m_Constant(&loInt)) &&
572       matchPattern(hi, m_Constant(&hiInt)) &&
573       matchPattern(step, m_Constant(&stepInt))) {
574     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
575       return rewriter.create<vector::BroadcastOp>(
576           loc, mtp, constantI1(rewriter, loc, true));
577   }
578   // Otherwise, generate a vector mask that avoids overrunning the upperbound
579   // during vector execution. Here we rely on subsequent loop optimizations to
580   // avoid executing the mask in all iterations, for example, by splitting the
581   // loop into an unconditional vector loop and a scalar cleanup loop.
582   auto minMap = AffineMap::get(
583       /*dimCount=*/2, /*symbolCount=*/1,
584       {rewriter.getAffineSymbolExpr(0),
585        rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
586       rewriter.getContext());
587   Value end =
588       rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
589   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
590 }
591 
592 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
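/// A gather is emitted when the innermost subscript is itself a vector of
/// indices; otherwise a masked load over a contiguous range suffices.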
593 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
594                            Value ptr, ArrayRef<Value> args) {
595   Location loc = ptr.getLoc();
596   VectorType vtp = vectorType(codegen, ptr);
597   Value pass = constantZero(rewriter, loc, vtp);
598   if (args.back().getType().isa<VectorType>()) {
599     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
600     Value indexVec = args.back();
601     scalarArgs.back() = constantIndex(rewriter, loc, 0);
602     return rewriter.create<vector::GatherOp>(
603         loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
604   }
605   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
606                                                codegen.curVecMask, pass);
607 }
608 
609 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
610 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
611                            Value rhs, Value ptr, ArrayRef<Value> args) {
612   Location loc = ptr.getLoc();
613   if (args.back().getType().isa<VectorType>()) {
614     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
615     Value indexVec = args.back();
616     scalarArgs.back() = constantIndex(rewriter, loc, 0);
617     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
618                                        codegen.curVecMask, rhs);
619     return;
620   }
621   rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
622                                          rhs);
623 }
624 
625 /// Generates a vectorized invariant. Here we rely on subsequent loop
626 /// optimizations to hoist the invariant broadcast out of the vector loop.
627 static Value genVectorInvariantValue(CodeGen &codegen,
628                                      PatternRewriter &rewriter, Value val) {
629   VectorType vtp = vectorType(codegen, val.getType());
630   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
631 }
632 
633 /// Generates an affine expression.
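/// For example, the dense subscript expression i + 2 lowers to an arith.addi
/// of the current loop index for i and an index constant 2.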
634 //
635 // TODO: generalize for sparse tensor subscripts
636 //
637 static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
638                        AffineExpr a, Location loc) {
639   switch (a.getKind()) {
640   case AffineExprKind::DimId: {
641     unsigned idx = a.cast<AffineDimExpr>().getPosition();
642     return codegen.loops[idx]; // universal dense index
643   }
644   case AffineExprKind::Add: {
645     auto binOp = a.cast<AffineBinaryOpExpr>();
646     return rewriter.create<arith::AddIOp>(
647         loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
648         genAffine(codegen, rewriter, binOp.getRHS(), loc));
649   }
650   case AffineExprKind::Mul: {
651     auto binOp = a.cast<AffineBinaryOpExpr>();
652     return rewriter.create<arith::MulIOp>(
653         loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
654         genAffine(codegen, rewriter, binOp.getRHS(), loc));
655   }
656   case AffineExprKind::Constant: {
657     int64_t c = a.cast<AffineConstantExpr>().getValue();
658     return constantIndex(rewriter, loc, c);
659   }
660   default:
661     llvm_unreachable("unexpected affine subscript");
662   }
663 }
664 
665 /// Generates index for load/store on sparse tensor.
666 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
667   auto map = op.getTiedIndexingMap(t);
668   auto enc = getSparseTensorEncoding(t->get().getType());
669   AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1));
670   assert(a.getKind() == AffineExprKind::DimId);
671   unsigned idx = a.cast<AffineDimExpr>().getPosition();
672   return codegen.loops[idx];
673 }
674 
675 /// Generates subscript for load/store on a dense or sparse tensor.
676 static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
677                           linalg::GenericOp op, OpOperand *t,
678                           SmallVector<Value, 4> &args) {
679   unsigned tensor = t->getOperandNumber();
680   auto map = op.getTiedIndexingMap(t);
681   auto enc = getSparseTensorEncoding(t->get().getType());
682   unsigned rank = map.getNumResults();
683   if (enc) {
684     // Note that currently, all sparse subscripts are simple.
685     // TODO: accept affine too?
686     AffineExpr a = map.getResult(perm(enc, rank - 1));
687     assert(a.getKind() == AffineExprKind::DimId);
688     unsigned idx = a.cast<AffineDimExpr>().getPosition();
689     assert(codegen.pidxs[tensor][idx] != nullptr);
690     args.push_back(codegen.pidxs[tensor][idx]); // position index
691   } else {
692     for (unsigned d = 0; d < rank; d++) {
693       AffineExpr a = map.getResult(perm(enc, d));
694       args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
695     }
696   }
697   return codegen.buffers[tensor];
698 }
699 
700 /// Generates insertion code to implement dynamic tensor load.
701 static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter,
702                               linalg::GenericOp op, OpOperand *t) {
703   Location loc = op.getLoc();
704   // Direct lexicographic index order, tensor loads as zero.
705   if (!codegen.expValues) {
706     Type tp = getElementTypeOrSelf(t->get().getType());
707     return constantZero(rewriter, loc, tp);
708   }
709   // Load from expanded access pattern.
710   Value index = genIndex(codegen, op, t);
711   return rewriter.create<memref::LoadOp>(loc, codegen.expValues, index);
712 }
713 
714 /// Generates insertion code to implement dynamic tensor store.
715 static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter,
716                               linalg::GenericOp op, OpOperand *t, Value rhs) {
717   Location loc = op.getLoc();
718   // Direct insertion in lexicographic index order.
719   if (!codegen.expValues) {
720     rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
721     return;
722   }
723   // Generates insertion code along expanded access pattern.
724   //   if (!expFilled[i]) then
725   //     expFilled[i] = true
726   //     expAdded[inserts++] = i
727   //   endif
728   //   values[i] = rhs
729   Value index = genIndex(codegen, op, t);
730   Value fval = constantI1(rewriter, loc, false);
731   Value tval = constantI1(rewriter, loc, true);
732   // If statement.
733   Value filled = rewriter.create<memref::LoadOp>(loc, codegen.expFilled, index);
734   Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
735                                               filled, fval);
736   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, rewriter.getIndexType(),
737                                               cond, /*else=*/true);
738   // True branch.
739   rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
740   rewriter.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
741   rewriter.create<memref::StoreOp>(loc, index, codegen.expAdded,
742                                    codegen.expCount);
743   Value one = constantIndex(rewriter, loc, 1);
744   Value add = rewriter.create<arith::AddIOp>(loc, codegen.expCount, one);
745   rewriter.create<scf::YieldOp>(loc, add);
746   // False branch.
747   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
748   rewriter.create<scf::YieldOp>(loc, codegen.expCount);
749   rewriter.setInsertionPointAfter(ifOp);
750   // Value assignment.
751   codegen.expCount = ifOp.getResult(0);
752   rewriter.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
753 }
754 
755 /// Generates a load on a dense or sparse tensor.
756 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
757                            PatternRewriter &rewriter, linalg::GenericOp op,
758                            unsigned exp) {
759   // Test if the load was hoisted to a higher loop nest.
760   Value val = merger.exp(exp).val;
761   if (val) {
762     if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
763       return genVectorInvariantValue(codegen, rewriter, val);
764     return val;
765   }
766   // Load during insertion.
767   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
768   if (t == codegen.sparseOut)
769     return genInsertionLoad(codegen, rewriter, op, t);
770   // Actual load.
771   SmallVector<Value, 4> args;
772   Value ptr = genSubscript(codegen, rewriter, op, t, args);
773   if (codegen.curVecLength > 1)
774     return genVectorLoad(codegen, rewriter, ptr, args);
775   return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
776 }
777 
778 /// Generates a store on a dense or sparse tensor.
779 static void genTensorStore(Merger &merger, CodeGen &codegen,
780                            PatternRewriter &rewriter, linalg::GenericOp op,
781                            Value rhs) {
782   Location loc = op.getLoc();
783   // Test if this is a scalarized reduction.
784   if (codegen.redVal) {
785     if (codegen.curVecLength > 1)
786       rhs = rewriter.create<arith::SelectOp>(loc, codegen.curVecMask, rhs,
787                                              codegen.redVal);
788     updateReduc(merger, codegen, rhs);
789     return;
790   }
791   // Store during insertion.
792   OpOperand *t = op.getOutputOperand(0);
793   if (t == codegen.sparseOut) {
794     genInsertionStore(codegen, rewriter, op, t, rhs);
795     return;
796   }
797   // Actual store.
798   SmallVector<Value, 4> args;
799   Value ptr = genSubscript(codegen, rewriter, op, t, args);
800   if (codegen.curVecLength > 1)
801     genVectorStore(codegen, rewriter, rhs, ptr, args);
802   else
803     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
804 }
805 
806 /// Generates a pointer/index load from the sparse storage scheme. Narrower
807 /// data types need to be zero extended before casting the value into the
808 /// index type used for looping and indexing.
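/// For example, i16 values are zero extended to i64 and cast to index in
/// scalar code, or zero extended to an i32 vector in vector code.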
809 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
810                      Value ptr, Value s) {
811   // See https://llvm.org/docs/GetElementPtr.html for some background on
812   // the complications described below.
813   if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
825     Type etp = ptr.getType().cast<MemRefType>().getElementType();
826     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
827     if (!etp.isa<IndexType>()) {
828       if (etp.getIntOrFloatBitWidth() < 32)
829         vload = rewriter.create<arith::ExtUIOp>(
830             loc, vectorType(codegen, rewriter.getI32Type()), vload);
831       else if (etp.getIntOrFloatBitWidth() < 64 &&
832                !codegen.options.enableSIMDIndex32)
833         vload = rewriter.create<arith::ExtUIOp>(
834             loc, vectorType(codegen, rewriter.getI64Type()), vload);
835     }
836     return vload;
837   }
838   // For the scalar case, we simply zero extend narrower indices into 64-bit
839   // values before casting to index without a performance penalty. Here too,
840   // however, indices that already are 64-bit, in theory, cannot express the
841   // full range as explained above.
842   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
843   if (!load.getType().isa<IndexType>()) {
844     if (load.getType().getIntOrFloatBitWidth() < 64)
845       load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
846     load =
847         rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
848   }
849   return load;
850 }
851 
852 /// Generates an invariant value.
853 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
854                                PatternRewriter &rewriter, unsigned exp) {
855   Value val = merger.exp(exp).val;
856   if (codegen.curVecLength > 1)
857     return genVectorInvariantValue(codegen, rewriter, val);
858   return val;
859 }
860 
861 /// Generates an address computation "sz * p + i".
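/// This linearizes dense storage: p is the position computed for the
/// enclosing dimensions and i is the index in the current dimension (the
/// product is broadcast when i is a vector).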
862 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
863                         Location loc, Value size, Value p, Value i) {
864   Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
865   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
866     Value inv =
867         rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
868     mul = genVectorInvariantValue(codegen, rewriter, inv);
869   }
870   return rewriter.create<arith::AddIOp>(loc, mul, i);
871 }
872 
873 /// Recursively generates tensor expression.
874 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
875                     linalg::GenericOp op, unsigned exp) {
876   Location loc = op.getLoc();
877   if (exp == -1u)
878     return Value();
879   if (merger.exp(exp).kind == Kind::kTensor)
880     return genTensorLoad(merger, codegen, rewriter, op, exp);
881   if (merger.exp(exp).kind == Kind::kInvariant)
882     return genInvariantValue(merger, codegen, rewriter, exp);
883   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
884   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
885   return merger.buildExp(rewriter, loc, exp, v0, v1);
886 }
887 
888 /// Determines if affine expression is invariant.
889 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
890                               unsigned ldx, bool &atLevel) {
891   switch (a.getKind()) {
892   case AffineExprKind::DimId: {
893     unsigned idx = a.cast<AffineDimExpr>().getPosition();
894     if (idx == ldx)
895       atLevel = true;
896     return codegen.loops[idx] != nullptr; // no longer in play?
897   }
898   case AffineExprKind::Add:
899   case AffineExprKind::Mul: {
900     auto binOp = a.cast<AffineBinaryOpExpr>();
901     return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
902            isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
903   }
904   default:
905     return true;
906   }
907 }
908 
909 /// Hoists loop invariant tensor loads for which indices have been exhausted.
910 static void genInvariants(Merger &merger, CodeGen &codegen,
911                           PatternRewriter &rewriter, linalg::GenericOp op,
912                           unsigned exp, unsigned ldx, bool atStart,
913                           Kind last = Kind::kTensor) {
914   if (exp == -1u)
915     return;
916   if (merger.exp(exp).kind == Kind::kTensor) {
917     // Inspect tensor indices.
918     bool atLevel = ldx == -1u;
919     OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
920     auto map = op.getTiedIndexingMap(t);
921     auto enc = getSparseTensorEncoding(t->get().getType());
922     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
923       AffineExpr a = map.getResult(perm(enc, d));
924       if (!isInvariantAffine(codegen, a, ldx, atLevel))
925         return; // still in play
926     }
927     // All exhausted at this level (atLevel denotes exactly at this level).
928     if (!atLevel)
929       return;
930     OpOperand *lhs = op.getOutputOperand(0);
931     if (lhs == t) {
932       // Start or end a scalarized reduction
933       if (atStart) {
934         Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
935         codegen.redKind = getReduction(last);
936         codegen.redExp = exp;
937         updateReduc(merger, codegen, load);
938       } else {
939         Value redVal = codegen.redVal;
940         updateReduc(merger, codegen, Value());
941         codegen.redExp = -1u;
942         codegen.redKind = kNoReduc;
943         genTensorStore(merger, codegen, rewriter, op, redVal);
944       }
945     } else {
946       // Start or end loop invariant hoisting of a tensor load.
947       merger.exp(exp).val =
948           atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
949     }
950   } else if (merger.exp(exp).kind != Kind::kInvariant) {
951     // Traverse into the binary operations. Note that we only hoist
952     // tensor loads, since subsequent MLIR/LLVM passes know how to
953     // deal with all other kinds of derived loop invariants.
954     Kind last = merger.exp(exp).kind;
955     unsigned e0 = merger.exp(exp).children.e0;
956     unsigned e1 = merger.exp(exp).children.e1;
957     genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
958     genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
959   }
960 }
961 
962 /// Generates an expanded access pattern in innermost dimension.
963 static void genExpansion(Merger &merger, CodeGen &codegen,
964                          PatternRewriter &rewriter, linalg::GenericOp op,
965                          unsigned at, bool atStart) {
966   OpOperand *lhs = codegen.sparseOut;
967   if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
968       at != codegen.outerParNest)
969     return; // not needed at this level
970   // Generate start or end of an expanded access pattern.
971   Value tensor = lhs->get();
972   Location loc = op.getLoc();
973   if (atStart) {
974     auto dynShape = {ShapedType::kDynamicSize};
975     Type etp = tensor.getType().cast<ShapedType>().getElementType();
976     Type t1 = MemRefType::get(dynShape, etp);
977     Type t2 = MemRefType::get(dynShape, rewriter.getI1Type());
978     Type t3 = MemRefType::get(dynShape, rewriter.getIndexType());
979     Type t4 = rewriter.getIndexType();
980     auto res =
981         rewriter.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
982     assert(res.getNumResults() == 4);
983     assert(!codegen.expValues);
984     codegen.expValues = res.getResult(0);
985     codegen.expFilled = res.getResult(1);
986     codegen.expAdded = res.getResult(2);
987     codegen.expCount = res.getResult(3);
988   } else {
989     assert(codegen.expValues);
990     rewriter.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues,
991                                 codegen.expFilled, codegen.expAdded,
992                                 codegen.expCount);
993     codegen.expValues = codegen.expFilled = codegen.expAdded =
994         codegen.expCount = Value();
995   }
996 }
997 
998 /// Generates initialization code for the subsequent loop sequence at
999 /// current index level. Returns true if the loop sequence needs to
1000 /// maintain the universal index.
1001 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1002                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1003                     unsigned at, BitVector &inits) {
1004   bool needsUniv = false;
1005   Location loc = op.getLoc();
1006   unsigned idx = topSort[at];
1007 
1008   // Initialize sparse positions.
1009   for (unsigned b = 0, be = inits.size(); b < be; b++) {
1010     if (inits[b]) {
1011       unsigned tensor = merger.tensor(b);
1012       assert(idx == merger.index(b));
1013       if (merger.isDim(b, Dim::kSparse)) {
1014         // Initialize sparse index.
1015         unsigned pat = at;
1016         for (; pat != 0; pat--) {
1017           if (codegen.pidxs[tensor][topSort[pat - 1]])
1018             break;
1019         }
1020         Value ptr = codegen.pointers[tensor][idx];
1021         Value one = constantIndex(rewriter, loc, 1);
1022         Value p0 = (pat == 0) ? constantIndex(rewriter, loc, 0)
1023                               : codegen.pidxs[tensor][topSort[pat - 1]];
1024         codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
1025         Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
1026         codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
1027       } else {
1028         // Dense index still in play.
1029         needsUniv = true;
1030       }
1031     }
1032   }
1033 
1034   // Initialize the universal dense index.
1035   codegen.loops[idx] = constantIndex(rewriter, loc, 0);
1036   return needsUniv;
1037 }
1038 
1039 /// Returns vectorization strategy. Any implicit inner loop in the Linalg
1040 /// operation is a candidate. Whether it is actually converted to SIMD code
1041 /// depends on the requested strategy.
1042 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
1043   switch (codegen.options.vectorizationStrategy) {
1044   case SparseVectorizationStrategy::kNone:
1045     return false;
1046   case SparseVectorizationStrategy::kDenseInnerLoop:
1047     return isInner && !isSparse;
1048   case SparseVectorizationStrategy::kAnyStorageInnerLoop:
1049     return isInner;
1050   }
1051   llvm_unreachable("unexpected vectorization strategy");
1052 }
1053 
1054 /// Returns parallelization strategy. Any implicit loop in the Linalg operation
1055 /// that is marked "parallel" is a candidate. Whether it is actually converted
1056 /// to a parallel operation depends on the requested strategy.
1057 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
1058                           bool isSparse, bool isVector) {
1059   switch (codegen.options.parallelizationStrategy) {
1060   case SparseParallelizationStrategy::kNone:
1061     return false;
1062   case SparseParallelizationStrategy::kDenseOuterLoop:
1063     return isOuter && !isSparse && !isReduction && !isVector;
1064   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
1065     return isOuter && !isReduction && !isVector;
1066   case SparseParallelizationStrategy::kDenseAnyLoop:
1067     return !isSparse && !isReduction && !isVector;
1068   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
1069     return !isReduction && !isVector;
1070   }
1071   llvm_unreachable("unexpected parallelization strategy");
1072 }
1073 
1074 /// Checks unit stride for dense tensors. The iteration graph may have ignored
1075 /// dense access patterns in order to avoid cycles (sparse access patterns are
1076 /// always placed innermost), but that means dense access has become strided.
1077 /// This prevents effective vectorization.
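/// For example, if the innermost loop index appears as the row subscript of
/// a dense row-major matrix, the access becomes strided and vectorization
/// is suppressed for that loop.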
1078 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
1079                              unsigned idx) {
1080   for (OpOperand *t : op.getInputAndOutputOperands()) {
1081     if (!getSparseTensorEncoding(t->get().getType())) {
1082       auto map = op.getTiedIndexingMap(t);
1083       for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
1084         AffineExpr a = map.getResult(d);
1085         // Report non-unit stride if innermost index appears at an outer
1086         // dimension (true non-unit stride) or if the innermost index appears
1087         // in a compound subscript in the innermost dimension. Even if the
1088         // latter is unit stride, it does not play well with scatter/gather.
1089         // TODO: accept unit stride affine innermost like a[i,j+k+1]?
1090         if (a.isFunctionOfDim(idx) &&
1091             ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
1092           return false;
1093       }
1094     }
1095   }
1096   return true;
1097 }
1098 
1099 /// Generates a for-loop on a single index.
1100 static Operation *genFor(Merger &merger, CodeGen &codegen,
1101                          PatternRewriter &rewriter, linalg::GenericOp op,
1102                          bool isOuter, bool isInner, unsigned idx,
1103                          BitVector &indices) {
1104   unsigned fb = indices.find_first();
1105   unsigned tensor = merger.tensor(fb);
1106   assert(idx == merger.index(fb));
1107   auto iteratorTypes = op.iterator_types().getValue();
1108   bool isReduction = isReductionIterator(iteratorTypes[idx]);
1109   bool isSparse = merger.isDim(fb, Dim::kSparse);
1110   bool isVector = !codegen.sparseOut &&
1111                   isVectorFor(codegen, isInner, isSparse) &&
1112                   denseUnitStrides(merger, op, idx);
1113   bool isParallel =
1114       !codegen.sparseOut &&
1115       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
1116 
1117   // Prepare vector length.
1118   if (isVector)
1119     codegen.curVecLength = codegen.options.vectorLength;
1120 
1121   // Loop bounds and increment.
1122   Location loc = op.getLoc();
1123   Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
1124   Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
1125   Value step = constantIndex(rewriter, loc, codegen.curVecLength);
1126 
1127   // Emit a parallel loop.
1128   if (isParallel) {
1129     assert(!isVector);
1130     scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
1131     if (isSparse)
1132       codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
1133     else
1134       codegen.loops[idx] = parOp.getInductionVars()[0];
1135     rewriter.setInsertionPointToStart(parOp.getBody());
1136     return parOp;
1137   }
1138 
1139   // Emit a sequential or vector loop.
1140   SmallVector<Value, 4> operands;
1141   if (codegen.redVal) {
1142     // In a vector loop, bring reduction into SIMD form, if not already.
1143     if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
1144       VectorType vtp = vectorType(codegen, codegen.redVal.getType());
1145       Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
1146       updateReduc(merger, codegen, vred);
1147     }
1148     operands.push_back(codegen.redVal);
1149   }
1150   if (codegen.expValues)
1151     operands.push_back(codegen.expCount);
1152   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
1153   if (codegen.redVal)
1154     updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
1155   if (codegen.expValues)
1156     codegen.expCount = forOp.getRegionIterArgs().back();
1157   // Assign induction variable to sparse or dense index.
1158   Value iv = forOp.getInductionVar();
1159   if (isSparse)
1160     codegen.pidxs[tensor][idx] = iv;
1161   else
1162     codegen.loops[idx] = iv;
1163   rewriter.setInsertionPointToStart(forOp.getBody());
1164   // Share vector iteration mask between all subsequent loads/stores.
1165   if (isVector)
1166     codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
1167   return forOp;
1168 }
1169 
1170 /// Emit a while-loop for co-iteration over multiple indices.
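/// For example (schematically), co-iterating two sparse vectors a(i) and b(i)
/// yields a while-loop carrying one position per sparse tensor, guarded by
/// the condition p_a < high_a && p_b < high_b, plus the universal dense
/// index as an extra parameter when needed.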
1171 static Operation *genWhile(Merger &merger, CodeGen &codegen,
1172                            PatternRewriter &rewriter, linalg::GenericOp op,
1173                            unsigned idx, bool needsUniv,
1174                            BitVector &indices) {
1175   SmallVector<Type, 4> types;
1176   SmallVector<Value, 4> operands;
1177   // Construct the while-loop with a parameter for each index.
1178   Type indexType = rewriter.getIndexType();
1179   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1180     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1181       unsigned tensor = merger.tensor(b);
1182       assert(idx == merger.index(b));
1183       types.push_back(indexType);
1184       operands.push_back(codegen.pidxs[tensor][idx]);
1185     }
1186   }
1187   if (codegen.redVal) {
1188     types.push_back(codegen.redVal.getType());
1189     operands.push_back(codegen.redVal);
1190   }
1191   if (codegen.expValues) {
1192     types.push_back(indexType);
1193     operands.push_back(codegen.expCount);
1194   }
1195   if (needsUniv) {
1196     types.push_back(indexType);
1197     operands.push_back(codegen.loops[idx]);
1198   }
1199   assert(types.size() == operands.size());
1200   Location loc = op.getLoc();
1201   scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
1202 
1203   SmallVector<Location> locs(types.size(), loc);
1204   Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, locs);
1205   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types, locs);
1206 
1207   // Build the "before" region, which effectively consists
1208   // of a conjunction of "i < upper" tests on all induction.
1209   rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
1210   Value cond;
1211   unsigned o = 0;
1212   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1213     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1214       unsigned tensor = merger.tensor(b);
1215       assert(idx == merger.index(b));
1216       Value op1 = before->getArgument(o);
1217       Value op2 = codegen.highs[tensor][idx];
1218       Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
1219                                                  op1, op2);
1220       cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
1221       codegen.pidxs[tensor][idx] = after->getArgument(o++);
1222     }
1223   }
1224   if (codegen.redVal)
1225     updateReduc(merger, codegen, after->getArgument(o++));
1226   if (codegen.expValues)
1227     codegen.expCount = after->getArgument(o++);
1228   if (needsUniv)
1229     codegen.loops[idx] = after->getArgument(o++);
1230   assert(o == operands.size());
1231   rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
1232   rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
1233   return whileOp;
1234 }
1235 
1236 /// Generates a for-loop or a while-loop, depending on whether it implements
1237 /// singleton iteration or co-iteration over the given conjunction.
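/// A lattice point with a single index yields an scf.for (or, depending on
/// the strategy, an scf.parallel or vectorized loop); co-iteration over
/// several sparse indices requires an scf.while.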
1238 static Operation *genLoop(Merger &merger, CodeGen &codegen,
1239                           PatternRewriter &rewriter, linalg::GenericOp op,
1240                           std::vector<unsigned> &topSort, unsigned at,
1241                           bool needsUniv, BitVector &indices) {
1242   unsigned idx = topSort[at];
1243   if (indices.count() == 1) {
1244     bool isOuter = at == 0;
1245     bool isInner = at == topSort.size() - 1;
1246     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
1247                   indices);
1248   }
1249   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
1250 }
1251 
1252 /// Generates the local variables for this loop, consisting of the sparse
1253 /// indices, restored universal dense index, and dense positions.
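/// Without a universal index, the loop index becomes the minimum over the
/// loaded sparse indices, computed with a chain of compare-and-select:
///   min = (load < min) ? load : min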
1254 static void genLocals(Merger &merger, CodeGen &codegen,
1255                       PatternRewriter &rewriter, linalg::GenericOp op,
1256                       std::vector<unsigned> &topSort, unsigned at,
1257                       bool needsUniv, BitVector &locals) {
1258   Location loc = op.getLoc();
1259   unsigned idx = topSort[at];
1260 
1261   // Initialize sparse indices.
1262   Value min;
1263   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1264     if (locals[b] && merger.isDim(b, Dim::kSparse)) {
1265       unsigned tensor = merger.tensor(b);
1266       assert(idx == merger.index(b));
1267       Value ptr = codegen.indices[tensor][idx];
1268       Value s = codegen.pidxs[tensor][idx];
1269       Value load = genLoad(codegen, rewriter, loc, ptr, s);
1270       codegen.idxs[tensor][idx] = load;
1271       if (!needsUniv) {
1272         if (min) {
1273           Value cmp = rewriter.create<arith::CmpIOp>(
1274               loc, arith::CmpIPredicate::ult, load, min);
1275           min = rewriter.create<arith::SelectOp>(loc, cmp, load, min);
1276         } else {
1277           min = load;
1278         }
1279       }
1280     }
1281   }
1282 
1283   // Merge dense universal index over minimum.
1284   if (min) {
1285     assert(!needsUniv);
1286     codegen.loops[idx] = min;
1287   }
1288 
1289   // Initialize dense positions. Note that we generate dense indices of the
1290   // output tensor unconditionally, since they may not appear in the lattice,
1291   // but may be needed for linearized codegen.
1292   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1293     if ((locals[b] || merger.isOutTensor(b, idx)) &&
1294         merger.isDim(b, Dim::kDense)) {
1295       unsigned tensor = merger.tensor(b);
1296       assert(idx == merger.index(b));
1297       unsigned pat = at;
1298       for (; pat != 0; pat--)
1299         if (codegen.pidxs[tensor][topSort[pat - 1]])
1300           break;
1301       Value p = (pat == 0) ? constantIndex(rewriter, loc, 0)
1302                            : codegen.pidxs[tensor][topSort[pat - 1]];
1303       codegen.pidxs[tensor][idx] = genAddress(
1304           codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
1305     }
1306   }
1307 
1308   // Move the insertion indices in lexicographic index order. During access
1309   // pattern expansion, we can skip setting the innermost dimension.
1310   if (codegen.sparseOut && !codegen.expValues) {
1311     Value pos = constantIndex(rewriter, loc, at);
1312     rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
1313                                      pos);
1314   }
1315 }
1316 
1317 /// Generates the induction structure for a while-loop.
1318 static void genWhileInduction(Merger &merger, CodeGen &codegen,
1319                               PatternRewriter &rewriter, linalg::GenericOp op,
1320                               unsigned idx, bool needsUniv,
1321                               BitVector &induction,
1322                               scf::WhileOp whileOp) {
1323   Location loc = op.getLoc();
1324   // Finalize each else branch of all if statements.
1325   if (codegen.redVal || codegen.expValues) {
1326     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
1327                rewriter.getInsertionBlock()->getParentOp())) {
1328       unsigned y = 0;
1329       SmallVector<Value, 4> yields;
1330       if (codegen.redVal) {
1331         yields.push_back(codegen.redVal);
1332         updateReduc(merger, codegen, ifOp.getResult(y++));
1333       }
1334       if (codegen.expValues) {
1335         yields.push_back(codegen.expCount);
1336         codegen.expCount = ifOp->getResult(y++);
1337       }
1338       assert(y == yields.size());
1339       rewriter.create<scf::YieldOp>(loc, yields);
1340       rewriter.setInsertionPointAfter(ifOp);
1341     }
1342   }
1343   rewriter.setInsertionPointToEnd(&whileOp.getAfter().front());
1344   // Finalize the induction. Note that the induction could be performed
1345   // in the individual if-branches to avoid re-evaluating the conditions.
1346   // However, that would result in a rather elaborate forest of yield
1347   // instructions during code generation. Moreover, performing the induction
1348   // after the if-statements more closely resembles code generated by TACO.
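  // In effect, each sparse position index advances by one only when its
  // stored index matched the current loop index:
  //   pidxs[t][idx] = (idxs[t][idx] == loops[idx]) ? pidxs[t][idx] + 1
  //                                                : pidxs[t][idx]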
1349   unsigned o = 0;
1350   SmallVector<Value, 4> operands;
1351   Value one = constantIndex(rewriter, loc, 1);
1352   for (unsigned b = 0, be = induction.size(); b < be; b++) {
1353     if (induction[b] && merger.isDim(b, Dim::kSparse)) {
1354       unsigned tensor = merger.tensor(b);
1355       assert(idx == merger.index(b));
1356       Value op1 = codegen.idxs[tensor][idx];
1357       Value op2 = codegen.loops[idx];
1358       Value op3 = codegen.pidxs[tensor][idx];
1359       Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1360                                                  op1, op2);
1361       Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
1362       operands.push_back(rewriter.create<arith::SelectOp>(loc, cmp, add, op3));
1363       codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
1364     }
1365   }
1366   if (codegen.redVal) {
1367     operands.push_back(codegen.redVal);
1368     updateReduc(merger, codegen, whileOp->getResult(o++));
1369   }
1370   if (codegen.expValues) {
1371     operands.push_back(codegen.expCount);
1372     codegen.expCount = whileOp->getResult(o++);
1373   }
1374   if (needsUniv) {
1375     operands.push_back(
1376         rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
1377     codegen.loops[idx] = whileOp->getResult(o++);
1378   }
1379   assert(o == operands.size());
1380   rewriter.create<scf::YieldOp>(loc, operands);
1381   rewriter.setInsertionPointAfter(whileOp);
1382 }
1383 
1384 /// Generates the induction structure for a for-loop.
1385 static void genForInduction(Merger &merger, CodeGen &codegen,
1386                             PatternRewriter &rewriter, linalg::GenericOp op,
1387                             Operation *loop) {
1388   Location loc = op.getLoc();
1389   unsigned o = 0;
1390   SmallVector<Value, 4> operands;
1391   if (codegen.redVal) {
1392     operands.push_back(codegen.redVal);
1393     updateReduc(merger, codegen, loop->getResult(o++));
1394   }
1395   if (codegen.expValues) {
1396     operands.push_back(codegen.expCount);
1397     codegen.expCount = loop->getResult(o++);
1398   }
1399   assert(o == operands.size());
1400   if (o > 0)
1401     rewriter.create<scf::YieldOp>(loc, operands);
1402   rewriter.setInsertionPointAfter(loop);
1403 }
1404 
1405 /// Generates a single if-statement within a while-loop.
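/// The guard is a conjunction over the lattice point: every sparse dimension
/// contributes an (index == loop index) test, while dense dimensions are
/// always taken.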
1406 static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
1407                        PatternRewriter &rewriter, linalg::GenericOp op,
1408                        unsigned idx, BitVector &conditions) {
1409   Location loc = op.getLoc();
1410   SmallVector<Type, 4> types;
1411   Value cond;
1412   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1413     if (conditions[b]) {
1414       unsigned tensor = merger.tensor(b);
1415       assert(idx == merger.index(b));
1416       Value clause;
1417       if (merger.isDim(b, Dim::kSparse)) {
1418         Value op1 = codegen.idxs[tensor][idx];
1419         Value op2 = codegen.loops[idx];
1420         clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1421                                                 op1, op2);
1422       } else {
1423         clause = constantI1(rewriter, loc, true);
1424       }
1425       cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
1426     }
1427   }
1428   if (codegen.redVal)
1429     types.push_back(codegen.redVal.getType());
1430   if (codegen.expValues)
1431     types.push_back(rewriter.getIndexType());
1432   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
1433   rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1434   return ifOp;
1435 }
1436 
/// Generates the end of the true branch of an if-statement within a
/// while-loop.
1438 static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1439                   linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
1440                   Value redInput, Value cntInput) {
1441   SmallVector<Value, 4> operands;
1442   if (codegen.redVal) {
1443     operands.push_back(codegen.redVal);
1444     updateReduc(merger, codegen, redInput);
1445   }
1446   if (codegen.expValues) {
1447     operands.push_back(codegen.expCount);
1448     codegen.expCount = cntInput;
1449   }
1450   if (!operands.empty())
1451     rewriter.create<scf::YieldOp>(op.getLoc(), operands);
1452   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1453 }
1454 
1455 //===----------------------------------------------------------------------===//
1456 // Sparse compiler synthesis methods (loop sequence).
1457 //===----------------------------------------------------------------------===//
1458 
/// Starts a loop sequence at the given level. Returns true if
/// the universal loop index must be maintained at this level.
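/// Invariants are hoisted and access pattern expansion is set up before the
/// first loop of the sequence is emitted; endLoopSeq() undoes this
/// bookkeeping.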
1461 static bool startLoopSeq(Merger &merger, CodeGen &codegen,
1462                          PatternRewriter &rewriter, linalg::GenericOp op,
1463                          std::vector<unsigned> &topSort, unsigned exp,
1464                          unsigned at, unsigned idx, unsigned ldx,
1465                          unsigned lts) {
1466   assert(codegen.curVecLength == 1);
1467   assert(!codegen.loops[idx]);
1468   // Emit invariants at this loop sequence level.
1469   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
1470   // Emit access pattern expansion for sparse tensor output.
1471   genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
1473   unsigned l0 = merger.set(lts)[0];
1474   bool needsUniv =
1475       genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
1476   // Maintain the universal index only if it is actually
1477   // consumed by a subsequent lattice point.
1478   if (needsUniv) {
1479     unsigned lsize = merger.set(lts).size();
1480     for (unsigned i = 1; i < lsize; i++) {
1481       unsigned li = merger.set(lts)[i];
1482       if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
1483         return true;
1484     }
1485   }
1486   return false;
1487 }
1488 
/// Starts a single loop in the current sequence.
1490 static Operation *startLoop(Merger &merger, CodeGen &codegen,
1491                             PatternRewriter &rewriter, linalg::GenericOp op,
1492                             std::vector<unsigned> &topSort, unsigned at,
1493                             unsigned li, bool needsUniv) {
1494   assert(codegen.curVecLength == 1);
1495   // Emit the for/while-loop control.
1496   Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
1497                             needsUniv, merger.lat(li).simple);
1498   // Emit the locals for this loop.
1499   genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
1500             merger.lat(li).bits);
1501   return loop;
1502 }
1503 
/// Ends a single loop in the current sequence. Returns the new value for
/// needsUniv.
1505 static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1506                     linalg::GenericOp op, Operation *loop, unsigned idx,
1507                     unsigned li, bool needsUniv) {
1508   codegen.curVecLength = 1;
1509   // End a while-loop.
1510   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1511     genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
1512                       merger.lat(li).bits, whileOp);
1513     return needsUniv;
1514   }
1515   // End a for-loop.
1516   genForInduction(merger, codegen, rewriter, op, loop);
1517   return false;
1518 }
1519 
/// Ends a loop sequence at the given level.
1521 static void endLoopSeq(Merger &merger, CodeGen &codegen,
1522                        PatternRewriter &rewriter, linalg::GenericOp op,
1523                        unsigned exp, unsigned at, unsigned idx, unsigned ldx) {
1524   assert(codegen.curVecLength == 1);
1525   codegen.loops[idx] = Value();
1526   // Bring a pending reduction back from SIMD form when sequence ends.
1527   if (codegen.redVal)
1528     if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
1529       updateReduc(merger, codegen,
1530                   genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
1531   // Unmark bookkeeping of invariants and loop index.
1532   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
1533   // Finalize access pattern expansion for sparse tensor output.
1534   genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false);
1535 }
1536 
1537 /// Recursively generates code while computing iteration lattices in order
1538 /// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
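/// For example, for x(i) = a(i) + b(i) with both a and b sparse along i,
/// the lattice for i contains the points {a(i) and b(i)}, {a(i)}, {b(i)},
/// which conceptually yields a while-loop co-iterating over the overlap
/// followed by loops over the remaining nonzeros of a and of b (a sketch;
/// the exact loop structure depends on the optimized lattice and the
/// annotations).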
1540 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1541                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1542                     unsigned exp, unsigned at) {
1543   // At each leaf, assign remaining tensor (sub)expression to output tensor.
1544   if (at == topSort.size()) {
1545     Value rhs = genExp(merger, codegen, rewriter, op, exp);
1546     genTensorStore(merger, codegen, rewriter, op, rhs);
1547     return;
1548   }
1549 
1550   // Construct iteration lattices for current loop index, with L0 at top.
1551   unsigned idx = topSort[at];
1552   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
1553   unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
1554 
1555   // Start a loop sequence.
1556   bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
1557                                 idx, ldx, lts);
1558 
1559   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1560   unsigned lsize = merger.set(lts).size();
1561   for (unsigned i = 0; i < lsize; i++) {
1562     // Start a loop.
1563     unsigned li = merger.set(lts)[i];
1564     Operation *loop =
1565         startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);
1566 
    // Visit all lattice points Lj with Li >= Lj to generate the
    // loop-body, possibly with if-statements for co-iteration.
1569     Value redInput = codegen.redVal;
1570     Value cntInput = codegen.expCount;
1571     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1572     for (unsigned j = 0; j < lsize; j++) {
1573       unsigned lj = merger.set(lts)[j];
1574       unsigned ej = merger.lat(lj).exp;
1575       if (li == lj || merger.latGT(li, lj)) {
1576         // Recurse into body of each branch.
1577         if (isWhile) {
1578           scf::IfOp ifOp =
1579               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1580           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1581           endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
1582         } else {
1583           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1584         }
1585       }
1586     }
1587 
1588     // End a loop.
1589     needsUniv =
1590         endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
1591   }
1592 
1593   // End a loop sequence.
1594   endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
1595 }
1596 
1597 /// Converts the result computed by the sparse kernel into the required form.
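/// An annotated (sparse) output rematerializes through a sparse_tensor.load,
/// flagged with hasInserts when this kernel inserted into the output; a
/// dense output simply converts its bufferized values back to a tensor.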
1598 static void genResult(Merger &merger, CodeGen &codegen,
1599                       PatternRewriter &rewriter, linalg::GenericOp op) {
1600   OpOperand *lhs = op.getOutputOperand(0);
1601   Type resType = lhs->get().getType();
1602   if (getSparseTensorEncoding(resType)) {
    // The sparse output tensor rematerializes from the underlying sparse
    // storage format of the original tensor.
1605     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
1606                                         codegen.sparseOut == lhs);
1607   } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
1610     Value val = codegen.buffers.back(); // value array
1611     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1612   }
1613 }
1614 
1615 //===----------------------------------------------------------------------===//
1616 // Sparse compiler rewriting methods.
1617 //===----------------------------------------------------------------------===//
1618 
1619 namespace {
1620 
/// Sparse rewriting rule for a generic Linalg operation.
1622 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1623 public:
1624   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1625       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1626 
1627   LogicalResult matchAndRewrite(linalg::GenericOp op,
1628                                 PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
1630     // information for all tensors to loop indices in the kernel.
1631     assert(op.getNumOutputs() == 1);
1632     unsigned numTensors = op.getNumInputsAndOutputs();
1633     unsigned numLoops = op.iterator_types().getValue().size();
1634     Merger merger(numTensors, numLoops);
1635     if (!findSparseAnnotations(merger, op))
1636       return failure();
1637 
1638     // Computes a topologically sorted iteration graph to ensure
1639     // tensors are visited in natural index order. Fails on cycles.
1640     // This assumes that higher-level passes have already put the
1641     // tensors in each tensor expression in a feasible order.
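    // The sort constraints are relaxed gradually: the mask that also orders
    // dense and undefined dimensions is tried first, and weaker masks are
    // attempted only when the stricter ones yield a cycle.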
1642     std::vector<unsigned> topSort;
1643     if (!computeIterationGraph(merger, op, topSort,
1644                                SortMask::kIncludeUndef |
1645                                    SortMask::kIncludeDense) &&
1646         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
1647         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
1648         !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
1649       return failure();
1650 
1651     // Builds the tensor expression for the Linalg operation in SSA form.
1652     Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
1653     if (!optExp.hasValue())
1654       return failure();
1655     unsigned exp = optExp.getValue();
1656 
    // Rejects an inadmissible tensor expression.
1658     OpOperand *sparseOut = nullptr;
1659     unsigned outerParNest = 0;
1660     if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
1661                                outerParNest))
1662       return failure();
1663 
1664     // Recursively generates code.
1665     merger.setHasSparseOut(sparseOut != nullptr);
1666     CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
1667     genBuffers(merger, codegen, rewriter, op);
1668     genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
1669     genResult(merger, codegen, rewriter, op);
1670     return success();
1671   }
1672 
1673 private:
1674   /// Options to control sparse code generation.
1675   SparsificationOptions options;
1676 };
1677 
1678 } // namespace
1679 
1680 /// Populates the given patterns list with rewriting rules required for
1681 /// the sparsification of linear algebra operations.
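/// A typical driver applies these patterns greedily (a sketch; `ctx`,
/// `options`, and `op` are placeholders):
///
///   RewritePatternSet patterns(ctx);
///   populateSparsificationPatterns(patterns, options);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));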
1682 void mlir::populateSparsificationPatterns(
1683     RewritePatternSet &patterns, const SparsificationOptions &options) {
1684   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1685 }
1686