// RUN: mlir-opt %s --sparse-compiler | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
//
// Repeat the same run with SIMD vectorization enabled. The output must not change.
//
// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=2 vl=4" | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

// 1-D sparse encoding: the only dimension uses the "compressed" level type.
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

// 2-D sparse encoding: both dimensions use the "compressed" level type.
#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>

// Trait for the 1-D kernels: one parallel loop, with the input a and the
// output x both accessed through the identity map.
#trait_1d = {
  doc = "X(i) = a(i) op i",
  indexing_maps = [
    affine_map<(d0) -> (d0)>,  // a (in)
    affine_map<(d0) -> (d0)>   // x (out)
  ],
  iterator_types = ["parallel"]
}

// Trait for the 2-D kernels: two parallel loops, with the input A and the
// output X both accessed through the identity map.
#trait_2d = {
  doc = "X(i,j) = A(i,j) op i op j",
  indexing_maps = [
    affine_map<(d0, d1) -> (d0, d1)>,  // A (in)
    affine_map<(d0, d1) -> (d0, d1)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"]
}

//
// Exercises linalg.index inside sparse kernels. Every kernel input is a
// sparse tensor; every kernel output is a dense tensor.
//
module {

  //
  // 1-D kernel x(i) = a(i) * i. The product is zero wherever a(i) is zero,
  // so this is a conjunction between a and the loop index.
  //
  func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> {
    %dst = linalg.init_tensor [8] : tensor<8xi64>
    %result = linalg.generic #trait_1d
      ins(%arga : tensor<8xi64, #SparseVector>)
      outs(%dst : tensor<8xi64>) {
      ^bb0(%in: i64, %out: i64):
        %pos = linalg.index 0 : index
        %pos_i64 = arith.index_cast %pos : index to i64
        %prod = arith.muli %in, %pos_i64 : i64
        linalg.yield %prod : i64
    } -> tensor<8xi64>
    return %result : tensor<8xi64>
  }

  //
  // 1-D kernel x(i) = a(i) + i. The sum can be nonzero even where a(i) is
  // zero, so this is a disjunction between a and the loop index.
  //
  func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> {
    %dst = linalg.init_tensor [8] : tensor<8xi64>
    %result = linalg.generic #trait_1d
      ins(%arga : tensor<8xi64, #SparseVector>)
      outs(%dst : tensor<8xi64>) {
      ^bb0(%in: i64, %out: i64):
        %pos = linalg.index 0 : index
        %pos_i64 = arith.index_cast %pos : index to i64
        %sum = arith.addi %in, %pos_i64 : i64
        linalg.yield %sum : i64
    } -> tensor<8xi64>
    return %result : tensor<8xi64>
  }

  //
  // 2-D kernel X(i,j) = j * (i * A(i,j)). The result is zero wherever A(i,j)
  // is zero, so this is a conjunction between A and the loop indices.
  //
  func.func @sparse_index_2d_conj(%arga: tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64> {
    %dst = linalg.init_tensor [3, 4] : tensor<3x4xi64>
    %result = linalg.generic #trait_2d
      ins(%arga : tensor<3x4xi64, #SparseMatrix>)
      outs(%dst : tensor<3x4xi64>) {
      ^bb0(%in: i64, %out: i64):
        %row = linalg.index 0 : index
        %col = linalg.index 1 : index
        %row_i64 = arith.index_cast %row : index to i64
        %col_i64 = arith.index_cast %col : index to i64
        %t0 = arith.muli %row_i64, %in : i64
        %t1 = arith.muli %col_i64, %t0 : i64
        linalg.yield %t1 : i64
    } -> tensor<3x4xi64>
    return %result : tensor<3x4xi64>
  }

  //
  // 2-D kernel X(i,j) = j + (i + A(i,j)). The result can be nonzero even
  // where A(i,j) is zero, so this is a disjunction between A and the indices.
  //
  func.func @sparse_index_2d_disj(%arga: tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64> {
    %dst = linalg.init_tensor [3, 4] : tensor<3x4xi64>
    %result = linalg.generic #trait_2d
      ins(%arga : tensor<3x4xi64, #SparseMatrix>)
      outs(%dst : tensor<3x4xi64>) {
      ^bb0(%in: i64, %out: i64):
        %row = linalg.index 0 : index
        %col = linalg.index 1 : index
        %row_i64 = arith.index_cast %row : index to i64
        %col_i64 = arith.index_cast %col : index to i64
        %t0 = arith.addi %row_i64, %in : i64
        %t1 = arith.addi %col_i64, %t0 : i64
        linalg.yield %t1 : i64
    } -> tensor<3x4xi64>
    return %result : tensor<3x4xi64>
  }

  //
  // Main driver. Runs every kernel on an input with only two nonzeros and on
  // an input whose entries are all nonzero, then prints the dense results.
  //
  func.func @entry() {
    %idx0 = arith.constant 0 : index
    %pad = arith.constant -1 : i64

    // Input literals. The sparse<> constants hold two nonzero entries each;
    // every entry of the dense<> constants is nonzero.
    %lit_sv = arith.constant sparse<[[2], [4]], [ 10, 20]> : tensor<8xi64>
    %lit_fv = arith.constant dense<[ 1,  2,  4,  8,  16,  32,  64,  128 ]> : tensor<8xi64>
    %lit_sm = arith.constant sparse<[[1,1], [2,3]], [10, 20]> : tensor<3x4xi64>
    %lit_fm = arith.constant dense<[ [ 1,  1,  1,  1 ],
                                     [ 1,  2,  1,  1 ],
                                     [ 1,  1,  3,  4 ] ]> : tensor<3x4xi64>

    // Convert every literal to the sparse encodings. The fully populated
    // inputs become "sparse" tensors in which every entry is stored.
    %sp_vec = sparse_tensor.convert %lit_sv : tensor<8xi64> to tensor<8xi64, #SparseVector>
    %full_vec = sparse_tensor.convert %lit_fv : tensor<8xi64> to tensor<8xi64, #SparseVector>
    %sp_mat = sparse_tensor.convert %lit_sm : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>
    %full_mat = sparse_tensor.convert %lit_fm : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>

    // Run each (conj, disj) kernel pair on the sparse operand, then on the
    // fully populated operand.
    %res0 = call @sparse_index_1d_conj(%sp_vec) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64>
    %res1 = call @sparse_index_1d_disj(%sp_vec) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64>
    %res2 = call @sparse_index_1d_conj(%full_vec) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64>
    %res3 = call @sparse_index_1d_disj(%full_vec) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64>
    %res4 = call @sparse_index_2d_conj(%sp_mat) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64>
    %res5 = call @sparse_index_2d_disj(%sp_mat) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64>
    %res6 = call @sparse_index_2d_conj(%full_mat) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64>
    %res7 = call @sparse_index_2d_disj(%full_mat) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64>

    // Obtain the memrefs behind the dense results.
    %buf0 = bufferization.to_memref %res0 : memref<8xi64>
    %buf1 = bufferization.to_memref %res1 : memref<8xi64>
    %buf2 = bufferization.to_memref %res2 : memref<8xi64>
    %buf3 = bufferization.to_memref %res3 : memref<8xi64>
    %buf4 = bufferization.to_memref %res4 : memref<3x4xi64>
    %buf5 = bufferization.to_memref %res5 : memref<3x4xi64>
    %buf6 = bufferization.to_memref %res6 : memref<3x4xi64>
    %buf7 = bufferization.to_memref %res7 : memref<3x4xi64>

    //
    // Verify result.
    //
    // CHECK:      ( 0, 0, 20, 0, 80, 0, 0, 0 )
    // CHECK-NEXT: ( 0, 1, 12, 3, 24, 5, 6, 7 )
    // CHECK-NEXT: ( 0, 2, 8, 24, 64, 160, 384, 896 )
    // CHECK-NEXT: ( 1, 3, 6, 11, 20, 37, 70, 135 )
    // CHECK-NEXT: ( ( 0, 0, 0, 0 ), ( 0, 10, 0, 0 ), ( 0, 0, 0, 120 ) )
    // CHECK-NEXT: ( ( 0, 1, 2, 3 ), ( 1, 12, 3, 4 ), ( 2, 3, 4, 25 ) )
    // CHECK-NEXT: ( ( 0, 0, 0, 0 ), ( 0, 2, 2, 3 ), ( 0, 2, 12, 24 ) )
    // CHECK-NEXT: ( ( 1, 2, 3, 4 ), ( 2, 4, 4, 5 ), ( 3, 4, 7, 9 ) )
    //
    %out0 = vector.transfer_read %buf0[%idx0], %pad : memref<8xi64>, vector<8xi64>
    vector.print %out0 : vector<8xi64>
    %out1 = vector.transfer_read %buf1[%idx0], %pad : memref<8xi64>, vector<8xi64>
    vector.print %out1 : vector<8xi64>
    %out2 = vector.transfer_read %buf2[%idx0], %pad : memref<8xi64>, vector<8xi64>
    vector.print %out2 : vector<8xi64>
    %out3 = vector.transfer_read %buf3[%idx0], %pad : memref<8xi64>, vector<8xi64>
    vector.print %out3 : vector<8xi64>
    %out4 = vector.transfer_read %buf4[%idx0, %idx0], %pad : memref<3x4xi64>, vector<3x4xi64>
    vector.print %out4 : vector<3x4xi64>
    %out5 = vector.transfer_read %buf5[%idx0, %idx0], %pad : memref<3x4xi64>, vector<3x4xi64>
    vector.print %out5 : vector<3x4xi64>
    %out6 = vector.transfer_read %buf6[%idx0, %idx0], %pad : memref<3x4xi64>, vector<3x4xi64>
    vector.print %out6 : vector<3x4xi64>
    %out7 = vector.transfer_read %buf7[%idx0, %idx0], %pad : memref<3x4xi64>, vector<3x4xi64>
    vector.print %out7 : vector<3x4xi64>

    // Release the sparse inputs and free the result buffers.
    sparse_tensor.release %sp_vec : tensor<8xi64, #SparseVector>
    sparse_tensor.release %full_vec : tensor<8xi64, #SparseVector>
    sparse_tensor.release %sp_mat : tensor<3x4xi64, #SparseMatrix>
    sparse_tensor.release %full_mat : tensor<3x4xi64, #SparseMatrix>
    memref.dealloc %buf0 : memref<8xi64>
    memref.dealloc %buf1 : memref<8xi64>
    memref.dealloc %buf2 : memref<8xi64>
    memref.dealloc %buf3 : memref<8xi64>
    memref.dealloc %buf4 : memref<3x4xi64>
    memref.dealloc %buf5 : memref<3x4xi64>
    memref.dealloc %buf6 : memref<3x4xi64>
    memref.dealloc %buf7 : memref<3x4xi64>

    return
  }
}