1 //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "CodegenUtils.h"
18
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 #include "mlir/Dialect/Linalg/Utils/Utils.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28 #include "mlir/Dialect/Tensor/IR/Tensor.h"
29 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
30 #include "mlir/Transforms/DialectConversion.h"
31
32 using namespace mlir;
33 using namespace mlir::sparse_tensor;
34
35 namespace {
36
/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`. The enumerators are
/// backed by `bool`, so `static_cast<bool>(EmitCInterface::On)` is `true`.
enum class EmitCInterface : bool {
  Off = false, // plain private declaration
  On = true,   // also request the "_emit_c_interface" wrapper
};
40
41 //===----------------------------------------------------------------------===//
42 // Helper methods.
43 //===----------------------------------------------------------------------===//
44
45 /// Returns the equivalent of `void*` for opaque arguments to the
46 /// execution engine.
getOpaquePointerType(OpBuilder & builder)47 static Type getOpaquePointerType(OpBuilder &builder) {
48 return LLVM::LLVMPointerType::get(builder.getI8Type());
49 }
50
51 /// Returns a function reference (first hit also inserts into module). Sets
52 /// the "_emit_c_interface" on the function declaration when requested,
53 /// so that LLVM lowering generates a wrapper function that takes care
54 /// of ABI complications with passing in and returning MemRefs to C functions.
getFunc(Operation * op,StringRef name,TypeRange resultType,ValueRange operands,EmitCInterface emitCInterface)55 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
56 TypeRange resultType, ValueRange operands,
57 EmitCInterface emitCInterface) {
58 MLIRContext *context = op->getContext();
59 auto module = op->getParentOfType<ModuleOp>();
60 auto result = SymbolRefAttr::get(context, name);
61 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
62 if (!func) {
63 OpBuilder moduleBuilder(module.getBodyRegion());
64 func = moduleBuilder.create<func::FuncOp>(
65 op->getLoc(), name,
66 FunctionType::get(context, operands.getTypes(), resultType));
67 func.setPrivate();
68 if (static_cast<bool>(emitCInterface))
69 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
70 UnitAttr::get(context));
71 }
72 return result;
73 }
74
75 /// Creates a `CallOp` to the function reference returned by `getFunc()`.
createFuncCall(OpBuilder & builder,Operation * op,StringRef name,TypeRange resultType,ValueRange operands,EmitCInterface emitCInterface)76 static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
77 StringRef name, TypeRange resultType,
78 ValueRange operands,
79 EmitCInterface emitCInterface) {
80 auto fn = getFunc(op, name, resultType, operands, emitCInterface);
81 return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
82 }
83
84 /// Replaces the `op` with a `CallOp` to the function reference returned
85 /// by `getFunc()`.
replaceOpWithFuncCall(RewriterBase & rewriter,Operation * op,StringRef name,TypeRange resultType,ValueRange operands,EmitCInterface emitCInterface)86 static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
87 StringRef name, TypeRange resultType,
88 ValueRange operands,
89 EmitCInterface emitCInterface) {
90 auto fn = getFunc(op, name, resultType, operands, emitCInterface);
91 return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
92 operands);
93 }
94
95 /// Generates dimension size call.
genDimSizeCall(OpBuilder & builder,Operation * op,SparseTensorEncodingAttr & enc,Value src,int64_t idx)96 static Value genDimSizeCall(OpBuilder &builder, Operation *op,
97 SparseTensorEncodingAttr &enc, Value src,
98 int64_t idx) {
99 // Permute the index according to an optional dimension ordering.
100 if (AffineMap p = enc.getDimOrdering())
101 idx = p.getPermutedPosition(idx);
102 // Generate the call.
103 StringRef name = "sparseDimSize";
104 SmallVector<Value, 2> params{src, constantIndex(builder, op->getLoc(), idx)};
105 Type iTp = builder.getIndexType();
106 return createFuncCall(builder, op, name, iTp, params, EmitCInterface::Off)
107 .getResult(0);
108 }
109
110 /// Generates a call into the "swiss army knife" method of the sparse runtime
111 /// support library for materializing sparse tensors into the computation.
genNewCall(OpBuilder & builder,Operation * op,ArrayRef<Value> params)112 static Value genNewCall(OpBuilder &builder, Operation *op,
113 ArrayRef<Value> params) {
114 StringRef name = "newSparseTensor";
115 Type pTp = getOpaquePointerType(builder);
116 return createFuncCall(builder, op, name, pTp, params, EmitCInterface::On)
117 .getResult(0);
118 }
119
120 /// Populates given sizes array from type.
sizesFromType(OpBuilder & builder,SmallVector<Value,4> & sizes,Location loc,ShapedType stp)121 static void sizesFromType(OpBuilder &builder, SmallVector<Value, 4> &sizes,
122 Location loc, ShapedType stp) {
123 auto shape = stp.getShape();
124 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
125 uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
126 sizes.push_back(constantIndex(builder, loc, s));
127 }
128 }
129
130 /// Populates given sizes array from source.
sizesFromSrc(OpBuilder & builder,SmallVector<Value,4> & sizes,Location loc,Value src)131 static void sizesFromSrc(OpBuilder &builder, SmallVector<Value, 4> &sizes,
132 Location loc, Value src) {
133 unsigned rank = src.getType().cast<ShapedType>().getRank();
134 for (unsigned i = 0; i < rank; i++)
135 sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i));
136 }
137
138 /// Populates given sizes array from type (for static sizes) and from
139 /// an already converted into opague pointer source (for dynamic sizes).
sizesFromPtr(OpBuilder & builder,SmallVector<Value,4> & sizes,Operation * op,SparseTensorEncodingAttr & enc,ShapedType stp,Value src)140 static void sizesFromPtr(OpBuilder &builder, SmallVector<Value, 4> &sizes,
141 Operation *op, SparseTensorEncodingAttr &enc,
142 ShapedType stp, Value src) {
143 Location loc = op->getLoc();
144 auto shape = stp.getShape();
145 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
146 if (shape[i] == ShapedType::kDynamicSize)
147 sizes.push_back(genDimSizeCall(builder, op, enc, src, i));
148 else
149 sizes.push_back(constantIndex(builder, loc, shape[i]));
150 }
151
152 /// Generates an uninitialized temporary buffer of the given size and
153 /// type, but returns it as type `memref<? x $tp>` (rather than as type
154 /// `memref<$sz x $tp>`).
genAlloca(OpBuilder & builder,Location loc,Value sz,Type tp)155 static Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) {
156 auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
157 return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
158 }
159
160 /// Generates an uninitialized buffer of the given size and type,
161 /// but returns it as type `memref<? x $tp>` (rather than as type
162 /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
163 /// this buffer must be explicitly deallocated by client.
genAlloc(RewriterBase & rewriter,Location loc,Value sz,Type tp)164 static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
165 auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
166 return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
167 }
168
169 /// Generates an uninitialized temporary buffer of the given size and
170 /// type, but returns it as type `memref<? x $tp>` (rather than as type
171 /// `memref<$sz x $tp>`).
genAlloca(OpBuilder & builder,Location loc,unsigned sz,Type tp)172 static Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp) {
173 return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
174 }
175
176 /// Generates an uninitialized temporary buffer with room for one value
177 /// of the given type, and returns the `memref<$tp>`.
genAllocaScalar(OpBuilder & builder,Location loc,Type tp)178 static Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp) {
179 return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
180 }
181
182 /// Generates a temporary buffer of the given type and given contents.
genBuffer(OpBuilder & builder,Location loc,ValueRange values)183 static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
184 unsigned sz = values.size();
185 assert(sz >= 1);
186 Value buffer = genAlloca(builder, loc, sz, values[0].getType());
187 for (unsigned i = 0; i < sz; i++) {
188 Value idx = constantIndex(builder, loc, i);
189 builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
190 }
191 return buffer;
192 }
193
194 /// Populates parameters required to call the "swiss army knife" method of the
195 /// sparse runtime support library for materializing sparse tensors into the
196 /// computation.
newParams(OpBuilder & builder,SmallVector<Value,8> & params,Operation * op,ShapedType stp,SparseTensorEncodingAttr & enc,Action action,ValueRange szs,Value ptr=Value ())197 static void newParams(OpBuilder &builder, SmallVector<Value, 8> ¶ms,
198 Operation *op, ShapedType stp,
199 SparseTensorEncodingAttr &enc, Action action,
200 ValueRange szs, Value ptr = Value()) {
201 Location loc = op->getLoc();
202 ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
203 unsigned sz = dlt.size();
204 // Sparsity annotations.
205 SmallVector<Value, 4> attrs;
206 for (unsigned i = 0; i < sz; i++)
207 attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
208 params.push_back(genBuffer(builder, loc, attrs));
209 // Dimension sizes array of the enveloping tensor. Useful for either
210 // verification of external data, or for construction of internal data.
211 params.push_back(genBuffer(builder, loc, szs));
212 // Dimension order permutation array. This is the "identity" permutation by
213 // default, or otherwise the "reverse" permutation of a given ordering, so
214 // that indices can be mapped quickly to the right position.
215 SmallVector<Value, 4> rev(sz);
216 if (AffineMap p = enc.getDimOrdering()) {
217 for (unsigned i = 0; i < sz; i++)
218 rev[p.getDimPosition(i)] = constantIndex(builder, loc, i);
219 } else {
220 for (unsigned i = 0; i < sz; i++)
221 rev[i] = constantIndex(builder, loc, i);
222 }
223 params.push_back(genBuffer(builder, loc, rev));
224 // Secondary and primary types encoding.
225 Type elemTp = stp.getElementType();
226 params.push_back(constantPointerTypeEncoding(builder, loc, enc));
227 params.push_back(constantIndexTypeEncoding(builder, loc, enc));
228 params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
229 // User action.
230 params.push_back(constantAction(builder, loc, action));
231 // Payload pointer.
232 if (!ptr)
233 ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
234 params.push_back(ptr);
235 }
236
237 /// Generates the code to read the value from tensor[ivs], and conditionally
238 /// stores the indices ivs to the memory in ind. The generated code looks like
239 /// the following and the insertion point after this routine is inside the
240 /// if-then branch behind the assignment to ind. This is to ensure that the
241 /// addEltX call generated after is inside the if-then branch.
242 /// if (tensor[ivs] != 0)
243 /// ind = ivs
genIndexAndValueForDense(OpBuilder & builder,Location loc,Value tensor,Value ind,ValueRange ivs)244 static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
245 Value tensor, Value ind, ValueRange ivs) {
246 Value val = builder.create<tensor::ExtractOp>(loc, tensor, ivs);
247 Value cond = genIsNonzero(builder, loc, val);
248 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else*/ false);
249 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
250 unsigned i = 0;
251 for (auto iv : ivs) {
252 Value idx = constantIndex(builder, loc, i++);
253 builder.create<memref::StoreOp>(loc, iv, ind, idx);
254 }
255 return val;
256 }
257
258 /// Generates a call to release/delete a `SparseTensorCOO`.
genDelCOOCall(OpBuilder & builder,Operation * op,Type elemTp,Value coo)259 static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
260 Value coo) {
261 SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
262 TypeRange noTp;
263 createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
264 }
265
266 /// Generates a call that adds one element to a coordinate scheme.
267 /// In particular, this generates code like the following:
268 /// val = a[i1,..,ik];
269 /// if val != 0
270 /// t->add(&val, [i1,..,ik], [p1,..,pk]);
genAddEltCall(OpBuilder & builder,Operation * op,Type eltType,Value ptr,Value valPtr,Value ind,Value perm)271 static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType,
272 Value ptr, Value valPtr, Value ind, Value perm) {
273 SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
274 SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
275 Type pTp = getOpaquePointerType(builder);
276 createFuncCall(builder, op, name, pTp, params, EmitCInterface::On);
277 }
278
279 /// Generates a call to `iter->getNext()`. If there is a next element,
280 /// then it is copied into the out-parameters `ind` and `elemPtr`,
281 /// and the return value is true. If there isn't a next element, then
282 /// the memory for `iter` is freed and the return value is false.
genGetNextCall(OpBuilder & builder,Operation * op,Value iter,Value ind,Value elemPtr)283 static Value genGetNextCall(OpBuilder &builder, Operation *op, Value iter,
284 Value ind, Value elemPtr) {
285 Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
286 SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
287 SmallVector<Value, 3> params{iter, ind, elemPtr};
288 Type i1 = builder.getI1Type();
289 return createFuncCall(builder, op, name, i1, params, EmitCInterface::On)
290 .getResult(0);
291 }
292
293 /// If the tensor is a sparse constant, generates and returns the pair of
294 /// the constants for the indices and the values.
295 static Optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder & builder,Location loc,Value tensor)296 genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
297 if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
298 if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
299 DenseElementsAttr indicesAttr = attr.getIndices();
300 Value indices = builder.create<arith::ConstantOp>(loc, indicesAttr);
301 DenseElementsAttr valuesAttr = attr.getValues();
302 Value values = builder.create<arith::ConstantOp>(loc, valuesAttr);
303 return std::make_pair(indices, values);
304 }
305 }
306 return {};
307 }
308
309 /// Generates the code to copy the index at indices[ivs] to ind, and return
310 /// the value at value[ivs].
genIndexAndValueForSparse(OpBuilder & builder,Location loc,Value indices,Value values,Value ind,ValueRange ivs,unsigned rank)311 static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc,
312 Value indices, Value values, Value ind,
313 ValueRange ivs, unsigned rank) {
314 for (unsigned i = 0; i < rank; i++) {
315 Value idx = constantIndex(builder, loc, i);
316 Value val = builder.create<tensor::ExtractOp>(loc, indices,
317 ValueRange{ivs[0], idx});
318 val = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), val);
319 builder.create<memref::StoreOp>(loc, val, ind, idx);
320 }
321 return builder.create<tensor::ExtractOp>(loc, values, ivs[0]);
322 }
323
324 /// Generates code to allocate a buffer of the given type, and zero
325 /// initialize it. If the buffer type has any dynamic sizes, then the
326 /// `sizes` parameter should be as filled by sizesFromPtr(); that way
327 /// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
allocDenseTensor(OpBuilder & builder,Location loc,RankedTensorType tensorTp,ValueRange sizes)328 static Value allocDenseTensor(OpBuilder &builder, Location loc,
329 RankedTensorType tensorTp, ValueRange sizes) {
330 Type elemTp = tensorTp.getElementType();
331 auto shape = tensorTp.getShape();
332 auto memTp = MemRefType::get(shape, elemTp);
333 SmallVector<Value> dynamicSizes;
334 for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
335 if (shape[i] == ShapedType::kDynamicSize)
336 dynamicSizes.push_back(sizes[i]);
337 }
338 Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
339 Value zero = constantZero(builder, loc, elemTp);
340 builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
341 return mem;
342 }
343
344 /// Generates code to deallocate a dense buffer.
deallocDenseTensor(OpBuilder & builder,Location loc,Value buffer)345 static void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer) {
346 builder.create<memref::DeallocOp>(loc, buffer);
347 }
348
349 /// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
350 /// the tensor created by allocDenseTensor(). The `rank` is the rank
351 /// of the `tensor` and the length of `ind`.
insertScalarIntoDenseTensor(OpBuilder & builder,Location loc,Value elemPtr,Value tensor,unsigned rank,Value ind)352 static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
353 Value elemPtr, Value tensor,
354 unsigned rank, Value ind) {
355 SmallVector<Value, 4> ivs;
356 ivs.reserve(rank);
357 for (unsigned i = 0; i < rank; i++) {
358 Value idx = constantIndex(builder, loc, i);
359 ivs.push_back(builder.create<memref::LoadOp>(loc, ind, idx));
360 }
361 Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
362 builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
363 }
364
365 /// Determine if the runtime library supports direct conversion to the
366 /// given target `dimTypes`.
canUseDirectConversion(ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes)367 static bool canUseDirectConversion(
368 ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
369 bool alreadyCompressed = false;
370 for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
371 switch (dimTypes[r]) {
372 case SparseTensorEncodingAttr::DimLevelType::Compressed:
373 if (alreadyCompressed)
374 return false; // Multiple compressed dimensions not yet supported.
375 alreadyCompressed = true;
376 break;
377 case SparseTensorEncodingAttr::DimLevelType::Dense:
378 if (alreadyCompressed)
379 return false; // Dense after Compressed not yet supported.
380 break;
381 case SparseTensorEncodingAttr::DimLevelType::Singleton:
382 // Although Singleton isn't generally supported yet, the direct
383 // conversion method doesn't have any particular problems with
384 // singleton after compressed.
385 break;
386 }
387 }
388 return true;
389 }
390
391 /// Helper method to translate indices during a reshaping operation.
392 /// TODO: provide as general utility to MLIR at large?
translateIndices(Location loc,ConversionPatternRewriter & rewriter,ArrayRef<ReassociationIndices> reassociation,TensorType dstTp,TensorType srcTp,Value dstIdx,Value srcIdx)393 static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
394 ArrayRef<ReassociationIndices> reassociation,
395 TensorType dstTp, TensorType srcTp, Value dstIdx,
396 Value srcIdx) {
397 unsigned dstRank = dstTp.getRank();
398 unsigned srcRank = srcTp.getRank();
399 unsigned start = 0;
400 unsigned i = 0;
401 bool isExpand = srcRank > dstRank;
402 ArrayRef<int64_t> shape = isExpand ? srcTp.getShape() : dstTp.getShape();
403 // Iterate over reassociation map.
404 for (const auto &map : llvm::enumerate(reassociation)) {
405 // Prepare strides information in dimension slice.
406 uint64_t linear = 1;
407 for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
408 assert(!ShapedType::isDynamic(shape[j]));
409 linear *= shape[j];
410 }
411 // Start collapse.
412 Value idx = constantIndex(rewriter, loc, i++);
413 Value val;
414 if (!isExpand)
415 val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
416 // Iterate over dimension slice.
417 for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
418 linear /= shape[j];
419 Value stride = constantIndex(rewriter, loc, linear);
420 Value jdx = constantIndex(rewriter, loc, j);
421 if (isExpand) {
422 Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
423 Value mul = linear == 1
424 ? old
425 : rewriter.create<arith::MulIOp>(loc, old, stride);
426 val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
427 } else {
428 Value old = val;
429 if (linear != 1)
430 val = rewriter.create<arith::DivUIOp>(loc, val, stride);
431 rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
432 if (linear != 1)
433 val = rewriter.create<arith::RemUIOp>(loc, old, stride);
434 }
435 }
436 // Finalize expansion.
437 if (isExpand)
438 rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
439 start += map.value().size();
440 }
441 // Sanity.
442 assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
443 }
444
445 /// Generate code for a general sparse to sparse reshaping operation.
446 /// Note that unlike dense reshaping (which can be done with a "cheap"
447 /// change of view), sparse reshaping is currently done with actual
448 /// data shuffling.
449 ///
450 /// TODO: proportional to nnz, but still a lot of data movement
451 /// https://github.com/llvm/llvm-project/issues/56477
452 ///
453 /// iter = src->toCOO();
454 /// coo = newSparseCOO()
455 /// while (elem = iter->getNext()) {
456 /// coo->add(reshape(elem.indices), elem.value)
457 /// }
458 /// s = newSparseTensor(coo)
459 static LogicalResult
genSparse2SparseReshape(Operation * op,ConversionPatternRewriter & rewriter,ArrayRef<ReassociationIndices> reassociation,Value src,RankedTensorType dstTp,RankedTensorType srcTp)460 genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
461 ArrayRef<ReassociationIndices> reassociation, Value src,
462 RankedTensorType dstTp, RankedTensorType srcTp) {
463 Location loc = op->getLoc();
464 auto encDst = getSparseTensorEncoding(dstTp);
465 auto encSrc = getSparseTensorEncoding(srcTp);
466 assert(encDst && encSrc);
467 unsigned srcRank = srcTp.getRank();
468 unsigned dstRank = dstTp.getRank();
469 Type elemTp = srcTp.getElementType();
470 assert(elemTp == dstTp.getElementType() &&
471 "reshape should not change element type");
472 // Start an iterator over the source tensor (in original index order).
473 auto noPerm = SparseTensorEncodingAttr::get(
474 op->getContext(), encSrc.getDimLevelType(), AffineMap(),
475 encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
476 SmallVector<Value, 4> sizes;
477 SmallVector<Value, 8> params;
478 sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
479 newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
480 src);
481 Value iter = genNewCall(rewriter, op, params);
482 // Start a new COO for the destination tensor.
483 sizes.clear();
484 params.clear();
485 sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
486 newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
487 Value coo = genNewCall(rewriter, op, params);
488 Value dstPerm = params[2];
489 // Construct a while loop over the iterator.
490 Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
491 Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
492 Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
493 SmallVector<Value> noArgs;
494 SmallVector<Type> noTypes;
495 auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
496 Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
497 rewriter.setInsertionPointToEnd(before);
498 Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr);
499 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
500 // Translate indices from source to target and insert. Note that we do
501 // not need to store the value in elemPtr, as the value is still there.
502 Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
503 rewriter.setInsertionPointToStart(after);
504 translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
505 genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
506 rewriter.create<scf::YieldOp>(loc);
507 // Final call to construct sparse tensor storage and free temporary resources.
508 rewriter.setInsertionPointAfter(whileOp);
509 params[6] = constantAction(rewriter, loc, Action::kFromCOO);
510 params[7] = coo;
511 Value dst = genNewCall(rewriter, op, params);
512 genDelCOOCall(rewriter, op, elemTp, coo);
513 genDelCOOCall(rewriter, op, elemTp, iter);
514 rewriter.replaceOp(op, dst);
515 return success();
516 }
517
518 //===----------------------------------------------------------------------===//
519 // Conversion rules.
520 //===----------------------------------------------------------------------===//
521
522 /// Sparse conversion rule for returns.
523 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
524 public:
525 using OpConversionPattern::OpConversionPattern;
526 LogicalResult
matchAndRewrite(func::ReturnOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const527 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter) const override {
529 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
530 return success();
531 }
532 };
533
534 /// Sparse conversion rule for dimension accesses.
535 class SparseTensorToDimSizeConverter
536 : public OpConversionPattern<tensor::DimOp> {
537 public:
538 using OpConversionPattern::OpConversionPattern;
539 LogicalResult
matchAndRewrite(tensor::DimOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const540 matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
541 ConversionPatternRewriter &rewriter) const override {
542 // Only rewrite annotated DimOp with constant index.
543 auto enc = getSparseTensorEncoding(op.getSource().getType());
544 if (!enc)
545 return failure();
546 Optional<int64_t> index = op.getConstantIndex();
547 if (!index)
548 return failure();
549 // Generate the call.
550 Value src = adaptor.getOperands()[0];
551 int64_t idx = *index;
552 rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
553 return success();
554 }
555 };
556
557 /// Sparse conversion rule for trivial tensor casts.
558 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
559 public:
560 using OpConversionPattern::OpConversionPattern;
561 LogicalResult
matchAndRewrite(tensor::CastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const562 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
563 ConversionPatternRewriter &rewriter) const override {
564 // Only rewrite identically annotated source/dest.
565 auto encDst = getSparseTensorEncoding(op.getType());
566 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
567 if (!encDst || encDst != encSrc)
568 return failure();
569 rewriter.replaceOp(op, adaptor.getOperands());
570 return success();
571 }
572 };
573
574 /// Sparse conversion rule for a reshape operator.
575 template <typename ReshapeOp>
576 class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
577 public:
578 using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor;
579 using OpConversionPattern<ReshapeOp>::OpConversionPattern;
580 LogicalResult
matchAndRewrite(ReshapeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const581 matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
582 ConversionPatternRewriter &rewriter) const override {
583 Type dstType = op.getResult().getType();
584 Type srcType = op.getSrc().getType();
585 auto encDst = getSparseTensorEncoding(dstType);
586 auto encSrc = getSparseTensorEncoding(srcType);
587 if (encDst && encSrc)
588 return genSparse2SparseReshape(
589 op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
590 dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
591 return failure(); // handled elsewhere
592 }
593 };
594
595 /// Sparse conversion rule for the new operator.
596 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
597 public:
598 using OpConversionPattern::OpConversionPattern;
599 LogicalResult
matchAndRewrite(NewOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const600 matchAndRewrite(NewOp op, OpAdaptor adaptor,
601 ConversionPatternRewriter &rewriter) const override {
602 Type resType = op.getType();
603 auto enc = getSparseTensorEncoding(resType);
604 if (!enc)
605 return failure();
606 // Generate the call to construct tensor from ptr. The sizes are
607 // inferred from the result type of the new operator.
608 SmallVector<Value, 4> sizes;
609 SmallVector<Value, 8> params;
610 ShapedType stp = resType.cast<ShapedType>();
611 sizesFromType(rewriter, sizes, op.getLoc(), stp);
612 Value ptr = adaptor.getOperands()[0];
613 newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
614 rewriter.replaceOp(op, genNewCall(rewriter, op, params));
615 return success();
616 }
617 };
618
619 /// Sparse conversion rule for the alloc operator.
620 class SparseTensorAllocConverter
621 : public OpConversionPattern<bufferization::AllocTensorOp> {
622 public:
623 using OpConversionPattern::OpConversionPattern;
624 LogicalResult
matchAndRewrite(bufferization::AllocTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const625 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
626 ConversionPatternRewriter &rewriter) const override {
627 if (op.getCopy())
628 return rewriter.notifyMatchFailure(op,
629 "sparse tensor copy not implemented");
630 RankedTensorType resType = op.getType();
631 auto enc = getSparseTensorEncoding(resType);
632 if (!enc)
633 return failure();
634 // Gather all dimension sizes as SSA values.
635 SmallVector<Value> sizes;
636 unsigned int operandCtr = 0;
637 for (int64_t i = 0; i < resType.getRank(); ++i) {
638 if (resType.isDynamicDim(i)) {
639 sizes.push_back(adaptor.getOperands()[operandCtr++]);
640 } else {
641 sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
642 op.getLoc(), op.getStaticSize(i)));
643 }
644 }
645 // Generate the call to construct empty tensor. The sizes are
646 // explicitly defined by the arguments to the alloc operator.
647 SmallVector<Value, 8> params;
648 ShapedType stp = resType.cast<ShapedType>();
649 newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
650 rewriter.replaceOp(op, genNewCall(rewriter, op, params));
651 return success();
652 }
653 };
654
/// Sparse conversion rule for the convert operator.
///
/// Dispatches on which side of the conversion carries a sparse encoding:
///   - sparse => sparse: identical encodings fold to the (converted) source.
///     Otherwise `options.sparseToSparseStrategy` selects one of two routes:
///     a direct runtime conversion (Action::kSparseToSparse), or a round trip
///     through a COO intermediate (Action::kToCOO, then Action::kFromCOO).
///   - sparse => dense: the source is walked with a runtime iterator
///     (Action::kToIterator). Each element is stored into a newly allocated
///     dense buffer, which is then wrapped as the result tensor.
///   - dense => sparse, or sparse constant => sparse: a generated loop nest
///     fills an empty COO tensor (Action::kEmptyCOO). The result is then
///     built from it (Action::kFromCOO).
///   - dense => dense is not handled by this pattern.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  /// Builds the pattern without a type converter.
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  /// Builds the pattern with a type converter.
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.getSource().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    // The source operand, as already converted by the conversion framework.
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion, but it avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      // Query the dimension sizes of the source through its runtime pointer.
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      // Pick the conversion route from the configured strategy:
      //   - kDirect asserts that the destination level types support it.
      //   - kAuto decides based on those level types.
      bool useDirectConversion;
      switch (options.sparseToSparseStrategy) {
      case SparseToSparseConversionStrategy::kViaCOO:
        useDirectConversion = false;
        break;
      case SparseToSparseConversionStrategy::kDirect:
        useDirectConversion = true;
        assert(canUseDirectConversion(encDst.getDimLevelType()) &&
               "Unsupported target for direct sparse-to-sparse conversion");
        break;
      case SparseToSparseConversionStrategy::kAuto:
        useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
        break;
      }
      if (useDirectConversion) {
        newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
                  sizes, src);
        rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      } else { // use via-COO conversion.
        // Set up an encoding with the right mix of src and dst, so that the
        // two method calls can share most parameters while still providing
        // the correct sparsity information to either of them.
        auto enc = SparseTensorEncodingAttr::get(
            op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
            encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
        newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
        Value coo = genNewCall(rewriter, op, params);
        // Reuse the parameter list for the second call, overwriting these
        // slots:
        //   - the pointer/index type encodings, now taken from the destination;
        //   - the action;
        //   - the input pointer, now the COO intermediate.
        // The slot indices assume the layout produced by newParams().
        params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
        params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
        params[6] = constantAction(rewriter, loc, Action::kFromCOO);
        params[7] = coo;
        Value dst = genNewCall(rewriter, op, params);
        // The COO intermediate is no longer needed once dst is built.
        genDelCOOCall(rewriter, op, stp.getElementType(), coo);
        rewriter.replaceOp(op, dst);
      }
      return success();
    }
    if (!encDst && encSrc) {
      // This is a sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      // Scratch storage that genGetNextCall() is handed for the current
      // element's indices (`ind`) and value (`elemPtr`).
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      // Remember the current block. The optional dealloc below is placed
      // before its terminator.
      Block *insertionBlock = rewriter.getInsertionBlock();
      // TODO: Dense buffers should be allocated/deallocated via the callback
      // in BufferizationOptions.
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      // Build an scf.while with no loop-carried values:
      //   before: cond = getNext(iter, ind, elemPtr); condition(cond)
      //   after:  dst[ind] = *elemPtr; yield
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      // Release the iterator, then expose the filled buffer as the result.
      genDelCOOCall(rewriter, op, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      // Deallocate the buffer at the end of the original block, but only
      // when bufferization reports that the result does not escape.
      if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
        rewriter.setInsertionPoint(insertionBlock->getTerminator());
        deallocDenseTensor(rewriter, loc, dst);
      }
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddElt().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value coo = genNewCall(rewriter, op, params);
    // Scratch storage for the element indices passed to genAddEltCall().
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    // Parameter slot 2 is forwarded to genAddEltCall() as the permutation.
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    // Loop bounds depend on the source kind:
    //   - a sparse constant in COO form gets a single loop over the entries
    //     of its values (dimension 0 of `values`);
    //   - a dense source gets one loop per dimension.
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.has_value();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    // Scratch slot holding the value handed to genAddEltCall().
    Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          // Compute the element value. The helpers are given `ind` and
          // presumably fill it with the element's indices (TODO confirm).
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          builder.create<memref::StoreOp>(loc, val, elemPtr);
          genAddEltCall(rewriter, op, eltType, coo, elemPtr, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage. This reuses the
    // parameter list, replacing the action and the input pointer slots.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;
};
870
871 /// Sparse conversion rule for the dealloc operator.
872 class SparseTensorDeallocConverter
873 : public OpConversionPattern<bufferization::DeallocTensorOp> {
874 public:
875 using OpConversionPattern::OpConversionPattern;
876 LogicalResult
matchAndRewrite(bufferization::DeallocTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const877 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
878 ConversionPatternRewriter &rewriter) const override {
879 auto enc = getSparseTensorEncoding(op.getTensor().getType());
880 if (!enc)
881 return failure();
882 StringRef name = "delSparseTensor";
883 TypeRange noTp;
884 createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
885 EmitCInterface::Off);
886 rewriter.eraseOp(op);
887 return success();
888 }
889 };
890
891 /// Sparse conversion rule for pointer accesses.
892 class SparseTensorToPointersConverter
893 : public OpConversionPattern<ToPointersOp> {
894 public:
895 using OpConversionPattern::OpConversionPattern;
896 LogicalResult
matchAndRewrite(ToPointersOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const897 matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
898 ConversionPatternRewriter &rewriter) const override {
899 Type resType = op.getType();
900 Type ptrType = resType.cast<ShapedType>().getElementType();
901 SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
902 replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
903 EmitCInterface::On);
904 return success();
905 }
906 };
907
908 /// Sparse conversion rule for index accesses.
909 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
910 public:
911 using OpConversionPattern::OpConversionPattern;
912 LogicalResult
matchAndRewrite(ToIndicesOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const913 matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
914 ConversionPatternRewriter &rewriter) const override {
915 Type resType = op.getType();
916 Type indType = resType.cast<ShapedType>().getElementType();
917 SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
918 replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
919 EmitCInterface::On);
920 return success();
921 }
922 };
923
924 /// Sparse conversion rule for value accesses.
925 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
926 public:
927 using OpConversionPattern::OpConversionPattern;
928 LogicalResult
matchAndRewrite(ToValuesOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const929 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
930 ConversionPatternRewriter &rewriter) const override {
931 Type resType = op.getType();
932 Type eltType = resType.cast<ShapedType>().getElementType();
933 SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
934 replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
935 EmitCInterface::On);
936 return success();
937 }
938 };
939
940 /// Sparse conversion rule for tensor rematerialization.
941 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
942 public:
943 using OpConversionPattern::OpConversionPattern;
944 LogicalResult
matchAndRewrite(LoadOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const945 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
946 ConversionPatternRewriter &rewriter) const override {
947 if (op.getHasInserts()) {
948 // Finalize any pending insertions.
949 StringRef name = "endInsert";
950 TypeRange noTp;
951 createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
952 EmitCInterface::Off);
953 }
954 rewriter.replaceOp(op, adaptor.getOperands());
955 return success();
956 }
957 };
958
959 /// Sparse conversion rule for inserting in lexicographic index order.
960 class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
961 public:
962 using OpConversionPattern::OpConversionPattern;
963 LogicalResult
matchAndRewrite(LexInsertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const964 matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
965 ConversionPatternRewriter &rewriter) const override {
966 Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
967 SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
968 TypeRange noTp;
969 replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
970 EmitCInterface::On);
971 return success();
972 }
973 };
974
975 /// Sparse conversion rule for the expand operator.
976 class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
977 public:
978 using OpConversionPattern::OpConversionPattern;
979 LogicalResult
matchAndRewrite(ExpandOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const980 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
981 ConversionPatternRewriter &rewriter) const override {
982 Location loc = op->getLoc();
983 ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
984 Type eltType = srcType.getElementType();
985 Type boolType = rewriter.getIntegerType(1);
986 Type idxType = rewriter.getIndexType();
987 // All initialization should be done on entry of the loop nest.
988 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
989 // Determine the size for access expansion.
990 auto enc = getSparseTensorEncoding(srcType);
991 Value src = adaptor.getOperands()[0];
992 Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
993 // Allocate temporary buffers for values, filled-switch, and indices.
994 // We do not use stack buffers for this, since the expanded size may
995 // be rather large (as it envelops a single expanded dense dimension).
996 Value values = genAlloc(rewriter, loc, sz, eltType);
997 Value filled = genAlloc(rewriter, loc, sz, boolType);
998 Value indices = genAlloc(rewriter, loc, sz, idxType);
999 Value zero = constantZero(rewriter, loc, idxType);
1000 // Reset the values/filled-switch to all-zero/false. Note that this
1001 // introduces an O(N) operation into the computation, but this reset
1002 // operation is amortized over the innermost loops for the access
1003 // pattern expansion. As noted in the operation doc, we would like
1004 // to amortize this setup cost even between kernels.
1005 rewriter.create<linalg::FillOp>(
1006 loc, ValueRange{constantZero(rewriter, loc, eltType)},
1007 ValueRange{values});
1008 rewriter.create<linalg::FillOp>(
1009 loc, ValueRange{constantZero(rewriter, loc, boolType)},
1010 ValueRange{filled});
1011 // Replace expansion op with these buffers and initial index.
1012 assert(op.getNumResults() == 4);
1013 rewriter.replaceOp(op, {values, filled, indices, zero});
1014 return success();
1015 }
1016 };
1017
1018 /// Sparse conversion rule for the compress operator.
1019 class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
1020 public:
1021 using OpConversionPattern::OpConversionPattern;
1022 LogicalResult
matchAndRewrite(CompressOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1023 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
1024 ConversionPatternRewriter &rewriter) const override {
1025 Location loc = op->getLoc();
1026 // Note that this method call resets the values/filled-switch back to
1027 // all-zero/false by only iterating over the set elements, so the
1028 // complexity remains proportional to the sparsity of the expanded
1029 // access pattern.
1030 Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
1031 SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
1032 TypeRange noTp;
1033 replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
1034 EmitCInterface::On);
1035 // Deallocate the buffers on exit of the loop nest.
1036 Operation *parent = op;
1037 for (; isa<scf::ForOp>(parent->getParentOp()) ||
1038 isa<scf::WhileOp>(parent->getParentOp()) ||
1039 isa<scf::ParallelOp>(parent->getParentOp()) ||
1040 isa<scf::IfOp>(parent->getParentOp());
1041 parent = parent->getParentOp())
1042 ;
1043 rewriter.setInsertionPointAfter(parent);
1044 rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
1045 rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
1046 rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
1047 return success();
1048 }
1049 };
1050
1051 /// Sparse conversion rule for the output operator.
1052 class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
1053 public:
1054 using OpConversionPattern::OpConversionPattern;
1055 LogicalResult
matchAndRewrite(OutOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1056 matchAndRewrite(OutOp op, OpAdaptor adaptor,
1057 ConversionPatternRewriter &rewriter) const override {
1058 Location loc = op->getLoc();
1059 ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
1060 // Convert to default permuted COO.
1061 Value src = adaptor.getOperands()[0];
1062 auto encSrc = getSparseTensorEncoding(srcType);
1063 SmallVector<Value, 4> sizes;
1064 SmallVector<Value, 8> params;
1065 sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
1066 auto enc = SparseTensorEncodingAttr::get(
1067 op->getContext(), encSrc.getDimLevelType(), AffineMap(),
1068 encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
1069 newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
1070 Value coo = genNewCall(rewriter, op, params);
1071 // Then output the tensor to external file with indices in the externally
1072 // visible lexicographic index order. A sort is required if the source was
1073 // not in that order yet (note that the sort can be dropped altogether if
1074 // external format does not care about the order at all, but here we assume
1075 // it does).
1076 bool sort =
1077 encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
1078 params.clear();
1079 params.push_back(coo);
1080 params.push_back(adaptor.getOperands()[1]);
1081 params.push_back(constantI1(rewriter, loc, sort));
1082 Type eltType = srcType.getElementType();
1083 SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
1084 TypeRange noTp;
1085 createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
1086 genDelCOOCall(rewriter, op, eltType, coo);
1087 rewriter.eraseOp(op);
1088 return success();
1089 }
1090 };
1091
1092 } // namespace
1093
1094 //===----------------------------------------------------------------------===//
1095 // Public method for populating conversion rules.
1096 //===----------------------------------------------------------------------===//
1097
1098 /// Populates the given patterns list with conversion rules required for
1099 /// the sparsification of linear algebra operations.
populateSparseTensorConversionPatterns(TypeConverter & typeConverter,RewritePatternSet & patterns,const SparseTensorConversionOptions & options)1100 void mlir::populateSparseTensorConversionPatterns(
1101 TypeConverter &typeConverter, RewritePatternSet &patterns,
1102 const SparseTensorConversionOptions &options) {
1103 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
1104 SparseCastConverter, SparseTensorNewConverter,
1105 SparseReshapeConverter<tensor::ExpandShapeOp>,
1106 SparseReshapeConverter<tensor::CollapseShapeOp>,
1107 SparseTensorAllocConverter, SparseTensorDeallocConverter,
1108 SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
1109 SparseTensorToValuesConverter, SparseTensorLoadConverter,
1110 SparseTensorLexInsertConverter, SparseTensorExpandConverter,
1111 SparseTensorCompressConverter, SparseTensorOutConverter>(
1112 typeConverter, patterns.getContext());
1113 patterns.add<SparseTensorConvertConverter>(typeConverter,
1114 patterns.getContext(), options);
1115 }
1116