1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements converting sparse tensor types to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodegenUtils.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
18 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 #include "mlir/Dialect/SCF/SCF.h"
25 #include "mlir/Dialect/SCF/Transforms.h"
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
29 #include "mlir/Dialect/Vector/IR/VectorOps.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/TensorEncoding.h"
32 #include "llvm/ADT/SmallBitVector.h"
33 
34 using namespace mlir;
35 using namespace mlir::sparse_tensor;
36 
37 //===----------------------------------------------------------------------===//
38 // Declarations of data structures.
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
43 // Iteration graph sorting.
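// The mask values are bit flags: kIncludeDense and kIncludeUndef may be
// combined to select which ordering constraints enter the iteration graph.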
44 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
45 
46 // Reduction kinds.
47 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
48 
49 // Code generation.
50 struct CodeGen {
51   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
52           OpOperand *op, unsigned nest)
53       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
54         pointers(numTensors, std::vector<Value>(numLoops)),
55         indices(numTensors, std::vector<Value>(numLoops)),
56         highs(numTensors, std::vector<Value>(numLoops)),
57         pidxs(numTensors, std::vector<Value>(numLoops)),
58         idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op),
59         outerParNest(nest), lexIdx(), expValues(), expFilled(), expAdded(),
60         expCount(), curVecMask() {}
61   /// Sparsification options.
62   SparsificationOptions options;
63   /// Universal dense indices and upper bounds (by index). The loops array
64   /// is updated with the value of the universal dense index in the current
65   /// loop. The sizes array is set once with the inferred dimension sizes.
66   std::vector<Value> loops;
67   std::vector<Value> sizes;
68   /// Buffers for storing dense and sparse numerical values (by tensor).
69   /// This array is set once during bufferization of all tensors.
70   std::vector<Value> buffers;
71   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
72   /// This array is set once during bufferization of all sparse tensors.
73   std::vector<std::vector<Value>> pointers;
74   std::vector<std::vector<Value>> indices;
75   /// Sparse iteration information (by tensor and index). These arrays
76   /// are updated to remain current within the current loop.
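  /// The highs array holds the upper bounds of the current iteration ranges,
  /// pidxs the positions within the sparse storage scheme, and idxs the index
  /// values loaded at those positions.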
77   std::vector<std::vector<Value>> highs;
78   std::vector<std::vector<Value>> pidxs;
79   std::vector<std::vector<Value>> idxs;
80   /// Current reduction, updated during code generation. When indices of a
81   /// reduction are exhausted, all inner loops can use a scalarized reduction.
82   unsigned redExp = -1u;
83   Value redVal;
84   Reduction redKind = kNoReduc;
85   // Sparse tensor as output. Implemented either through direct injective
86   // insertion in lexicographic index order (where indices are updated
87   // in the temporary array `lexIdx`) or through access pattern expansion
88   // in the innermost loop nest (`expValues` through `expCount`).
89   OpOperand *sparseOut;
90   unsigned outerParNest;
91   Value lexIdx;
92   Value expValues;
93   Value expFilled;
94   Value expAdded;
95   Value expCount;
96   // Current vector length and mask.
97   unsigned curVecLength = 1;
98   Value curVecMask;
99 };
100 
101 } // namespace
102 
103 //===----------------------------------------------------------------------===//
104 // Sparse compiler analysis methods.
105 //===----------------------------------------------------------------------===//
106 
107 /// Helper method to apply dimension ordering permutation.
108 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
109   if (enc) {
110     auto order = enc.getDimOrdering();
111     if (order) {
112       assert(order.isPermutation());
113       return order.getDimPosition(d);
114     }
115   }
116   return d;
117 }
118 
119 /// Helper method to translate dim level type to internal representation.
120 static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
121   if (enc) {
122     SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
123     if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
124       return Dim::kSparse;
125     if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
126       return Dim::kSingle;
127   }
128   return Dim::kDense;
129 }
130 
131 /// Helper method to inspect affine expressions. Rejects cases where the
132 /// same index is used more than once. Also rejects affine expressions
133 /// that are not a direct index for annotated tensors.
134 // TODO: accept more affine cases for sparse tensors
135 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
136                        bool isDense) {
137   switch (a.getKind()) {
138   case AffineExprKind::DimId: {
139     unsigned idx = a.cast<AffineDimExpr>().getPosition();
140     if (!merger.isDim(tensor, idx, Dim::kUndef))
141       return false; // used more than once
142     merger.setDim(tensor, idx, dim);
143     return true;
144   }
145   case AffineExprKind::Add:
146   case AffineExprKind::Mul: {
147     if (!isDense)
148       return false;
149     auto binOp = a.cast<AffineBinaryOpExpr>();
150     return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
151            findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
152   }
153   case AffineExprKind::Constant:
154     return isDense;
155   default:
156     return false;
157   }
158 }
159 
160 /// Helper method to inspect sparse encodings in the tensor types.
161 /// Fills the per-dimension sparsity information for all tensors.
162 /// Returns true if the sparse annotations and affine subscript
163 /// expressions of all tensors are admissable. Returns false if
164 /// no annotations are found or inadmissable constructs occur.
165 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
166   bool annotated = false;
167   for (OpOperand *t : op.getInputAndOutputOperands()) {
168     auto map = op.getTiedIndexingMap(t);
169     auto enc = getSparseTensorEncoding(t->get().getType());
170     if (enc)
171       annotated = true;
172     assert(map.getNumResults() == op.getRank(t));
173     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
174       unsigned tensor = t->getOperandNumber();
175       AffineExpr a = map.getResult(perm(enc, d));
176       if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
177         return false; // inadmissable affine expression
178     }
179   }
180   return annotated;
181 }
182 
183 /// A DFS helper to compute a topological sort. Note that recursion is
184 /// bounded by the number of implicit loops, which is always small.
185 /// Returns false when a cycle is detected.
186 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
187                        std::vector<unsigned> &topSort,
188                        std::vector<std::vector<bool>> &adjM) {
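  // Visit states: 0 = not yet visited, 1 = in progress (still on the DFS
  // stack), 2 = fully visited.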
189   if (visit[i] != 0)
190     return visit[i] != 1; // 1 denotes cycle!
191   visit[i] = 1;
192   for (unsigned j = 0, e = visit.size(); j < e; j++)
193     if (adjM[i][j])
194       if (!topSortDFS(j, visit, topSort, adjM))
195         return false;
196   visit[i] = 2;
197   topSort.push_back(i);
198   return true;
199 }
200 
201 /// Helper method to add all constraints from the indices in one affine
202 /// expression before all indices in the other affine expression. For
203 /// example, i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
204 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
205                                AffineExpr a, AffineExpr b, unsigned fidx) {
206   switch (a.getKind()) {
207   case AffineExprKind::DimId: {
208     unsigned idx = a.cast<AffineDimExpr>().getPosition();
209     if (b)
210       addAffineOrderings(adjM, b, AffineExpr(), idx);
211     else
212       adjM[fidx][idx] = true;
213     break;
214   }
215   case AffineExprKind::Add:
216   case AffineExprKind::Mul: {
217     auto binOp = a.cast<AffineBinaryOpExpr>();
218     addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
219     addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
220     break;
221   }
222   default:
223     break;
224   }
225 }
226 
227 /// Computes a topologically sorted iteration graph for the linalg operation.
228 /// Ensures all tensors are visited in natural index order. This is essential
229 /// for sparse storage formats since these only support access along fixed
230 /// dimensions. Even for dense storage formats, however, the natural index
231 /// order yields innermost unit-stride access with better spatial locality.
232 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
233                                   std::vector<unsigned> &topSort,
234                                   unsigned mask) {
235   // Set up an n x n from/to adjacency matrix of the iteration graph
236   // for the implicit loop indices i_0 .. i_n-1.
237   unsigned n = op.getNumLoops();
238   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
239 
240   // Iterate over the indexing maps of every tensor in the tensor expression.
241   for (OpOperand *t : op.getInputAndOutputOperands()) {
242     auto map = op.getTiedIndexingMap(t);
243     auto enc = getSparseTensorEncoding(t->get().getType());
244     assert(map.getNumDims() == n);
245     // Skip dense tensor constraints when not requested.
246     if (!(mask & SortMask::kIncludeDense) && !enc)
247       continue;
248     // Each tensor expression and optional dimension ordering (row-major
249     // by default) puts an ordering constraint on the loop indices. For
250     // example, the tensor expression A_ijk forces the ordering i < j < k
251     // on the loop indices if no explicit dimension ordering is given.
252     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
253       AffineExpr f = map.getResult(perm(enc, d - 1));
254       AffineExpr t = map.getResult(perm(enc, d));
255       addAffineOrderings(adjM, f, t, 0);
256     }
257     // Push unrelated loops into sparse iteration space, so these
258     // will be skipped more often.
259     if (mask & SortMask::kIncludeUndef) {
260       unsigned tensor = t->getOperandNumber();
261       for (unsigned i = 0; i < n; i++)
262         if (merger.isDim(tensor, i, Dim::kSparse))
263           for (unsigned j = 0; j < n; j++)
264             if (merger.isDim(tensor, j, Dim::kUndef))
265               adjM[i][j] = true;
266     }
267   }
268 
269   // Topologically sort the iteration graph to determine loop order.
270   // Report failure for a cyclic iteration graph.
271   topSort.clear();
272   topSort.reserve(n);
273   std::vector<unsigned> visit(n, 0);
274   for (unsigned i = 0; i < n; i++)
275     if (visit[i] == 0)
276       if (!topSortDFS(i, visit, topSort, adjM))
277         return false; // cycle!
278   std::reverse(std::begin(topSort), std::end(topSort));
279   return true;
280 }
281 
282 /// Returns true if tensor has an in-place annotation.
283 static bool isInPlace(Value val) {
284   if (auto arg = val.dyn_cast<BlockArgument>())
285     if (auto funcOp = dyn_cast<func::FuncOp>(arg.getOwner()->getParentOp()))
286       if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
287               arg.getArgNumber(),
288               bufferization::BufferizableOpInterface::kInplaceableAttrName))
289         return attr.getValue();
290   return false;
291 }
292 
293 /// Returns true if tensor materializes uninitialized into the computation.
294 static bool isMaterializing(Value val) {
295   return val.getDefiningOp<linalg::InitTensorOp>() ||
296          val.getDefiningOp<InitOp>();
297 }
298 
299 /// Returns true when the tensor expression is admissable for codegen.
300 /// Since all sparse input tensors are admissable, we just need to check
301 /// whether the output tensor in the tensor expression codegen is admissable.
302 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
303 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
304 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
305                                   std::vector<unsigned> &topSort, unsigned exp,
306                                   OpOperand **sparseOut,
307                                   unsigned &outerParNest) {
308   OpOperand *lhs = op.getOutputOperand(0);
309   unsigned tensor = lhs->getOperandNumber();
310   auto enc = getSparseTensorEncoding(lhs->get().getType());
311   // A non-annotated output tensor is assumed dense, and becomes a random
312   // access n-dim memref. Admissable since insertions cannot occur.
313   if (!enc)
314     return true;
315   // An all-dense annotated "sparse" output tensor becomes a linearized random
316   // access 1-dim memref. Also admissable since insertions cannot occur.
317   bool allDense = true;
318   auto iteratorTypes = op.iterator_types().getValue();
319   unsigned numLoops = iteratorTypes.size();
320   for (unsigned i = 0; i < numLoops; i++)
321     if (merger.isDim(tensor, i, Dim::kSparse)) {
322       allDense = false;
323       break;
324     }
325   if (allDense)
326     return true;
327   // A tensor expression with a sparse output tensor that changes its values
328   // but not its nonzero structure, an operation called "simply dynamic" in
329   // [Bik96,Ch9], is also admissable without special codegen, provided
330   // the tensor's underlying sparse storage scheme can be modified in place.
331   if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get()))
332     return true;
333   // Accept "truly dynamic" if the output tensor materializes uninitialized
334   // into the computation and insertions occur in lexicographic index order.
335   if (isMaterializing(lhs->get())) {
336     unsigned nest = 0;
337     for (unsigned i = 0; i < numLoops; i++) {
338       if (isReductionIterator(iteratorTypes[topSort[i]]))
339         break; // terminate at first reduction
340       nest++;
341     }
342     // Determine admissable dynamic insertion situations:
343     // (1) fully injective, since there are no reductions,
344     // (2) admissable 1-d expansion in innermost dimension.
345     if (nest >= op.getRank(lhs) - 1) {
346       *sparseOut = lhs;
347       outerParNest = nest;
348       return true;
349     }
350   }
351   return false;
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // Sparse compiler synthesis methods (reductions).
356 //===----------------------------------------------------------------------===//
357 
358 /// Maps reduction kind to vector::CombiningKind.
359 static vector::CombiningKind getCombiningKind(Reduction kind) {
360   switch (kind) {
361   case kNoReduc:
362     break;
363   case kSum:
364     return vector::CombiningKind::ADD;
365   case kProduct:
366     return vector::CombiningKind::MUL;
367   case kAnd:
368     return vector::CombiningKind::AND;
369   case kOr:
370     return vector::CombiningKind::OR;
371   case kXor:
372     return vector::CombiningKind::XOR;
373   }
374   llvm_unreachable("unknown reduction kind");
375 }
376 
377 /// Maps operation to reduction.
378 static Reduction getReduction(Kind kind) {
379   switch (kind) {
380   case Kind::kAddF:
381   case Kind::kAddI:
382   case Kind::kSubF:
383   case Kind::kSubI:
384     return kSum;
385   case Kind::kMulF:
386   case Kind::kMulI:
387     return kProduct;
388   case Kind::kAndI:
389     return kAnd;
390   case Kind::kOrI:
391     return kOr;
392   case Kind::kXorI:
393     return kXor;
394   default:
395     llvm_unreachable("unexpected reduction operator");
396   }
397 }
398 
399 /// Generates an initial value for a vector reduction, following the scheme
400 /// given in Chapter 5 of "The Software Vectorization Handbook", where the
401 /// initial scalar value is correctly embedded in the vector reduction value,
402 /// and a straightforward horizontal reduction will complete the operation.
403 static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
404                                 Location loc, VectorType vtp) {
405   Value r = codegen.redVal;
406   switch (codegen.redKind) {
407   case kNoReduc:
408     break;
409   case kSum:
410   case kXor:
411     // Initialize reduction vector to: | 0 | .. | 0 | r |
412     return rewriter.create<vector::InsertElementOp>(
413         loc, r, constantZero(rewriter, loc, vtp),
414         constantIndex(rewriter, loc, 0));
415   case kProduct:
416     // Initialize reduction vector to: | 1 | .. | 1 | r |
417     return rewriter.create<vector::InsertElementOp>(
418         loc, r, constantOne(rewriter, loc, vtp),
419         constantIndex(rewriter, loc, 0));
420   case kAnd:
421   case kOr:
422     // Initialize reduction vector to: | r | .. | r | r |
423     return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
424   }
425   llvm_unreachable("unknown reduction kind");
426 }
427 
428 /// Generates final value for a vector reduction.
429 static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
430                                Location loc, VectorType vtp) {
431   vector::CombiningKind kind = getCombiningKind(codegen.redKind);
432   return rewriter.create<vector::ReductionOp>(loc, kind, codegen.redVal);
433 }
434 
435 /// Updates scalarized reduction value.
436 static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
437   assert(codegen.redKind != kNoReduc);
438   codegen.redVal = merger.exp(codegen.redExp).val = reduc;
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // Sparse compiler synthesis methods (statements and expressions).
443 //===----------------------------------------------------------------------===//
444 
445 /// Generates buffer for the output tensor. Note that all sparse kernels
446 /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
447 /// the output buffer is already initialized to all zeroes and only nonzero
448 /// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
449 /// only nonzero values are used for the updates and no assumption about the
450 /// original contents of the output buffer is necessary.
451 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
452                              linalg::GenericOp op, MemRefType denseTp,
453                              ArrayRef<Value> args) {
454   Location loc = op.getLoc();
455   Value tensor = op.getOutputOperand(0)->get();
456   // The output tensor could simply materialize from the buffer that will
457   // be generated for the tensor present in the outs() clause. This has
458   // the major advantage that the sparse kernel only updates the nonzero
459   // positions for the output tensor.
460   if (isInPlace(tensor))
461     return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
462   // By default, a new buffer is allocated which is initialized to the
463   // tensor defined in the outs() clause. This is always correct but
464   // introduces a dense initialization component that may negatively
465   // impact the running complexity of the sparse kernel. If the tensor
466   // materializes into the computation, we need to preserve the zero
467   // initialization assumption of all sparse output buffers.
468   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
469   if (isMaterializing(tensor)) {
470     Value zero = constantZero(rewriter, loc, denseTp.getElementType());
471     rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{alloc});
472   } else {
473     Value init =
474         rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
475     rewriter.create<memref::CopyOp>(loc, init, alloc);
476   }
477   return alloc;
478 }
479 
480 /// Local bufferization of all dense and sparse data structures.
481 /// This code enables testing the first prototype sparse compiler.
482 // TODO: replace this with a more general bufferization strategy
483 static void genBuffers(Merger &merger, CodeGen &codegen,
484                        PatternRewriter &rewriter, linalg::GenericOp op) {
485   Location loc = op.getLoc();
486   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
487   // For every tensor, find lower and upper bound on dimensions, set the
488   // same bounds on loop indices, and obtain dense or sparse buffer(s).
489   SmallVector<Value, 4> args;
490   for (OpOperand *t : op.getInputAndOutputOperands()) {
491     unsigned tensor = t->getOperandNumber();
492     auto shape = op.getShape(t);
493     auto map = op.getTiedIndexingMap(t);
494     auto enc = getSparseTensorEncoding(t->get().getType());
495     // Scan all dimensions of current tensor.
496     args.clear();
497     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
498       AffineExpr a = map.getResult(perm(enc, d));
499       if (a.getKind() != AffineExprKind::DimId)
500         continue; // compound
501       unsigned idx = a.cast<AffineDimExpr>().getPosition();
502       // Handle sparse storage schemes.
503       if (merger.isDim(tensor, idx, Dim::kSparse)) {
504         auto dynShape = {ShapedType::kDynamicSize};
505         auto ptrTp =
506             MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc));
507         auto indTp =
508             MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc));
509         Value dim = constantIndex(rewriter, loc, d);
510         // Generate sparse primitives to obtain pointers and indices.
511         codegen.pointers[tensor][idx] =
512             rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
513         codegen.indices[tensor][idx] =
514             rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
515       }
516       // Find upper bound in current dimension.
517       unsigned p = perm(enc, d);
518       Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
519       if (ShapedType::isDynamic(shape[p]))
520         args.push_back(up);
521       assert(codegen.highs[tensor][idx] == nullptr);
522       codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
523     }
524     // Perform the required bufferization. Dense inputs materialize
525     // from the input tensors. Dense outputs need special handling.
526     // Sparse inputs use sparse primitives to obtain the values.
527     // We also accept in-place all-dense annotated "sparse" outputs.
528     Type elementType = getElementTypeOrSelf(t->get().getType());
529     if (!enc) {
530       // Non-annotated dense tensors.
531       auto denseTp = MemRefType::get(shape, elementType);
532       if (tensor < op.getNumInputs())
533         codegen.buffers[tensor] =
534             rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
535       else
536         codegen.buffers[tensor] =
537             genOutputBuffer(codegen, rewriter, op, denseTp, args);
538     } else if (t == codegen.sparseOut) {
539       // True sparse output needs a lexIdx array.
540       Value rank = constantIndex(rewriter, loc, op.getRank(t));
541       auto dynShape = {ShapedType::kDynamicSize};
542       auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
543       codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
544     } else {
545       // Annotated sparse tensors.
546       auto dynShape = {ShapedType::kDynamicSize};
547       auto sparseTp = MemRefType::get(dynShape, elementType);
548       codegen.buffers[tensor] =
549           rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
550     }
551   }
552 }
553 
554 /// Constructs vector type.
555 static VectorType vectorType(CodeGen &codegen, Type etp) {
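  // With VLA vectorization enabled, the single vector dimension is scalable.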
556   unsigned numScalableDims = codegen.options.enableVLAVectorization;
557   return VectorType::get(codegen.curVecLength, etp, numScalableDims);
558 }
559 
560 /// Constructs vector type from pointer.
561 static VectorType vectorType(CodeGen &codegen, Value ptr) {
562   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
563 }
564 
565 /// Constructs vector iteration mask.
566 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
567                            Value iv, Value lo, Value hi, Value step) {
568   Location loc = iv.getLoc();
569   VectorType mtp = vectorType(codegen, rewriter.getI1Type());
570   // Special case if the vector length evenly divides the trip count (for
571   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
572   // so that all subsequent masked memory operations are immediately folded
573   // into unconditional memory operations.
574   IntegerAttr loInt, hiInt, stepInt;
575   if (matchPattern(lo, m_Constant(&loInt)) &&
576       matchPattern(hi, m_Constant(&hiInt)) &&
577       matchPattern(step, m_Constant(&stepInt))) {
578     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
579       return rewriter.create<vector::BroadcastOp>(
580           loc, mtp, constantI1(rewriter, loc, true));
581   }
582   // Otherwise, generate a vector mask that avoids overrunning the upperbound
583   // during vector execution. Here we rely on subsequent loop optimizations to
584   // avoid executing the mask in all iterations, for example, by splitting the
585   // loop into an unconditional vector loop and a scalar cleanup loop.
586   auto minMap = AffineMap::get(
587       /*dimCount=*/2, /*symbolCount=*/1,
588       {rewriter.getAffineSymbolExpr(0),
589        rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
590       rewriter.getContext());
591   Value end =
592       rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
593   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
594 }
595 
596 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
597 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
598                            Value ptr, ArrayRef<Value> args) {
599   Location loc = ptr.getLoc();
600   VectorType vtp = vectorType(codegen, ptr);
601   Value pass = constantZero(rewriter, loc, vtp);
602   if (args.back().getType().isa<VectorType>()) {
603     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
604     Value indexVec = args.back();
605     scalarArgs.back() = constantIndex(rewriter, loc, 0);
606     return rewriter.create<vector::GatherOp>(
607         loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
608   }
609   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
610                                                codegen.curVecMask, pass);
611 }
612 
613 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
614 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
615                            Value rhs, Value ptr, ArrayRef<Value> args) {
616   Location loc = ptr.getLoc();
617   if (args.back().getType().isa<VectorType>()) {
618     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
619     Value indexVec = args.back();
620     scalarArgs.back() = constantIndex(rewriter, loc, 0);
621     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
622                                        codegen.curVecMask, rhs);
623     return;
624   }
625   rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
626                                          rhs);
627 }
628 
629 /// Generates a vectorized invariant. Here we rely on subsequent loop
630 /// optimizations to hoist the invariant broadcast out of the vector loop.
631 static Value genVectorInvariantValue(CodeGen &codegen,
632                                      PatternRewriter &rewriter, Value val) {
633   VectorType vtp = vectorType(codegen, val.getType());
634   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
635 }
636 
637 /// Generates an affine expression.
638 //
639 // TODO: generalize for sparse tensor subscripts
640 //
641 static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
642                        AffineExpr a, Location loc) {
643   switch (a.getKind()) {
644   case AffineExprKind::DimId: {
645     unsigned idx = a.cast<AffineDimExpr>().getPosition();
646     return codegen.loops[idx]; // universal dense index
647   }
648   case AffineExprKind::Add: {
649     auto binOp = a.cast<AffineBinaryOpExpr>();
650     return rewriter.create<arith::AddIOp>(
651         loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
652         genAffine(codegen, rewriter, binOp.getRHS(), loc));
653   }
654   case AffineExprKind::Mul: {
655     auto binOp = a.cast<AffineBinaryOpExpr>();
656     return rewriter.create<arith::MulIOp>(
657         loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
658         genAffine(codegen, rewriter, binOp.getRHS(), loc));
659   }
660   case AffineExprKind::Constant: {
661     int64_t c = a.cast<AffineConstantExpr>().getValue();
662     return constantIndex(rewriter, loc, c);
663   }
664   default:
665     llvm_unreachable("unexpected affine subscript");
666   }
667 }
668 
669 /// Generates index for load/store on sparse tensor.
670 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
671   auto map = op.getTiedIndexingMap(t);
672   auto enc = getSparseTensorEncoding(t->get().getType());
673   AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1));
674   assert(a.getKind() == AffineExprKind::DimId);
675   unsigned idx = a.cast<AffineDimExpr>().getPosition();
676   return codegen.loops[idx];
677 }
678 
679 /// Generates subscript for load/store on a dense or sparse tensor.
680 static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
681                           linalg::GenericOp op, OpOperand *t,
682                           SmallVector<Value, 4> &args) {
683   unsigned tensor = t->getOperandNumber();
684   auto map = op.getTiedIndexingMap(t);
685   auto enc = getSparseTensorEncoding(t->get().getType());
686   unsigned rank = map.getNumResults();
687   if (enc) {
688     // Note that currently, all sparse subscripts are simple.
689     // TODO: accept affine too?
690     AffineExpr a = map.getResult(perm(enc, rank - 1));
691     assert(a.getKind() == AffineExprKind::DimId);
692     unsigned idx = a.cast<AffineDimExpr>().getPosition();
693     assert(codegen.pidxs[tensor][idx] != nullptr);
694     args.push_back(codegen.pidxs[tensor][idx]); // position index
695   } else {
696     for (unsigned d = 0; d < rank; d++) {
697       AffineExpr a = map.getResult(perm(enc, d));
698       args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
699     }
700   }
701   return codegen.buffers[tensor];
702 }
703 
704 /// Generates insertion code to implement dynamic tensor load.
705 static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter,
706                               linalg::GenericOp op, OpOperand *t) {
707   Location loc = op.getLoc();
708   // Direct lexicographic index order, tensor loads as zero.
709   if (!codegen.expValues) {
710     Type tp = getElementTypeOrSelf(t->get().getType());
711     return constantZero(rewriter, loc, tp);
712   }
713   // Load from expanded access pattern.
714   Value index = genIndex(codegen, op, t);
715   return rewriter.create<memref::LoadOp>(loc, codegen.expValues, index);
716 }
717 
718 /// Generates insertion code to implement dynamic tensor store.
719 static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter,
720                               linalg::GenericOp op, OpOperand *t, Value rhs) {
721   Location loc = op.getLoc();
722   // Direct insertion in lexicographic index order.
723   if (!codegen.expValues) {
724     rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
725     return;
726   }
727   // Generates insertion code along expanded access pattern.
728   //   if (!expFilled[i]) then
729   //     expFilled[i] = true
730   //     expAdded[inserts++] = i
731   //   endif
732   //   values[i] = rhs
733   Value index = genIndex(codegen, op, t);
734   Value fval = constantI1(rewriter, loc, false);
735   Value tval = constantI1(rewriter, loc, true);
736   // If statement.
737   Value filled = rewriter.create<memref::LoadOp>(loc, codegen.expFilled, index);
738   Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
739                                               filled, fval);
740   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, rewriter.getIndexType(),
741                                               cond, /*else=*/true);
742   // True branch.
743   rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
744   rewriter.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
745   rewriter.create<memref::StoreOp>(loc, index, codegen.expAdded,
746                                    codegen.expCount);
747   Value one = constantIndex(rewriter, loc, 1);
748   Value add = rewriter.create<arith::AddIOp>(loc, codegen.expCount, one);
749   rewriter.create<scf::YieldOp>(loc, add);
750   // False branch.
751   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
752   rewriter.create<scf::YieldOp>(loc, codegen.expCount);
753   rewriter.setInsertionPointAfter(ifOp);
754   // Value assignment.
755   codegen.expCount = ifOp.getResult(0);
756   rewriter.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
757 }
758 
759 /// Generates a load on a dense or sparse tensor.
760 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
761                            PatternRewriter &rewriter, linalg::GenericOp op,
762                            unsigned exp) {
763   // Test if the load was hoisted to a higher loop nest.
764   Value val = merger.exp(exp).val;
765   if (val) {
766     if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
767       return genVectorInvariantValue(codegen, rewriter, val);
768     return val;
769   }
770   // Load during insertion.
771   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
772   if (t == codegen.sparseOut)
773     return genInsertionLoad(codegen, rewriter, op, t);
774   // Actual load.
775   SmallVector<Value, 4> args;
776   Value ptr = genSubscript(codegen, rewriter, op, t, args);
777   if (codegen.curVecLength > 1)
778     return genVectorLoad(codegen, rewriter, ptr, args);
779   return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
780 }
781 
782 /// Generates a store on a dense or sparse tensor.
783 static void genTensorStore(Merger &merger, CodeGen &codegen,
784                            PatternRewriter &rewriter, linalg::GenericOp op,
785                            unsigned exp, Value rhs) {
786   Location loc = op.getLoc();
787   // Test if this is a scalarized reduction.
788   if (codegen.redVal) {
789     if (codegen.curVecLength > 1)
790       rhs = rewriter.create<arith::SelectOp>(loc, codegen.curVecMask, rhs,
791                                              codegen.redVal);
792     updateReduc(merger, codegen, rhs);
793     return;
794   }
795   // Store during insertion.
796   OpOperand *t = op.getOutputOperand(0);
797   if (t == codegen.sparseOut) {
798     if (!rhs) {
799       // Only unary and binary are allowed to return uninitialized rhs
800       // to indicate missing output.
801       assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
802     } else {
803       genInsertionStore(codegen, rewriter, op, t, rhs);
804     }
805     return;
806   }
807   // Actual store.
808   SmallVector<Value, 4> args;
809   Value ptr = genSubscript(codegen, rewriter, op, t, args);
810   if (codegen.curVecLength > 1)
811     genVectorStore(codegen, rewriter, rhs, ptr, args);
812   else
813     rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
814 }
815 
816 /// Generates a pointer/index load from the sparse storage scheme. Narrower
817 /// data types need to be zero extended before casting the value into the
818 /// index type used for looping and indexing.
819 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
820                      Value ptr, Value s) {
821   // See https://llvm.org/docs/GetElementPtr.html for some background on
822   // the complications described below.
823   if (codegen.curVecLength > 1) {
824     // Since the index vector is used in subsequent gather/scatter operations,
825     // which effectively defines an unsigned pointer + signed index, we must
826     // zero extend the vector to an index width. For 8-bit and 16-bit values,
827     // a 32-bit index width suffices. For 32-bit values, zero extending the
828     // elements into 64-bit loses some performance since the 32-bit indexed
829     // gather/scatter is more efficient than the 64-bit index variant (if the
830     // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
831     // preserve this performance). For 64-bit values, there is no good way
832     // to state that the indices are unsigned, which creates the potential for
833     // incorrect address calculations in the unlikely case we need such
834     // extremely large offsets.
835     Type etp = ptr.getType().cast<MemRefType>().getElementType();
836     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
837     if (!etp.isa<IndexType>()) {
838       if (etp.getIntOrFloatBitWidth() < 32)
839         vload = rewriter.create<arith::ExtUIOp>(
840             loc, vectorType(codegen, rewriter.getI32Type()), vload);
841       else if (etp.getIntOrFloatBitWidth() < 64 &&
842                !codegen.options.enableSIMDIndex32)
843         vload = rewriter.create<arith::ExtUIOp>(
844             loc, vectorType(codegen, rewriter.getI64Type()), vload);
845     }
846     return vload;
847   }
848   // For the scalar case, we simply zero extend narrower indices into 64-bit
849   // values before casting to index without a performance penalty. Here too,
850   // however, indices that already are 64-bit, in theory, cannot express the
851   // full range as explained above.
852   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
853   if (!load.getType().isa<IndexType>()) {
854     if (load.getType().getIntOrFloatBitWidth() < 64)
855       load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
856     load =
857         rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
858   }
859   return load;
860 }
861 
862 /// Generates an invariant value.
863 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
864                                PatternRewriter &rewriter, unsigned exp) {
865   Value val = merger.exp(exp).val;
866   if (codegen.curVecLength > 1)
867     return genVectorInvariantValue(codegen, rewriter, val);
868   return val;
869 }
870 
871 /// Generates an address computation "sz * p + i".
872 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
873                         Location loc, Value size, Value p, Value i) {
874   Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
875   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
876     Value inv =
877         rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
878     mul = genVectorInvariantValue(codegen, rewriter, inv);
879   }
880   return rewriter.create<arith::AddIOp>(loc, mul, i);
881 }
882 
883 /// Generates an index value.
884 static Value genIndexValue(Merger &merger, CodeGen &codegen,
885                            PatternRewriter &rewriter, unsigned exp,
886                            unsigned ldx) {
887   unsigned idx = merger.exp(exp).index;
888   Value ival = codegen.loops[idx];
889   Type itype = ival.getType();
890   // During vectorization, we either encounter:
891   // (1) indices already in vector form, as in ... = ind[lo:hi], good to go, or
892   // (2) single index, as in ... = i, must convert to [i, i+1, ...] for inner i.
893   unsigned vl = codegen.curVecLength;
894   if (vl > 1 && !itype.isa<VectorType>()) {
895     Location loc = ival.getLoc();
896     VectorType vtp = vectorType(codegen, itype);
897     ival = rewriter.create<vector::BroadcastOp>(loc, vtp, ival);
898     if (idx == ldx) {
899       Value incr;
900       if (vtp.isScalable()) {
901         Type stepvty = vectorType(codegen, rewriter.getI64Type());
902         Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
903         incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
904       } else {
905         SmallVector<APInt, 4> integers;
906         for (unsigned i = 0; i < vl; i++)
907           integers.push_back(APInt(/*width=*/64, i));
908         auto values = DenseElementsAttr::get(vtp, integers);
909         incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
910       }
911       ival = rewriter.create<arith::AddIOp>(loc, ival, incr);
912     }
913   }
914   return ival;
915 }
916 
917 /// Recursively generates tensor expression.
918 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
919                     linalg::GenericOp op, unsigned exp, unsigned ldx) {
920   Location loc = op.getLoc();
921   if (exp == -1u)
922     return Value();
923   if (merger.exp(exp).kind == Kind::kTensor)
924     return genTensorLoad(merger, codegen, rewriter, op, exp);
925   if (merger.exp(exp).kind == Kind::kInvariant)
926     return genInvariantValue(merger, codegen, rewriter, exp);
927   if (merger.exp(exp).kind == Kind::kIndex)
928     return genIndexValue(merger, codegen, rewriter, exp, ldx);
929   Value v0 =
930       genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
931   Value v1 =
932       genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx);
933   return merger.buildExp(rewriter, loc, exp, v0, v1);
934 }
935 
936 /// Determines if affine expression is invariant.
937 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
938                               unsigned ldx, bool &atLevel) {
939   switch (a.getKind()) {
940   case AffineExprKind::DimId: {
941     unsigned idx = a.cast<AffineDimExpr>().getPosition();
942     if (idx == ldx)
943       atLevel = true;
944     return codegen.loops[idx] != nullptr; // no longer in play?
945   }
946   case AffineExprKind::Add:
947   case AffineExprKind::Mul: {
948     auto binOp = a.cast<AffineBinaryOpExpr>();
949     return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
950            isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
951   }
952   default:
953     return true;
954   }
955 }
956 
957 /// Hoists loop invariant tensor loads for which indices have been exhausted.
958 static void genInvariants(Merger &merger, CodeGen &codegen,
959                           PatternRewriter &rewriter, linalg::GenericOp op,
960                           unsigned exp, unsigned ldx, bool atStart,
961                           Kind last = Kind::kTensor) {
962   if (exp == -1u)
963     return;
964   if (merger.exp(exp).kind == Kind::kTensor) {
965     // Inspect tensor indices.
966     bool atLevel = ldx == -1u;
967     OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
968     auto map = op.getTiedIndexingMap(t);
969     auto enc = getSparseTensorEncoding(t->get().getType());
970     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
971       AffineExpr a = map.getResult(perm(enc, d));
972       if (!isInvariantAffine(codegen, a, ldx, atLevel))
973         return; // still in play
974     }
975     // All exhausted at this level (atLevel denotes exactly at this level).
976     if (!atLevel)
977       return;
978     OpOperand *lhs = op.getOutputOperand(0);
979     if (lhs == t) {
980       // Start or end a scalarized reduction.
981       if (atStart) {
982         Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
983         codegen.redKind = getReduction(last);
984         codegen.redExp = exp;
985         updateReduc(merger, codegen, load);
986       } else {
987         Value redVal = codegen.redVal;
988         updateReduc(merger, codegen, Value());
989         codegen.redExp = -1u;
990         codegen.redKind = kNoReduc;
991         genTensorStore(merger, codegen, rewriter, op, exp, redVal);
992       }
993     } else {
994       // Start or end loop invariant hoisting of a tensor load.
995       merger.exp(exp).val =
996           atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
997     }
998   } else if (merger.exp(exp).kind != Kind::kInvariant &&
999              merger.exp(exp).kind != Kind::kIndex) {
1000     // Traverse into the binary operations. Note that we only hoist
1001     // tensor loads, since subsequent MLIR/LLVM passes know how to
1002     // deal with all other kinds of derived loop invariants.
1003     Kind last = merger.exp(exp).kind;
1004     unsigned e0 = merger.exp(exp).children.e0;
1005     unsigned e1 = merger.exp(exp).children.e1;
1006     genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
1007     genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
1008   }
1009 }
1010 
1011 /// Generates an expanded access pattern in innermost dimension.
1012 static void genExpansion(Merger &merger, CodeGen &codegen,
1013                          PatternRewriter &rewriter, linalg::GenericOp op,
1014                          unsigned at, bool atStart) {
1015   OpOperand *lhs = codegen.sparseOut;
1016   if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
1017       at != codegen.outerParNest)
1018     return; // not needed at this level
1019   // Generate start or end of an expanded access pattern.
1020   Value tensor = lhs->get();
1021   Location loc = op.getLoc();
1022   if (atStart) {
1023     auto dynShape = {ShapedType::kDynamicSize};
1024     Type etp = tensor.getType().cast<ShapedType>().getElementType();
1025     Type t1 = MemRefType::get(dynShape, etp);
1026     Type t2 = MemRefType::get(dynShape, rewriter.getI1Type());
1027     Type t3 = MemRefType::get(dynShape, rewriter.getIndexType());
1028     Type t4 = rewriter.getIndexType();
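    // The expand op returns the values buffer, the filled flags, the list of
    // added indices, and the running count of added entries.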
1029     auto res =
1030         rewriter.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
1031     assert(res.getNumResults() == 4);
1032     assert(!codegen.expValues);
1033     codegen.expValues = res.getResult(0);
1034     codegen.expFilled = res.getResult(1);
1035     codegen.expAdded = res.getResult(2);
1036     codegen.expCount = res.getResult(3);
1037   } else {
1038     assert(codegen.expValues);
1039     rewriter.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues,
1040                                 codegen.expFilled, codegen.expAdded,
1041                                 codegen.expCount);
1042     codegen.expValues = codegen.expFilled = codegen.expAdded =
1043         codegen.expCount = Value();
1044   }
1045 }
1046 
1047 /// Generates initialization code for the subsequent loop sequence at
1048 /// current index level. Returns true if the loop sequence needs to
1049 /// maintain the universal index.
1050 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1051                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1052                     unsigned at, BitVector &inits) {
1053   bool needsUniv = false;
1054   Location loc = op.getLoc();
1055   unsigned idx = topSort[at];
1056 
1057   // Initialize sparse positions.
1058   for (unsigned b = 0, be = inits.size(); b < be; b++) {
1059     if (inits[b]) {
1060       unsigned tensor = merger.tensor(b);
1061       assert(idx == merger.index(b));
1062       if (merger.isDim(b, Dim::kSparse)) {
1063         // Initialize sparse index.
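        // Search the enclosing loops for the innermost level at which this
        // tensor already has a position index; that position (or 0 at the
        // outermost level) indexes the pointers array for the bounds below.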
1064         unsigned pat = at;
1065         for (; pat != 0; pat--) {
1066           if (codegen.pidxs[tensor][topSort[pat - 1]])
1067             break;
1068         }
1069         Value ptr = codegen.pointers[tensor][idx];
1070         Value one = constantIndex(rewriter, loc, 1);
1071         Value p0 = (pat == 0) ? constantIndex(rewriter, loc, 0)
1072                               : codegen.pidxs[tensor][topSort[pat - 1]];
1073         codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
1074         Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
1075         codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
1076       } else {
1077         // Dense index still in play.
1078         needsUniv = true;
1079       }
1080     }
1081   }
1082 
1083   // Initialize the universal dense index.
1084   codegen.loops[idx] = constantIndex(rewriter, loc, 0);
1085   return needsUniv;
1086 }
1087 
1088 /// Decides whether to vectorize. Any implicit inner loop in the Linalg
1089 /// operation is a candidate. Whether it is actually converted to SIMD code
1090 /// depends on the requested strategy.
1091 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction,
1092                         bool isSparse) {
1093   // Reject vectorization of sparse output, unless innermost is reduction.
1094   if (codegen.sparseOut && !isReduction)
1095     return false;
1096   // Inspect strategy.
1097   switch (codegen.options.vectorizationStrategy) {
1098   case SparseVectorizationStrategy::kNone:
1099     return false;
1100   case SparseVectorizationStrategy::kDenseInnerLoop:
1101     return isInner && !isSparse;
1102   case SparseVectorizationStrategy::kAnyStorageInnerLoop:
1103     return isInner;
1104   }
1105   llvm_unreachable("unexpected vectorization strategy");
1106 }
1107 
1108 /// Decides whether to parallelize. Any implicit loop in the Linalg operation
1109 /// that is marked "parallel" is a candidate. Whether it is actually converted
1110 /// to a parallel operation depends on the requested strategy.
1111 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
1112                           bool isSparse, bool isVector) {
1113   // Reject parallelization of sparse output.
1114   if (codegen.sparseOut)
1115     return false;
1116   // Inspect strategy.
1117   switch (codegen.options.parallelizationStrategy) {
1118   case SparseParallelizationStrategy::kNone:
1119     return false;
1120   case SparseParallelizationStrategy::kDenseOuterLoop:
1121     return isOuter && !isSparse && !isReduction && !isVector;
1122   case SparseParallelizationStrategy::kAnyStorageOuterLoop:
1123     return isOuter && !isReduction && !isVector;
1124   case SparseParallelizationStrategy::kDenseAnyLoop:
1125     return !isSparse && !isReduction && !isVector;
1126   case SparseParallelizationStrategy::kAnyStorageAnyLoop:
1127     return !isReduction && !isVector;
1128   }
1129   llvm_unreachable("unexpected parallelization strategy");
1130 }
1131 
1132 /// Checks unit stride for dense tensors. The iteration graph may have ignored
1133 /// dense access patterns in order to avoid cycles (sparse access patterns are
1134 /// always placed innermost), but that means dense access has become strided.
1135 /// This prevents effective vectorization.
1136 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
1137                              unsigned idx) {
1138   for (OpOperand *t : op.getInputAndOutputOperands()) {
1139     if (!getSparseTensorEncoding(t->get().getType())) {
1140       auto map = op.getTiedIndexingMap(t);
1141       for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
1142         AffineExpr a = map.getResult(d);
1143         // Report non-unit stride if innermost index appears at an outer
1144         // dimension (true non-unit stride) or if the innermost index appears
1145         // in a compound subscript in the innermost dimension. Even if the
1146         // latter is unit stride, it does not play well with scatter/gather.
1147         // TODO: accept unit stride affine innermost like a[i,j+k+1]?
1148         if (a.isFunctionOfDim(idx) &&
1149             ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
1150           return false;
1151       }
1152     }
1153   }
1154   return true;
1155 }
1156 
1157 /// Generates a for-loop on a single index.
1158 static Operation *genFor(Merger &merger, CodeGen &codegen,
1159                          PatternRewriter &rewriter, linalg::GenericOp op,
1160                          bool isOuter, bool isInner, unsigned idx,
1161                          BitVector &indices) {
1162   unsigned fb = indices.find_first();
1163   unsigned tensor = merger.tensor(fb);
1164   assert(idx == merger.index(fb));
1165   auto iteratorTypes = op.iterator_types().getValue();
1166   bool isReduction = isReductionIterator(iteratorTypes[idx]);
1167   bool isSparse = merger.isDim(fb, Dim::kSparse);
1168   bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
1169                   denseUnitStrides(merger, op, idx);
1170   bool isParallel =
1171       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
1172 
1173   // Prepare vector length.
1174   if (isVector)
1175     codegen.curVecLength = codegen.options.vectorLength;
1176 
1177   // Loop bounds and increment.
1178   Location loc = op.getLoc();
1179   Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
1180   Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
1181   Value step = constantIndex(rewriter, loc, codegen.curVecLength);
1182   if (isVector && codegen.options.enableVLAVectorization) {
1183     Value vscale = rewriter.create<vector::VectorScaleOp>(
1184         loc, IndexType::get(rewriter.getContext()));
1185     step = rewriter.create<arith::MulIOp>(loc, vscale, step);
1186   }
1187 
1188   // Emit a parallel loop.
1189   if (isParallel) {
1190     assert(!isVector);
1191     scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
1192     if (isSparse)
1193       codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
1194     else
1195       codegen.loops[idx] = parOp.getInductionVars()[0];
1196     rewriter.setInsertionPointToStart(parOp.getBody());
1197     return parOp;
1198   }
1199 
1200   // Emit a sequential or vector loop.
1201   SmallVector<Value, 4> operands;
1202   if (codegen.redVal) {
1203     // In a vector loop, bring reduction into SIMD form, if not already.
1204     if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
1205       VectorType vtp = vectorType(codegen, codegen.redVal.getType());
1206       Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
1207       updateReduc(merger, codegen, vred);
1208     }
1209     operands.push_back(codegen.redVal);
1210   }
1211   if (codegen.expValues)
1212     operands.push_back(codegen.expCount);
1213   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
1214   if (codegen.redVal)
1215     updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
1216   if (codegen.expValues)
1217     codegen.expCount = forOp.getRegionIterArgs().back();
1218   // Assign induction variable to sparse or dense index.
1219   Value iv = forOp.getInductionVar();
1220   if (isSparse)
1221     codegen.pidxs[tensor][idx] = iv;
1222   else
1223     codegen.loops[idx] = iv;
1224   rewriter.setInsertionPointToStart(forOp.getBody());
1225   // Share vector iteration mask between all subsequent loads/stores.
1226   if (isVector)
1227     codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
1228   return forOp;
1229 }
1230 
1231 /// Emits a while-loop for co-iteration over multiple indices.
1232 static Operation *genWhile(Merger &merger, CodeGen &codegen,
1233                            PatternRewriter &rewriter, linalg::GenericOp op,
1234                            unsigned idx, bool needsUniv, BitVector &indices) {
1235   SmallVector<Type, 4> types;
1236   SmallVector<Value, 4> operands;
1237   // Construct the while-loop with a parameter for each index.
1238   Type indexType = rewriter.getIndexType();
1239   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1240     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1241       unsigned tensor = merger.tensor(b);
1242       assert(idx == merger.index(b));
1243       types.push_back(indexType);
1244       operands.push_back(codegen.pidxs[tensor][idx]);
1245     }
1246   }
1247   if (codegen.redVal) {
1248     types.push_back(codegen.redVal.getType());
1249     operands.push_back(codegen.redVal);
1250   }
1251   if (codegen.expValues) {
1252     types.push_back(indexType);
1253     operands.push_back(codegen.expCount);
1254   }
1255   if (needsUniv) {
1256     types.push_back(indexType);
1257     operands.push_back(codegen.loops[idx]);
1258   }
1259   assert(types.size() == operands.size());
1260   Location loc = op.getLoc();
1261   scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
1262 
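  // Create the "before" (condition) and "after" (body) blocks of the
  // while-loop, each with one block argument per operand.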
1263   SmallVector<Location> locs(types.size(), loc);
1264   Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, locs);
1265   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types, locs);
1266 
1267   // Build the "before" region, which effectively consists
1268   // of a conjunction of "i < upper" tests on all induction.
1269   rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
1270   Value cond;
1271   unsigned o = 0;
1272   for (unsigned b = 0, be = indices.size(); b < be; b++) {
1273     if (indices[b] && merger.isDim(b, Dim::kSparse)) {
1274       unsigned tensor = merger.tensor(b);
1275       assert(idx == merger.index(b));
1276       Value op1 = before->getArgument(o);
1277       Value op2 = codegen.highs[tensor][idx];
1278       Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
1279                                                  op1, op2);
1280       cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
1281       codegen.pidxs[tensor][idx] = after->getArgument(o++);
1282     }
1283   }
1284   if (codegen.redVal)
1285     updateReduc(merger, codegen, after->getArgument(o++));
1286   if (codegen.expValues)
1287     codegen.expCount = after->getArgument(o++);
1288   if (needsUniv)
1289     codegen.loops[idx] = after->getArgument(o++);
1290   assert(o == operands.size());
1291   rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
1292   rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
1293   return whileOp;
1294 }
1295 
1296 /// Generates a for-loop or a while-loop, depending on whether it implements
1297 /// singleton iteration or co-iteration over the given conjunction.
1298 static Operation *genLoop(Merger &merger, CodeGen &codegen,
1299                           PatternRewriter &rewriter, linalg::GenericOp op,
1300                           std::vector<unsigned> &topSort, unsigned at,
1301                           bool needsUniv, BitVector &indices) {
1302   unsigned idx = topSort[at];
1303   if (indices.count() == 1) {
1304     bool isOuter = at == 0;
1305     bool isInner = at == topSort.size() - 1;
1306     return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
1307                   indices);
1308   }
1309   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
1310 }
1311 
1312 /// Generates the local variables for this loop, consisting of the sparse
1313 /// indices, restored universal dense index, and dense positions.
1314 static void genLocals(Merger &merger, CodeGen &codegen,
1315                       PatternRewriter &rewriter, linalg::GenericOp op,
1316                       std::vector<unsigned> &topSort, unsigned at,
1317                       bool needsUniv, BitVector &locals) {
1318   Location loc = op.getLoc();
1319   unsigned idx = topSort[at];
1320 
1321   // Initialize sparse indices.
1322   Value min;
1323   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1324     if (locals[b] && merger.isDim(b, Dim::kSparse)) {
1325       unsigned tensor = merger.tensor(b);
1326       assert(idx == merger.index(b));
1327       Value ptr = codegen.indices[tensor][idx];
1328       Value s = codegen.pidxs[tensor][idx];
1329       Value load = genLoad(codegen, rewriter, loc, ptr, s);
1330       codegen.idxs[tensor][idx] = load;
1331       if (!needsUniv) {
1332         if (min) {
1333           Value cmp = rewriter.create<arith::CmpIOp>(
1334               loc, arith::CmpIPredicate::ult, load, min);
1335           min = rewriter.create<arith::SelectOp>(loc, cmp, load, min);
1336         } else {
1337           min = load;
1338         }
1339       }
1340     }
1341   }
1342 
  // Without a universal index, the minimum of the sparse indices becomes
  // the current loop index.
1344   if (min) {
1345     assert(!needsUniv);
1346     codegen.loops[idx] = min;
1347   }
1348 
1349   // Initialize dense positions. Note that we generate dense indices of the
1350   // output tensor unconditionally, since they may not appear in the lattice,
1351   // but may be needed for linearized codegen.
1352   for (unsigned b = 0, be = locals.size(); b < be; b++) {
1353     if ((locals[b] || merger.isOutTensor(b, idx)) &&
1354         merger.isDim(b, Dim::kDense)) {
1355       unsigned tensor = merger.tensor(b);
1356       assert(idx == merger.index(b));
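      // Scan backwards in the topological order for the most recently
      // computed position of this tensor to linearize the address.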
1357       unsigned pat = at;
1358       for (; pat != 0; pat--)
1359         if (codegen.pidxs[tensor][topSort[pat - 1]])
1360           break;
1361       Value p = (pat == 0) ? constantIndex(rewriter, loc, 0)
1362                            : codegen.pidxs[tensor][topSort[pat - 1]];
1363       codegen.pidxs[tensor][idx] = genAddress(
1364           codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
1365     }
1366   }
1367 
1368   // Move the insertion indices in lexicographic index order. During access
1369   // pattern expansion, we can skip setting the innermost dimension.
1370   if (codegen.sparseOut && !codegen.expValues) {
1371     Value pos = constantIndex(rewriter, loc, at);
1372     rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
1373                                      pos);
1374   }
1375 }
1376 
1377 /// Generates the induction structure for a while-loop.
1378 static void genWhileInduction(Merger &merger, CodeGen &codegen,
1379                               PatternRewriter &rewriter, linalg::GenericOp op,
1380                               unsigned idx, bool needsUniv,
1381                               BitVector &induction, scf::WhileOp whileOp) {
1382   Location loc = op.getLoc();
1383   // Finalize each else branch of all if statements.
1384   if (codegen.redVal || codegen.expValues) {
1385     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
1386                rewriter.getInsertionBlock()->getParentOp())) {
1387       unsigned y = 0;
1388       SmallVector<Value, 4> yields;
1389       if (codegen.redVal) {
1390         yields.push_back(codegen.redVal);
1391         updateReduc(merger, codegen, ifOp.getResult(y++));
1392       }
1393       if (codegen.expValues) {
1394         yields.push_back(codegen.expCount);
1395         codegen.expCount = ifOp->getResult(y++);
1396       }
1397       assert(y == yields.size());
1398       rewriter.create<scf::YieldOp>(loc, yields);
1399       rewriter.setInsertionPointAfter(ifOp);
1400     }
1401   }
1402   rewriter.setInsertionPointToEnd(&whileOp.getAfter().front());
1403   // Finalize the induction. Note that the induction could be performed
1404   // in the individual if-branches to avoid re-evaluating the conditions.
1405   // However, that would result in a rather elaborate forest of yield
1406   // instructions during code generation. Moreover, performing the induction
1407   // after the if-statements more closely resembles code generated by TACO.
1408   unsigned o = 0;
1409   SmallVector<Value, 4> operands;
1410   Value one = constantIndex(rewriter, loc, 1);
1411   for (unsigned b = 0, be = induction.size(); b < be; b++) {
1412     if (induction[b] && merger.isDim(b, Dim::kSparse)) {
1413       unsigned tensor = merger.tensor(b);
1414       assert(idx == merger.index(b));
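      // Advance this tensor's position by one only if its index matched the
      // loop index in this iteration; otherwise keep the position unchanged.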
1415       Value op1 = codegen.idxs[tensor][idx];
1416       Value op2 = codegen.loops[idx];
1417       Value op3 = codegen.pidxs[tensor][idx];
1418       Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1419                                                  op1, op2);
1420       Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
1421       operands.push_back(rewriter.create<arith::SelectOp>(loc, cmp, add, op3));
1422       codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
1423     }
1424   }
1425   if (codegen.redVal) {
1426     operands.push_back(codegen.redVal);
1427     updateReduc(merger, codegen, whileOp->getResult(o++));
1428   }
1429   if (codegen.expValues) {
1430     operands.push_back(codegen.expCount);
1431     codegen.expCount = whileOp->getResult(o++);
1432   }
1433   if (needsUniv) {
1434     operands.push_back(
1435         rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
1436     codegen.loops[idx] = whileOp->getResult(o++);
1437   }
1438   assert(o == operands.size());
1439   rewriter.create<scf::YieldOp>(loc, operands);
1440   rewriter.setInsertionPointAfter(whileOp);
1441 }
1442 
1443 /// Generates the induction structure for a for-loop.
1444 static void genForInduction(Merger &merger, CodeGen &codegen,
1445                             PatternRewriter &rewriter, linalg::GenericOp op,
1446                             Operation *loop) {
1447   Location loc = op.getLoc();
1448   unsigned o = 0;
1449   SmallVector<Value, 4> operands;
1450   if (codegen.redVal) {
1451     operands.push_back(codegen.redVal);
1452     updateReduc(merger, codegen, loop->getResult(o++));
1453   }
1454   if (codegen.expValues) {
1455     operands.push_back(codegen.expCount);
1456     codegen.expCount = loop->getResult(o++);
1457   }
1458   assert(o == operands.size());
1459   if (o > 0)
1460     rewriter.create<scf::YieldOp>(loc, operands);
1461   rewriter.setInsertionPointAfter(loop);
1462 }
1463 
1464 /// Generates a single if-statement within a while-loop.
1465 static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
1466                        PatternRewriter &rewriter, linalg::GenericOp op,
1467                        unsigned idx, BitVector &conditions) {
1468   Location loc = op.getLoc();
1469   SmallVector<Type, 4> types;
1470   Value cond;
1471   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1472     if (conditions[b]) {
1473       unsigned tensor = merger.tensor(b);
1474       assert(idx == merger.index(b));
1475       Value clause;
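      // A sparse dimension contributes an index equality test; a dense
      // dimension is trivially part of the conjunction.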
1476       if (merger.isDim(b, Dim::kSparse)) {
1477         Value op1 = codegen.idxs[tensor][idx];
1478         Value op2 = codegen.loops[idx];
1479         clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1480                                                 op1, op2);
1481       } else {
1482         clause = constantI1(rewriter, loc, true);
1483       }
1484       cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
1485     }
1486   }
1487   if (codegen.redVal)
1488     types.push_back(codegen.redVal.getType());
1489   if (codegen.expValues)
1490     types.push_back(rewriter.getIndexType());
1491   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
1492   rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1493   return ifOp;
1494 }
1495 
/// Generates the end of the true branch of an if-statement in a while-loop.
1497 static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1498                   linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
1499                   Value redInput, Value cntInput) {
1500   SmallVector<Value, 4> operands;
1501   if (codegen.redVal) {
1502     operands.push_back(codegen.redVal);
1503     updateReduc(merger, codegen, redInput);
1504   }
1505   if (codegen.expValues) {
1506     operands.push_back(codegen.expCount);
1507     codegen.expCount = cntInput;
1508   }
1509   if (!operands.empty())
1510     rewriter.create<scf::YieldOp>(op.getLoc(), operands);
1511   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1512 }
1513 
1514 //===----------------------------------------------------------------------===//
1515 // Sparse compiler synthesis methods (loop sequence).
1516 //===----------------------------------------------------------------------===//
1517 
/// Starts a loop sequence at the given level. Returns true if
/// the universal loop index must be maintained at this level.
1520 static bool startLoopSeq(Merger &merger, CodeGen &codegen,
1521                          PatternRewriter &rewriter, linalg::GenericOp op,
1522                          std::vector<unsigned> &topSort, unsigned exp,
1523                          unsigned at, unsigned idx, unsigned ldx,
1524                          unsigned lts) {
1525   assert(codegen.curVecLength == 1);
1526   assert(!codegen.loops[idx]);
1527   // Emit invariants at this loop sequence level.
1528   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
1529   // Emit access pattern expansion for sparse tensor output.
1530   genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
1532   unsigned l0 = merger.set(lts)[0];
1533   bool needsUniv =
1534       genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
1535   // Maintain the universal index only if it is actually
1536   // consumed by a subsequent lattice point.
1537   if (needsUniv) {
1538     unsigned lsize = merger.set(lts).size();
1539     for (unsigned i = 1; i < lsize; i++) {
1540       unsigned li = merger.set(lts)[i];
1541       if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
1542         return true;
1543     }
1544   }
1545   return false;
1546 }
1547 
/// Starts a single loop in the current sequence.
1549 static Operation *startLoop(Merger &merger, CodeGen &codegen,
1550                             PatternRewriter &rewriter, linalg::GenericOp op,
1551                             std::vector<unsigned> &topSort, unsigned at,
1552                             unsigned li, bool needsUniv) {
1553   assert(codegen.curVecLength == 1);
1554   // Emit the for/while-loop control.
1555   Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
1556                             needsUniv, merger.lat(li).simple);
1557   // Emit the locals for this loop.
1558   genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
1559             merger.lat(li).bits);
1560   return loop;
1561 }
1562 
/// Ends a single loop in the current sequence. Returns the new value
/// for needsUniv.
1564 static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1565                     linalg::GenericOp op, Operation *loop, unsigned idx,
1566                     unsigned li, bool needsUniv) {
1567   codegen.curVecLength = 1;
1568   // End a while-loop.
1569   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1570     genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
1571                       merger.lat(li).bits, whileOp);
1572     return needsUniv;
1573   }
1574   // End a for-loop.
1575   genForInduction(merger, codegen, rewriter, op, loop);
1576   return false;
1577 }
1578 
/// Ends a loop sequence at the given level.
1580 static void endLoopSeq(Merger &merger, CodeGen &codegen,
1581                        PatternRewriter &rewriter, linalg::GenericOp op,
1582                        unsigned exp, unsigned at, unsigned idx, unsigned ldx) {
1583   assert(codegen.curVecLength == 1);
1584   codegen.loops[idx] = Value();
1585   // Bring a pending reduction back from SIMD form when sequence ends.
1586   if (codegen.redVal)
1587     if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
1588       updateReduc(merger, codegen,
1589                   genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
1590   // Unmark bookkeeping of invariants and loop index.
1591   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
1592   // Finalize access pattern expansion for sparse tensor output.
1593   genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false);
1594 }
1595 
1596 /// Recursively generates code while computing iteration lattices in order
1597 /// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
1599 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
1600                     linalg::GenericOp op, std::vector<unsigned> &topSort,
1601                     unsigned exp, unsigned at) {
  // At each leaf, assign the remaining tensor (sub)expression to the
  // output tensor.
1603   if (at == topSort.size()) {
1604     unsigned ldx = topSort[at - 1];
1605     Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
1606     genTensorStore(merger, codegen, rewriter, op, exp, rhs);
1607     return;
1608   }
1609 
1610   // Construct iteration lattices for current loop index, with L0 at top.
1611   unsigned idx = topSort[at];
1612   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
1613   unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
1614 
1615   // Start a loop sequence.
1616   bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
1617                                 idx, ldx, lts);
1618 
1619   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1620   unsigned lsize = merger.set(lts).size();
1621   for (unsigned i = 0; i < lsize; i++) {
1622     // Start a loop.
1623     unsigned li = merger.set(lts)[i];
1624     Operation *loop =
1625         startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);
1626 
    // Visit all lattice points Lj with Li >= Lj to generate the
    // loop-body, possibly with if-statements for co-iteration.
1629     Value redInput = codegen.redVal;
1630     Value cntInput = codegen.expCount;
1631     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1632     for (unsigned j = 0; j < lsize; j++) {
1633       unsigned lj = merger.set(lts)[j];
1634       unsigned ej = merger.lat(lj).exp;
1635       if (li == lj || merger.latGT(li, lj)) {
1636         // Recurse into body of each branch.
1637         if (isWhile) {
1638           scf::IfOp ifOp =
1639               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1640           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1641           endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
1642         } else {
1643           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1644         }
1645       }
1646     }
1647 
1648     // End a loop.
1649     needsUniv =
1650         endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
1651   }
1652 
1653   // End a loop sequence.
1654   endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
1655 }
1656 
1657 /// Converts the result computed by the sparse kernel into the required form.
1658 static void genResult(Merger &merger, CodeGen &codegen,
1659                       PatternRewriter &rewriter, linalg::GenericOp op) {
1660   OpOperand *lhs = op.getOutputOperand(0);
1661   Type resType = lhs->get().getType();
1662   if (getSparseTensorEncoding(resType)) {
1663     // The sparse tensor rematerializes from the original sparse tensor's
1664     // underlying sparse storage format.
1665     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
1666                                         codegen.sparseOut == lhs);
1667   } else {
    // To rematerialize a non-annotated tensor, simply load it
1669     // from the bufferized value.
1670     Value val = codegen.buffers.back(); // value array
1671     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1672   }
1673 }
1674 
1675 //===----------------------------------------------------------------------===//
1676 // Sparse compiler rewriting methods.
1677 //===----------------------------------------------------------------------===//
1678 
1679 namespace {
1680 
/// Sparse rewriting rule for generic Linalg operation.
1682 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1683 public:
1684   GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1685       : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1686 
1687   LogicalResult matchAndRewrite(linalg::GenericOp op,
1688                                 PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
1691     assert(op.getNumOutputs() == 1);
1692     unsigned numTensors = op.getNumInputsAndOutputs();
1693     unsigned numLoops = op.iterator_types().getValue().size();
1694     Merger merger(numTensors, numLoops);
1695     if (!findSparseAnnotations(merger, op))
1696       return failure();
1697 
1698     // Computes a topologically sorted iteration graph to ensure
1699     // tensors are visited in natural index order. Fails on cycles.
1700     // This assumes that higher-level passes have already put the
1701     // tensors in each tensor expression in a feasible order.
1702     std::vector<unsigned> topSort;
1703     if (!computeIterationGraph(merger, op, topSort,
1704                                SortMask::kIncludeUndef |
1705                                    SortMask::kIncludeDense) &&
1706         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
1707         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
1708         !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
1709       return failure();
1710 
1711     // Builds the tensor expression for the Linalg operation in SSA form.
1712     Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
1713     if (!optExp.hasValue())
1714       return failure();
1715     unsigned exp = optExp.getValue();
1716 
    // Rejects an inadmissible tensor expression.
1718     OpOperand *sparseOut = nullptr;
1719     unsigned outerParNest = 0;
1720     if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
1721                                outerParNest))
1722       return failure();
1723 
1724     // Recursively generates code.
1725     merger.setHasSparseOut(sparseOut != nullptr);
1726     CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
1727     genBuffers(merger, codegen, rewriter, op);
1728     genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
1729     genResult(merger, codegen, rewriter, op);
1730     return success();
1731   }
1732 
1733 private:
1734   /// Options to control sparse code generation.
1735   SparsificationOptions options;
1736 };
1737 
1738 } // namespace
1739 
1740 /// Populates the given patterns list with rewriting rules required for
1741 /// the sparsification of linear algebra operations.
1742 void mlir::populateSparsificationPatterns(
1743     RewritePatternSet &patterns, const SparsificationOptions &options) {
1744   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1745 }
1746