//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that calling into a runtime library is a deliberate implementation
// choice that keeps the conversion simple. In principle, these primitives
// could also be lowered directly to elaborate IR that implements them on the
// selected sparse tensor storage schemes.
//
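// For example (illustrative only), a dimension query on an annotated tensor
//   %d = tensor.dim %src, %c0 : tensor<?x?xf64, #SparseMatrix>
// becomes, after conversion, a call into the support library:
//   %d = call @sparseDimSize(%src, %c0) : (!llvm.ptr<i8>, index) -> index
//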
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(OpBuilder &builder) {
  return LLVM::LLVMPointerType::get(builder.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}
/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
                                   StringRef name, TypeRange resultType,
                                   ValueRange operands,
                                   EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}
/// Generates a dimension size call.
static Value genDimSizeCall(OpBuilder &builder, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(builder, op->getLoc(), idx)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(OpBuilder &builder, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(builder);
  return createFuncCall(builder, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates the given sizes array from the static shape of a type.
static void sizesFromType(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                          Location loc, ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(builder, loc, s));
  }
}

/// Populates the given sizes array from a source tensor.
static void sizesFromSrc(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Location loc, Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i));
}
/// Populates the given sizes array from the type (for static sizes) and
/// from the source, which has already been converted into an opaque
/// pointer (for dynamic sizes).
static void sizesFromPtr(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Operation *op, SparseTensorEncodingAttr &enc,
                         ShapedType stp, Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(builder, op, enc, src, i));
    else
      sizes.push_back(constantIndex(builder, loc, shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp) {
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
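/// For example (illustrative only), two index values yield roughly:
///   %buf = memref.alloca(%c2) : memref<?xindex>
///   memref.store %v0, %buf[%c0] : memref<?xindex>
///   memref.store %v1, %buf[%c1] : memref<?xindex>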
static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
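/// On return, `params` is laid out as follows (later code patches entries
/// in place by these positions when reusing the parameter list):
///   params[0] : dimension level-type annotations buffer
///   params[1] : dimension sizes buffer
///   params[2] : dimension order permutation buffer
///   params[3] : pointer type encoding
///   params[4] : index type encoding
///   params[5] : primary (element) type encoding
///   params[6] : action
///   params[7] : opaque payload pointer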
static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
                      Operation *op, ShapedType stp,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
  params.push_back(genBuffer(builder, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  params.push_back(genBuffer(builder, loc, szs));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(builder, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(builder, loc, i);
  }
  params.push_back(genBuffer(builder, loc, rev));
  // Secondary and primary types encoding.
  Type elemTp = stp.getElementType();
  params.push_back(constantPointerTypeEncoding(builder, loc, enc));
  params.push_back(constantIndexTypeEncoding(builder, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
  // User action.
  params.push_back(constantAction(builder, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
  params.push_back(ptr);
}
/// Generates code to read the value from tensor[ivs] and, if it is nonzero,
/// to store the indices ivs to the memory in ind. On return, the insertion
/// point is left inside the if-then branch, right after the assignment to
/// ind, so that the addEltX call generated afterwards also lands inside
/// that branch:
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
                                      Value tensor, Value ind, ValueRange ivs) {
  Value val = builder.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(builder, loc, val);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else*/ false);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(builder, loc, i++);
    builder.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  TypeRange noTp;
  createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(&val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType,
                          Value ptr, Value valPtr, Value ind, Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
  Type pTp = getOpaquePointerType(builder);
  createFuncCall(builder, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`.  If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true.  If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(OpBuilder &builder, Operation *op, Value iter,
                            Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = builder.getI1Type();
  return createFuncCall(builder, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
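/// For example (illustrative only), a sparse constant such as
///   sparse<[[0, 0], [1, 2]], [1.0, 5.0]> : tensor<2x3xf64>
/// is split into a 2x2 index constant [[0, 0], [1, 2]] and a
/// 2-element value constant [1.0, 5.0].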
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = builder.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = builder.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates code to copy the index at indices[ivs[0]] to ind, and returns
/// the value at values[ivs[0]].
static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc,
                                       Value indices, Value values, Value ind,
                                       ValueRange ivs, unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    Value val = builder.create<tensor::ExtractOp>(loc, indices,
                                                  ValueRange{ivs[0], idx});
    val = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), val);
    builder.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return builder.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a buffer of the given type, and zero
/// initialize it.  If the buffer type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(OpBuilder &builder, Location loc,
                              RankedTensorType tensorTp, ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

/// Generates code to deallocate a dense buffer.
static void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor().  The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
                                        Value elemPtr, Value tensor,
                                        unsigned rank, Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    ivs.push_back(builder.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
  builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

/// Determine if the runtime library supports direct conversion to the
/// given target `dimTypes`.
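/// For example, (Dense, Compressed) and (Compressed, Singleton) are
/// supported, whereas (Compressed, Dense) and (Compressed, Compressed)
/// are not yet supported.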
static bool canUseDirectConversion(
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
  bool alreadyCompressed = false;
  for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
    switch (dimTypes[r]) {
    case SparseTensorEncodingAttr::DimLevelType::Compressed:
      if (alreadyCompressed)
        return false; // Multiple compressed dimensions not yet supported.
      alreadyCompressed = true;
      break;
    case SparseTensorEncodingAttr::DimLevelType::Dense:
      if (alreadyCompressed)
        return false; // Dense after Compressed not yet supported.
      break;
    case SparseTensorEncodingAttr::DimLevelType::Singleton:
      // Although Singleton isn't generally supported yet, the direct
      // conversion method doesn't have any particular problems with
      // singleton after compressed.
      break;
    }
  }
  return true;
}

/// Helper method to translate indices during a reshaping operation.
/// TODO: provide as general utility to MLIR at large?
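/// For example (illustrative only), with a 10x20 dimension slice the
/// multi-dimensional side relates to the linearized side as k = i*20 + j,
/// and the indices are recovered again as i = k/20 and j = k%20.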
static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
                             ArrayRef<ReassociationIndices> reassociation,
                             TensorType dstTp, TensorType srcTp, Value dstIdx,
                             Value srcIdx) {
  unsigned dstRank = dstTp.getRank();
  unsigned srcRank = srcTp.getRank();
  unsigned start = 0;
  unsigned i = 0;
  bool isExpand = srcRank > dstRank;
  ArrayRef<int64_t> shape = isExpand ? srcTp.getShape() : dstTp.getShape();
  // Iterate over reassociation map.
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    uint64_t linear = 1;
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      assert(!ShapedType::isDynamic(shape[j]));
      linear *= shape[j];
    }
    // Start collapse.
    Value idx = constantIndex(rewriter, loc, i++);
    Value val;
    if (!isExpand)
      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear /= shape[j];
      Value stride = constantIndex(rewriter, loc, linear);
      Value jdx = constantIndex(rewriter, loc, j);
      if (isExpand) {
        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
        Value mul = linear == 1
                        ? old
                        : rewriter.create<arith::MulIOp>(loc, old, stride);
        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        Value old = val;
        if (linear != 1)
          val = rewriter.create<arith::DivUIOp>(loc, val, stride);
        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
        if (linear != 1)
          val = rewriter.create<arith::RemUIOp>(loc, old, stride);
      }
    }
    // Finalize expansion.
    if (isExpand)
      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
    start += map.value().size();
  }
  // Sanity.
  assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
}

/// Generate code for a general sparse to sparse reshaping operation.
/// Note that unlike dense reshaping (which can be done with a "cheap"
/// change of view), sparse reshaping is currently done with actual
/// data shuffling.
///
/// TODO: proportional to nnz, but still a lot of data movement
///       https://github.com/llvm/llvm-project/issues/56477
///
///   iter = src->toCOO();
///   coo = newSparseCOO()
///   while (elem = iter->getNext()) {
///     coo->add(reshape(elem.indices), elem.value)
///   }
///   s = newSparseTensor(coo)
static LogicalResult
genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
                        ArrayRef<ReassociationIndices> reassociation, Value src,
                        RankedTensorType dstTp, RankedTensorType srcTp) {
  Location loc = op->getLoc();
  auto encDst = getSparseTensorEncoding(dstTp);
  auto encSrc = getSparseTensorEncoding(srcTp);
  assert(encDst && encSrc);
  unsigned srcRank = srcTp.getRank();
  unsigned dstRank = dstTp.getRank();
  Type elemTp = srcTp.getElementType();
  assert(elemTp == dstTp.getElementType() &&
         "reshape should not change element type");
  // Start an iterator over the source tensor (in original index order).
  auto noPerm = SparseTensorEncodingAttr::get(
      op->getContext(), encSrc.getDimLevelType(), AffineMap(),
      encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
  SmallVector<Value, 4> sizes;
  SmallVector<Value, 8> params;
  sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
  newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
            src);
  Value iter = genNewCall(rewriter, op, params);
  // Start a new COO for the destination tensor.
  sizes.clear();
  params.clear();
  sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
  newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
  Value coo = genNewCall(rewriter, op, params);
  Value dstPerm = params[2];
  // Construct a while loop over the iterator.
  Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
  Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
  Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
  SmallVector<Value> noArgs;
  SmallVector<Type> noTypes;
  auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
  rewriter.setInsertionPointToEnd(before);
  Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr);
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  // Translate indices from source to target and insert. Note that we do
  // not need to store the value in elemPtr, as the value is still there.
  Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
  rewriter.setInsertionPointToStart(after);
  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
  genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
  rewriter.create<scf::YieldOp>(loc);
  // Final call to construct sparse tensor storage and free temporary
  // resources.
  rewriter.setInsertionPointAfter(whileOp);
  params[6] = constantAction(rewriter, loc, Action::kFromCOO);
  params[7] = coo;
  Value dst = genNewCall(rewriter, op, params);
  genDelCOOCall(rewriter, op, elemTp, coo);
  genDelCOOCall(rewriter, op, elemTp, iter);
  rewriter.replaceOp(op, dst);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.getSource().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = *index;
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for a reshape operator.
template <typename ReshapeOp>
class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
public:
  using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor;
  using OpConversionPattern<ReshapeOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = op.getResult().getType();
    Type srcType = op.getSrc().getType();
    auto encDst = getSparseTensorEncoding(dstType);
    auto encSrc = getSparseTensorEncoding(srcType);
    if (encDst && encSrc)
      return genSparse2SparseReshape(
          op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
          dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
    return failure(); // handled elsewhere
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    sizesFromType(rewriter, sizes, op.getLoc(), stp);
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    RankedTensorType resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Gather all dimension sizes as SSA values.
    SmallVector<Value> sizes;
    unsigned int operandCtr = 0;
    for (int64_t i = 0; i < resType.getRank(); ++i) {
      if (resType.isDynamicDim(i)) {
        sizes.push_back(adaptor.getOperands()[operandCtr++]);
      } else {
        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
            op.getLoc(), op.getStaticSize(i)));
      }
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.getSource().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      bool useDirectConversion;
      switch (options.sparseToSparseStrategy) {
      case SparseToSparseConversionStrategy::kViaCOO:
        useDirectConversion = false;
        break;
      case SparseToSparseConversionStrategy::kDirect:
        useDirectConversion = true;
        assert(canUseDirectConversion(encDst.getDimLevelType()) &&
               "Unsupported target for direct sparse-to-sparse conversion");
        break;
      case SparseToSparseConversionStrategy::kAuto:
        useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
        break;
      }
      if (useDirectConversion) {
        newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
                  sizes, src);
        rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      } else { // use via-COO conversion.
        // Set up encoding with right mix of src and dst so that the two
        // method calls can share most parameters, while still providing
        // the correct sparsity information to either of them.
        auto enc = SparseTensorEncodingAttr::get(
            op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
            encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
        newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
        Value coo = genNewCall(rewriter, op, params);
        params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
        params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
        params[6] = constantAction(rewriter, loc, Action::kFromCOO);
        params[7] = coo;
        Value dst = genNewCall(rewriter, op, params);
        genDelCOOCall(rewriter, op, stp.getElementType(), coo);
        rewriter.replaceOp(op, dst);
      }
      return success();
    }
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Block *insertionBlock = rewriter.getInsertionBlock();
      // TODO: Dense buffers should be allocated/deallocated via the callback
      // in BufferizationOptions.
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      genDelCOOCall(rewriter, op, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      // Deallocate the buffer.
      if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
        rewriter.setInsertionPoint(insertionBlock->getTerminator());
        deallocDenseTensor(rewriter, loc, dst);
      }
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop
    // is generated by genIndexAndValueForDense/Sparse() and genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value coo = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.has_value();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          builder.create<memref::StoreOp>(loc, val, elemPtr);
          genAddEltCall(rewriter, op, eltType, coo, elemPtr, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto enc = getSparseTensorEncoding(op.getTensor().getType());
    if (!enc)
      return failure();
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and indices.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value indices = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = op;
    for (; isa<scf::ForOp>(parent->getParentOp()) ||
           isa<scf::WhileOp>(parent->getParentOp()) ||
           isa<scf::ParallelOp>(parent->getParentOp()) ||
           isa<scf::IfOp>(parent->getParentOp());
         parent = parent->getParentOp())
      ;
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    auto encSrc = getSparseTensorEncoding(srcType);
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
    auto enc = SparseTensorEncodingAttr::get(
        op->getContext(), encSrc.getDimLevelType(), AffineMap(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
    newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
    Value coo = genNewCall(rewriter, op, params);
    // Then output the tensor to external file with indices in the externally
    // visible lexicographic index order. A sort is required if the source was
    // not in that order yet (note that the sort can be dropped altogether if
    // the external format does not care about the order at all, but here we
    // assume it does).
    bool sort =
        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
    params.clear();
    params.push_back(coo);
    params.push_back(adaptor.getOperands()[1]);
    params.push_back(constantI1(rewriter, loc, sort));
    Type eltType = srcType.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
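/// A typical driver (illustrative sketch only; the names of the pass-local
/// type converter and conversion target are assumptions, not defined here):
///   RewritePatternSet patterns(ctx);
///   SparseTensorConversionOptions options;
///   populateSparseTensorConversionPatterns(typeConverter, patterns, options);
///   if (failed(applyPartialConversion(getOperation(), target,
///                                     std::move(patterns))))
///     signalPassFailure();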
void mlir::populateSparseTensorConversionPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    const SparseTensorConversionOptions &options) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseReshapeConverter<tensor::ExpandShapeOp>,
               SparseReshapeConverter<tensor::CollapseShapeOp>,
               SparseTensorAllocConverter, SparseTensorDeallocConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorLoadConverter,
               SparseTensorLexInsertConverter, SparseTensorExpandConverter,
               SparseTensorCompressConverter, SparseTensorOutConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorConvertConverter>(typeConverter,
                                             patterns.getContext(), options);
}