//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//
16a2c9d4bbSAart Bik
1785b8d03eSwren romano #include "CodegenUtils.h"
18efa15f41SAart Bik
19c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
2057470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2123aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
22236a9080SAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23236a9080SAart Bik #include "mlir/Dialect/Linalg/Utils/Utils.h"
24a2c9d4bbSAart Bik #include "mlir/Dialect/MemRef/IR/MemRef.h"
258b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
26a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
27a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28ca5d0a73SAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h"
29845561ecSwren romano #include "mlir/ExecutionEngine/SparseTensorUtils.h"
30a2c9d4bbSAart Bik #include "mlir/Transforms/DialectConversion.h"
31a2c9d4bbSAart Bik
32a2c9d4bbSAart Bik using namespace mlir;
3396a23911SAart Bik using namespace mlir::sparse_tensor;
34a2c9d4bbSAart Bik
35a2c9d4bbSAart Bik namespace {
36a2c9d4bbSAart Bik
/// Strongly-typed on/off switch for the `emitCInterface` argument taken by
/// `getFunc()`, `createFuncCall()`, and `replaceOpWithFuncCall()`. The
/// underlying type is `bool`, so a value can be cast straight to a flag.
enum class EmitCInterface : bool {
  Off = false,
  On = true,
};
40d8731bfcSwren romano
4105c7f450SAart Bik //===----------------------------------------------------------------------===//
4205c7f450SAart Bik // Helper methods.
4305c7f450SAart Bik //===----------------------------------------------------------------------===//
4405c7f450SAart Bik
45f527fdf5Swren romano /// Returns the equivalent of `void*` for opaque arguments to the
46f527fdf5Swren romano /// execution engine.
getOpaquePointerType(OpBuilder & builder)47e9fa5590SMatthias Springer static Type getOpaquePointerType(OpBuilder &builder) {
48e9fa5590SMatthias Springer return LLVM::LLVMPointerType::get(builder.getI8Type());
49f527fdf5Swren romano }
50f527fdf5Swren romano
51128a9e1cSAart Bik /// Returns a function reference (first hit also inserts into module). Sets
52128a9e1cSAart Bik /// the "_emit_c_interface" on the function declaration when requested,
53128a9e1cSAart Bik /// so that LLVM lowering generates a wrapper function that takes care
54128a9e1cSAart Bik /// of ABI complications with passing in and returning MemRefs to C functions.
getFunc(Operation * op,StringRef name,TypeRange resultType,ValueRange operands,EmitCInterface emitCInterface)5516b8f4ddSAart Bik static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
5616b8f4ddSAart Bik TypeRange resultType, ValueRange operands,
57d8731bfcSwren romano EmitCInterface emitCInterface) {
58a2c9d4bbSAart Bik MLIRContext *context = op->getContext();
59a2c9d4bbSAart Bik auto module = op->getParentOfType<ModuleOp>();
6041d4aa7dSChris Lattner auto result = SymbolRefAttr::get(context, name);
6158ceae95SRiver Riddle auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
62a2c9d4bbSAart Bik if (!func) {
63a2c9d4bbSAart Bik OpBuilder moduleBuilder(module.getBodyRegion());
6458ceae95SRiver Riddle func = moduleBuilder.create<func::FuncOp>(
6541d4aa7dSChris Lattner op->getLoc(), name,
66128a9e1cSAart Bik FunctionType::get(context, operands.getTypes(), resultType));
67128a9e1cSAart Bik func.setPrivate();
68d8731bfcSwren romano if (static_cast<bool>(emitCInterface))
69610139d2SAlex Zinenko func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
70610139d2SAlex Zinenko UnitAttr::get(context));
71a2c9d4bbSAart Bik }
7241d4aa7dSChris Lattner return result;
73a2c9d4bbSAart Bik }
74a2c9d4bbSAart Bik
75f527fdf5Swren romano /// Creates a `CallOp` to the function reference returned by `getFunc()`.
createFuncCall(OpBuilder & builder,Operation * op,StringRef name,TypeRange resultType,ValueRange operands,EmitCInterface emitCInterface)7623aa5a74SRiver Riddle static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
77f527fdf5Swren romano StringRef name, TypeRange resultType,
78f527fdf5Swren romano ValueRange operands,
79d8731bfcSwren romano EmitCInterface emitCInterface) {
80f527fdf5Swren romano auto fn = getFunc(op, name, resultType, operands, emitCInterface);
8123aa5a74SRiver Riddle return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
8223aa5a74SRiver Riddle }
8323aa5a74SRiver Riddle
8423aa5a74SRiver Riddle /// Replaces the `op` with a `CallOp` to the function reference returned
8523aa5a74SRiver Riddle /// by `getFunc()`.
replaceOpWithFuncCall(RewriterBase & rewriter,Operation * op,StringRef name,TypeRange resultType,ValueRange operands,EmitCInterface emitCInterface)86e9fa5590SMatthias Springer static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
87e9fa5590SMatthias Springer StringRef name, TypeRange resultType,
8823aa5a74SRiver Riddle ValueRange operands,
8923aa5a74SRiver Riddle EmitCInterface emitCInterface) {
9023aa5a74SRiver Riddle auto fn = getFunc(op, name, resultType, operands, emitCInterface);
9123aa5a74SRiver Riddle return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
9223aa5a74SRiver Riddle operands);
93f527fdf5Swren romano }
94f527fdf5Swren romano
959d1db3d4SAart Bik /// Generates dimension size call.
genDimSizeCall(OpBuilder & builder,Operation * op,SparseTensorEncodingAttr & enc,Value src,int64_t idx)96e9fa5590SMatthias Springer static Value genDimSizeCall(OpBuilder &builder, Operation *op,
979d1db3d4SAart Bik SparseTensorEncodingAttr &enc, Value src,
989d1db3d4SAart Bik int64_t idx) {
999d1db3d4SAart Bik // Permute the index according to an optional dimension ordering.
1009d1db3d4SAart Bik if (AffineMap p = enc.getDimOrdering())
1019d1db3d4SAart Bik idx = p.getPermutedPosition(idx);
1029d1db3d4SAart Bik // Generate the call.
1039d1db3d4SAart Bik StringRef name = "sparseDimSize";
104e9fa5590SMatthias Springer SmallVector<Value, 2> params{src, constantIndex(builder, op->getLoc(), idx)};
105e9fa5590SMatthias Springer Type iTp = builder.getIndexType();
106e9fa5590SMatthias Springer return createFuncCall(builder, op, name, iTp, params, EmitCInterface::Off)
107d8731bfcSwren romano .getResult(0);
1089d1db3d4SAart Bik }
1099d1db3d4SAart Bik
1109d1db3d4SAart Bik /// Generates a call into the "swiss army knife" method of the sparse runtime
1119d1db3d4SAart Bik /// support library for materializing sparse tensors into the computation.
genNewCall(OpBuilder & builder,Operation * op,ArrayRef<Value> params)112e9fa5590SMatthias Springer static Value genNewCall(OpBuilder &builder, Operation *op,
1139d1db3d4SAart Bik ArrayRef<Value> params) {
1149d1db3d4SAart Bik StringRef name = "newSparseTensor";
115e9fa5590SMatthias Springer Type pTp = getOpaquePointerType(builder);
116e9fa5590SMatthias Springer return createFuncCall(builder, op, name, pTp, params, EmitCInterface::On)
117d8731bfcSwren romano .getResult(0);
1189d1db3d4SAart Bik }
1199d1db3d4SAart Bik
1209d1db3d4SAart Bik /// Populates given sizes array from type.
sizesFromType(OpBuilder & builder,SmallVector<Value,4> & sizes,Location loc,ShapedType stp)121e9fa5590SMatthias Springer static void sizesFromType(OpBuilder &builder, SmallVector<Value, 4> &sizes,
122e9fa5590SMatthias Springer Location loc, ShapedType stp) {
1239d1db3d4SAart Bik auto shape = stp.getShape();
1249d1db3d4SAart Bik for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
1259d1db3d4SAart Bik uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
126e9fa5590SMatthias Springer sizes.push_back(constantIndex(builder, loc, s));
1279d1db3d4SAart Bik }
1289d1db3d4SAart Bik }
1299d1db3d4SAart Bik
1309d1db3d4SAart Bik /// Populates given sizes array from source.
sizesFromSrc(OpBuilder & builder,SmallVector<Value,4> & sizes,Location loc,Value src)131e9fa5590SMatthias Springer static void sizesFromSrc(OpBuilder &builder, SmallVector<Value, 4> &sizes,
132e9fa5590SMatthias Springer Location loc, Value src) {
133f527fdf5Swren romano unsigned rank = src.getType().cast<ShapedType>().getRank();
134f527fdf5Swren romano for (unsigned i = 0; i < rank; i++)
135e9fa5590SMatthias Springer sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i));
1369d1db3d4SAart Bik }
1379d1db3d4SAart Bik
1389d1db3d4SAart Bik /// Populates given sizes array from type (for static sizes) and from
1399d1db3d4SAart Bik /// an already converted into opague pointer source (for dynamic sizes).
sizesFromPtr(OpBuilder & builder,SmallVector<Value,4> & sizes,Operation * op,SparseTensorEncodingAttr & enc,ShapedType stp,Value src)140e9fa5590SMatthias Springer static void sizesFromPtr(OpBuilder &builder, SmallVector<Value, 4> &sizes,
141e9fa5590SMatthias Springer Operation *op, SparseTensorEncodingAttr &enc,
142e9fa5590SMatthias Springer ShapedType stp, Value src) {
143f527fdf5Swren romano Location loc = op->getLoc();
1449d1db3d4SAart Bik auto shape = stp.getShape();
1459d1db3d4SAart Bik for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
1469d1db3d4SAart Bik if (shape[i] == ShapedType::kDynamicSize)
147e9fa5590SMatthias Springer sizes.push_back(genDimSizeCall(builder, op, enc, src, i));
1489d1db3d4SAart Bik else
149e9fa5590SMatthias Springer sizes.push_back(constantIndex(builder, loc, shape[i]));
1509d1db3d4SAart Bik }
1519d1db3d4SAart Bik
15228882b65Swren romano /// Generates an uninitialized temporary buffer of the given size and
15328882b65Swren romano /// type, but returns it as type `memref<? x $tp>` (rather than as type
15428882b65Swren romano /// `memref<$sz x $tp>`).
genAlloca(OpBuilder & builder,Location loc,Value sz,Type tp)155e9fa5590SMatthias Springer static Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) {
156b24788abSAart Bik auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
157e9fa5590SMatthias Springer return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
1584f2ec7f9SAart Bik }
1594f2ec7f9SAart Bik
1600b55f94dSAart Bik /// Generates an uninitialized buffer of the given size and type,
1610b55f94dSAart Bik /// but returns it as type `memref<? x $tp>` (rather than as type
1620b55f94dSAart Bik /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
1630b55f94dSAart Bik /// this buffer must be explicitly deallocated by client.
genAlloc(RewriterBase & rewriter,Location loc,Value sz,Type tp)164e9fa5590SMatthias Springer static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
1650b55f94dSAart Bik auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
1660b55f94dSAart Bik return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
1670b55f94dSAart Bik }
1680b55f94dSAart Bik
1694f2ec7f9SAart Bik /// Generates an uninitialized temporary buffer of the given size and
1704f2ec7f9SAart Bik /// type, but returns it as type `memref<? x $tp>` (rather than as type
1714f2ec7f9SAart Bik /// `memref<$sz x $tp>`).
genAlloca(OpBuilder & builder,Location loc,unsigned sz,Type tp)172e9fa5590SMatthias Springer static Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp) {
173e9fa5590SMatthias Springer return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
174b24788abSAart Bik }
175b24788abSAart Bik
17628882b65Swren romano /// Generates an uninitialized temporary buffer with room for one value
17728882b65Swren romano /// of the given type, and returns the `memref<$tp>`.
genAllocaScalar(OpBuilder & builder,Location loc,Type tp)178e9fa5590SMatthias Springer static Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp) {
179e9fa5590SMatthias Springer return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
18028882b65Swren romano }
18128882b65Swren romano
1829d1db3d4SAart Bik /// Generates a temporary buffer of the given type and given contents.
genBuffer(OpBuilder & builder,Location loc,ValueRange values)183e9fa5590SMatthias Springer static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
184b24788abSAart Bik unsigned sz = values.size();
185b24788abSAart Bik assert(sz >= 1);
186e9fa5590SMatthias Springer Value buffer = genAlloca(builder, loc, sz, values[0].getType());
187b24788abSAart Bik for (unsigned i = 0; i < sz; i++) {
188e9fa5590SMatthias Springer Value idx = constantIndex(builder, loc, i);
189e9fa5590SMatthias Springer builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
190b24788abSAart Bik }
191b24788abSAart Bik return buffer;
192b24788abSAart Bik }
193b24788abSAart Bik
1949d1db3d4SAart Bik /// Populates parameters required to call the "swiss army knife" method of the
1959d1db3d4SAart Bik /// sparse runtime support library for materializing sparse tensors into the
1969d1db3d4SAart Bik /// computation.
newParams(OpBuilder & builder,SmallVector<Value,8> & params,Operation * op,ShapedType stp,SparseTensorEncodingAttr & enc,Action action,ValueRange szs,Value ptr=Value ())197e9fa5590SMatthias Springer static void newParams(OpBuilder &builder, SmallVector<Value, 8> ¶ms,
198e9fa5590SMatthias Springer Operation *op, ShapedType stp,
199e9fa5590SMatthias Springer SparseTensorEncodingAttr &enc, Action action,
200e9fa5590SMatthias Springer ValueRange szs, Value ptr = Value()) {
20105c7f450SAart Bik Location loc = op->getLoc();
202b24788abSAart Bik ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
203b24788abSAart Bik unsigned sz = dlt.size();
2049d1db3d4SAart Bik // Sparsity annotations.
2059d1db3d4SAart Bik SmallVector<Value, 4> attrs;
20605c7f450SAart Bik for (unsigned i = 0; i < sz; i++)
207e9fa5590SMatthias Springer attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
208e9fa5590SMatthias Springer params.push_back(genBuffer(builder, loc, attrs));
2099d1db3d4SAart Bik // Dimension sizes array of the enveloping tensor. Useful for either
21005c7f450SAart Bik // verification of external data, or for construction of internal data.
211e9fa5590SMatthias Springer params.push_back(genBuffer(builder, loc, szs));
21205c7f450SAart Bik // Dimension order permutation array. This is the "identity" permutation by
21305c7f450SAart Bik // default, or otherwise the "reverse" permutation of a given ordering, so
21405c7f450SAart Bik // that indices can be mapped quickly to the right position.
215b24788abSAart Bik SmallVector<Value, 4> rev(sz);
216236a9080SAart Bik if (AffineMap p = enc.getDimOrdering()) {
21705c7f450SAart Bik for (unsigned i = 0; i < sz; i++)
218e9fa5590SMatthias Springer rev[p.getDimPosition(i)] = constantIndex(builder, loc, i);
21905c7f450SAart Bik } else {
22005c7f450SAart Bik for (unsigned i = 0; i < sz; i++)
221e9fa5590SMatthias Springer rev[i] = constantIndex(builder, loc, i);
22205c7f450SAart Bik }
223e9fa5590SMatthias Springer params.push_back(genBuffer(builder, loc, rev));
22405c7f450SAart Bik // Secondary and primary types encoding.
225efa15f41SAart Bik Type elemTp = stp.getElementType();
226e9fa5590SMatthias Springer params.push_back(constantPointerTypeEncoding(builder, loc, enc));
227e9fa5590SMatthias Springer params.push_back(constantIndexTypeEncoding(builder, loc, enc));
228e9fa5590SMatthias Springer params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
229f527fdf5Swren romano // User action.
230e9fa5590SMatthias Springer params.push_back(constantAction(builder, loc, action));
231f527fdf5Swren romano // Payload pointer.
232f527fdf5Swren romano if (!ptr)
233e9fa5590SMatthias Springer ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
23405c7f450SAart Bik params.push_back(ptr);
23505c7f450SAart Bik }
23605c7f450SAart Bik
237fbd5821cSBixia Zheng /// Generates the code to read the value from tensor[ivs], and conditionally
238fbd5821cSBixia Zheng /// stores the indices ivs to the memory in ind. The generated code looks like
239fbd5821cSBixia Zheng /// the following and the insertion point after this routine is inside the
240fbd5821cSBixia Zheng /// if-then branch behind the assignment to ind. This is to ensure that the
241fbd5821cSBixia Zheng /// addEltX call generated after is inside the if-then branch.
242faa00c13SAart Bik /// if (tensor[ivs] != 0)
243fbd5821cSBixia Zheng /// ind = ivs
genIndexAndValueForDense(OpBuilder & builder,Location loc,Value tensor,Value ind,ValueRange ivs)244e9fa5590SMatthias Springer static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
245e9fa5590SMatthias Springer Value tensor, Value ind, ValueRange ivs) {
246e9fa5590SMatthias Springer Value val = builder.create<tensor::ExtractOp>(loc, tensor, ivs);
247e9fa5590SMatthias Springer Value cond = genIsNonzero(builder, loc, val);
248e9fa5590SMatthias Springer scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else*/ false);
249e9fa5590SMatthias Springer builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
250fbd5821cSBixia Zheng unsigned i = 0;
251fbd5821cSBixia Zheng for (auto iv : ivs) {
252e9fa5590SMatthias Springer Value idx = constantIndex(builder, loc, i++);
253e9fa5590SMatthias Springer builder.create<memref::StoreOp>(loc, iv, ind, idx);
254fbd5821cSBixia Zheng }
255fbd5821cSBixia Zheng return val;
256fbd5821cSBixia Zheng }
257fbd5821cSBixia Zheng
25863bdcaf9Swren romano /// Generates a call to release/delete a `SparseTensorCOO`.
genDelCOOCall(OpBuilder & builder,Operation * op,Type elemTp,Value coo)25963bdcaf9Swren romano static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
26063bdcaf9Swren romano Value coo) {
26163bdcaf9Swren romano SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
26263bdcaf9Swren romano TypeRange noTp;
26363bdcaf9Swren romano createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
26463bdcaf9Swren romano }
26563bdcaf9Swren romano
266236a9080SAart Bik /// Generates a call that adds one element to a coordinate scheme.
267221856f5Swren romano /// In particular, this generates code like the following:
268221856f5Swren romano /// val = a[i1,..,ik];
269221856f5Swren romano /// if val != 0
270aef20f59SAart Bik /// t->add(&val, [i1,..,ik], [p1,..,pk]);
genAddEltCall(OpBuilder & builder,Operation * op,Type eltType,Value ptr,Value valPtr,Value ind,Value perm)271e9fa5590SMatthias Springer static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType,
272aef20f59SAart Bik Value ptr, Value valPtr, Value ind, Value perm) {
273c9489225Swren romano SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
274aef20f59SAart Bik SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
275e9fa5590SMatthias Springer Type pTp = getOpaquePointerType(builder);
276e9fa5590SMatthias Springer createFuncCall(builder, op, name, pTp, params, EmitCInterface::On);
27705c7f450SAart Bik }
27805c7f450SAart Bik
27928882b65Swren romano /// Generates a call to `iter->getNext()`. If there is a next element,
28028882b65Swren romano /// then it is copied into the out-parameters `ind` and `elemPtr`,
28128882b65Swren romano /// and the return value is true. If there isn't a next element, then
2826be36fd7Swren romano /// the memory for `iter` is freed and the return value is false.
genGetNextCall(OpBuilder & builder,Operation * op,Value iter,Value ind,Value elemPtr)283e9fa5590SMatthias Springer static Value genGetNextCall(OpBuilder &builder, Operation *op, Value iter,
284e9fa5590SMatthias Springer Value ind, Value elemPtr) {
28528882b65Swren romano Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
286c9489225Swren romano SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
287f527fdf5Swren romano SmallVector<Value, 3> params{iter, ind, elemPtr};
288e9fa5590SMatthias Springer Type i1 = builder.getI1Type();
289e9fa5590SMatthias Springer return createFuncCall(builder, op, name, i1, params, EmitCInterface::On)
290d8731bfcSwren romano .getResult(0);
29128882b65Swren romano }
29228882b65Swren romano
293fbd5821cSBixia Zheng /// If the tensor is a sparse constant, generates and returns the pair of
294fbd5821cSBixia Zheng /// the constants for the indices and the values.
295fbd5821cSBixia Zheng static Optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder & builder,Location loc,Value tensor)296e9fa5590SMatthias Springer genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
297a54f4eaeSMogball if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
298cfb72fd3SJacques Pienaar if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
299fbd5821cSBixia Zheng DenseElementsAttr indicesAttr = attr.getIndices();
300e9fa5590SMatthias Springer Value indices = builder.create<arith::ConstantOp>(loc, indicesAttr);
301fbd5821cSBixia Zheng DenseElementsAttr valuesAttr = attr.getValues();
302e9fa5590SMatthias Springer Value values = builder.create<arith::ConstantOp>(loc, valuesAttr);
303fbd5821cSBixia Zheng return std::make_pair(indices, values);
304fbd5821cSBixia Zheng }
305fbd5821cSBixia Zheng }
306fbd5821cSBixia Zheng return {};
307fbd5821cSBixia Zheng }
308fbd5821cSBixia Zheng
309fbd5821cSBixia Zheng /// Generates the code to copy the index at indices[ivs] to ind, and return
310fbd5821cSBixia Zheng /// the value at value[ivs].
genIndexAndValueForSparse(OpBuilder & builder,Location loc,Value indices,Value values,Value ind,ValueRange ivs,unsigned rank)311e9fa5590SMatthias Springer static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc,
312e9fa5590SMatthias Springer Value indices, Value values, Value ind,
313e9fa5590SMatthias Springer ValueRange ivs, unsigned rank) {
314fbd5821cSBixia Zheng for (unsigned i = 0; i < rank; i++) {
315e9fa5590SMatthias Springer Value idx = constantIndex(builder, loc, i);
316e9fa5590SMatthias Springer Value val = builder.create<tensor::ExtractOp>(loc, indices,
317fbd5821cSBixia Zheng ValueRange{ivs[0], idx});
318e9fa5590SMatthias Springer val = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), val);
319e9fa5590SMatthias Springer builder.create<memref::StoreOp>(loc, val, ind, idx);
320fbd5821cSBixia Zheng }
321e9fa5590SMatthias Springer return builder.create<tensor::ExtractOp>(loc, values, ivs[0]);
322fbd5821cSBixia Zheng }
323fbd5821cSBixia Zheng
324c66303c2SMatthias Springer /// Generates code to allocate a buffer of the given type, and zero
325c66303c2SMatthias Springer /// initialize it. If the buffer type has any dynamic sizes, then the
3265389cdc8Swren romano /// `sizes` parameter should be as filled by sizesFromPtr(); that way
3275389cdc8Swren romano /// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
allocDenseTensor(OpBuilder & builder,Location loc,RankedTensorType tensorTp,ValueRange sizes)328e9fa5590SMatthias Springer static Value allocDenseTensor(OpBuilder &builder, Location loc,
3295389cdc8Swren romano RankedTensorType tensorTp, ValueRange sizes) {
33028882b65Swren romano Type elemTp = tensorTp.getElementType();
3315389cdc8Swren romano auto shape = tensorTp.getShape();
3325389cdc8Swren romano auto memTp = MemRefType::get(shape, elemTp);
3335389cdc8Swren romano SmallVector<Value> dynamicSizes;
3345389cdc8Swren romano for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
3355389cdc8Swren romano if (shape[i] == ShapedType::kDynamicSize)
3365389cdc8Swren romano dynamicSizes.push_back(sizes[i]);
3375389cdc8Swren romano }
338e9fa5590SMatthias Springer Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
339e9fa5590SMatthias Springer Value zero = constantZero(builder, loc, elemTp);
340e9fa5590SMatthias Springer builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
34128882b65Swren romano return mem;
34228882b65Swren romano }
34328882b65Swren romano
344c66303c2SMatthias Springer /// Generates code to deallocate a dense buffer.
deallocDenseTensor(OpBuilder & builder,Location loc,Value buffer)345c66303c2SMatthias Springer static void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer) {
346c66303c2SMatthias Springer builder.create<memref::DeallocOp>(loc, buffer);
347c66303c2SMatthias Springer }
348c66303c2SMatthias Springer
34928882b65Swren romano /// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
35028882b65Swren romano /// the tensor created by allocDenseTensor(). The `rank` is the rank
35128882b65Swren romano /// of the `tensor` and the length of `ind`.
insertScalarIntoDenseTensor(OpBuilder & builder,Location loc,Value elemPtr,Value tensor,unsigned rank,Value ind)352e9fa5590SMatthias Springer static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
353e9fa5590SMatthias Springer Value elemPtr, Value tensor,
354e9fa5590SMatthias Springer unsigned rank, Value ind) {
35528882b65Swren romano SmallVector<Value, 4> ivs;
35628882b65Swren romano ivs.reserve(rank);
35728882b65Swren romano for (unsigned i = 0; i < rank; i++) {
358e9fa5590SMatthias Springer Value idx = constantIndex(builder, loc, i);
359e9fa5590SMatthias Springer ivs.push_back(builder.create<memref::LoadOp>(loc, ind, idx));
36028882b65Swren romano }
361e9fa5590SMatthias Springer Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
362e9fa5590SMatthias Springer builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
36328882b65Swren romano }
36428882b65Swren romano
3658cb33240Swren romano /// Determine if the runtime library supports direct conversion to the
3668cb33240Swren romano /// given target `dimTypes`.
canUseDirectConversion(ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes)3678cb33240Swren romano static bool canUseDirectConversion(
3688cb33240Swren romano ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
3698cb33240Swren romano bool alreadyCompressed = false;
3708cb33240Swren romano for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
3718cb33240Swren romano switch (dimTypes[r]) {
3728cb33240Swren romano case SparseTensorEncodingAttr::DimLevelType::Compressed:
3738cb33240Swren romano if (alreadyCompressed)
3748cb33240Swren romano return false; // Multiple compressed dimensions not yet supported.
3758cb33240Swren romano alreadyCompressed = true;
3768cb33240Swren romano break;
3778cb33240Swren romano case SparseTensorEncodingAttr::DimLevelType::Dense:
3788cb33240Swren romano if (alreadyCompressed)
3798cb33240Swren romano return false; // Dense after Compressed not yet supported.
3808cb33240Swren romano break;
3818cb33240Swren romano case SparseTensorEncodingAttr::DimLevelType::Singleton:
3828cb33240Swren romano // Although Singleton isn't generally supported yet, the direct
3838cb33240Swren romano // conversion method doesn't have any particular problems with
3848cb33240Swren romano // singleton after compressed.
3858cb33240Swren romano break;
3868cb33240Swren romano }
3878cb33240Swren romano }
3888cb33240Swren romano return true;
3898cb33240Swren romano }
3908cb33240Swren romano
391faa00c13SAart Bik /// Helper method to translate indices during a reshaping operation.
392faa00c13SAart Bik /// TODO: provide as general utility to MLIR at large?
translateIndices(Location loc,ConversionPatternRewriter & rewriter,ArrayRef<ReassociationIndices> reassociation,TensorType dstTp,TensorType srcTp,Value dstIdx,Value srcIdx)393faa00c13SAart Bik static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
394faa00c13SAart Bik ArrayRef<ReassociationIndices> reassociation,
395faa00c13SAart Bik TensorType dstTp, TensorType srcTp, Value dstIdx,
396faa00c13SAart Bik Value srcIdx) {
397faa00c13SAart Bik unsigned dstRank = dstTp.getRank();
398faa00c13SAart Bik unsigned srcRank = srcTp.getRank();
399faa00c13SAart Bik unsigned start = 0;
400faa00c13SAart Bik unsigned i = 0;
401faa00c13SAart Bik bool isExpand = srcRank > dstRank;
402faa00c13SAart Bik ArrayRef<int64_t> shape = isExpand ? srcTp.getShape() : dstTp.getShape();
403faa00c13SAart Bik // Iterate over reassociation map.
404faa00c13SAart Bik for (const auto &map : llvm::enumerate(reassociation)) {
405faa00c13SAart Bik // Prepare strides information in dimension slice.
406faa00c13SAart Bik uint64_t linear = 1;
407faa00c13SAart Bik for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
408faa00c13SAart Bik assert(!ShapedType::isDynamic(shape[j]));
409faa00c13SAart Bik linear *= shape[j];
410faa00c13SAart Bik }
411faa00c13SAart Bik // Start collapse.
412faa00c13SAart Bik Value idx = constantIndex(rewriter, loc, i++);
413faa00c13SAart Bik Value val;
414faa00c13SAart Bik if (!isExpand)
415faa00c13SAart Bik val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
416faa00c13SAart Bik // Iterate over dimension slice.
417faa00c13SAart Bik for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
418faa00c13SAart Bik linear /= shape[j];
419faa00c13SAart Bik Value stride = constantIndex(rewriter, loc, linear);
420faa00c13SAart Bik Value jdx = constantIndex(rewriter, loc, j);
421faa00c13SAart Bik if (isExpand) {
422faa00c13SAart Bik Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
423faa00c13SAart Bik Value mul = linear == 1
424faa00c13SAart Bik ? old
425faa00c13SAart Bik : rewriter.create<arith::MulIOp>(loc, old, stride);
426faa00c13SAart Bik val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
427faa00c13SAart Bik } else {
428faa00c13SAart Bik Value old = val;
429faa00c13SAart Bik if (linear != 1)
430faa00c13SAart Bik val = rewriter.create<arith::DivUIOp>(loc, val, stride);
431faa00c13SAart Bik rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
432faa00c13SAart Bik if (linear != 1)
433faa00c13SAart Bik val = rewriter.create<arith::RemUIOp>(loc, old, stride);
434faa00c13SAart Bik }
435faa00c13SAart Bik }
436faa00c13SAart Bik // Finalize expansion.
437faa00c13SAart Bik if (isExpand)
438faa00c13SAart Bik rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
439faa00c13SAart Bik start += map.value().size();
440faa00c13SAart Bik }
441faa00c13SAart Bik // Sanity.
442faa00c13SAart Bik assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
443faa00c13SAart Bik }
444faa00c13SAart Bik
445faa00c13SAart Bik /// Generate code for a general sparse to sparse reshaping operation.
446faa00c13SAart Bik /// Note that unlike dense reshaping (which can be done with a "cheap"
447faa00c13SAart Bik /// change of view), sparse reshaping is currently done with actual
448faa00c13SAart Bik /// data shuffling.
449faa00c13SAart Bik ///
450faa00c13SAart Bik /// TODO: proportional to nnz, but still a lot of data movement
451faa00c13SAart Bik /// https://github.com/llvm/llvm-project/issues/56477
452faa00c13SAart Bik ///
453faa00c13SAart Bik /// iter = src->toCOO();
454faa00c13SAart Bik /// coo = newSparseCOO()
455faa00c13SAart Bik /// while (elem = iter->getNext()) {
456faa00c13SAart Bik /// coo->add(reshape(elem.indices), elem.value)
457faa00c13SAart Bik /// }
458faa00c13SAart Bik /// s = newSparseTensor(coo)
459faa00c13SAart Bik static LogicalResult
genSparse2SparseReshape(Operation * op,ConversionPatternRewriter & rewriter,ArrayRef<ReassociationIndices> reassociation,Value src,RankedTensorType dstTp,RankedTensorType srcTp)460faa00c13SAart Bik genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
461faa00c13SAart Bik ArrayRef<ReassociationIndices> reassociation, Value src,
462faa00c13SAart Bik RankedTensorType dstTp, RankedTensorType srcTp) {
463faa00c13SAart Bik Location loc = op->getLoc();
464faa00c13SAart Bik auto encDst = getSparseTensorEncoding(dstTp);
465faa00c13SAart Bik auto encSrc = getSparseTensorEncoding(srcTp);
466faa00c13SAart Bik assert(encDst && encSrc);
467faa00c13SAart Bik unsigned srcRank = srcTp.getRank();
468faa00c13SAart Bik unsigned dstRank = dstTp.getRank();
469faa00c13SAart Bik Type elemTp = srcTp.getElementType();
470faa00c13SAart Bik assert(elemTp == dstTp.getElementType() &&
471faa00c13SAart Bik "reshape should not change element type");
472faa00c13SAart Bik // Start an iterator over the source tensor (in original index order).
473faa00c13SAart Bik auto noPerm = SparseTensorEncodingAttr::get(
474faa00c13SAart Bik op->getContext(), encSrc.getDimLevelType(), AffineMap(),
475faa00c13SAart Bik encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
476faa00c13SAart Bik SmallVector<Value, 4> sizes;
477faa00c13SAart Bik SmallVector<Value, 8> params;
478faa00c13SAart Bik sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
479faa00c13SAart Bik newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
480faa00c13SAart Bik src);
481faa00c13SAart Bik Value iter = genNewCall(rewriter, op, params);
482faa00c13SAart Bik // Start a new COO for the destination tensor.
483faa00c13SAart Bik sizes.clear();
484faa00c13SAart Bik params.clear();
485faa00c13SAart Bik sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
486faa00c13SAart Bik newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
487faa00c13SAart Bik Value coo = genNewCall(rewriter, op, params);
488faa00c13SAart Bik Value dstPerm = params[2];
489faa00c13SAart Bik // Construct a while loop over the iterator.
490faa00c13SAart Bik Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
491faa00c13SAart Bik Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
492faa00c13SAart Bik Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
493faa00c13SAart Bik SmallVector<Value> noArgs;
494faa00c13SAart Bik SmallVector<Type> noTypes;
495faa00c13SAart Bik auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
496faa00c13SAart Bik Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
497faa00c13SAart Bik rewriter.setInsertionPointToEnd(before);
498faa00c13SAart Bik Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr);
499faa00c13SAart Bik rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
500faa00c13SAart Bik // Translate indices from source to target and insert. Note that we do
501faa00c13SAart Bik // not need to store the value in elemPtr, as the value is still there.
502faa00c13SAart Bik Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
503faa00c13SAart Bik rewriter.setInsertionPointToStart(after);
504faa00c13SAart Bik translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
505faa00c13SAart Bik genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
506faa00c13SAart Bik rewriter.create<scf::YieldOp>(loc);
507faa00c13SAart Bik // Final call to construct sparse tensor storage and free temporary resources.
508faa00c13SAart Bik rewriter.setInsertionPointAfter(whileOp);
509faa00c13SAart Bik params[6] = constantAction(rewriter, loc, Action::kFromCOO);
510faa00c13SAart Bik params[7] = coo;
511faa00c13SAart Bik Value dst = genNewCall(rewriter, op, params);
512faa00c13SAart Bik genDelCOOCall(rewriter, op, elemTp, coo);
513faa00c13SAart Bik genDelCOOCall(rewriter, op, elemTp, iter);
514faa00c13SAart Bik rewriter.replaceOp(op, dst);
515faa00c13SAart Bik return success();
516faa00c13SAart Bik }
517faa00c13SAart Bik
51805c7f450SAart Bik //===----------------------------------------------------------------------===//
51905c7f450SAart Bik // Conversion rules.
52005c7f450SAart Bik //===----------------------------------------------------------------------===//
52105c7f450SAart Bik
52296a23911SAart Bik /// Sparse conversion rule for returns.
52323aa5a74SRiver Riddle class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
52496a23911SAart Bik public:
525a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern;
526a2c9d4bbSAart Bik LogicalResult
matchAndRewrite(func::ReturnOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const52723aa5a74SRiver Riddle matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
528a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override {
52923aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
530a2c9d4bbSAart Bik return success();
531a2c9d4bbSAart Bik }
532a2c9d4bbSAart Bik };
533a2c9d4bbSAart Bik
534a2c9d4bbSAart Bik /// Sparse conversion rule for dimension accesses.
535a2c9d4bbSAart Bik class SparseTensorToDimSizeConverter
536c0a6318dSMatthias Springer : public OpConversionPattern<tensor::DimOp> {
537a2c9d4bbSAart Bik public:
538a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern;
539a2c9d4bbSAart Bik LogicalResult
matchAndRewrite(tensor::DimOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const540b54c724bSRiver Riddle matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
541a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override {
5429d1db3d4SAart Bik // Only rewrite annotated DimOp with constant index.
5438df54a6aSJacques Pienaar auto enc = getSparseTensorEncoding(op.getSource().getType());
544d37d72eaSAart Bik if (!enc)
545d37d72eaSAart Bik return failure();
546d37d72eaSAart Bik Optional<int64_t> index = op.getConstantIndex();
547037f0995SKazu Hirata if (!index)
548d37d72eaSAart Bik return failure();
549d37d72eaSAart Bik // Generate the call.
5509d1db3d4SAart Bik Value src = adaptor.getOperands()[0];
5516d5fc1e3SKazu Hirata int64_t idx = *index;
5529d1db3d4SAart Bik rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
553a2c9d4bbSAart Bik return success();
554a2c9d4bbSAart Bik }
555a2c9d4bbSAart Bik };
556a2c9d4bbSAart Bik
5571b15160eSAart Bik /// Sparse conversion rule for trivial tensor casts.
5581b15160eSAart Bik class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
559faa00c13SAart Bik public:
5601b15160eSAart Bik using OpConversionPattern::OpConversionPattern;
5611b15160eSAart Bik LogicalResult
matchAndRewrite(tensor::CastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const5621b15160eSAart Bik matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
5631b15160eSAart Bik ConversionPatternRewriter &rewriter) const override {
5641b15160eSAart Bik // Only rewrite identically annotated source/dest.
5651b15160eSAart Bik auto encDst = getSparseTensorEncoding(op.getType());
5668df54a6aSJacques Pienaar auto encSrc = getSparseTensorEncoding(op.getSource().getType());
5671b15160eSAart Bik if (!encDst || encDst != encSrc)
5681b15160eSAart Bik return failure();
5691b15160eSAart Bik rewriter.replaceOp(op, adaptor.getOperands());
5701b15160eSAart Bik return success();
5711b15160eSAart Bik }
5721b15160eSAart Bik };
5731b15160eSAart Bik
574faa00c13SAart Bik /// Sparse conversion rule for a reshape operator.
575faa00c13SAart Bik template <typename ReshapeOp>
576faa00c13SAart Bik class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
577faa00c13SAart Bik public:
578faa00c13SAart Bik using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor;
579faa00c13SAart Bik using OpConversionPattern<ReshapeOp>::OpConversionPattern;
580faa00c13SAart Bik LogicalResult
matchAndRewrite(ReshapeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const581faa00c13SAart Bik matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
582faa00c13SAart Bik ConversionPatternRewriter &rewriter) const override {
583faa00c13SAart Bik Type dstType = op.getResult().getType();
584faa00c13SAart Bik Type srcType = op.getSrc().getType();
585faa00c13SAart Bik auto encDst = getSparseTensorEncoding(dstType);
586faa00c13SAart Bik auto encSrc = getSparseTensorEncoding(srcType);
587faa00c13SAart Bik if (encDst && encSrc)
588faa00c13SAart Bik return genSparse2SparseReshape(
589faa00c13SAart Bik op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
590faa00c13SAart Bik dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
591faa00c13SAart Bik return failure(); // handled elsewhere
592faa00c13SAart Bik }
593faa00c13SAart Bik };
594faa00c13SAart Bik
59596a23911SAart Bik /// Sparse conversion rule for the new operator.
59696a23911SAart Bik class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
597faa00c13SAart Bik public:
59896a23911SAart Bik using OpConversionPattern::OpConversionPattern;
59996a23911SAart Bik LogicalResult
matchAndRewrite(NewOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const600b54c724bSRiver Riddle matchAndRewrite(NewOp op, OpAdaptor adaptor,
60196a23911SAart Bik ConversionPatternRewriter &rewriter) const override {
60296a23911SAart Bik Type resType = op.getType();
60396a23911SAart Bik auto enc = getSparseTensorEncoding(resType);
60496a23911SAart Bik if (!enc)
60596a23911SAart Bik return failure();
6069d1db3d4SAart Bik // Generate the call to construct tensor from ptr. The sizes are
6079d1db3d4SAart Bik // inferred from the result type of the new operator.
6089d1db3d4SAart Bik SmallVector<Value, 4> sizes;
6099d1db3d4SAart Bik SmallVector<Value, 8> params;
610efa15f41SAart Bik ShapedType stp = resType.cast<ShapedType>();
611efa15f41SAart Bik sizesFromType(rewriter, sizes, op.getLoc(), stp);
6129d1db3d4SAart Bik Value ptr = adaptor.getOperands()[0];
613efa15f41SAart Bik newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
6149d1db3d4SAart Bik rewriter.replaceOp(op, genNewCall(rewriter, op, params));
615b24788abSAart Bik return success();
616b24788abSAart Bik }
617b24788abSAart Bik };
618b24788abSAart Bik
6196232a8f3SMatthias Springer /// Sparse conversion rule for the alloc operator.
6206232a8f3SMatthias Springer class SparseTensorAllocConverter
6216232a8f3SMatthias Springer : public OpConversionPattern<bufferization::AllocTensorOp> {
622faa00c13SAart Bik public:
623b24788abSAart Bik using OpConversionPattern::OpConversionPattern;
624b24788abSAart Bik LogicalResult
matchAndRewrite(bufferization::AllocTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const6256232a8f3SMatthias Springer matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
626b24788abSAart Bik ConversionPatternRewriter &rewriter) const override {
627c66303c2SMatthias Springer if (op.getCopy())
628c66303c2SMatthias Springer return rewriter.notifyMatchFailure(op,
629c66303c2SMatthias Springer "sparse tensor copy not implemented");
6306232a8f3SMatthias Springer RankedTensorType resType = op.getType();
631b24788abSAart Bik auto enc = getSparseTensorEncoding(resType);
632b24788abSAart Bik if (!enc)
633b24788abSAart Bik return failure();
6346232a8f3SMatthias Springer // Gather all dimension sizes as SSA values.
6356232a8f3SMatthias Springer SmallVector<Value> sizes;
6366232a8f3SMatthias Springer unsigned int operandCtr = 0;
6376232a8f3SMatthias Springer for (int64_t i = 0; i < resType.getRank(); ++i) {
6386232a8f3SMatthias Springer if (resType.isDynamicDim(i)) {
6396232a8f3SMatthias Springer sizes.push_back(adaptor.getOperands()[operandCtr++]);
6406232a8f3SMatthias Springer } else {
6416232a8f3SMatthias Springer sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
6426232a8f3SMatthias Springer op.getLoc(), op.getStaticSize(i)));
6436232a8f3SMatthias Springer }
6446232a8f3SMatthias Springer }
6459d1db3d4SAart Bik // Generate the call to construct empty tensor. The sizes are
6466232a8f3SMatthias Springer // explicitly defined by the arguments to the alloc operator.
6479d1db3d4SAart Bik SmallVector<Value, 8> params;
648efa15f41SAart Bik ShapedType stp = resType.cast<ShapedType>();
6496232a8f3SMatthias Springer newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
6509d1db3d4SAart Bik rewriter.replaceOp(op, genNewCall(rewriter, op, params));
65196a23911SAart Bik return success();
65296a23911SAart Bik }
65396a23911SAart Bik };
65496a23911SAart Bik
655697ea09dSAart Bik /// Sparse conversion rule for the convert operator.
656697ea09dSAart Bik class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
657c7e24db4Swren romano public:
658697ea09dSAart Bik using OpConversionPattern::OpConversionPattern;
SparseTensorConvertConverter(MLIRContext * context,SparseTensorConversionOptions o)659c7e24db4Swren romano SparseTensorConvertConverter(MLIRContext *context,
660c7e24db4Swren romano SparseTensorConversionOptions o)
661c7e24db4Swren romano : OpConversionPattern<ConvertOp>(context), options(o) {}
SparseTensorConvertConverter(TypeConverter & typeConv,MLIRContext * context,SparseTensorConversionOptions o)662c7e24db4Swren romano SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
663c7e24db4Swren romano SparseTensorConversionOptions o)
664c7e24db4Swren romano : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}
665c7e24db4Swren romano
666697ea09dSAart Bik LogicalResult
matchAndRewrite(ConvertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const667b54c724bSRiver Riddle matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
668697ea09dSAart Bik ConversionPatternRewriter &rewriter) const override {
6699d1db3d4SAart Bik Location loc = op->getLoc();
67005c7f450SAart Bik Type resType = op.getType();
6718df54a6aSJacques Pienaar Type srcType = op.getSource().getType();
67205c7f450SAart Bik auto encDst = getSparseTensorEncoding(resType);
6739d1db3d4SAart Bik auto encSrc = getSparseTensorEncoding(srcType);
6749d1db3d4SAart Bik Value src = adaptor.getOperands()[0];
6750a7b8cc5SAart Bik if (encDst && encSrc) {
6760a7b8cc5SAart Bik // This is a sparse => sparse conversion, which is handled as follows:
67721895486Swren romano // t = src->toCOO(); ; src to COO in dst order
6780a7b8cc5SAart Bik // dst = newSparseTensor(t)
6790a7b8cc5SAart Bik // Using the coordinate scheme as an intermediate does not always
6800a7b8cc5SAart Bik // yield the fastest conversion but avoids the need for a full
6810a7b8cc5SAart Bik // O(N^2) conversion matrix.
6821b15160eSAart Bik if (encDst == encSrc) {
6831b15160eSAart Bik rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
6841b15160eSAart Bik return success();
6851b15160eSAart Bik }
6869d1db3d4SAart Bik SmallVector<Value, 4> sizes;
6879d1db3d4SAart Bik SmallVector<Value, 8> params;
688efa15f41SAart Bik ShapedType stp = srcType.cast<ShapedType>();
689efa15f41SAart Bik sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
6908cb33240Swren romano bool useDirectConversion;
6918cb33240Swren romano switch (options.sparseToSparseStrategy) {
6928cb33240Swren romano case SparseToSparseConversionStrategy::kViaCOO:
6938cb33240Swren romano useDirectConversion = false;
6948cb33240Swren romano break;
6958cb33240Swren romano case SparseToSparseConversionStrategy::kDirect:
6968cb33240Swren romano useDirectConversion = true;
6978cb33240Swren romano assert(canUseDirectConversion(encDst.getDimLevelType()) &&
6988cb33240Swren romano "Unsupported target for direct sparse-to-sparse conversion");
6998cb33240Swren romano break;
7008cb33240Swren romano case SparseToSparseConversionStrategy::kAuto:
7018cb33240Swren romano useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
7028cb33240Swren romano break;
7038cb33240Swren romano }
7048cb33240Swren romano if (useDirectConversion) {
7058cb33240Swren romano newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
7068cb33240Swren romano sizes, src);
7078cb33240Swren romano rewriter.replaceOp(op, genNewCall(rewriter, op, params));
7088cb33240Swren romano } else { // use via-COO conversion.
709185960dcSAart Bik // Set up encoding with right mix of src and dst so that the two
710185960dcSAart Bik // method calls can share most parameters, while still providing
711185960dcSAart Bik // the correct sparsity information to either of them.
712185960dcSAart Bik auto enc = SparseTensorEncodingAttr::get(
713185960dcSAart Bik op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
714185960dcSAart Bik encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
715efa15f41SAart Bik newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
7169d1db3d4SAart Bik Value coo = genNewCall(rewriter, op, params);
717d7d7ffe2Swren romano params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
718d7d7ffe2Swren romano params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
719845561ecSwren romano params[6] = constantAction(rewriter, loc, Action::kFromCOO);
7209d1db3d4SAart Bik params[7] = coo;
72163bdcaf9Swren romano Value dst = genNewCall(rewriter, op, params);
72263bdcaf9Swren romano genDelCOOCall(rewriter, op, stp.getElementType(), coo);
72363bdcaf9Swren romano rewriter.replaceOp(op, dst);
7248cb33240Swren romano }
7250a7b8cc5SAart Bik return success();
7260a7b8cc5SAart Bik }
72728882b65Swren romano if (!encDst && encSrc) {
72828882b65Swren romano // This is sparse => dense conversion, which is handled as follows:
72928882b65Swren romano // dst = new Tensor(0);
7306be36fd7Swren romano // iter = src->toCOO();
7316be36fd7Swren romano // iter->startIterator();
73228882b65Swren romano // while (elem = iter->getNext()) {
73328882b65Swren romano // dst[elem.indices] = elem.value;
73428882b65Swren romano // }
7356be36fd7Swren romano RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
7366be36fd7Swren romano RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
7376be36fd7Swren romano unsigned rank = dstTensorTp.getRank();
7386be36fd7Swren romano Type elemTp = dstTensorTp.getElementType();
7396be36fd7Swren romano // Fabricate a no-permutation encoding for newParams().
7406be36fd7Swren romano // The pointer/index types must be those of `src`.
741845561ecSwren romano // The dimLevelTypes aren't actually used by Action::kToIterator.
74228882b65Swren romano encDst = SparseTensorEncodingAttr::get(
74328882b65Swren romano op->getContext(),
74428882b65Swren romano SmallVector<SparseTensorEncodingAttr::DimLevelType>(
74528882b65Swren romano rank, SparseTensorEncodingAttr::DimLevelType::Dense),
74628882b65Swren romano AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
74728882b65Swren romano SmallVector<Value, 4> sizes;
74828882b65Swren romano SmallVector<Value, 8> params;
7496be36fd7Swren romano sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
750efa15f41SAart Bik newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
751efa15f41SAart Bik sizes, src);
75228882b65Swren romano Value iter = genNewCall(rewriter, op, params);
7535389cdc8Swren romano Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
7546be36fd7Swren romano Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
755c66303c2SMatthias Springer Block *insertionBlock = rewriter.getInsertionBlock();
756c66303c2SMatthias Springer // TODO: Dense buffers should be allocated/deallocated via the callback
757c66303c2SMatthias Springer // in BufferizationOptions.
7586be36fd7Swren romano Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
75928882b65Swren romano SmallVector<Value> noArgs;
76028882b65Swren romano SmallVector<Type> noTypes;
76128882b65Swren romano auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
762c0342a2dSJacques Pienaar Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
76328882b65Swren romano rewriter.setInsertionPointToEnd(before);
76428882b65Swren romano Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
76528882b65Swren romano rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
766c0342a2dSJacques Pienaar Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
76728882b65Swren romano rewriter.setInsertionPointToStart(after);
76828882b65Swren romano insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
76928882b65Swren romano rewriter.create<scf::YieldOp>(loc);
77028882b65Swren romano rewriter.setInsertionPointAfter(whileOp);
77163bdcaf9Swren romano genDelCOOCall(rewriter, op, elemTp, iter);
77257470abcSAlexander Belyaev rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
773c66303c2SMatthias Springer // Deallocate the buffer.
774c66303c2SMatthias Springer if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
775c66303c2SMatthias Springer rewriter.setInsertionPoint(insertionBlock->getTerminator());
776c66303c2SMatthias Springer deallocDenseTensor(rewriter, loc, dst);
777c66303c2SMatthias Springer }
77828882b65Swren romano return success();
77928882b65Swren romano }
78028882b65Swren romano if (!encDst && !encSrc) {
78128882b65Swren romano // dense => dense
782697ea09dSAart Bik return failure();
7830a7b8cc5SAart Bik }
784fbd5821cSBixia Zheng // This is a dense => sparse conversion or a sparse constant in COO =>
785fbd5821cSBixia Zheng // sparse conversion, which is handled as follows:
786236a9080SAart Bik // t = newSparseCOO()
787fbd5821cSBixia Zheng // ...code to fill the COO tensor t...
788fbd5821cSBixia Zheng // s = newSparseTensor(t)
789fbd5821cSBixia Zheng //
790fbd5821cSBixia Zheng // To fill the COO tensor from a dense tensor:
791236a9080SAart Bik // for i1 in dim1
792236a9080SAart Bik // ..
793236a9080SAart Bik // for ik in dimk
794236a9080SAart Bik // val = a[i1,..,ik]
795236a9080SAart Bik // if val != 0
796236a9080SAart Bik // t->add(val, [i1,..,ik], [p1,..,pk])
797fbd5821cSBixia Zheng //
798fbd5821cSBixia Zheng // To fill the COO tensor from a sparse constant in COO format:
799fbd5821cSBixia Zheng // for i in range(NNZ)
800fbd5821cSBixia Zheng // val = values[i]
801fbd5821cSBixia Zheng // [i1,..,ik] = indices[i]
802fbd5821cSBixia Zheng // t->add(val, [i1,..,ik], [p1,..,pk])
803fbd5821cSBixia Zheng //
804236a9080SAart Bik // Note that the dense tensor traversal code is actually implemented
805236a9080SAart Bik // using MLIR IR to avoid having to expose too much low-level
806236a9080SAart Bik // memref traversal details to the runtime support library.
807221856f5Swren romano // Also note that the code below only generates the "new" ops and
808221856f5Swren romano // the loop-nest per se; whereas the entire body of the innermost
809221856f5Swren romano // loop is generated by genAddElt().
8109d1db3d4SAart Bik ShapedType stp = resType.cast<ShapedType>();
8119d1db3d4SAart Bik unsigned rank = stp.getRank();
8129d1db3d4SAart Bik SmallVector<Value, 4> sizes;
8139d1db3d4SAart Bik SmallVector<Value, 8> params;
8149d1db3d4SAart Bik sizesFromSrc(rewriter, sizes, loc, src);
815efa15f41SAart Bik newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
81663bdcaf9Swren romano Value coo = genNewCall(rewriter, op, params);
8179d1db3d4SAart Bik Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
8189d1db3d4SAart Bik Value perm = params[2];
819236a9080SAart Bik SmallVector<Value> lo;
820236a9080SAart Bik SmallVector<Value> hi;
821236a9080SAart Bik SmallVector<Value> st;
82263d4fc94Swren romano Value zero = constantIndex(rewriter, loc, 0);
82363d4fc94Swren romano Value one = constantIndex(rewriter, loc, 1);
8245167c36aSwren romano auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
8250916d96dSKazu Hirata bool isCOOConstant = indicesValues.has_value();
826fbd5821cSBixia Zheng Value indices;
827fbd5821cSBixia Zheng Value values;
828fbd5821cSBixia Zheng if (isCOOConstant) {
829fbd5821cSBixia Zheng indices = indicesValues->first;
830fbd5821cSBixia Zheng values = indicesValues->second;
831fbd5821cSBixia Zheng lo.push_back(zero);
832fbd5821cSBixia Zheng hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
833fbd5821cSBixia Zheng st.push_back(one);
834fbd5821cSBixia Zheng } else {
8359d1db3d4SAart Bik for (unsigned i = 0; i < rank; i++) {
836236a9080SAart Bik lo.push_back(zero);
837af7ac1d9Swren romano hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
838236a9080SAart Bik st.push_back(one);
839236a9080SAart Bik }
840fbd5821cSBixia Zheng }
8419d1db3d4SAart Bik Type eltType = stp.getElementType();
842aef20f59SAart Bik Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
843a54f4eaeSMogball scf::buildLoopNest(
844a54f4eaeSMogball rewriter, op.getLoc(), lo, hi, st, {},
845236a9080SAart Bik [&](OpBuilder &builder, Location loc, ValueRange ivs,
846236a9080SAart Bik ValueRange args) -> scf::ValueVector {
847fbd5821cSBixia Zheng Value val;
848fbd5821cSBixia Zheng if (isCOOConstant)
8495167c36aSwren romano val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
850a54f4eaeSMogball ivs, rank);
851fbd5821cSBixia Zheng else
8525167c36aSwren romano val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
853aef20f59SAart Bik builder.create<memref::StoreOp>(loc, val, elemPtr);
854aef20f59SAart Bik genAddEltCall(rewriter, op, eltType, coo, elemPtr, ind, perm);
855236a9080SAart Bik return {};
856236a9080SAart Bik });
8579d1db3d4SAart Bik // Final call to construct sparse tensor storage.
858845561ecSwren romano params[6] = constantAction(rewriter, loc, Action::kFromCOO);
85963bdcaf9Swren romano params[7] = coo;
86063bdcaf9Swren romano Value dst = genNewCall(rewriter, op, params);
86163bdcaf9Swren romano genDelCOOCall(rewriter, op, eltType, coo);
86263bdcaf9Swren romano rewriter.replaceOp(op, dst);
86305c7f450SAart Bik return success();
864697ea09dSAart Bik }
865faa00c13SAart Bik
866faa00c13SAart Bik private:
867faa00c13SAart Bik /// Options to control sparse code generation.
868faa00c13SAart Bik SparseTensorConversionOptions options;
869697ea09dSAart Bik };
870697ea09dSAart Bik
871*27a431f5SMatthias Springer /// Sparse conversion rule for the dealloc operator.
872*27a431f5SMatthias Springer class SparseTensorDeallocConverter
873*27a431f5SMatthias Springer : public OpConversionPattern<bufferization::DeallocTensorOp> {
87416b8f4ddSAart Bik public:
87516b8f4ddSAart Bik using OpConversionPattern::OpConversionPattern;
87616b8f4ddSAart Bik LogicalResult
matchAndRewrite(bufferization::DeallocTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const877*27a431f5SMatthias Springer matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
87816b8f4ddSAart Bik ConversionPatternRewriter &rewriter) const override {
879*27a431f5SMatthias Springer auto enc = getSparseTensorEncoding(op.getTensor().getType());
880*27a431f5SMatthias Springer if (!enc)
881*27a431f5SMatthias Springer return failure();
88216b8f4ddSAart Bik StringRef name = "delSparseTensor";
883f527fdf5Swren romano TypeRange noTp;
884d8731bfcSwren romano createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
885d8731bfcSwren romano EmitCInterface::Off);
88616b8f4ddSAart Bik rewriter.eraseOp(op);
88716b8f4ddSAart Bik return success();
88816b8f4ddSAart Bik }
88916b8f4ddSAart Bik };
89016b8f4ddSAart Bik
891a2c9d4bbSAart Bik /// Sparse conversion rule for pointer accesses.
892a2c9d4bbSAart Bik class SparseTensorToPointersConverter
89396a23911SAart Bik : public OpConversionPattern<ToPointersOp> {
894a2c9d4bbSAart Bik public:
895a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern;
896a2c9d4bbSAart Bik LogicalResult
matchAndRewrite(ToPointersOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const897b54c724bSRiver Riddle matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
898a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override {
899a2c9d4bbSAart Bik Type resType = op.getType();
900c9489225Swren romano Type ptrType = resType.cast<ShapedType>().getElementType();
901c9489225Swren romano SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
902f527fdf5Swren romano replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
903d8731bfcSwren romano EmitCInterface::On);
904a2c9d4bbSAart Bik return success();
905a2c9d4bbSAart Bik }
906a2c9d4bbSAart Bik };
907a2c9d4bbSAart Bik
908a2c9d4bbSAart Bik /// Sparse conversion rule for index accesses.
90996a23911SAart Bik class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
910a2c9d4bbSAart Bik public:
911a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern;
912a2c9d4bbSAart Bik LogicalResult
matchAndRewrite(ToIndicesOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const913b54c724bSRiver Riddle matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
914a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override {
915a2c9d4bbSAart Bik Type resType = op.getType();
916c9489225Swren romano Type indType = resType.cast<ShapedType>().getElementType();
917c9489225Swren romano SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
918f527fdf5Swren romano replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
919d8731bfcSwren romano EmitCInterface::On);
920a2c9d4bbSAart Bik return success();
921a2c9d4bbSAart Bik }
922a2c9d4bbSAart Bik };
923a2c9d4bbSAart Bik
924a2c9d4bbSAart Bik /// Sparse conversion rule for value accesses.
92596a23911SAart Bik class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
926a2c9d4bbSAart Bik public:
927a2c9d4bbSAart Bik using OpConversionPattern::OpConversionPattern;
928a2c9d4bbSAart Bik LogicalResult
matchAndRewrite(ToValuesOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const929b54c724bSRiver Riddle matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
930a2c9d4bbSAart Bik ConversionPatternRewriter &rewriter) const override {
931a2c9d4bbSAart Bik Type resType = op.getType();
932a2c9d4bbSAart Bik Type eltType = resType.cast<ShapedType>().getElementType();
933c9489225Swren romano SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
934f527fdf5Swren romano replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
935d8731bfcSwren romano EmitCInterface::On);
936a2c9d4bbSAart Bik return success();
937a2c9d4bbSAart Bik }
938a2c9d4bbSAart Bik };
939a2c9d4bbSAart Bik
940f66e5769SAart Bik /// Sparse conversion rule for tensor rematerialization.
941f66e5769SAart Bik class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
942727a63e0SAart Bik public:
943727a63e0SAart Bik using OpConversionPattern::OpConversionPattern;
944727a63e0SAart Bik LogicalResult
matchAndRewrite(LoadOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const945f66e5769SAart Bik matchAndRewrite(LoadOp op, OpAdaptor adaptor,
946727a63e0SAart Bik ConversionPatternRewriter &rewriter) const override {
9478df54a6aSJacques Pienaar if (op.getHasInserts()) {
948f66e5769SAart Bik // Finalize any pending insertions.
949f66e5769SAart Bik StringRef name = "endInsert";
950f66e5769SAart Bik TypeRange noTp;
951d8731bfcSwren romano createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
952d8731bfcSwren romano EmitCInterface::Off);
95336b66ab9SAart Bik }
954f66e5769SAart Bik rewriter.replaceOp(op, adaptor.getOperands());
955f66e5769SAart Bik return success();
95636b66ab9SAart Bik }
957f66e5769SAart Bik };
958f66e5769SAart Bik
959f66e5769SAart Bik /// Sparse conversion rule for inserting in lexicographic index order.
960f66e5769SAart Bik class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
961f66e5769SAart Bik public:
962f66e5769SAart Bik using OpConversionPattern::OpConversionPattern;
963f66e5769SAart Bik LogicalResult
matchAndRewrite(LexInsertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const964f66e5769SAart Bik matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
965f66e5769SAart Bik ConversionPatternRewriter &rewriter) const override {
9668df54a6aSJacques Pienaar Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
967c9489225Swren romano SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
968f66e5769SAart Bik TypeRange noTp;
969f527fdf5Swren romano replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
970d8731bfcSwren romano EmitCInterface::On);
97136b66ab9SAart Bik return success();
972727a63e0SAart Bik }
973727a63e0SAart Bik };
974727a63e0SAart Bik
975faa00c13SAart Bik /// Sparse conversion rule for the expand operator.
9764f2ec7f9SAart Bik class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
9774f2ec7f9SAart Bik public:
9784f2ec7f9SAart Bik using OpConversionPattern::OpConversionPattern;
9794f2ec7f9SAart Bik LogicalResult
matchAndRewrite(ExpandOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const9804f2ec7f9SAart Bik matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
9814f2ec7f9SAart Bik ConversionPatternRewriter &rewriter) const override {
9824f2ec7f9SAart Bik Location loc = op->getLoc();
9838df54a6aSJacques Pienaar ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
9844f2ec7f9SAart Bik Type eltType = srcType.getElementType();
9854f2ec7f9SAart Bik Type boolType = rewriter.getIntegerType(1);
9864f2ec7f9SAart Bik Type idxType = rewriter.getIndexType();
9874f2ec7f9SAart Bik // All initialization should be done on entry of the loop nest.
9888df54a6aSJacques Pienaar rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
9894f2ec7f9SAart Bik // Determine the size for access expansion.
9904f2ec7f9SAart Bik auto enc = getSparseTensorEncoding(srcType);
9914f2ec7f9SAart Bik Value src = adaptor.getOperands()[0];
9924f2ec7f9SAart Bik Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
9930b55f94dSAart Bik // Allocate temporary buffers for values, filled-switch, and indices.
9940b55f94dSAart Bik // We do not use stack buffers for this, since the expanded size may
9950b55f94dSAart Bik // be rather large (as it envelops a single expanded dense dimension).
9960b55f94dSAart Bik Value values = genAlloc(rewriter, loc, sz, eltType);
9970b55f94dSAart Bik Value filled = genAlloc(rewriter, loc, sz, boolType);
9980b55f94dSAart Bik Value indices = genAlloc(rewriter, loc, sz, idxType);
9994f2ec7f9SAart Bik Value zero = constantZero(rewriter, loc, idxType);
10004f2ec7f9SAart Bik // Reset the values/filled-switch to all-zero/false. Note that this
10014f2ec7f9SAart Bik // introduces an O(N) operation into the computation, but this reset
10024f2ec7f9SAart Bik // operation is amortized over the innermost loops for the access
10030b55f94dSAart Bik // pattern expansion. As noted in the operation doc, we would like
10040b55f94dSAart Bik // to amortize this setup cost even between kernels.
10057294be2bSgysit rewriter.create<linalg::FillOp>(
10067294be2bSgysit loc, ValueRange{constantZero(rewriter, loc, eltType)},
10077294be2bSgysit ValueRange{values});
10087294be2bSgysit rewriter.create<linalg::FillOp>(
10097294be2bSgysit loc, ValueRange{constantZero(rewriter, loc, boolType)},
10107294be2bSgysit ValueRange{filled});
10114f2ec7f9SAart Bik // Replace expansion op with these buffers and initial index.
10124f2ec7f9SAart Bik assert(op.getNumResults() == 4);
10134f2ec7f9SAart Bik rewriter.replaceOp(op, {values, filled, indices, zero});
10144f2ec7f9SAart Bik return success();
10154f2ec7f9SAart Bik }
10164f2ec7f9SAart Bik };
10174f2ec7f9SAart Bik
1018faa00c13SAart Bik /// Sparse conversion rule for the compress operator.
10194f2ec7f9SAart Bik class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
10204f2ec7f9SAart Bik public:
10214f2ec7f9SAart Bik using OpConversionPattern::OpConversionPattern;
10224f2ec7f9SAart Bik LogicalResult
matchAndRewrite(CompressOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const10234f2ec7f9SAart Bik matchAndRewrite(CompressOp op, OpAdaptor adaptor,
10244f2ec7f9SAart Bik ConversionPatternRewriter &rewriter) const override {
10250b55f94dSAart Bik Location loc = op->getLoc();
10264f2ec7f9SAart Bik // Note that this method call resets the values/filled-switch back to
10274f2ec7f9SAart Bik // all-zero/false by only iterating over the set elements, so the
10284f2ec7f9SAart Bik // complexity remains proportional to the sparsity of the expanded
10294f2ec7f9SAart Bik // access pattern.
10308df54a6aSJacques Pienaar Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
1031c9489225Swren romano SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
10324f2ec7f9SAart Bik TypeRange noTp;
1033bb8632c1SAart Bik replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
1034bb8632c1SAart Bik EmitCInterface::On);
10350b55f94dSAart Bik // Deallocate the buffers on exit of the loop nest.
10360b55f94dSAart Bik Operation *parent = op;
10370b55f94dSAart Bik for (; isa<scf::ForOp>(parent->getParentOp()) ||
10380b55f94dSAart Bik isa<scf::WhileOp>(parent->getParentOp()) ||
10390b55f94dSAart Bik isa<scf::ParallelOp>(parent->getParentOp()) ||
10400b55f94dSAart Bik isa<scf::IfOp>(parent->getParentOp());
10410b55f94dSAart Bik parent = parent->getParentOp())
10420b55f94dSAart Bik ;
10430b55f94dSAart Bik rewriter.setInsertionPointAfter(parent);
10440b55f94dSAart Bik rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
10450b55f94dSAart Bik rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
10460b55f94dSAart Bik rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
10474f2ec7f9SAart Bik return success();
10484f2ec7f9SAart Bik }
10494f2ec7f9SAart Bik };
10504f2ec7f9SAart Bik
1051faa00c13SAart Bik /// Sparse conversion rule for the output operator.
1052efa15f41SAart Bik class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
1053efa15f41SAart Bik public:
1054efa15f41SAart Bik using OpConversionPattern::OpConversionPattern;
1055efa15f41SAart Bik LogicalResult
matchAndRewrite(OutOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1056efa15f41SAart Bik matchAndRewrite(OutOp op, OpAdaptor adaptor,
1057efa15f41SAart Bik ConversionPatternRewriter &rewriter) const override {
1058efa15f41SAart Bik Location loc = op->getLoc();
10598df54a6aSJacques Pienaar ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
1060efa15f41SAart Bik // Convert to default permuted COO.
1061efa15f41SAart Bik Value src = adaptor.getOperands()[0];
1062efa15f41SAart Bik auto encSrc = getSparseTensorEncoding(srcType);
1063efa15f41SAart Bik SmallVector<Value, 4> sizes;
1064efa15f41SAart Bik SmallVector<Value, 8> params;
1065efa15f41SAart Bik sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
1066efa15f41SAart Bik auto enc = SparseTensorEncodingAttr::get(
1067efa15f41SAart Bik op->getContext(), encSrc.getDimLevelType(), AffineMap(),
1068efa15f41SAart Bik encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
1069efa15f41SAart Bik newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
1070efa15f41SAart Bik Value coo = genNewCall(rewriter, op, params);
1071efa15f41SAart Bik // Then output the tensor to external file with indices in the externally
1072efa15f41SAart Bik // visible lexicographic index order. A sort is required if the source was
1073efa15f41SAart Bik // not in that order yet (note that the sort can be dropped altogether if
1074efa15f41SAart Bik // external format does not care about the order at all, but here we assume
1075efa15f41SAart Bik // it does).
1076efa15f41SAart Bik bool sort =
1077efa15f41SAart Bik encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
1078efa15f41SAart Bik params.clear();
1079efa15f41SAart Bik params.push_back(coo);
1080efa15f41SAart Bik params.push_back(adaptor.getOperands()[1]);
1081efa15f41SAart Bik params.push_back(constantI1(rewriter, loc, sort));
1082efa15f41SAart Bik Type eltType = srcType.getElementType();
1083efa15f41SAart Bik SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
1084efa15f41SAart Bik TypeRange noTp;
108563bdcaf9Swren romano createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
108663bdcaf9Swren romano genDelCOOCall(rewriter, op, eltType, coo);
108763bdcaf9Swren romano rewriter.eraseOp(op);
1088efa15f41SAart Bik return success();
1089efa15f41SAart Bik }
1090efa15f41SAart Bik };
1091efa15f41SAart Bik
1092a2c9d4bbSAart Bik } // namespace
1093a2c9d4bbSAart Bik
//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//
109705c7f450SAart Bik
1098a2c9d4bbSAart Bik /// Populates the given patterns list with conversion rules required for
1099a2c9d4bbSAart Bik /// the sparsification of linear algebra operations.
populateSparseTensorConversionPatterns(TypeConverter & typeConverter,RewritePatternSet & patterns,const SparseTensorConversionOptions & options)1100c7e24db4Swren romano void mlir::populateSparseTensorConversionPatterns(
1101c7e24db4Swren romano TypeConverter &typeConverter, RewritePatternSet &patterns,
1102c7e24db4Swren romano const SparseTensorConversionOptions &options) {
110396a23911SAart Bik patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
11041b15160eSAart Bik SparseCastConverter, SparseTensorNewConverter,
1105faa00c13SAart Bik SparseReshapeConverter<tensor::ExpandShapeOp>,
1106faa00c13SAart Bik SparseReshapeConverter<tensor::CollapseShapeOp>,
1107*27a431f5SMatthias Springer SparseTensorAllocConverter, SparseTensorDeallocConverter,
1108c7e24db4Swren romano SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
1109c7e24db4Swren romano SparseTensorToValuesConverter, SparseTensorLoadConverter,
1110c7e24db4Swren romano SparseTensorLexInsertConverter, SparseTensorExpandConverter,
1111c7e24db4Swren romano SparseTensorCompressConverter, SparseTensorOutConverter>(
1112c7e24db4Swren romano typeConverter, patterns.getContext());
1113c7e24db4Swren romano patterns.add<SparseTensorConvertConverter>(typeConverter,
1114c7e24db4Swren romano patterns.getContext(), options);
1115a2c9d4bbSAart Bik }
1116