1*bc1df1faSAlex Zinenko //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
2f13893f6SStella Laurenzo //
3f13893f6SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f13893f6SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5f13893f6SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f13893f6SStella Laurenzo //
7f13893f6SStella Laurenzo //===----------------------------------------------------------------------===//
8f13893f6SStella Laurenzo 
9f13893f6SStella Laurenzo #include "Dialects.h"
10f13893f6SStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h"
11f13893f6SStella Laurenzo #include "mlir-c/IR.h"
12f13893f6SStella Laurenzo #include "mlir/Bindings/Python/PybindAdaptors.h"
13f13893f6SStella Laurenzo 
14f13893f6SStella Laurenzo namespace py = pybind11;
15f13893f6SStella Laurenzo using namespace llvm;
16f13893f6SStella Laurenzo using namespace mlir;
17f13893f6SStella Laurenzo using namespace mlir::python::adaptors;
18f13893f6SStella Laurenzo 
19f13893f6SStella Laurenzo void mlir::python::populateDialectSparseTensorSubmodule(
201fc096afSMehdi Amini     const py::module &m, const py::module &irModule) {
21f13893f6SStella Laurenzo   auto attributeClass = irModule.attr("Attribute");
22f13893f6SStella Laurenzo 
238dca953dSSean Silva   py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local())
24f13893f6SStella Laurenzo       .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
25f13893f6SStella Laurenzo       .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
26f13893f6SStella Laurenzo       .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON);
27f13893f6SStella Laurenzo 
28f13893f6SStella Laurenzo   mlir_attribute_subclass(m, "EncodingAttr",
29f13893f6SStella Laurenzo                           mlirAttributeIsASparseTensorEncodingAttr,
30f13893f6SStella Laurenzo                           attributeClass)
31f13893f6SStella Laurenzo       .def_classmethod(
32f13893f6SStella Laurenzo           "get",
33f13893f6SStella Laurenzo           [](py::object cls,
34f13893f6SStella Laurenzo              std::vector<MlirSparseTensorDimLevelType> dimLevelTypes,
35f13893f6SStella Laurenzo              llvm::Optional<MlirAffineMap> dimOrdering, int pointerBitWidth,
36f13893f6SStella Laurenzo              int indexBitWidth, MlirContext context) {
37f13893f6SStella Laurenzo             return cls(mlirSparseTensorEncodingAttrGet(
38f13893f6SStella Laurenzo                 context, dimLevelTypes.size(), dimLevelTypes.data(),
39f13893f6SStella Laurenzo                 dimOrdering ? *dimOrdering : MlirAffineMap{nullptr},
40f13893f6SStella Laurenzo                 pointerBitWidth, indexBitWidth));
41f13893f6SStella Laurenzo           },
42f13893f6SStella Laurenzo           py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"),
43f13893f6SStella Laurenzo           py::arg("pointer_bit_width"), py::arg("index_bit_width"),
44f13893f6SStella Laurenzo           py::arg("context") = py::none(),
45f13893f6SStella Laurenzo           "Gets a sparse_tensor.encoding from parameters.")
46f13893f6SStella Laurenzo       .def_property_readonly(
47f13893f6SStella Laurenzo           "dim_level_types",
48f13893f6SStella Laurenzo           [](MlirAttribute self) {
49f13893f6SStella Laurenzo             std::vector<MlirSparseTensorDimLevelType> ret;
50f13893f6SStella Laurenzo             for (int i = 0,
51f13893f6SStella Laurenzo                      e = mlirSparseTensorEncodingGetNumDimLevelTypes(self);
52f13893f6SStella Laurenzo                  i < e; ++i)
53f13893f6SStella Laurenzo               ret.push_back(
54f13893f6SStella Laurenzo                   mlirSparseTensorEncodingAttrGetDimLevelType(self, i));
55f13893f6SStella Laurenzo             return ret;
56f13893f6SStella Laurenzo           })
57f13893f6SStella Laurenzo       .def_property_readonly(
58f13893f6SStella Laurenzo           "dim_ordering",
59f13893f6SStella Laurenzo           [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> {
60f13893f6SStella Laurenzo             MlirAffineMap ret =
61f13893f6SStella Laurenzo                 mlirSparseTensorEncodingAttrGetDimOrdering(self);
62f13893f6SStella Laurenzo             if (mlirAffineMapIsNull(ret))
63f13893f6SStella Laurenzo               return {};
64f13893f6SStella Laurenzo             return ret;
65f13893f6SStella Laurenzo           })
66f13893f6SStella Laurenzo       .def_property_readonly(
67f13893f6SStella Laurenzo           "pointer_bit_width",
68f13893f6SStella Laurenzo           [](MlirAttribute self) {
69f13893f6SStella Laurenzo             return mlirSparseTensorEncodingAttrGetPointerBitWidth(self);
70f13893f6SStella Laurenzo           })
71f13893f6SStella Laurenzo       .def_property_readonly("index_bit_width", [](MlirAttribute self) {
72f13893f6SStella Laurenzo         return mlirSparseTensorEncodingAttrGetIndexBitWidth(self);
73f13893f6SStella Laurenzo       });
74f13893f6SStella Laurenzo }
75