1 //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/Dialect/SparseTensor.h"
10 #include "mlir-c/IR.h"
11 #include "mlir/Bindings/Python/PybindAdaptors.h"
12
13 namespace py = pybind11;
14 using namespace llvm;
15 using namespace mlir;
16 using namespace mlir::python::adaptors;
17
populateDialectSparseTensorSubmodule(const py::module & m)18 static void populateDialectSparseTensorSubmodule(const py::module &m) {
19 py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local())
20 .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
21 .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
22 .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON);
23
24 mlir_attribute_subclass(m, "EncodingAttr",
25 mlirAttributeIsASparseTensorEncodingAttr)
26 .def_classmethod(
27 "get",
28 [](py::object cls,
29 std::vector<MlirSparseTensorDimLevelType> dimLevelTypes,
30 llvm::Optional<MlirAffineMap> dimOrdering, int pointerBitWidth,
31 int indexBitWidth, MlirContext context) {
32 return cls(mlirSparseTensorEncodingAttrGet(
33 context, dimLevelTypes.size(), dimLevelTypes.data(),
34 dimOrdering ? *dimOrdering : MlirAffineMap{nullptr},
35 pointerBitWidth, indexBitWidth));
36 },
37 py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"),
38 py::arg("pointer_bit_width"), py::arg("index_bit_width"),
39 py::arg("context") = py::none(),
40 "Gets a sparse_tensor.encoding from parameters.")
41 .def_property_readonly(
42 "dim_level_types",
43 [](MlirAttribute self) {
44 std::vector<MlirSparseTensorDimLevelType> ret;
45 for (int i = 0,
46 e = mlirSparseTensorEncodingGetNumDimLevelTypes(self);
47 i < e; ++i)
48 ret.push_back(
49 mlirSparseTensorEncodingAttrGetDimLevelType(self, i));
50 return ret;
51 })
52 .def_property_readonly(
53 "dim_ordering",
54 [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> {
55 MlirAffineMap ret =
56 mlirSparseTensorEncodingAttrGetDimOrdering(self);
57 if (mlirAffineMapIsNull(ret))
58 return {};
59 return ret;
60 })
61 .def_property_readonly(
62 "pointer_bit_width",
63 [](MlirAttribute self) {
64 return mlirSparseTensorEncodingAttrGetPointerBitWidth(self);
65 })
66 .def_property_readonly("index_bit_width", [](MlirAttribute self) {
67 return mlirSparseTensorEncodingAttrGetIndexBitWidth(self);
68 });
69 }
70
// Extension-module entry point for `_mlirDialectsSparseTensor`: sets the
// module docstring, then installs the sparse_tensor dialect bindings on it.
PYBIND11_MODULE(_mlirDialectsSparseTensor, mod) {
  mod.attr("__doc__") = "MLIR SparseTensor dialect.";
  populateDialectSparseTensorSubmodule(mod);
}
75