// RUN: mlir-opt %s \
// RUN:   --sparsification --sparse-tensor-conversion \
// RUN:   --linalg-bufferize --convert-linalg-to-loops \
// RUN:   --convert-vector-to-scf --convert-scf-to-std \
// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
// RUN:   --std-bufferize --finalizing-bufferize --lower-affine \
// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm --convert-math-to-llvm \
// RUN:   --convert-std-to-llvm --reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner \
// RUN:  -e entry -entry-point-result=void \
// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

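// Doubly compressed sparse row storage: both dimensions compressed in
// row-major order, with narrow 8-bit pointer and index overhead types.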
#DCSR  = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "compressed" ],
  pointerBitWidth = 8,
  indexBitWidth = 8
}>

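// Doubly compressed sparse column storage: both dimensions compressed in
// column-major order, with wide 64-bit pointer and index overhead types.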
#DCSC  = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (j,i)>,
  pointerBitWidth = 64,
  indexBitWidth = 64
}>

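// Compressed sparse column storage: the column dimension is kept dense and
// the row positions within each column are compressed, with 16-bit pointers
// and 32-bit indices.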
#CSC  = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (j,i)>,
  pointerBitWidth = 16,
  indexBitWidth = 32
}>

//
// Integration test that exercises conversions between sparse tensors, where
// the pointer and index bit widths of the overhead storage change in
// addition to the dimension ordering.
//
module {

  //
  // Helper methods to print the values and indices arrays. The transfer
  // deliberately reads more elements than are stored, so the buffer size is
  // verified as well: out-of-bounds lanes show the -1 padding value.
  //
  func @dumpf64(%arg0: memref<?xf64>) {
    %c = arith.constant 0 : index
    %d = arith.constant -1.0 : f64
    %0 = vector.transfer_read %arg0[%c], %d: memref<?xf64>, vector<8xf64>
    vector.print %0 : vector<8xf64>
    return
  }
  func @dumpi08(%arg0: memref<?xi8>) {
    %c = arith.constant 0 : index
    %d = arith.constant -1 : i8
    %0 = vector.transfer_read %arg0[%c], %d: memref<?xi8>, vector<8xi8>
    vector.print %0 : vector<8xi8>
    return
  }
  func @dumpi32(%arg0: memref<?xi32>) {
    %c = arith.constant 0 : index
    %d = arith.constant -1 : i32
    %0 = vector.transfer_read %arg0[%c], %d: memref<?xi32>, vector<8xi32>
    vector.print %0 : vector<8xi32>
    return
  }
  func @dumpi64(%arg0: memref<?xi64>) {
    %c = arith.constant 0 : index
    %d = arith.constant -1 : i64
    %0 = vector.transfer_read %arg0[%c], %d: memref<?xi64>, vector<8xi64>
    vector.print %0 : vector<8xi64>
    return
  }

  func @entry() {
    %c1 = arith.constant 1 : index
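    // A 32x64 test matrix with seven nonzero entries, defined by means of a
    // sparse element attribute.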
    %t1 = arith.constant sparse<
      [ [0,0], [0,1], [0,63], [1,0], [1,1], [31,0], [31,63] ],
        [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 ]> : tensor<32x64xf64>
    %t2 = tensor.cast %t1 : tensor<32x64xf64> to tensor<?x?xf64>

    // Dense to sparse.
    %1 = sparse_tensor.convert %t1 : tensor<32x64xf64> to tensor<32x64xf64, #DCSR>
    %2 = sparse_tensor.convert %t1 : tensor<32x64xf64> to tensor<32x64xf64, #DCSC>
    %3 = sparse_tensor.convert %t1 : tensor<32x64xf64> to tensor<32x64xf64, #CSC>

    // Sparse to sparse.
    %4 = sparse_tensor.convert %1 : tensor<32x64xf64, #DCSR> to tensor<32x64xf64, #DCSC>
    %5 = sparse_tensor.convert %2 : tensor<32x64xf64, #DCSC> to tensor<32x64xf64, #DCSR>
    %6 = sparse_tensor.convert %3 : tensor<32x64xf64, #CSC>  to tensor<32x64xf64, #DCSR>

    //
    // All proper row-/column-wise?
    //
    // CHECK:      ( 1, 2, 3, 4, 5, 6, 7, -1 )
    // CHECK-NEXT: ( 1, 4, 6, 2, 5, 3, 7, -1 )
    // CHECK-NEXT: ( 1, 4, 6, 2, 5, 3, 7, -1 )
    // CHECK-NEXT: ( 1, 4, 6, 2, 5, 3, 7, -1 )
    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, -1 )
    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, -1 )
    //
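    // The row-major storage (#DCSR) lists the values row by row (1 through 7),
    // while the column-major storages (#DCSC, #CSC) list them column by column
    // (1, 4, 6, 2, 5, 3, 7).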
    %m1 = sparse_tensor.values %1 : tensor<32x64xf64, #DCSR> to memref<?xf64>
    %m2 = sparse_tensor.values %2 : tensor<32x64xf64, #DCSC> to memref<?xf64>
    %m3 = sparse_tensor.values %3 : tensor<32x64xf64, #CSC>  to memref<?xf64>
    %m4 = sparse_tensor.values %4 : tensor<32x64xf64, #DCSC> to memref<?xf64>
    %m5 = sparse_tensor.values %5 : tensor<32x64xf64, #DCSR> to memref<?xf64>
    %m6 = sparse_tensor.values %6 : tensor<32x64xf64, #DCSR> to memref<?xf64>
    call @dumpf64(%m1) : (memref<?xf64>) -> ()
    call @dumpf64(%m2) : (memref<?xf64>) -> ()
    call @dumpf64(%m3) : (memref<?xf64>) -> ()
    call @dumpf64(%m4) : (memref<?xf64>) -> ()
    call @dumpf64(%m5) : (memref<?xf64>) -> ()
    call @dumpf64(%m6) : (memref<?xf64>) -> ()

    //
    // Sanity check on indices.
    //
    // CHECK-NEXT: ( 0, 1, 63, 0, 1, 0, 63, -1 )
    // CHECK-NEXT: ( 0, 1, 31, 0, 1, 0, 31, -1 )
    // CHECK-NEXT: ( 0, 1, 31, 0, 1, 0, 31, -1 )
    // CHECK-NEXT: ( 0, 1, 31, 0, 1, 0, 31, -1 )
    // CHECK-NEXT: ( 0, 1, 63, 0, 1, 0, 63, -1 )
    // CHECK-NEXT: ( 0, 1, 63, 0, 1, 0, 63, -1 )
    //
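    // At the innermost stored dimension, the row-major storages hold column
    // indices (up to 63), whereas the column-major storages hold row indices
    // (up to 31).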
    %i1 = sparse_tensor.indices %1, %c1 : tensor<32x64xf64, #DCSR> to memref<?xi8>
    %i2 = sparse_tensor.indices %2, %c1 : tensor<32x64xf64, #DCSC> to memref<?xi64>
    %i3 = sparse_tensor.indices %3, %c1 : tensor<32x64xf64, #CSC>  to memref<?xi32>
    %i4 = sparse_tensor.indices %4, %c1 : tensor<32x64xf64, #DCSC> to memref<?xi64>
    %i5 = sparse_tensor.indices %5, %c1 : tensor<32x64xf64, #DCSR> to memref<?xi8>
    %i6 = sparse_tensor.indices %6, %c1 : tensor<32x64xf64, #DCSR> to memref<?xi8>
    call @dumpi08(%i1) : (memref<?xi8>)  -> ()
    call @dumpi64(%i2) : (memref<?xi64>) -> ()
    call @dumpi32(%i3) : (memref<?xi32>) -> ()
    call @dumpi64(%i4) : (memref<?xi64>) -> ()
    call @dumpi08(%i5) : (memref<?xi8>)  -> ()
    call @dumpi08(%i6) : (memref<?xi8>)  -> ()

    // Release the resources.
    sparse_tensor.release %1 : tensor<32x64xf64, #DCSR>
    sparse_tensor.release %2 : tensor<32x64xf64, #DCSC>
    sparse_tensor.release %3 : tensor<32x64xf64, #CSC>
    sparse_tensor.release %4 : tensor<32x64xf64, #DCSC>
    sparse_tensor.release %5 : tensor<32x64xf64, #DCSR>
    sparse_tensor.release %6 : tensor<32x64xf64, #DCSR>

    return
  }
}