1// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 vl=16" -cse -split-input-file | \
2// RUN:   FileCheck %s --check-prefix=CHECK-VEC0
3// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 vl=16" -cse -split-input-file | \
4// RUN:   FileCheck %s --check-prefix=CHECK-VEC1
5// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" -cse -split-input-file | \
6// RUN:   FileCheck %s --check-prefix=CHECK-VEC2
7// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" -cse -split-input-file | \
8// RUN:   FileCheck %s --check-prefix=CHECK-VEC3
9
// All-dense annotated "sparse" vector: storage is a plain dense buffer,
// but lowering still goes through the sparse compiler.
#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>

// Trait for the scalar-scaling kernel x(i) = a(i) * b.
#trait_scale_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b"
}
20
21//
22// CHECK-VEC0-LABEL: func @scale_d
23// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
24// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
25// CHECK-VEC0-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
26// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
27// CHECK-VEC0:         %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
28// CHECK-VEC0:         %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32
29// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
30// CHECK-VEC0:       }
31// CHECK-VEC0:       return
32//
33// CHECK-VEC1-LABEL: func @scale_d
34// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
35// CHECK-VEC1-DAG:   %[[c16:.*]] = arith.constant 16 : index
36// CHECK-VEC1-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
37// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
38// CHECK-VEC1:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
39// CHECK-VEC1:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
40// CHECK-VEC1:         %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32>
41// CHECK-VEC1:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
42// CHECK-VEC1:       }
43// CHECK-VEC1:       return
44//
45// CHECK-VEC2-LABEL: func @scale_d
46// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
47// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
48// CHECK-VEC2-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
49// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
50// CHECK-VEC2:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
51// CHECK-VEC2:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
52// CHECK-VEC2:         %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32>
53// CHECK-VEC2:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
54// CHECK-VEC2:       }
55// CHECK-VEC2:       return
56//
// Scales a dense-annotated 1024-vector by a scalar. The dense storage
// yields a fixed trip count (1024), so strategies 1 and 2 vectorize the
// loop with plain vector.load/vector.store (see CHECK-VEC1/CHECK-VEC2).
func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_scale_d
    ins(%arga: tensor<1024xf32, #DenseVector>)
    outs(%argx: tensor<1024xf32>) {
      ^bb(%a: f32, %x: f32):
        // x(i) = a(i) * b
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}
67
68// -----
69
// Compressed 1-d tensor with narrowed 32-bit pointer/index storage.
// The 32-bit widths drive the extui(i32->i64)/index_cast sequences in the
// scalar checks and the vector<16xi32> index vectors under enable-simd-index32
// (CHECK-VEC3).
#SparseVector = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed" ],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

// Trait for the element-wise product x(i) = a(i) * b(i).
#trait_mul_s = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}
85
86//
87// CHECK-VEC0-LABEL: func @mul_s
88// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
89// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
90// CHECK-VEC0:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
91// CHECK-VEC0:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
92// CHECK-VEC0:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
93// CHECK-VEC0:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
94// CHECK-VEC0:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
95// CHECK-VEC0:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
96// CHECK-VEC0:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
97// CHECK-VEC0:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
98// CHECK-VEC0:         %[[zi:.*]] = arith.extui %[[li]] : i32 to i64
99// CHECK-VEC0:         %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index
100// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
101// CHECK-VEC0:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
102// CHECK-VEC0:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
103// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
104// CHECK-VEC0:       }
105// CHECK-VEC0:       return
106//
107// CHECK-VEC1-LABEL: func @mul_s
108// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
109// CHECK-VEC1-DAG:   %[[c1:.*]] = arith.constant 1 : index
110// CHECK-VEC1:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
111// CHECK-VEC1:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
112// CHECK-VEC1:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
113// CHECK-VEC1:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
114// CHECK-VEC1:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
115// CHECK-VEC1:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
116// CHECK-VEC1:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
117// CHECK-VEC1:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
118// CHECK-VEC1:         %[[zi:.*]] = arith.extui %[[li]] : i32 to i64
119// CHECK-VEC1:         %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index
120// CHECK-VEC1:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
121// CHECK-VEC1:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
122// CHECK-VEC1:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
123// CHECK-VEC1:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
124// CHECK-VEC1:       }
125// CHECK-VEC1:       return
126//
127// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
128// CHECK-VEC2-LABEL: func @mul_s
129// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
130// CHECK-VEC2-DAG:   %[[c1:.*]] = arith.constant 1 : index
131// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
132// CHECK-VEC2:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
133// CHECK-VEC2:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
134// CHECK-VEC2:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
135// CHECK-VEC2:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
136// CHECK-VEC2:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
137// CHECK-VEC2:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
138// CHECK-VEC2:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
139// CHECK-VEC2:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
140// CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
141// CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
142// CHECK-VEC2:         %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
143// CHECK-VEC2:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
144// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
145// CHECK-VEC2:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
146// CHECK-VEC2:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
147// CHECK-VEC2:       }
148// CHECK-VEC2:       return
149//
150// CHECK-VEC3:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
151// CHECK-VEC3-LABEL: func @mul_s
152// CHECK-VEC3-DAG:   %[[c0:.*]] = arith.constant 0 : index
153// CHECK-VEC3-DAG:   %[[c1:.*]] = arith.constant 1 : index
154// CHECK-VEC3-DAG:   %[[c16:.*]] = arith.constant 16 : index
155// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
156// CHECK-VEC3:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
157// CHECK-VEC3:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
158// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
159// CHECK-VEC3:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
160// CHECK-VEC3:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
161// CHECK-VEC3:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
162// CHECK-VEC3:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
163// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
164// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
165// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
166// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
167// CHECK-VEC3:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
168// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
169// CHECK-VEC3:       }
170// CHECK-VEC3:       return
171//
// Sparse-times-dense element-wise product over a compressed vector.
// The loop over the compressed dimension has a dynamic trip count, so
// strategy 1 must stay scalar (CHECK-VEC1 == CHECK-VEC0), while strategy 2
// vectorizes with masked loads plus gather/scatter into the dense operands
// (CHECK-VEC2/CHECK-VEC3).
func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_mul_s
    ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
    outs(%argx: tensor<1024xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        // x(i) = a(i) * b(i)
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}
182
183// -----
184
// Re-declared after the split marker: all-dense annotated vector.
#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>

// Trait for the dot-product reduction x += a(i) * b(i); the output map
// has rank 0 (scalar result).
#trait_reduction_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> ()>    // x (out)
  ],
  iterator_types = ["reduction"],
  doc = "x += a(i) * b(i)"
}
196
197//
198// CHECK-VEC0-LABEL: func @reduction_d
199// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
200// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
201// CHECK-VEC0-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
202// CHECK-VEC0:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
203// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
204// CHECK-VEC0:         %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
205// CHECK-VEC0:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
206// CHECK-VEC0:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32
207// CHECK-VEC0:         scf.yield %[[a]] : f32
208// CHECK-VEC0:       }
209// CHECK-VEC0:       return
210//
211// CHECK-VEC1-LABEL: func @reduction_d
212// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
213// CHECK-VEC1-DAG:   %[[c16:.*]] = arith.constant 16 : index
214// CHECK-VEC1-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
215// CHECK-VEC1-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
216// CHECK-VEC1:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
217// CHECK-VEC1:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
218// CHECK-VEC1:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
219// CHECK-VEC1:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
220// CHECK-VEC1:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
221// CHECK-VEC1:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
222// CHECK-VEC1:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32>
223// CHECK-VEC1:         scf.yield %[[a]] : vector<16xf32>
224// CHECK-VEC1:       }
225// CHECK-VEC1:       %{{.*}} = vector.reduction "add", %[[red]] : vector<16xf32> into f32
226// CHECK-VEC1:       return
227//
228// CHECK-VEC2-LABEL: func @reduction_d
229// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
230// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
231// CHECK-VEC2-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
232// CHECK-VEC2-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
233// CHECK-VEC2:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
234// CHECK-VEC2:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
235// CHECK-VEC2:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
236// CHECK-VEC2:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
237// CHECK-VEC2:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
238// CHECK-VEC2:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
239// CHECK-VEC2:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32>
240// CHECK-VEC2:         scf.yield %[[a]] : vector<16xf32>
241// CHECK-VEC2:       }
242// CHECK-VEC2:       %{{.*}} = vector.reduction "add", %[[red]] : vector<16xf32> into f32
243// CHECK-VEC2:       return
244//
// Dot-product reduction into a 0-d tensor. Vectorized strategies carry a
// vector<16xf32> accumulator through scf.for iter_args and finish with a
// horizontal vector.reduction "add" (CHECK-VEC1/CHECK-VEC2).
func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic #trait_reduction_d
    ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32>)
    outs(%argx: tensor<f32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        // x += a(i) * b(i)
        %0 = arith.mulf %a, %b : f32
        %1 = arith.addf %x, %0 : f32
        linalg.yield %1 : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}
256
257// -----
258
// CSR-like matrix: dense outer dimension, compressed inner dimension,
// with narrowed 32-bit pointer/index storage (exercises the same
// extui/index_cast and index32 paths as the vector case above).
#SparseMatrix = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

// Trait for the element-wise matrix product X(i,j) = A(i,j) * B(i,j).
#trait_mul_ds = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i,j)>,  // B
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) * B(i,j)"
}
274
275//
276// CHECK-VEC0-LABEL: func @mul_ds
277// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
278// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
279// CHECK-VEC0-DAG:   %[[c512:.*]] = arith.constant 512 : index
280// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
281// CHECK-VEC0:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
282// CHECK-VEC0:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
283// CHECK-VEC0:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
284// CHECK-VEC0:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
285// CHECK-VEC0:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
286// CHECK-VEC0:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
287// CHECK-VEC0:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
288// CHECK-VEC0:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
289// CHECK-VEC0:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
290// CHECK-VEC0:           %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64
291// CHECK-VEC0:           %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index
292// CHECK-VEC0:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
293// CHECK-VEC0:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
294// CHECK-VEC0:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
295// CHECK-VEC0:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
296// CHECK-VEC0:         }
297// CHECK-VEC0:       }
298// CHECK-VEC0:       return
299//
300// CHECK-VEC1-LABEL: func @mul_ds
301// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
302// CHECK-VEC1-DAG:   %[[c1:.*]] = arith.constant 1 : index
303// CHECK-VEC1-DAG:   %[[c512:.*]] = arith.constant 512 : index
304// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
305// CHECK-VEC1:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
306// CHECK-VEC1:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
307// CHECK-VEC1:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
308// CHECK-VEC1:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
309// CHECK-VEC1:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
310// CHECK-VEC1:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
311// CHECK-VEC1:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
312// CHECK-VEC1:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
313// CHECK-VEC1:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
314// CHECK-VEC1:           %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64
315// CHECK-VEC1:           %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index
316// CHECK-VEC1:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
317// CHECK-VEC1:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
318// CHECK-VEC1:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
319// CHECK-VEC1:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
320// CHECK-VEC1:         }
321// CHECK-VEC1:       }
322// CHECK-VEC1:       return
323//
324// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
325// CHECK-VEC2-LABEL: func @mul_ds
326// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
327// CHECK-VEC2-DAG:   %[[c1:.*]] = arith.constant 1 : index
328// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
329// CHECK-VEC2-DAG:   %[[c512:.*]] = arith.constant 512 : index
330// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
331// CHECK-VEC2:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
332// CHECK-VEC2:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
333// CHECK-VEC2:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
334// CHECK-VEC2:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
335// CHECK-VEC2:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
336// CHECK-VEC2:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
337// CHECK-VEC2:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
338// CHECK-VEC2:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
339// CHECK-VEC2:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
340// CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
341// CHECK-VEC2:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
342// CHECK-VEC2:           %[[zj:.*]] = arith.extui %[[lj]] : vector<16xi32> to vector<16xi64>
343// CHECK-VEC2:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
344// CHECK-VEC2:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
345// CHECK-VEC2:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
346// CHECK-VEC2:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
347// CHECK-VEC2:         }
348// CHECK-VEC2:       }
349// CHECK-VEC2:       return
350//
351// CHECK-VEC3:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
352// CHECK-VEC3-LABEL: func @mul_ds
353// CHECK-VEC3-DAG:   %[[c0:.*]] = arith.constant 0 : index
354// CHECK-VEC3-DAG:   %[[c1:.*]] = arith.constant 1 : index
355// CHECK-VEC3-DAG:   %[[c16:.*]] = arith.constant 16 : index
356// CHECK-VEC3-DAG:   %[[c512:.*]] = arith.constant 512 : index
357// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
358// CHECK-VEC3:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
359// CHECK-VEC3:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
360// CHECK-VEC3:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
361// CHECK-VEC3:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
362// CHECK-VEC3:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
363// CHECK-VEC3:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
364// CHECK-VEC3:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
365// CHECK-VEC3:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
366// CHECK-VEC3:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
367// CHECK-VEC3:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
368// CHECK-VEC3:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
369// CHECK-VEC3:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
370// CHECK-VEC3:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
371// CHECK-VEC3:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
372// CHECK-VEC3:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
373// CHECK-VEC3:         }
374// CHECK-VEC3:       }
375// CHECK-VEC3:       return
376//
// Element-wise sparse-times-dense matrix product. Only the inner
// (compressed) loop is a candidate for vectorization: strategy 2 emits
// masked loads and 2-d gather/scatter at [%i, %c0] (CHECK-VEC2/CHECK-VEC3),
// while the outer dense loop stays scalar in all strategies.
func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
  %0 = linalg.generic #trait_mul_ds
    ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)
    outs(%argx: tensor<512x1024xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        // X(i,j) = A(i,j) * B(i,j)
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<512x1024xf32>
  return %0 : tensor<512x1024xf32>
}
387
388// -----
389
// CSR-like matrix with default (index-width) pointer/index storage, hence
// the memref<?xindex> loads in the checks (no extui/index_cast needed).
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}>

// Trait with an affine expression (i+1) in the OUTPUT map: the result row
// is shifted by one relative to the input row.
#trait_affine = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,
    affine_map<(i,j) -> (i+1,j)>
  ],
  iterator_types = ["parallel","parallel"],
  doc = "X(i+1,j) += A(i,j)"
}
400
401//
402// CHECK-VEC0-LABEL: func @add_dense
403// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
404// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
405// CHECK-VEC0-DAG:   %[[c32:.*]] = arith.constant 32 : index
406// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
407// CHECK-VEC0:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
408// CHECK-VEC0:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
409// CHECK-VEC0:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
410// CHECK-VEC0:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
411// CHECK-VEC0:           %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
412// CHECK-VEC0:           %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
413// CHECK-VEC0:           %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
414// CHECK-VEC0:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64
415// CHECK-VEC0:           memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
416// CHECK-VEC0:         }
417// CHECK-VEC0:       }
418// CHECK-VEC0:       return
419//
420// CHECK-VEC1-LABEL: func @add_dense
421// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
422// CHECK-VEC1-DAG:   %[[c1:.*]] = arith.constant 1 : index
423// CHECK-VEC1-DAG:   %[[c32:.*]] = arith.constant 32 : index
424// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
425// CHECK-VEC1:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
426// CHECK-VEC1:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
427// CHECK-VEC1:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
428// CHECK-VEC1:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
429// CHECK-VEC1:           %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
430// CHECK-VEC1:           %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
431// CHECK-VEC1:           %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
432// CHECK-VEC1:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64
433// CHECK-VEC1:           memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
434// CHECK-VEC1:         }
435// CHECK-VEC1:       }
436// CHECK-VEC1:       return
437//
438// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
439// CHECK-VEC2-LABEL: func @add_dense
440// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
441// CHECK-VEC2-DAG:   %[[c1:.*]] = arith.constant 1 : index
442// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
443// CHECK-VEC2-DAG:   %[[c32:.*]] = arith.constant 32 : index
444// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
445// CHECK-VEC2:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
446// CHECK-VEC2:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
447// CHECK-VEC2:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
448// CHECK-VEC2:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] {
449// CHECK-VEC2:           %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]]
450// CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
451// CHECK-VEC2:           %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xindex>
452// CHECK-VEC2:           %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64>
453// CHECK-VEC2:           %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xf64>
454// CHECK-VEC2:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<16xf64>
455// CHECK-VEC2:           vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
456// CHECK-VEC2:         }
457// CHECK-VEC2:       }
458// CHECK-VEC2:       return
459//
// Adds each sparse row i of A into dense row i+1 of X (hence the 33-row
// output for a 32-row input). The inplaceable output lets the update happen
// directly in the X buffer. Checked for VEC0/VEC1/VEC2 only; no CHECK-VEC3
// block exists for this kernel (presumably because default index widths make
// the index32 variant uninteresting here — not verifiable from this file).
func @add_dense(%arga: tensor<32x64xf64, #SparseMatrix>,
                %argx: tensor<33x64xf64> {linalg.inplaceable = true}) -> tensor<33x64xf64> {
  %0 = linalg.generic #trait_affine
     ins(%arga: tensor<32x64xf64, #SparseMatrix>)
    outs(%argx: tensor<33x64xf64>) {
      ^bb(%a: f64, %x: f64):
        // X(i+1,j) += A(i,j)
        %0 = arith.addf %x, %a : f64
        linalg.yield %0 : f64
  } -> tensor<33x64xf64>
  return %0 : tensor<33x64xf64>
}
471