1*f13893f6SStella Laurenzo //===- DialectLinalg.cpp - 'sparse_tensor' dialect submodule --------------===// 2*f13893f6SStella Laurenzo // 3*f13893f6SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*f13893f6SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5*f13893f6SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*f13893f6SStella Laurenzo // 7*f13893f6SStella Laurenzo //===----------------------------------------------------------------------===// 8*f13893f6SStella Laurenzo 9*f13893f6SStella Laurenzo #include "Dialects.h" 10*f13893f6SStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h" 11*f13893f6SStella Laurenzo #include "mlir-c/IR.h" 12*f13893f6SStella Laurenzo #include "mlir/Bindings/Python/PybindAdaptors.h" 13*f13893f6SStella Laurenzo 14*f13893f6SStella Laurenzo namespace py = pybind11; 15*f13893f6SStella Laurenzo using namespace llvm; 16*f13893f6SStella Laurenzo using namespace mlir; 17*f13893f6SStella Laurenzo using namespace mlir::python::adaptors; 18*f13893f6SStella Laurenzo 19*f13893f6SStella Laurenzo void mlir::python::populateDialectSparseTensorSubmodule( 20*f13893f6SStella Laurenzo py::module m, const py::module &irModule) { 21*f13893f6SStella Laurenzo auto attributeClass = irModule.attr("Attribute"); 22*f13893f6SStella Laurenzo 23*f13893f6SStella Laurenzo py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType") 24*f13893f6SStella Laurenzo .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) 25*f13893f6SStella Laurenzo .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) 26*f13893f6SStella Laurenzo .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON); 27*f13893f6SStella Laurenzo 28*f13893f6SStella Laurenzo mlir_attribute_subclass(m, "EncodingAttr", 29*f13893f6SStella Laurenzo mlirAttributeIsASparseTensorEncodingAttr, 30*f13893f6SStella Laurenzo attributeClass) 31*f13893f6SStella Laurenzo .def_classmethod( 32*f13893f6SStella Laurenzo "get", 33*f13893f6SStella Laurenzo [](py::object cls, 34*f13893f6SStella Laurenzo std::vector<MlirSparseTensorDimLevelType> dimLevelTypes, 35*f13893f6SStella Laurenzo llvm::Optional<MlirAffineMap> dimOrdering, int pointerBitWidth, 36*f13893f6SStella Laurenzo int indexBitWidth, MlirContext context) { 37*f13893f6SStella Laurenzo return cls(mlirSparseTensorEncodingAttrGet( 38*f13893f6SStella Laurenzo context, dimLevelTypes.size(), dimLevelTypes.data(), 39*f13893f6SStella Laurenzo dimOrdering ? *dimOrdering : MlirAffineMap{nullptr}, 40*f13893f6SStella Laurenzo pointerBitWidth, indexBitWidth)); 41*f13893f6SStella Laurenzo }, 42*f13893f6SStella Laurenzo py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"), 43*f13893f6SStella Laurenzo py::arg("pointer_bit_width"), py::arg("index_bit_width"), 44*f13893f6SStella Laurenzo py::arg("context") = py::none(), 45*f13893f6SStella Laurenzo "Gets a sparse_tensor.encoding from parameters.") 46*f13893f6SStella Laurenzo .def_property_readonly( 47*f13893f6SStella Laurenzo "dim_level_types", 48*f13893f6SStella Laurenzo [](MlirAttribute self) { 49*f13893f6SStella Laurenzo std::vector<MlirSparseTensorDimLevelType> ret; 50*f13893f6SStella Laurenzo for (int i = 0, 51*f13893f6SStella Laurenzo e = mlirSparseTensorEncodingGetNumDimLevelTypes(self); 52*f13893f6SStella Laurenzo i < e; ++i) 53*f13893f6SStella Laurenzo ret.push_back( 54*f13893f6SStella Laurenzo mlirSparseTensorEncodingAttrGetDimLevelType(self, i)); 55*f13893f6SStella Laurenzo return ret; 56*f13893f6SStella Laurenzo }) 57*f13893f6SStella Laurenzo .def_property_readonly( 58*f13893f6SStella Laurenzo "dim_ordering", 59*f13893f6SStella Laurenzo [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> { 60*f13893f6SStella Laurenzo MlirAffineMap ret = 61*f13893f6SStella Laurenzo mlirSparseTensorEncodingAttrGetDimOrdering(self); 62*f13893f6SStella Laurenzo if (mlirAffineMapIsNull(ret)) 63*f13893f6SStella Laurenzo return {}; 64*f13893f6SStella Laurenzo return ret; 65*f13893f6SStella Laurenzo }) 66*f13893f6SStella Laurenzo .def_property_readonly( 67*f13893f6SStella Laurenzo "pointer_bit_width", 68*f13893f6SStella Laurenzo [](MlirAttribute self) { 69*f13893f6SStella Laurenzo return mlirSparseTensorEncodingAttrGetPointerBitWidth(self); 70*f13893f6SStella Laurenzo }) 71*f13893f6SStella Laurenzo .def_property_readonly("index_bit_width", [](MlirAttribute self) { 72*f13893f6SStella Laurenzo return mlirSparseTensorEncodingAttrGetIndexBitWidth(self); 73*f13893f6SStella Laurenzo }); 74*f13893f6SStella Laurenzo } 75