// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 vl=16" -cse -split-input-file | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC0
// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 vl=16" -cse -split-input-file | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC1
// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" -cse -split-input-file | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC2
// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" -cse -split-input-file | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC3
// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=4 enable-vla-vectorization=true" -cse -split-input-file | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC4

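// The five RUN configurations cover the vectorization strategies of the
// sparse compiler: strategy 0 leaves all loops scalar (CHECK-VEC0),
// strategy 1 vectorizes dense inner loops only (CHECK-VEC1), and strategy 2
// also vectorizes compressed inner loops with masked memory operations
// (CHECK-VEC2). The last two runs additionally exercise 32-bit SIMD indices
// for gather/scatter (CHECK-VEC3) and scalable, vector-length agnostic
// code (CHECK-VEC4).
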
#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>

#trait_scale_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b"
}

//
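// Dense storage: strategies 1 and 2 both turn the scalar loop into
// unit-stride vector<16xf32> loads and stores, while VLA mode uses a
// vscale-scaled step and masks a scalable vector<[4]xf32> body.
//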
// CHECK-VEC0-LABEL: func @scale_d
// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC0-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
// CHECK-VEC0:         %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC0:         %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32
// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @scale_d
// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC1-DAG:   %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC1-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
// CHECK-VEC1:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC1:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC1:         %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32>
// CHECK-VEC1:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2-LABEL: func @scale_d
// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC2-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
// CHECK-VEC2:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC2:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC2:         %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32>
// CHECK-VEC2:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-LABEL: func @scale_d
// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC4-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-VEC4-DAG:   %[[vscale:.*]] = vector.vscale
// CHECK-VEC4:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] {
// CHECK-VEC4:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
// CHECK-VEC4:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4:         %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4:         %[[scalev:.*]] = vector.broadcast %{{.*}} : f32 to vector<[4]xf32>
// CHECK-VEC4:         %[[scaled:.*]] = arith.mulf %[[val]], %[[scalev]] : vector<[4]xf32>
// CHECK-VEC4:         vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32>
// CHECK-VEC4:       }
// CHECK-VEC4:       return
//
func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_scale_d
    ins(%arga: tensor<1024xf32, #DenseVector>)
    outs(%argx: tensor<1024xf32>) {
      ^bb(%a: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}

// -----

#SparseVector = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed" ],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

#trait_mul_s = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}

//
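// Sparse storage with 32-bit pointers and indices: strategy 1 cannot
// vectorize the compressed loop, strategy 2 uses masked loads plus
// gather/scatter with indices zero-extended to i64, and enable-simd-index32
// keeps the 32-bit indices as-is (CHECK-VEC3).
//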
// CHECK-VEC0-LABEL: func @mul_s
// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC0:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC0:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC0:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC0:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC0:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC0:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC0:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC0:         %[[zi:.*]] = arith.extui %[[li]] : i32 to i64
// CHECK-VEC0:         %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index
// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC0:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC0:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @mul_s
// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC1-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC1:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC1:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC1:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC1:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC1:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC1:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC1:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC1:         %[[zi:.*]] = arith.extui %[[li]] : i32 to i64
// CHECK-VEC1:         %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index
// CHECK-VEC1:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC1:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC1:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK-VEC1:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC2-LABEL: func @mul_s
// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC2-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC2:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC2:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC2:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC2:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC2:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC2:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC2:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
// CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2:         %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
// CHECK-VEC2:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
// CHECK-VEC3:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC3-LABEL: func @mul_s
// CHECK-VEC3-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC3-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC3-DAG:   %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC3:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC3:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC3:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC3:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC3:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC3:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC3:       }
// CHECK-VEC3:       return
//
// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-LABEL: func @mul_s
// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-DAG:   %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-VEC4-DAG:   %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-VEC4:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC4:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC4:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC4:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC4:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC4:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC4:       %[[vscale:.*]] = vector.vscale
// CHECK-VEC4:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[step]] {
// CHECK-VEC4:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[step]]]
// CHECK-VEC4:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK-VEC4:         %[[lii64:.*]] = arith.extui %[[li]] : vector<[4]xi32> to vector<[4]xi64>
// CHECK-VEC4:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0f]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[v0f]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
// CHECK-VEC4:         vector.scatter %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32>
// CHECK-VEC4:       }
// CHECK-VEC4:       return
//
func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_mul_s
    ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
    outs(%argx: tensor<1024xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}

// -----

#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>

#trait_reduction_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> ()>    // x (out)
  ],
  iterator_types = ["reduction"],
  doc = "x += a(i) * b(i)"
}

//
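// Reduction: the incoming scalar is inserted into lane 0 of a zero vector,
// the loop accumulates vector partial sums, and a final vector.reduction
// folds them back to a scalar; VLA mode selects under the mask so that
// inactive lanes keep their previous partial sums.
//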
// CHECK-VEC0-LABEL: func @reduction_d
// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC0-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC0:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC0:         %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK-VEC0:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK-VEC0:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32
// CHECK-VEC0:         scf.yield %[[a]] : f32
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @reduction_d
// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC1-DAG:   %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC1-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC1-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC1:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC1:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
// CHECK-VEC1:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC1:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC1:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC1:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC1:         scf.yield %[[a]] : vector<16xf32>
// CHECK-VEC1:       }
// CHECK-VEC1:       %{{.*}} = vector.reduction <add>, %[[red]] : vector<16xf32> into f32
// CHECK-VEC1:       return
//
// CHECK-VEC2-LABEL: func @reduction_d
// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC2-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC2-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC2:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC2:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
// CHECK-VEC2:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC2:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC2:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC2:         scf.yield %[[a]] : vector<16xf32>
// CHECK-VEC2:       }
// CHECK-VEC2:       %{{.*}} = vector.reduction <add>, %[[red]] : vector<16xf32> into f32
// CHECK-VEC2:       return
//
// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-LABEL: func @reduction_d
// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC4-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-VEC4:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC4:       %[[vscale:.*]] = vector.vscale
// CHECK-VEC4:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
// CHECK-VEC4:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
// CHECK-VEC4:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
// CHECK-VEC4:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4:         %[[lb:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
// CHECK-VEC4:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<[4]xf32>
// CHECK-VEC4:         %[[sa:.*]] = arith.select %[[mask]], %[[a]], %[[red_in]] : vector<[4]xi1>, vector<[4]xf32>
// CHECK-VEC4:         scf.yield %[[sa]] : vector<[4]xf32>
// CHECK-VEC4:       }
// CHECK-VEC4:       %{{.*}} = vector.reduction <add>, %[[red]] : vector<[4]xf32> into f32
// CHECK-VEC4:       return
//
func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic #trait_reduction_d
    ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32>)
    outs(%argx: tensor<f32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        %1 = arith.addf %x, %0 : f32
        linalg.yield %1 : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

#SparseMatrix = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

#trait_mul_ds = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i,j)>,  // B
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) * B(i,j)"
}

//
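// Dense-sparse matrix kernel: only the inner compressed loop is vectorized;
// the outer dense loop remains scalar under every strategy.
//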
// CHECK-VEC0-LABEL: func @mul_ds
// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC0-DAG:   %[[c512:.*]] = arith.constant 512 : index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC0:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC0:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC0:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC0:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC0:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC0:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC0:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC0:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC0:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
// CHECK-VEC0:           %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64
// CHECK-VEC0:           %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index
// CHECK-VEC0:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
// CHECK-VEC0:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC0:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK-VEC0:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC0:         }
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @mul_ds
// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC1-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC1-DAG:   %[[c512:.*]] = arith.constant 512 : index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC1:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC1:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC1:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC1:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC1:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC1:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC1:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC1:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC1:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
// CHECK-VEC1:           %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64
// CHECK-VEC1:           %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index
// CHECK-VEC1:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
// CHECK-VEC1:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC1:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK-VEC1:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC1:         }
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC2-LABEL: func @mul_ds
// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC2-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC2-DAG:   %[[c512:.*]] = arith.constant 512 : index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC2:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC2:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC2:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC2:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC2:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC2:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC2:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC2:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC2:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
// CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2:           %[[zj:.*]] = arith.extui %[[lj]] : vector<16xi32> to vector<16xi64>
// CHECK-VEC2:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
// CHECK-VEC2:         }
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
// CHECK-VEC3:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC3-LABEL: func @mul_ds
// CHECK-VEC3-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC3-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC3-DAG:   %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC3-DAG:   %[[c512:.*]] = arith.constant 512 : index
// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC3:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC3:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC3:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC3:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC3:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC3:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC3:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC3:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC3:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
// CHECK-VEC3:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC3:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC3:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC3:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC3:         }
// CHECK-VEC3:       }
// CHECK-VEC3:       return
//
// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-LABEL: func @mul_ds
// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-DAG:   %[[c512:.*]] = arith.constant 512 : index
// CHECK-VEC4-DAG:   %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-VEC4-DAG:   %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-VEC4:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC4:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC4:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC4:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC4:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC4:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC4:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC4:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC4:         %[[vscale:.*]] = vector.vscale
// CHECK-VEC4:         %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[step]] {
// CHECK-VEC4:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[step]]]
// CHECK-VEC4:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4:           %[[lji32:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK-VEC4:           %[[lj:.*]] = arith.extui %[[lji32]] : vector<[4]xi32> to vector<[4]xi64>
// CHECK-VEC4:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0f]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[v0f]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
// CHECK-VEC4:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32>
// CHECK-VEC4:         }
// CHECK-VEC4:       }
// CHECK-VEC4:       return
//
func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
  %0 = linalg.generic #trait_mul_ds
    ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)
    outs(%argx: tensor<512x1024xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<512x1024xf32>
  return %0 : tensor<512x1024xf32>
}

// -----

#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}>

#trait_affine = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,
    affine_map<(i,j) -> (i+1,j)>
  ],
  iterator_types = ["parallel","parallel"],
  doc = "X(i+1,j) += A(i,j)"
}

//
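// Affine output index (i+1,j): the compressed inner loop still vectorizes
// under strategy 2 and VLA mode, with gather/scatter on the native index
// type of the storage.
//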
// CHECK-VEC0-LABEL: func @add_dense
// CHECK-VEC0-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC0-DAG:   %[[c32:.*]] = arith.constant 32 : index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC0:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC0:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC0:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC0:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
// CHECK-VEC0:           %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
// CHECK-VEC0:           %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-VEC0:           %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
// CHECK-VEC0:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64
// CHECK-VEC0:           memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-VEC0:         }
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @add_dense
// CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC1-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC1-DAG:   %[[c32:.*]] = arith.constant 32 : index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC1:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC1:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC1:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC1:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
// CHECK-VEC1:           %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
// CHECK-VEC1:           %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-VEC1:           %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
// CHECK-VEC1:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64
// CHECK-VEC1:           memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-VEC1:         }
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC2-LABEL: func @add_dense
// CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC2-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC2-DAG:   %[[c32:.*]] = arith.constant 32 : index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC2:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC2:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC2:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC2:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] {
// CHECK-VEC2:           %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]]
// CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2:           %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xindex>
// CHECK-VEC2:           %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64>
// CHECK-VEC2:           %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xf64>
// CHECK-VEC2:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<16xf64>
// CHECK-VEC2:           vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
// CHECK-VEC2:         }
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-LABEL: func @add_dense
// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-DAG:   %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-DAG:   %[[c32:.*]] = arith.constant 32 : index
// CHECK-VEC4-DAG:   %[[v0idx:.*]] = arith.constant dense<0> : vector<[4]xindex>
// CHECK-VEC4-DAG:   %[[v0f64:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf64>
// CHECK-VEC4:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC4:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC4:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC4:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC4:         %[[vscale:.*]] = vector.vscale
// CHECK-VEC4:         %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[step]] {
// CHECK-VEC4:           %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[step]]]
// CHECK-VEC4:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4:           %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0idx]] : memref<?xindex>
// CHECK-VEC4:           %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[v0f64]] : memref<33x64xf64>
// CHECK-VEC4:           %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0f64]] : memref<?xf64>
// CHECK-VEC4:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<[4]xf64>
// CHECK-VEC4:           vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
// CHECK-VEC4:         }
// CHECK-VEC4:       }
// CHECK-VEC4:       return
//
func.func @add_dense(%arga: tensor<32x64xf64, #SparseMatrix>,
                %argx: tensor<33x64xf64>) -> tensor<33x64xf64> {
  %0 = linalg.generic #trait_affine
     ins(%arga: tensor<32x64xf64, #SparseMatrix>)
    outs(%argx: tensor<33x64xf64>) {
      ^bb(%a: f64, %x: f64):
        %0 = arith.addf %x, %a : f64
        linalg.yield %0 : f64
  } -> tensor<33x64xf64>
  return %0 : tensor<33x64xf64>
}