1 //===- sparse_tensor.c - Test of sparse_tensor APIs -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 // Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 10 // RUN: mlir-capi-sparse-tensor-test 2>&1 | FileCheck %s 11 12 #include "mlir-c/Dialect/SparseTensor.h" 13 #include "mlir-c/IR.h" 14 #include "mlir-c/Registration.h" 15 16 #include <assert.h> 17 #include <math.h> 18 #include <stdio.h> 19 #include <stdlib.h> 20 #include <string.h> 21 22 // CHECK-LABEL: testRoundtripEncoding() 23 static int testRoundtripEncoding(MlirContext ctx) { 24 fprintf(stderr, "testRoundtripEncoding()\n"); 25 // clang-format off 26 const char *originalAsm = 27 "#sparse_tensor.encoding<{ " 28 "dimLevelType = [ \"dense\", \"compressed\", \"singleton\"], " 29 "dimOrdering = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, " 30 "pointerBitWidth = 32, indexBitWidth = 64 }>"; 31 // clang-format on 32 MlirAttribute originalAttr = 33 mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(originalAsm)); 34 // CHECK: isa: 1 35 fprintf(stderr, "isa: %d\n", 36 mlirAttributeIsASparseTensorEncodingAttr(originalAttr)); 37 MlirAffineMap dimOrdering = 38 mlirSparseTensorEncodingAttrGetDimOrdering(originalAttr); 39 // CHECK: (d0, d1, d2) -> (d0, d1, d2) 40 mlirAffineMapDump(dimOrdering); 41 // CHECK: level_type: 0 42 // CHECK: level_type: 1 43 // CHECK: level_type: 2 44 int numLevelTypes = mlirSparseTensorEncodingGetNumDimLevelTypes(originalAttr); 45 enum MlirSparseTensorDimLevelType *levelTypes = 46 malloc(sizeof(enum MlirSparseTensorDimLevelType) * numLevelTypes); 47 for (int i = 0; i < numLevelTypes; ++i) { 48 levelTypes[i] = 49 mlirSparseTensorEncodingAttrGetDimLevelType(originalAttr, i); 50 fprintf(stderr, "level_type: %d\n", levelTypes[i]); 51 } 52 // CHECK: pointer: 32 53 int pointerBitWidth = 54 mlirSparseTensorEncodingAttrGetPointerBitWidth(originalAttr); 55 fprintf(stderr, "pointer: %d\n", pointerBitWidth); 56 // CHECK: index: 64 57 int indexBitWidth = 58 mlirSparseTensorEncodingAttrGetIndexBitWidth(originalAttr); 59 fprintf(stderr, "index: %d\n", indexBitWidth); 60 61 MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet( 62 ctx, numLevelTypes, levelTypes, dimOrdering, pointerBitWidth, 63 indexBitWidth); 64 mlirAttributeDump(newAttr); // For debugging filecheck output. 65 // CHECK: equal: 1 66 fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr)); 67 68 free(levelTypes); 69 return 0; 70 } 71 72 int main() { 73 MlirContext ctx = mlirContextCreate(); 74 mlirDialectHandleRegisterDialect(mlirGetDialectHandle__sparse_tensor__(), 75 ctx); 76 if (testRoundtripEncoding(ctx)) 77 return 1; 78 79 mlirContextDestroy(ctx); 80 return 0; 81 } 82