// RUN: mlir-opt %s --sparse-compiler | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
//
// Do the same run, but now with SIMDization as well. This should not change the outcome.
//
// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=2 vl=4" | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

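// A sparse vector type: the single dimension is stored compressed, i.e.
// only the nonzero entries and their coordinates are kept.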
#SparseVector = #sparse_tensor.encoding<{
  dimLevelType = ["compressed"]
}>

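// A doubly compressed sparse matrix type: both dimensions store only
// their nonzero (sub)structures.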
#SparseMatrix = #sparse_tensor.encoding<{
  dimLevelType = ["compressed", "compressed"]
}>

#trait_1d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "X(i) = a(i) op i"
}

#trait_2d = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) op i op j"
}

//
// Test with indices and sparse inputs. All outputs are dense.
//
module {

  //
  // Kernel that uses index in the index notation (conjunction).
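  // Multiplication annihilates the implicit zeros of the sparse input,
  // so the compiler only needs to iterate over the stored entries.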
  //
  func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>,
                                  %out: tensor<8xi64>) -> tensor<8xi64> {
    %r = linalg.generic #trait_1d
        ins(%arga: tensor<8xi64, #SparseVector>)
       outs(%out: tensor<8xi64>) {
        ^bb(%a: i64, %x: i64):
          %i = linalg.index 0 : index
          %ii = arith.index_cast %i : index to i64
          %m1 = arith.muli %a, %ii : i64
          linalg.yield %m1 : i64
    } -> tensor<8xi64>
    return %r : tensor<8xi64>
  }

  //
  // Kernel that uses index in the index notation (disjunction).
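  // Addition does not annihilate implicit zeros, so every position,
  // stored or not, contributes to the output.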
  //
  func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>,
                                  %out: tensor<8xi64>) -> tensor<8xi64> {
    %r = linalg.generic #trait_1d
        ins(%arga: tensor<8xi64, #SparseVector>)
       outs(%out: tensor<8xi64>) {
        ^bb(%a: i64, %x: i64):
          %i = linalg.index 0 : index
          %ii = arith.index_cast %i : index to i64
          %m1 = arith.addi %a, %ii : i64
          linalg.yield %m1 : i64
    } -> tensor<8xi64>
    return %r : tensor<8xi64>
  }

  //
  // Kernel that uses indices in the index notation (conjunction).
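  // As in the 1-D case, multiplication lets the compiler restrict the
  // iteration to the stored entries of the sparse matrix.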
  //
  func.func @sparse_index_2d_conj(%arga: tensor<3x4xi64, #SparseMatrix>,
                                  %out: tensor<3x4xi64>) -> tensor<3x4xi64> {
    %r = linalg.generic #trait_2d
        ins(%arga: tensor<3x4xi64, #SparseMatrix>)
       outs(%out: tensor<3x4xi64>) {
        ^bb(%a: i64, %x: i64):
          %i = linalg.index 0 : index
          %j = linalg.index 1 : index
          %ii = arith.index_cast %i : index to i64
          %jj = arith.index_cast %j : index to i64
          %m1 = arith.muli %ii, %a : i64
          %m2 = arith.muli %jj, %m1 : i64
          linalg.yield %m2 : i64
    } -> tensor<3x4xi64>
    return %r : tensor<3x4xi64>
  }

  //
  // Kernel that uses indices in the index notation (disjunction).
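  // As in the 1-D case, addition forces iteration over the full 3x4
  // index space.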
  //
  func.func @sparse_index_2d_disj(%arga: tensor<3x4xi64, #SparseMatrix>,
                                  %out: tensor<3x4xi64>) -> tensor<3x4xi64> {
    %r = linalg.generic #trait_2d
        ins(%arga: tensor<3x4xi64, #SparseMatrix>)
       outs(%out: tensor<3x4xi64>) {
        ^bb(%a: i64, %x: i64):
          %i = linalg.index 0 : index
          %j = linalg.index 1 : index
          %ii = arith.index_cast %i : index to i64
          %jj = arith.index_cast %j : index to i64
          %m1 = arith.addi %ii, %a : i64
          %m2 = arith.addi %jj, %m1 : i64
          linalg.yield %m2 : i64
    } -> tensor<3x4xi64>
    return %r : tensor<3x4xi64>
  }

  //
  // Main driver.
  //
  func.func @entry() {
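    // Index zero and the padding value used when the results are read
    // back as vectors for printing.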
    %c0 = arith.constant 0 : index
    %du = arith.constant -1 : i64

    // Setup input sparse vector.
    %v1 = arith.constant sparse<[[2], [4]], [10, 20]> : tensor<8xi64>
    %sv = sparse_tensor.convert %v1 : tensor<8xi64> to tensor<8xi64, #SparseVector>

    // Setup input "sparse" vector (all positions nonzero, but stored in the sparse format).
    %v2 = arith.constant dense<[ 1,  2,  4,  8,  16,  32,  64,  128 ]> : tensor<8xi64>
    %dv = sparse_tensor.convert %v2 : tensor<8xi64> to tensor<8xi64, #SparseVector>

    // Setup input sparse matrix.
    %m1 = arith.constant sparse<[[1,1], [2,3]], [10, 20]> : tensor<3x4xi64>
    %sm = sparse_tensor.convert %m1 : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>

    // Setup input "sparse" matrix (all positions nonzero, but stored in the sparse format).
    %m2 = arith.constant dense<[ [ 1,  1,  1,  1 ],
                                 [ 1,  2,  1,  1 ],
                                 [ 1,  1,  3,  4 ] ]> : tensor<3x4xi64>
    %dm = sparse_tensor.convert %m2 : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>

    // Setup out tensors.
    %init_8 = bufferization.alloc_tensor() : tensor<8xi64>
    %init_3_4 = bufferization.alloc_tensor() : tensor<3x4xi64>

    // Call the kernels.
    %0 = call @sparse_index_1d_conj(%sv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64>
    %1 = call @sparse_index_1d_disj(%sv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64>
    %2 = call @sparse_index_1d_conj(%dv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64>
    %3 = call @sparse_index_1d_disj(%dv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64>
    %4 = call @sparse_index_2d_conj(%sm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64>
    %5 = call @sparse_index_2d_disj(%sm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64>
    %6 = call @sparse_index_2d_conj(%dm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64>
    %7 = call @sparse_index_2d_disj(%dm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64>

    //
    // Verify result.
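    // (One print per kernel call %0..%7 above: 1-D conjunction/disjunction
    // on the sparse vector, the same pair on the all-dense vector, then
    // the four 2-D counterparts on the sparse and all-dense matrices.)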
    //
    // CHECK:      ( 0, 0, 20, 0, 80, 0, 0, 0 )
    // CHECK-NEXT: ( 0, 1, 12, 3, 24, 5, 6, 7 )
    // CHECK-NEXT: ( 0, 2, 8, 24, 64, 160, 384, 896 )
    // CHECK-NEXT: ( 1, 3, 6, 11, 20, 37, 70, 135 )
    // CHECK-NEXT: ( ( 0, 0, 0, 0 ), ( 0, 10, 0, 0 ), ( 0, 0, 0, 120 ) )
    // CHECK-NEXT: ( ( 0, 1, 2, 3 ), ( 1, 12, 3, 4 ), ( 2, 3, 4, 25 ) )
    // CHECK-NEXT: ( ( 0, 0, 0, 0 ), ( 0, 2, 2, 3 ), ( 0, 2, 12, 24 ) )
    // CHECK-NEXT: ( ( 1, 2, 3, 4 ), ( 2, 4, 4, 5 ), ( 3, 4, 7, 9 ) )
    //
    %vv0 = vector.transfer_read %0[%c0], %du: tensor<8xi64>, vector<8xi64>
    %vv1 = vector.transfer_read %1[%c0], %du: tensor<8xi64>, vector<8xi64>
    %vv2 = vector.transfer_read %2[%c0], %du: tensor<8xi64>, vector<8xi64>
    %vv3 = vector.transfer_read %3[%c0], %du: tensor<8xi64>, vector<8xi64>
    %vv4 = vector.transfer_read %4[%c0,%c0], %du: tensor<3x4xi64>, vector<3x4xi64>
    %vv5 = vector.transfer_read %5[%c0,%c0], %du: tensor<3x4xi64>, vector<3x4xi64>
    %vv6 = vector.transfer_read %6[%c0,%c0], %du: tensor<3x4xi64>, vector<3x4xi64>
    %vv7 = vector.transfer_read %7[%c0,%c0], %du: tensor<3x4xi64>, vector<3x4xi64>
    vector.print %vv0 : vector<8xi64>
    vector.print %vv1 : vector<8xi64>
    vector.print %vv2 : vector<8xi64>
    vector.print %vv3 : vector<8xi64>
    vector.print %vv4 : vector<3x4xi64>
    vector.print %vv5 : vector<3x4xi64>
    vector.print %vv6 : vector<3x4xi64>
    vector.print %vv7 : vector<3x4xi64>

    // Release resources.
    bufferization.dealloc_tensor %sv : tensor<8xi64, #SparseVector>
    bufferization.dealloc_tensor %dv : tensor<8xi64, #SparseVector>
    bufferization.dealloc_tensor %sm : tensor<3x4xi64, #SparseMatrix>
    bufferization.dealloc_tensor %dm : tensor<3x4xi64, #SparseMatrix>

    return
  }
}