//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering sparse tensor types to actual sparse code.
//
// The concept of letting a compiler generate sparse code automatically was
// pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
// formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
// Algebra Compiler (TACO). The implementation in this file closely follows
// the "sparse iteration theory" that forms the foundation of TACO. A rewriting
// rule is applied to each tensor expression in linalg (MLIR's tensor index
// notation) where the sparsity of tensors is indicated with an annotation
// that gives a per-dimension specification of sparse/dense storage together
// with a specification of the order on the dimensions. Subsequently, a
// topologically sorted iteration graph, reflecting the required order on
// indices with respect to the dimensions of each tensor, is constructed to
// ensure that all tensors are visited in natural index order. Next, iteration
// lattices are constructed for the tensor expression for every index in
// topological order. Each iteration lattice point consists of a conjunction
// of tensor indices together with a tensor (sub)expression that needs to be
// evaluated for that conjunction. Within the lattice, iteration points are
// ordered according to the way indices are exhausted. As such, these
// iteration lattices drive the actual sparse code generation, which consists
// of a tedious but relatively straightforward one-to-one mapping from
// iteration lattices to combinations of for-loops, while-loops, and
// if-statements.
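//
// For example, for a sparse vector addition x(i) = a(i) + b(i), the
// iteration lattice for index i contains the points {a /\ b, a, b}:
// a while-loop co-iterates over the intersection of the nonzeros of
// a and b, followed by loops over the remaining nonzeros of a and b,
// respectively.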
//
// [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
// PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
// [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
// David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
// Proceedings of the ACM on Programming Languages, October 2017.
// [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
// PhD thesis, MIT, February, 2020 (tensor-compiler.org).
//
// Implementation detail: We use llvm::SmallVector for vectors with
// variable lengths and std::vector for vectors with fixed lengths.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
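  /// For a sparse dimension, pidxs holds the current position into the
  /// sparse storage and highs the upper bound of the position range,
  /// while idxs holds the index value loaded at that position. For a
  /// dense dimension, highs holds the dimension size and pidxs the
  /// linearized position (if needed).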
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

// Helper method to apply dimension ordering permutation.
static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    if (!map.isProjectedPermutation())
      return false;
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
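/// The visit vector encodes the DFS state per index: 0 = not visited,
/// 1 = currently on the DFS stack (a revisit indicates a cycle), and
/// 2 = finished.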
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  bool sparseOnly) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when sparse-only is requested.
    if (sparseOnly && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      unsigned f = map.getDimPosition(perm(enc, d - 1));
      unsigned t = map.getDimPosition(perm(enc, d));
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
/// This simplifies constructing (sub)expressions during iteration lattice
/// building (compared to using the SSA representation everywhere).
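/// For example, a generic op body that multiplies block arguments a and b
/// and then adds c before yielding the result builds the expression
/// kAddF(kMulF(kTensor a, kTensor b), kTensor c).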
static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
                                         Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return merger.addExp(Kind::kTensor, argN);
      val = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return merger.addExp(Kind::kInvariant, val);
  }
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front()) {
    // Something defined outside is invariant.
    return merger.addExp(Kind::kInvariant, val);
  } else if (def->getNumOperands() == 2) {
    // Construct binary operations if subexpressions could be built.
    auto x = buildTensorExp(merger, op, def->getOperand(0));
    auto y = buildTensorExp(merger, op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return merger.addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return merger.addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return merger.addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return merger.addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build (yet).
  return None;
}

/// Returns true if the given tensor co-iterates with conjunction only.
/// For the output tensor, this defines a "simply dynamic" operation.
/// For instance: A(I) = A(I) * B(I) * C(I)
static bool isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
    return merger.exp(exp).e0 == tensor;
  case Kind::kMulF:
  case Kind::kMulI:
    return isConjunction(merger, tensor, merger.exp(exp).e0) ||
           isConjunction(merger, tensor, merger.exp(exp).e1);
  default:
    return false;
  }
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen.
  if (isConjunction(merger, tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

/// Builds the iteration lattices in a bottom-up traversal given the remaining
/// tensor (sub)expression and the next loop index in the iteration graph.
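/// A multiplication only contributes where both operands are nonzero
/// (conjunction), whereas an addition contributes where either operand
/// is nonzero (disjunction); for example, a(i) * b(i) yields the single
/// lattice point a /\ b, while a(i) + b(i) yields {a /\ b, a, b}.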
static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
                              unsigned exp, unsigned idx) {
  Kind kind = merger.exp(exp).kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = merger.addSet();
    unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
                                       : op.getNumInputsAndOutputs();
    merger.set(s).push_back(merger.addLat(t, idx, exp));
    return s;
  }
  unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
  unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
  case Kind::kMulI:
    return merger.takeConj(kind, s0, s1);
  case Kind::kAddF:
  case Kind::kAddI:
    return merger.takeDisj(kind, s0, s1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<linalg::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static bool genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find lower and upper bound in current dimension.
      Value up;
      if (shape[d] == MemRefType::kDynamicSize) {
        up = rewriter.create<memref::DimOp>(loc, t->get(), d);
        args.push_back(up);
      } else {
        up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
      }
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  Value end = rewriter.create<SubIOp>(loc, hi, iv);
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
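/// A gather operation is emitted when the innermost argument is itself a
/// vector of indices (indirect, sparse access); otherwise a masked
/// contiguous load is emitted.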
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  OpOperand *lhs = op.getOutputOperand(0);
  if (lhs == t && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "size * p + i".
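/// This is used to linearize the position of a dense dimension: p is the
/// position inherited from the enclosing dimension and i the index in the
/// current dimension.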
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  OpOperand *lhs = op.getOutputOperand(0);
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  else if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
    return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).e0;
    unsigned e1 = merger.exp(exp).e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
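/// For a sparse dimension, the position range of the nonzero segment is
/// obtained by loading pointers[p] and pointers[p+1], where p is the
/// position inherited from the enclosing dimension (or zero at the
/// outermost level).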
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns true if the requested vectorization strategy applies to the given
/// loop. Any implicit inner loop in the Linalg operation is a candidate;
/// whether it is actually converted to SIMD code depends on the strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns true if the requested parallelization strategy applies to the given
/// loop. Any implicit loop in the Linalg operation that is marked "parallel"
/// is a candidate; whether it is actually converted to a parallel operation
/// depends on the strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided loads/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        if (map.getDimPosition(d) == idx && d != rank - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
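/// A sparse position is advanced only when its index value matches the
/// current loop index; the universal dense index, if maintained, is always
/// advanced by one.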
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
    Operation *yield = op.region().front().getTerminator();
    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
    if (!exp.hasValue())
      return failure(); // build failure

    // Reject an inadmissible tensor expression.
    if (!isAdmissibleTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}