1*bcfa7baeSStella Laurenzo //===- Tensor.cpp - C API for SparseTensor dialect ------------------------===//
2*bcfa7baeSStella Laurenzo //
3*bcfa7baeSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*bcfa7baeSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5*bcfa7baeSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*bcfa7baeSStella Laurenzo //
7*bcfa7baeSStella Laurenzo //===----------------------------------------------------------------------===//
8*bcfa7baeSStella Laurenzo 
9*bcfa7baeSStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h"
10*bcfa7baeSStella Laurenzo #include "mlir-c/IR.h"
11*bcfa7baeSStella Laurenzo #include "mlir/CAPI/AffineMap.h"
12*bcfa7baeSStella Laurenzo #include "mlir/CAPI/Registration.h"
13*bcfa7baeSStella Laurenzo #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14*bcfa7baeSStella Laurenzo #include "mlir/Support/LLVM.h"
15*bcfa7baeSStella Laurenzo 
16*bcfa7baeSStella Laurenzo using namespace llvm;
17*bcfa7baeSStella Laurenzo using namespace mlir::sparse_tensor;
18*bcfa7baeSStella Laurenzo 
19*bcfa7baeSStella Laurenzo MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
20*bcfa7baeSStella Laurenzo                                       mlir::sparse_tensor::SparseTensorDialect)
21*bcfa7baeSStella Laurenzo 
22*bcfa7baeSStella Laurenzo // Ensure the C-API enums are int-castable to C++ equivalents.
23*bcfa7baeSStella Laurenzo static_assert(
24*bcfa7baeSStella Laurenzo     static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
25*bcfa7baeSStella Laurenzo             static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Dense) &&
26*bcfa7baeSStella Laurenzo         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
27*bcfa7baeSStella Laurenzo             static_cast<int>(
28*bcfa7baeSStella Laurenzo                 SparseTensorEncodingAttr::DimLevelType::Compressed) &&
29*bcfa7baeSStella Laurenzo         static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
30*bcfa7baeSStella Laurenzo             static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Singleton),
31*bcfa7baeSStella Laurenzo     "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
32*bcfa7baeSStella Laurenzo 
mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr)33*bcfa7baeSStella Laurenzo bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
34*bcfa7baeSStella Laurenzo   return unwrap(attr).isa<SparseTensorEncodingAttr>();
35*bcfa7baeSStella Laurenzo }
36*bcfa7baeSStella Laurenzo 
mlirSparseTensorEncodingAttrGet(MlirContext ctx,intptr_t numDimLevelTypes,MlirSparseTensorDimLevelType const * dimLevelTypes,MlirAffineMap dimOrdering,int pointerBitWidth,int indexBitWidth)37*bcfa7baeSStella Laurenzo MlirAttribute mlirSparseTensorEncodingAttrGet(
38*bcfa7baeSStella Laurenzo     MlirContext ctx, intptr_t numDimLevelTypes,
39*bcfa7baeSStella Laurenzo     MlirSparseTensorDimLevelType const *dimLevelTypes,
40*bcfa7baeSStella Laurenzo     MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth) {
41*bcfa7baeSStella Laurenzo   SmallVector<SparseTensorEncodingAttr::DimLevelType> cppDimLevelTypes;
42*bcfa7baeSStella Laurenzo   cppDimLevelTypes.resize(numDimLevelTypes);
43*bcfa7baeSStella Laurenzo   for (intptr_t i = 0; i < numDimLevelTypes; ++i)
44*bcfa7baeSStella Laurenzo     cppDimLevelTypes[i] =
45*bcfa7baeSStella Laurenzo         static_cast<SparseTensorEncodingAttr::DimLevelType>(dimLevelTypes[i]);
46*bcfa7baeSStella Laurenzo   return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppDimLevelTypes,
47*bcfa7baeSStella Laurenzo                                             unwrap(dimOrdering),
48*bcfa7baeSStella Laurenzo                                             pointerBitWidth, indexBitWidth));
49*bcfa7baeSStella Laurenzo }
50*bcfa7baeSStella Laurenzo 
mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr)51*bcfa7baeSStella Laurenzo MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
52*bcfa7baeSStella Laurenzo   return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering());
53*bcfa7baeSStella Laurenzo }
54*bcfa7baeSStella Laurenzo 
mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr)55*bcfa7baeSStella Laurenzo intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) {
56*bcfa7baeSStella Laurenzo   return unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType().size();
57*bcfa7baeSStella Laurenzo }
58*bcfa7baeSStella Laurenzo 
59*bcfa7baeSStella Laurenzo MlirSparseTensorDimLevelType
mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr,intptr_t pos)60*bcfa7baeSStella Laurenzo mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) {
61*bcfa7baeSStella Laurenzo   return static_cast<MlirSparseTensorDimLevelType>(
62*bcfa7baeSStella Laurenzo       unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType()[pos]);
63*bcfa7baeSStella Laurenzo }
64*bcfa7baeSStella Laurenzo 
mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr)65*bcfa7baeSStella Laurenzo int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) {
66*bcfa7baeSStella Laurenzo   return unwrap(attr).cast<SparseTensorEncodingAttr>().getPointerBitWidth();
67*bcfa7baeSStella Laurenzo }
68*bcfa7baeSStella Laurenzo 
mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr)69*bcfa7baeSStella Laurenzo int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) {
70*bcfa7baeSStella Laurenzo   return unwrap(attr).cast<SparseTensorEncodingAttr>().getIndexBitWidth();
71*bcfa7baeSStella Laurenzo }
72