//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
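// For example, a `sparse_tensor.values` access on an f64-valued tensor
// simply lowers to a call to the runtime function `sparseValuesF64`.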
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;
using mlir::bufferization::BufferizableOpInterface;
using mlir::bufferization::BufferizationDialect;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(OpBuilder &builder) {
  return LLVM::LLVMPointerType::get(builder.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
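/// For example (schematically, with the already converted operand types),
/// the first request for `sparseDimSize` inserts the declaration
///   func.func private @sparseDimSize(!llvm.ptr<i8>, index) -> index
/// at module scope.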
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
                                   StringRef name, TypeRange resultType,
                                   ValueRange operands,
                                   EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates dimension size call.
static Value genDimSizeCall(OpBuilder &builder, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(builder, op->getLoc(), idx)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(OpBuilder &builder, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(builder);
  return createFuncCall(builder, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                          Location loc, ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(builder, loc, s));
  }
}

/// Populates given sizes array from source.
static void sizesFromSrc(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Location loc, Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// the source, which has already been converted into an opaque pointer
/// (for dynamic sizes).
static void sizesFromPtr(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Operation *op, SparseTensorEncodingAttr &enc,
                         ShapedType stp, Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(builder, op, enc, src, i));
    else
      sizes.push_back(constantIndex(builder, loc, shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp) {
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
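/// The layout of the resulting `params` array, which some callers rely on
/// when reading or patching individual entries afterwards, is:
///   params[0] dim-level-type array,    params[1] dimension-sizes array,
///   params[2] dimension-order array,   params[3] pointer type encoding,
///   params[4] index type encoding,     params[5] primary type encoding,
///   params[6] action,                  params[7] opaque payload pointer.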
static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
                      Operation *op, ShapedType stp,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
  params.push_back(genBuffer(builder, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  params.push_back(genBuffer(builder, loc, szs));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(builder, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(builder, loc, i);
  }
  params.push_back(genBuffer(builder, loc, rev));
  // Secondary and primary types encoding.
  Type elemTp = stp.getElementType();
  params.push_back(constantPointerTypeEncoding(builder, loc, enc));
  params.push_back(constantIndexTypeEncoding(builder, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
  // User action.
  params.push_back(constantAction(builder, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
  params.push_back(ptr);
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks
/// like the following, and the insertion point after this routine is left
/// inside the if-then branch behind the assignment to ind, to ensure that
/// the subsequently generated addEltX call is also inside that branch.
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
                                      Value tensor, Value ind, ValueRange ivs) {
  Value val = builder.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(builder, loc, val);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else*/ false);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(builder, loc, i++);
    builder.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  TypeRange noTp;
  createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(&val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType,
                          Value ptr, Value valPtr, Value ind, Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
  Type pTp = getOpaquePointerType(builder);
  createFuncCall(builder, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`.  If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true.  If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
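/// Callers below drive this from the before-region of an `scf.while`, i.e.,
/// schematically: while (getNext(iter, ind, elemPtr)) { use ind/elemPtr; }.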
static Value genGetNextCall(OpBuilder &builder, Operation *op, Value iter,
                            Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = builder.getI1Type();
  return createFuncCall(builder, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = builder.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = builder.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc,
                                       Value indices, Value values, Value ind,
                                       ValueRange ivs, unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    Value val = builder.create<tensor::ExtractOp>(loc, indices,
                                                  ValueRange{ivs[0], idx});
    val = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), val);
    builder.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return builder.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a buffer of the given type, and zero
/// initialize it.  If the buffer type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(OpBuilder &builder, Location loc,
                              RankedTensorType tensorTp, ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

/// Generates code to deallocate a dense buffer.
static void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor().  The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
                                        Value elemPtr, Value tensor,
                                        unsigned rank, Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    ivs.push_back(builder.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
  builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

/// Determines whether the runtime library supports direct conversion to the
/// given target `dimTypes`.
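/// For example, (Dense, Compressed) is supported (CSR-like), whereas
/// (Compressed, Compressed) and (Compressed, Dense) currently are not.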
static bool canUseDirectConversion(
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
  bool alreadyCompressed = false;
  for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
    switch (dimTypes[r]) {
    case SparseTensorEncodingAttr::DimLevelType::Compressed:
      if (alreadyCompressed)
        return false; // Multiple compressed dimensions not yet supported.
      alreadyCompressed = true;
      break;
    case SparseTensorEncodingAttr::DimLevelType::Dense:
      if (alreadyCompressed)
        return false; // Dense after Compressed not yet supported.
      break;
    case SparseTensorEncodingAttr::DimLevelType::Singleton:
      // Although Singleton isn't generally supported yet, the direct
      // conversion method doesn't have any particular problems with
      // singleton after compressed.
      break;
    }
  }
  return true;
}

/// Helper method to translate indices during a reshaping operation.
/// TODO: provide as general utility to MLIR at large?
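/// For example, collapsing a static 3x4 source into a 12-vector maps source
/// index (i, j) to destination index i * 4 + j, while the inverse expansion
/// recovers i = x / 4 and j = x % 4 from a source index x.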
static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
                             ArrayRef<ReassociationIndices> reassociation,
                             TensorType dstTp, TensorType srcTp, Value dstIdx,
                             Value srcIdx) {
  unsigned dstRank = dstTp.getRank();
  unsigned srcRank = srcTp.getRank();
  unsigned start = 0;
  unsigned i = 0;
  bool isCollapse = srcRank > dstRank;
  ArrayRef<int64_t> shape = isCollapse ? srcTp.getShape() : dstTp.getShape();
  // Iterate over reassociation map.
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    uint64_t linear = 1;
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      assert(!ShapedType::isDynamic(shape[j]));
      linear *= shape[j];
    }
    // For expansion, load the single source index of this group up front.
    Value idx = constantIndex(rewriter, loc, i++);
    Value val;
    if (!isCollapse)
      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear /= shape[j];
      Value stride = constantIndex(rewriter, loc, linear);
      Value jdx = constantIndex(rewriter, loc, j);
      if (isCollapse) {
        // Linearize the group of source indices into one destination index.
        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
        Value mul = linear == 1
                        ? old
                        : rewriter.create<arith::MulIOp>(loc, old, stride);
        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        // Split the source index into a group of destination indices.
        Value old = val;
        if (linear != 1)
          val = rewriter.create<arith::DivUIOp>(loc, val, stride);
        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
        if (linear != 1)
          val = rewriter.create<arith::RemUIOp>(loc, old, stride);
      }
    }
    // Finalize the collapsed destination index.
    if (isCollapse)
      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
    start += map.value().size();
  }
  // Sanity.
  assert((isCollapse && i == dstRank) || (!isCollapse && i == srcRank));
}

/// Generate code for a general sparse to sparse reshaping operation.
/// Note that unlike dense reshaping (which can be done with a "cheap"
/// change of view), sparse reshaping is currently done with actual
/// data shuffling.
///
/// TODO: proportional to nnz, but still a lot of data movement
///       https://github.com/llvm/llvm-project/issues/56477
///
///   iter = src->toCOO();
///   coo = newSparseCOO()
///   while (elem = iter->getNext()) {
///     coo->add(reshape(elem.indices), elem.value)
///   }
///   s = newSparseTensor(coo)
static LogicalResult
genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
                        ArrayRef<ReassociationIndices> reassociation, Value src,
                        RankedTensorType dstTp, RankedTensorType srcTp) {
  Location loc = op->getLoc();
  auto encDst = getSparseTensorEncoding(dstTp);
  auto encSrc = getSparseTensorEncoding(srcTp);
  assert(encDst && encSrc);
  unsigned srcRank = srcTp.getRank();
  unsigned dstRank = dstTp.getRank();
  Type elemTp = srcTp.getElementType();
  assert(elemTp == dstTp.getElementType() &&
         "reshape should not change element type");
  // Start an iterator over the source tensor (in original index order).
  auto noPerm = SparseTensorEncodingAttr::get(
      op->getContext(), encSrc.getDimLevelType(), AffineMap(),
      encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
  SmallVector<Value, 4> sizes;
  SmallVector<Value, 8> params;
  sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
  newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
            src);
  Value iter = genNewCall(rewriter, op, params);
  // Start a new COO for the destination tensor.
  sizes.clear();
  params.clear();
  sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
  newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
  Value coo = genNewCall(rewriter, op, params);
  Value dstPerm = params[2];
  // Construct a while loop over the iterator.
  Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
  Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
  Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
  SmallVector<Value> noArgs;
  SmallVector<Type> noTypes;
  auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
  rewriter.setInsertionPointToEnd(before);
  Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr);
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  // Translate indices from source to target and insert. Note that we do
  // not need to store the value in elemPtr, as the value is still there.
  Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
  rewriter.setInsertionPointToStart(after);
  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
  genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
  rewriter.create<scf::YieldOp>(loc);
  // Final call to construct sparse tensor storage and free temporary resources.
  rewriter.setInsertionPointAfter(whileOp);
  params[6] = constantAction(rewriter, loc, Action::kFromCOO);
  params[7] = coo;
  Value dst = genNewCall(rewriter, op, params);
  genDelCOOCall(rewriter, op, elemTp, coo);
  genDelCOOCall(rewriter, op, elemTp, iter);
  rewriter.replaceOp(op, dst);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.getSource().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = *index;
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for a reshape operator.
template <typename ReshapeOp>
class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
public:
  using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor;
  using OpConversionPattern<ReshapeOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = op.getResult().getType();
    Type srcType = op.getSrc().getType();
    auto encDst = getSparseTensorEncoding(dstType);
    auto encSrc = getSparseTensorEncoding(srcType);
    if (encDst && encSrc)
      return genSparse2SparseReshape(
          op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
          dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
    return failure(); // handled elsewhere
  }
};

/// Sparse conversion rule for the new operator.
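/// For example (schematically, with a sparse-annotated result type), the op
///   %t = sparse_tensor.new %ptr : !llvm.ptr<i8> to tensor<?x?xf64, #CSR>
/// becomes a single runtime call
///   %t = call @newSparseTensor(%attrs, %sizes, %perm, %ptrTp, %indTp,
///                              %valTp, %fromFile, %ptr)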
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    sizesFromType(rewriter, sizes, op.getLoc(), stp);
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    RankedTensorType resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Gather all dimension sizes as SSA values.
    SmallVector<Value> sizes;
    unsigned int operandCtr = 0;
    for (int64_t i = 0; i < resType.getRank(); ++i) {
      if (resType.isDynamicDim(i)) {
        sizes.push_back(adaptor.getOperands()[operandCtr++]);
      } else {
        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
            op.getLoc(), op.getStaticSize(i)));
      }
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.getSource().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      bool useDirectConversion;
      switch (options.sparseToSparseStrategy) {
      case SparseToSparseConversionStrategy::kViaCOO:
        useDirectConversion = false;
        break;
      case SparseToSparseConversionStrategy::kDirect:
        useDirectConversion = true;
        assert(canUseDirectConversion(encDst.getDimLevelType()) &&
               "Unsupported target for direct sparse-to-sparse conversion");
        break;
      case SparseToSparseConversionStrategy::kAuto:
        useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
        break;
      }
      if (useDirectConversion) {
        newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
                  sizes, src);
        rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      } else { // use via-COO conversion.
        // Set up encoding with right mix of src and dst so that the two
        // method calls can share most parameters, while still providing
        // the correct sparsity information to either of them.
        auto enc = SparseTensorEncodingAttr::get(
            op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
            encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
        newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
        Value coo = genNewCall(rewriter, op, params);
        params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
        params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
        params[6] = constantAction(rewriter, loc, Action::kFromCOO);
        params[7] = coo;
        Value dst = genNewCall(rewriter, op, params);
        genDelCOOCall(rewriter, op, stp.getElementType(), coo);
        rewriter.replaceOp(op, dst);
      }
      return success();
    }
    if (!encDst && encSrc) {
      // This is a sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Block *insertionBlock = rewriter.getInsertionBlock();
      // TODO: Dense buffers should be allocated/deallocated via the callback
      // in BufferizationOptions.
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      genDelCOOCall(rewriter, op, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      // Deallocate the buffer.
      if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
        rewriter.setInsertionPoint(insertionBlock->getTerminator());
        deallocDenseTensor(rewriter, loc, dst);
      }
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too many low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se, whereas the entire body of the innermost
    // loop is generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value coo = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.has_value();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          builder.create<memref::StoreOp>(loc, val, elemPtr);
          genAddEltCall(rewriter, op, eltType, coo, elemPtr, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and indices.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value indices = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = op;
    for (; isa<scf::ForOp>(parent->getParentOp()) ||
           isa<scf::WhileOp>(parent->getParentOp()) ||
           isa<scf::ParallelOp>(parent->getParentOp()) ||
           isa<scf::IfOp>(parent->getParentOp());
         parent = parent->getParentOp())
      ;
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    auto encSrc = getSparseTensorEncoding(srcType);
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
    auto enc = SparseTensorEncodingAttr::get(
        op->getContext(), encSrc.getDimLevelType(), AffineMap(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
    newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
    Value coo = genNewCall(rewriter, op, params);
    // Then output the tensor to external file with indices in the externally
    // visible lexicographic index order. A sort is required if the source was
    // not in that order yet (note that the sort can be dropped altogether if
    // external format does not care about the order at all, but here we assume
    // it does).
    bool sort =
        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
    params.clear();
    params.push_back(coo);
    params.push_back(adaptor.getOperands()[1]);
    params.push_back(constantI1(rewriter, loc, sort));
    Type eltType = srcType.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    const SparseTensorConversionOptions &options) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseReshapeConverter<tensor::ExpandShapeOp>,
               SparseReshapeConverter<tensor::CollapseShapeOp>,
               SparseTensorAllocConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorLoadConverter,
               SparseTensorLexInsertConverter, SparseTensorExpandConverter,
               SparseTensorCompressConverter, SparseTensorOutConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorConvertConverter>(typeConverter,
                                             patterns.getContext(), options);
}