1 //===- Tensor.cpp - C API for SparseTensor dialect ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/Dialect/SparseTensor.h"
10 #include "mlir-c/IR.h"
11 #include "mlir/CAPI/AffineMap.h"
12 #include "mlir/CAPI/Registration.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 #include "mlir/Support/LLVM.h"
15
16 using namespace llvm;
17 using namespace mlir::sparse_tensor;
18
19 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
20 mlir::sparse_tensor::SparseTensorDialect)
21
22 // Ensure the C-API enums are int-castable to C++ equivalents.
23 static_assert(
24 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
25 static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Dense) &&
26 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
27 static_cast<int>(
28 SparseTensorEncodingAttr::DimLevelType::Compressed) &&
29 static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
30 static_cast<int>(SparseTensorEncodingAttr::DimLevelType::Singleton),
31 "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
32
mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr)33 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
34 return unwrap(attr).isa<SparseTensorEncodingAttr>();
35 }
36
mlirSparseTensorEncodingAttrGet(MlirContext ctx,intptr_t numDimLevelTypes,MlirSparseTensorDimLevelType const * dimLevelTypes,MlirAffineMap dimOrdering,int pointerBitWidth,int indexBitWidth)37 MlirAttribute mlirSparseTensorEncodingAttrGet(
38 MlirContext ctx, intptr_t numDimLevelTypes,
39 MlirSparseTensorDimLevelType const *dimLevelTypes,
40 MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth) {
41 SmallVector<SparseTensorEncodingAttr::DimLevelType> cppDimLevelTypes;
42 cppDimLevelTypes.resize(numDimLevelTypes);
43 for (intptr_t i = 0; i < numDimLevelTypes; ++i)
44 cppDimLevelTypes[i] =
45 static_cast<SparseTensorEncodingAttr::DimLevelType>(dimLevelTypes[i]);
46 return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppDimLevelTypes,
47 unwrap(dimOrdering),
48 pointerBitWidth, indexBitWidth));
49 }
50
mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr)51 MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
52 return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering());
53 }
54
mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr)55 intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) {
56 return unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType().size();
57 }
58
59 MlirSparseTensorDimLevelType
mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr,intptr_t pos)60 mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) {
61 return static_cast<MlirSparseTensorDimLevelType>(
62 unwrap(attr).cast<SparseTensorEncodingAttr>().getDimLevelType()[pos]);
63 }
64
mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr)65 int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) {
66 return unwrap(attr).cast<SparseTensorEncodingAttr>().getPointerBitWidth();
67 }
68
mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr)69 int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) {
70 return unwrap(attr).cast<SparseTensorEncodingAttr>().getIndexBitWidth();
71 }
72