1*bc1df1faSAlex Zinenko //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===// 2f13893f6SStella Laurenzo // 3f13893f6SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f13893f6SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5f13893f6SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f13893f6SStella Laurenzo // 7f13893f6SStella Laurenzo //===----------------------------------------------------------------------===// 8f13893f6SStella Laurenzo 9f13893f6SStella Laurenzo #include "Dialects.h" 10f13893f6SStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h" 11f13893f6SStella Laurenzo #include "mlir-c/IR.h" 12f13893f6SStella Laurenzo #include "mlir/Bindings/Python/PybindAdaptors.h" 13f13893f6SStella Laurenzo 14f13893f6SStella Laurenzo namespace py = pybind11; 15f13893f6SStella Laurenzo using namespace llvm; 16f13893f6SStella Laurenzo using namespace mlir; 17f13893f6SStella Laurenzo using namespace mlir::python::adaptors; 18f13893f6SStella Laurenzo 19f13893f6SStella Laurenzo void mlir::python::populateDialectSparseTensorSubmodule( 201fc096afSMehdi Amini const py::module &m, const py::module &irModule) { 21f13893f6SStella Laurenzo auto attributeClass = irModule.attr("Attribute"); 22f13893f6SStella Laurenzo 238dca953dSSean Silva py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local()) 24f13893f6SStella Laurenzo .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) 25f13893f6SStella Laurenzo .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) 26f13893f6SStella Laurenzo .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON); 27f13893f6SStella Laurenzo 28f13893f6SStella Laurenzo mlir_attribute_subclass(m, "EncodingAttr", 29f13893f6SStella Laurenzo mlirAttributeIsASparseTensorEncodingAttr, 30f13893f6SStella Laurenzo attributeClass) 31f13893f6SStella Laurenzo .def_classmethod( 32f13893f6SStella Laurenzo "get", 33f13893f6SStella Laurenzo [](py::object cls, 34f13893f6SStella Laurenzo std::vector<MlirSparseTensorDimLevelType> dimLevelTypes, 35f13893f6SStella Laurenzo llvm::Optional<MlirAffineMap> dimOrdering, int pointerBitWidth, 36f13893f6SStella Laurenzo int indexBitWidth, MlirContext context) { 37f13893f6SStella Laurenzo return cls(mlirSparseTensorEncodingAttrGet( 38f13893f6SStella Laurenzo context, dimLevelTypes.size(), dimLevelTypes.data(), 39f13893f6SStella Laurenzo dimOrdering ? *dimOrdering : MlirAffineMap{nullptr}, 40f13893f6SStella Laurenzo pointerBitWidth, indexBitWidth)); 41f13893f6SStella Laurenzo }, 42f13893f6SStella Laurenzo py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"), 43f13893f6SStella Laurenzo py::arg("pointer_bit_width"), py::arg("index_bit_width"), 44f13893f6SStella Laurenzo py::arg("context") = py::none(), 45f13893f6SStella Laurenzo "Gets a sparse_tensor.encoding from parameters.") 46f13893f6SStella Laurenzo .def_property_readonly( 47f13893f6SStella Laurenzo "dim_level_types", 48f13893f6SStella Laurenzo [](MlirAttribute self) { 49f13893f6SStella Laurenzo std::vector<MlirSparseTensorDimLevelType> ret; 50f13893f6SStella Laurenzo for (int i = 0, 51f13893f6SStella Laurenzo e = mlirSparseTensorEncodingGetNumDimLevelTypes(self); 52f13893f6SStella Laurenzo i < e; ++i) 53f13893f6SStella Laurenzo ret.push_back( 54f13893f6SStella Laurenzo mlirSparseTensorEncodingAttrGetDimLevelType(self, i)); 55f13893f6SStella Laurenzo return ret; 56f13893f6SStella Laurenzo }) 57f13893f6SStella Laurenzo .def_property_readonly( 58f13893f6SStella Laurenzo "dim_ordering", 59f13893f6SStella Laurenzo [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> { 60f13893f6SStella Laurenzo MlirAffineMap ret = 61f13893f6SStella Laurenzo mlirSparseTensorEncodingAttrGetDimOrdering(self); 62f13893f6SStella Laurenzo if (mlirAffineMapIsNull(ret)) 63f13893f6SStella Laurenzo return {}; 64f13893f6SStella Laurenzo return ret; 65f13893f6SStella Laurenzo }) 66f13893f6SStella Laurenzo .def_property_readonly( 67f13893f6SStella Laurenzo "pointer_bit_width", 68f13893f6SStella Laurenzo [](MlirAttribute self) { 69f13893f6SStella Laurenzo return mlirSparseTensorEncodingAttrGetPointerBitWidth(self); 70f13893f6SStella Laurenzo }) 71f13893f6SStella Laurenzo .def_property_readonly("index_bit_width", [](MlirAttribute self) { 72f13893f6SStella Laurenzo return mlirSparseTensorEncodingAttrGetIndexBitWidth(self); 73f13893f6SStella Laurenzo }); 74f13893f6SStella Laurenzo } 75