1bc1df1faSAlex Zinenko //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
2f13893f6SStella Laurenzo //
3f13893f6SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f13893f6SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5f13893f6SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f13893f6SStella Laurenzo //
7f13893f6SStella Laurenzo //===----------------------------------------------------------------------===//
8f13893f6SStella Laurenzo 
9f13893f6SStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h"
10f13893f6SStella Laurenzo #include "mlir-c/IR.h"
11f13893f6SStella Laurenzo #include "mlir/Bindings/Python/PybindAdaptors.h"
12f13893f6SStella Laurenzo 
13f13893f6SStella Laurenzo namespace py = pybind11;
14f13893f6SStella Laurenzo using namespace llvm;
15f13893f6SStella Laurenzo using namespace mlir;
16f13893f6SStella Laurenzo using namespace mlir::python::adaptors;
17f13893f6SStella Laurenzo 
populateDialectSparseTensorSubmodule(const py::module & m)18*95ddbed9SAlex Zinenko static void populateDialectSparseTensorSubmodule(const py::module &m) {
198dca953dSSean Silva   py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local())
20f13893f6SStella Laurenzo       .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
21f13893f6SStella Laurenzo       .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
22f13893f6SStella Laurenzo       .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON);
23f13893f6SStella Laurenzo 
24f13893f6SStella Laurenzo   mlir_attribute_subclass(m, "EncodingAttr",
25*95ddbed9SAlex Zinenko                           mlirAttributeIsASparseTensorEncodingAttr)
26f13893f6SStella Laurenzo       .def_classmethod(
27f13893f6SStella Laurenzo           "get",
28f13893f6SStella Laurenzo           [](py::object cls,
29f13893f6SStella Laurenzo              std::vector<MlirSparseTensorDimLevelType> dimLevelTypes,
30f13893f6SStella Laurenzo              llvm::Optional<MlirAffineMap> dimOrdering, int pointerBitWidth,
31f13893f6SStella Laurenzo              int indexBitWidth, MlirContext context) {
32f13893f6SStella Laurenzo             return cls(mlirSparseTensorEncodingAttrGet(
33f13893f6SStella Laurenzo                 context, dimLevelTypes.size(), dimLevelTypes.data(),
34f13893f6SStella Laurenzo                 dimOrdering ? *dimOrdering : MlirAffineMap{nullptr},
35f13893f6SStella Laurenzo                 pointerBitWidth, indexBitWidth));
36f13893f6SStella Laurenzo           },
37f13893f6SStella Laurenzo           py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"),
38f13893f6SStella Laurenzo           py::arg("pointer_bit_width"), py::arg("index_bit_width"),
39f13893f6SStella Laurenzo           py::arg("context") = py::none(),
40f13893f6SStella Laurenzo           "Gets a sparse_tensor.encoding from parameters.")
41f13893f6SStella Laurenzo       .def_property_readonly(
42f13893f6SStella Laurenzo           "dim_level_types",
43f13893f6SStella Laurenzo           [](MlirAttribute self) {
44f13893f6SStella Laurenzo             std::vector<MlirSparseTensorDimLevelType> ret;
45f13893f6SStella Laurenzo             for (int i = 0,
46f13893f6SStella Laurenzo                      e = mlirSparseTensorEncodingGetNumDimLevelTypes(self);
47f13893f6SStella Laurenzo                  i < e; ++i)
48f13893f6SStella Laurenzo               ret.push_back(
49f13893f6SStella Laurenzo                   mlirSparseTensorEncodingAttrGetDimLevelType(self, i));
50f13893f6SStella Laurenzo             return ret;
51f13893f6SStella Laurenzo           })
52f13893f6SStella Laurenzo       .def_property_readonly(
53f13893f6SStella Laurenzo           "dim_ordering",
54f13893f6SStella Laurenzo           [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> {
55f13893f6SStella Laurenzo             MlirAffineMap ret =
56f13893f6SStella Laurenzo                 mlirSparseTensorEncodingAttrGetDimOrdering(self);
57f13893f6SStella Laurenzo             if (mlirAffineMapIsNull(ret))
58f13893f6SStella Laurenzo               return {};
59f13893f6SStella Laurenzo             return ret;
60f13893f6SStella Laurenzo           })
61f13893f6SStella Laurenzo       .def_property_readonly(
62f13893f6SStella Laurenzo           "pointer_bit_width",
63f13893f6SStella Laurenzo           [](MlirAttribute self) {
64f13893f6SStella Laurenzo             return mlirSparseTensorEncodingAttrGetPointerBitWidth(self);
65f13893f6SStella Laurenzo           })
66f13893f6SStella Laurenzo       .def_property_readonly("index_bit_width", [](MlirAttribute self) {
67f13893f6SStella Laurenzo         return mlirSparseTensorEncodingAttrGetIndexBitWidth(self);
68f13893f6SStella Laurenzo       });
69f13893f6SStella Laurenzo }
70*95ddbed9SAlex Zinenko 
PYBIND11_MODULE(_mlirDialectsSparseTensor,m)71*95ddbed9SAlex Zinenko PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
72*95ddbed9SAlex Zinenko   m.doc() = "MLIR SparseTensor dialect.";
73*95ddbed9SAlex Zinenko   populateDialectSparseTensorSubmodule(m);
74*95ddbed9SAlex Zinenko }
75