1*f13893f6SStella Laurenzo //===- DialectLinalg.cpp - 'sparse_tensor' dialect submodule --------------===//
2*f13893f6SStella Laurenzo //
3*f13893f6SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*f13893f6SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5*f13893f6SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*f13893f6SStella Laurenzo //
7*f13893f6SStella Laurenzo //===----------------------------------------------------------------------===//
8*f13893f6SStella Laurenzo 
9*f13893f6SStella Laurenzo #include "Dialects.h"
10*f13893f6SStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h"
11*f13893f6SStella Laurenzo #include "mlir-c/IR.h"
12*f13893f6SStella Laurenzo #include "mlir/Bindings/Python/PybindAdaptors.h"
13*f13893f6SStella Laurenzo 
14*f13893f6SStella Laurenzo namespace py = pybind11;
15*f13893f6SStella Laurenzo using namespace llvm;
16*f13893f6SStella Laurenzo using namespace mlir;
17*f13893f6SStella Laurenzo using namespace mlir::python::adaptors;
18*f13893f6SStella Laurenzo 
19*f13893f6SStella Laurenzo void mlir::python::populateDialectSparseTensorSubmodule(
20*f13893f6SStella Laurenzo     py::module m, const py::module &irModule) {
21*f13893f6SStella Laurenzo   auto attributeClass = irModule.attr("Attribute");
22*f13893f6SStella Laurenzo 
23*f13893f6SStella Laurenzo   py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType")
24*f13893f6SStella Laurenzo       .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
25*f13893f6SStella Laurenzo       .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
26*f13893f6SStella Laurenzo       .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON);
27*f13893f6SStella Laurenzo 
28*f13893f6SStella Laurenzo   mlir_attribute_subclass(m, "EncodingAttr",
29*f13893f6SStella Laurenzo                           mlirAttributeIsASparseTensorEncodingAttr,
30*f13893f6SStella Laurenzo                           attributeClass)
31*f13893f6SStella Laurenzo       .def_classmethod(
32*f13893f6SStella Laurenzo           "get",
33*f13893f6SStella Laurenzo           [](py::object cls,
34*f13893f6SStella Laurenzo              std::vector<MlirSparseTensorDimLevelType> dimLevelTypes,
35*f13893f6SStella Laurenzo              llvm::Optional<MlirAffineMap> dimOrdering, int pointerBitWidth,
36*f13893f6SStella Laurenzo              int indexBitWidth, MlirContext context) {
37*f13893f6SStella Laurenzo             return cls(mlirSparseTensorEncodingAttrGet(
38*f13893f6SStella Laurenzo                 context, dimLevelTypes.size(), dimLevelTypes.data(),
39*f13893f6SStella Laurenzo                 dimOrdering ? *dimOrdering : MlirAffineMap{nullptr},
40*f13893f6SStella Laurenzo                 pointerBitWidth, indexBitWidth));
41*f13893f6SStella Laurenzo           },
42*f13893f6SStella Laurenzo           py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"),
43*f13893f6SStella Laurenzo           py::arg("pointer_bit_width"), py::arg("index_bit_width"),
44*f13893f6SStella Laurenzo           py::arg("context") = py::none(),
45*f13893f6SStella Laurenzo           "Gets a sparse_tensor.encoding from parameters.")
46*f13893f6SStella Laurenzo       .def_property_readonly(
47*f13893f6SStella Laurenzo           "dim_level_types",
48*f13893f6SStella Laurenzo           [](MlirAttribute self) {
49*f13893f6SStella Laurenzo             std::vector<MlirSparseTensorDimLevelType> ret;
50*f13893f6SStella Laurenzo             for (int i = 0,
51*f13893f6SStella Laurenzo                      e = mlirSparseTensorEncodingGetNumDimLevelTypes(self);
52*f13893f6SStella Laurenzo                  i < e; ++i)
53*f13893f6SStella Laurenzo               ret.push_back(
54*f13893f6SStella Laurenzo                   mlirSparseTensorEncodingAttrGetDimLevelType(self, i));
55*f13893f6SStella Laurenzo             return ret;
56*f13893f6SStella Laurenzo           })
57*f13893f6SStella Laurenzo       .def_property_readonly(
58*f13893f6SStella Laurenzo           "dim_ordering",
59*f13893f6SStella Laurenzo           [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> {
60*f13893f6SStella Laurenzo             MlirAffineMap ret =
61*f13893f6SStella Laurenzo                 mlirSparseTensorEncodingAttrGetDimOrdering(self);
62*f13893f6SStella Laurenzo             if (mlirAffineMapIsNull(ret))
63*f13893f6SStella Laurenzo               return {};
64*f13893f6SStella Laurenzo             return ret;
65*f13893f6SStella Laurenzo           })
66*f13893f6SStella Laurenzo       .def_property_readonly(
67*f13893f6SStella Laurenzo           "pointer_bit_width",
68*f13893f6SStella Laurenzo           [](MlirAttribute self) {
69*f13893f6SStella Laurenzo             return mlirSparseTensorEncodingAttrGetPointerBitWidth(self);
70*f13893f6SStella Laurenzo           })
71*f13893f6SStella Laurenzo       .def_property_readonly("index_bit_width", [](MlirAttribute self) {
72*f13893f6SStella Laurenzo         return mlirSparseTensorEncodingAttrGetIndexBitWidth(self);
73*f13893f6SStella Laurenzo       });
74*f13893f6SStella Laurenzo }
75