1bc1df1faSAlex Zinenko //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
2f13893f6SStella Laurenzo //
3f13893f6SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f13893f6SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5f13893f6SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f13893f6SStella Laurenzo //
7f13893f6SStella Laurenzo //===----------------------------------------------------------------------===//
8f13893f6SStella Laurenzo
9f13893f6SStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h"
10f13893f6SStella Laurenzo #include "mlir-c/IR.h"
11f13893f6SStella Laurenzo #include "mlir/Bindings/Python/PybindAdaptors.h"
12f13893f6SStella Laurenzo
13f13893f6SStella Laurenzo namespace py = pybind11;
14f13893f6SStella Laurenzo using namespace llvm;
15f13893f6SStella Laurenzo using namespace mlir;
16f13893f6SStella Laurenzo using namespace mlir::python::adaptors;
17f13893f6SStella Laurenzo
populateDialectSparseTensorSubmodule(const py::module & m)18*95ddbed9SAlex Zinenko static void populateDialectSparseTensorSubmodule(const py::module &m) {
198dca953dSSean Silva py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local())
20f13893f6SStella Laurenzo .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
21f13893f6SStella Laurenzo .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
22f13893f6SStella Laurenzo .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON);
23f13893f6SStella Laurenzo
24f13893f6SStella Laurenzo mlir_attribute_subclass(m, "EncodingAttr",
25*95ddbed9SAlex Zinenko mlirAttributeIsASparseTensorEncodingAttr)
26f13893f6SStella Laurenzo .def_classmethod(
27f13893f6SStella Laurenzo "get",
28f13893f6SStella Laurenzo [](py::object cls,
29f13893f6SStella Laurenzo std::vector<MlirSparseTensorDimLevelType> dimLevelTypes,
30f13893f6SStella Laurenzo llvm::Optional<MlirAffineMap> dimOrdering, int pointerBitWidth,
31f13893f6SStella Laurenzo int indexBitWidth, MlirContext context) {
32f13893f6SStella Laurenzo return cls(mlirSparseTensorEncodingAttrGet(
33f13893f6SStella Laurenzo context, dimLevelTypes.size(), dimLevelTypes.data(),
34f13893f6SStella Laurenzo dimOrdering ? *dimOrdering : MlirAffineMap{nullptr},
35f13893f6SStella Laurenzo pointerBitWidth, indexBitWidth));
36f13893f6SStella Laurenzo },
37f13893f6SStella Laurenzo py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"),
38f13893f6SStella Laurenzo py::arg("pointer_bit_width"), py::arg("index_bit_width"),
39f13893f6SStella Laurenzo py::arg("context") = py::none(),
40f13893f6SStella Laurenzo "Gets a sparse_tensor.encoding from parameters.")
41f13893f6SStella Laurenzo .def_property_readonly(
42f13893f6SStella Laurenzo "dim_level_types",
43f13893f6SStella Laurenzo [](MlirAttribute self) {
44f13893f6SStella Laurenzo std::vector<MlirSparseTensorDimLevelType> ret;
45f13893f6SStella Laurenzo for (int i = 0,
46f13893f6SStella Laurenzo e = mlirSparseTensorEncodingGetNumDimLevelTypes(self);
47f13893f6SStella Laurenzo i < e; ++i)
48f13893f6SStella Laurenzo ret.push_back(
49f13893f6SStella Laurenzo mlirSparseTensorEncodingAttrGetDimLevelType(self, i));
50f13893f6SStella Laurenzo return ret;
51f13893f6SStella Laurenzo })
52f13893f6SStella Laurenzo .def_property_readonly(
53f13893f6SStella Laurenzo "dim_ordering",
54f13893f6SStella Laurenzo [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> {
55f13893f6SStella Laurenzo MlirAffineMap ret =
56f13893f6SStella Laurenzo mlirSparseTensorEncodingAttrGetDimOrdering(self);
57f13893f6SStella Laurenzo if (mlirAffineMapIsNull(ret))
58f13893f6SStella Laurenzo return {};
59f13893f6SStella Laurenzo return ret;
60f13893f6SStella Laurenzo })
61f13893f6SStella Laurenzo .def_property_readonly(
62f13893f6SStella Laurenzo "pointer_bit_width",
63f13893f6SStella Laurenzo [](MlirAttribute self) {
64f13893f6SStella Laurenzo return mlirSparseTensorEncodingAttrGetPointerBitWidth(self);
65f13893f6SStella Laurenzo })
66f13893f6SStella Laurenzo .def_property_readonly("index_bit_width", [](MlirAttribute self) {
67f13893f6SStella Laurenzo return mlirSparseTensorEncodingAttrGetIndexBitWidth(self);
68f13893f6SStella Laurenzo });
69f13893f6SStella Laurenzo }
70*95ddbed9SAlex Zinenko
PYBIND11_MODULE(_mlirDialectsSparseTensor,m)71*95ddbed9SAlex Zinenko PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
72*95ddbed9SAlex Zinenko m.doc() = "MLIR SparseTensor dialect.";
73*95ddbed9SAlex Zinenko populateDialectSparseTensorSubmodule(m);
74*95ddbed9SAlex Zinenko }
75