// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s -sparsification | FileCheck %s

#Td = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>

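// The eight 3-D encodings below enumerate every dense/compressed combination
// of dimension level types (suffix letters: d = dense, s = sparse/compressed);
// each layout is exercised by a matching add/mul pair of kernels.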
#Tddd = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "dense",      "dense"      ] }>
#Tdds = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "dense",      "compressed" ] }>
#Tdsd = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "compressed", "dense"      ] }>
#Tdss = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "compressed", "compressed" ] }>
#Tsdd = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense",      "dense"      ] }>
#Tsds = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense",      "compressed" ] }>
#Tssd = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense"      ] }>
#Tsss = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>

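// Trait for the element-wise kernels below: all three tensors use the
// identity map over (i,j,k) and every iterator is parallel.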
#trait3 = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,j,k)>,  // A
    affine_map<(i,j,k) -> (i,j,k)>,  // B
    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel", "parallel"],
  doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
}

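// All-dense annotated operand: addition lowers to three nested scf.for loops
// over the full 32x16x8 space, with A read through a linearized index into
// its values array.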
// CHECK-LABEL:   func @add_ddd(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
// CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
// CHECK:               %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
// CHECK:               %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_13]] : index
// CHECK:               scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
// CHECK:                 %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index
// CHECK:                 %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : index
// CHECK:                 %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
// CHECK:                 %[[VAL_21:.*]] = arith.addf %[[VAL_19]], %[[VAL_20]] : f32
// CHECK:                 memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_22]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @add_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tddd>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

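// Same all-dense loop structure as @add_ddd; only the scalar operation
// changes from arith.addf to arith.mulf.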
// CHECK-LABEL:   func @mul_ddd(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
// CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
// CHECK:               %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
// CHECK:               %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_13]] : index
// CHECK:               scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
// CHECK:                 %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index
// CHECK:                 %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : index
// CHECK:                 %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
// CHECK:                 %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32
// CHECK:                 memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_22]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @mul_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tddd>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

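// Innermost dimension compressed: since x + 0 = x, implicit zeros of A still
// pass B through, so the k dimension is co-iterated with scf.while (stored
// entry vs. gap) and a trailing scf.for copies B over the remaining suffix.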
// CHECK-LABEL:   func @add_dds(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
// CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_9]] {
// CHECK:             scf.for %[[VAL_17:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_9]] {
// CHECK:               %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_5]] : index
// CHECK:               %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_17]] : index
// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:               %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_9]] : index
// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:               %[[VAL_23:.*]]:2 = scf.while (%[[VAL_24:.*]] = %[[VAL_20]], %[[VAL_25:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
// CHECK:                 %[[VAL_26:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_22]] : index
// CHECK:                 scf.condition(%[[VAL_26]]) %[[VAL_24]], %[[VAL_25]] : index, index
// CHECK:               } do {
// CHECK:               ^bb0(%[[VAL_27:.*]]: index, %[[VAL_28:.*]]: index):
// CHECK:                 %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref<?xindex>
// CHECK:                 %[[VAL_30:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_28]] : index
// CHECK:                 scf.if %[[VAL_30]] {
// CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xf32>
// CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
// CHECK:                   %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32
// CHECK:                   memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
// CHECK:                 } else {
// CHECK:                   scf.if %[[VAL_8]] {
// CHECK:                     %[[VAL_34:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
// CHECK:                     memref.store %[[VAL_34]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
// CHECK:                   } else {
// CHECK:                   }
// CHECK:                 }
// CHECK:                 %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_28]] : index
// CHECK:                 %[[VAL_36:.*]] = arith.addi %[[VAL_27]], %[[VAL_9]] : index
// CHECK:                 %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_36]], %[[VAL_27]] : index
// CHECK:                 %[[VAL_38:.*]] = arith.addi %[[VAL_28]], %[[VAL_9]] : index
// CHECK:                 scf.yield %[[VAL_37]], %[[VAL_38]] : index, index
// CHECK:               }
// CHECK:               scf.for %[[VAL_39:.*]] = %[[VAL_40:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                 %[[VAL_41:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_39]]] : memref<32x16x8xf32>
// CHECK:                 memref.store %[[VAL_41]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_39]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_42:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_42]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @add_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tdds>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

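// Multiplication is a conjunction (x * 0 = 0), so only stored entries of A
// contribute: the compressed k dimension lowers to a single scf.for over the
// pointer range, with no co-iteration or cleanup loop.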
// CHECK-LABEL:   func @mul_dds(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
// CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
// CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
// CHECK:               %[[VAL_16:.*]] = arith.muli %[[VAL_14]], %[[VAL_5]] : index
// CHECK:               %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_15]] : index
// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK:               %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_7]] : index
// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:               scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_7]] {
// CHECK:                 %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:                 %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xf32>
// CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_22]]] : memref<32x16x8xf32>
// CHECK:                 %[[VAL_25:.*]] = arith.mulf %[[VAL_23]], %[[VAL_24]] : f32
// CHECK:                 memref.store %[[VAL_25]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_22]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_26]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @mul_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tdds>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

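// Middle dimension compressed: j is co-iterated with scf.while; the hit
// branch adds over a dense k loop, while the miss branch and the trailing
// cleanup loop copy B.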
// CHECK-LABEL:   func @add_dsd(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
// CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_8]] {
// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK:             %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_8]] : index
// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK:             %[[VAL_19:.*]]:2 = scf.while (%[[VAL_20:.*]] = %[[VAL_16]], %[[VAL_21:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
// CHECK:               %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_18]] : index
// CHECK:               scf.condition(%[[VAL_22]]) %[[VAL_20]], %[[VAL_21]] : index, index
// CHECK:             } do {
// CHECK:             ^bb0(%[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index):
// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:               %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index
// CHECK:               scf.if %[[VAL_26]] {
// CHECK:                 scf.for %[[VAL_27:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                   %[[VAL_28:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index
// CHECK:                   %[[VAL_29:.*]] = arith.addi %[[VAL_28]], %[[VAL_27]] : index
// CHECK:                   %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
// CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32>
// CHECK:                   %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[VAL_31]] : f32
// CHECK:                   memref.store %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32>
// CHECK:                 }
// CHECK:               } else {
// CHECK:                 scf.if %[[VAL_6]] {
// CHECK:                   scf.for %[[VAL_33:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                     %[[VAL_34:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_33]]] : memref<32x16x8xf32>
// CHECK:                     memref.store %[[VAL_34]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_33]]] : memref<32x16x8xf32>
// CHECK:                   }
// CHECK:                 } else {
// CHECK:                 }
// CHECK:               }
// CHECK:               %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index
// CHECK:               %[[VAL_36:.*]] = arith.addi %[[VAL_23]], %[[VAL_8]] : index
// CHECK:               %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_36]], %[[VAL_23]] : index
// CHECK:               %[[VAL_38:.*]] = arith.addi %[[VAL_24]], %[[VAL_8]] : index
// CHECK:               scf.yield %[[VAL_37]], %[[VAL_38]] : index, index
// CHECK:             }
// CHECK:             scf.for %[[VAL_39:.*]] = %[[VAL_40:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] {
// CHECK:               scf.for %[[VAL_41:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                 %[[VAL_42:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_39]], %[[VAL_41]]] : memref<32x16x8xf32>
// CHECK:                 memref.store %[[VAL_42]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_39]], %[[VAL_41]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_43:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_43]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @add_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tdsd>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

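// Middle dimension compressed under multiplication: a single scf.for over
// the stored j entries with a dense inner k loop and linearized access into
// the values array.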
// CHECK-LABEL:   func @mul_dsd(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
// CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
// CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : index
// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_6]] {
// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK:               scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
// CHECK:                 %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index
// CHECK:                 %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : index
// CHECK:                 %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xf32>
// CHECK:                 %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_19]]] : memref<32x16x8xf32>
// CHECK:                 %[[VAL_24:.*]] = arith.mulf %[[VAL_22]], %[[VAL_23]] : f32
// CHECK:                 memref.store %[[VAL_24]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_19]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_25:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_25]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @mul_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tdsd>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

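// Two innermost dimensions compressed: the scf.while co-iteration over j
// wraps another scf.while over k, and each level has its own cleanup loop
// that copies the remaining B values.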
// CHECK-LABEL:   func @add_dss(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
// CHECK:           scf.for %[[VAL_18:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_9]] {
// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK:             %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_9]] : index
// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xindex>
// CHECK:             %[[VAL_22:.*]]:2 = scf.while (%[[VAL_23:.*]] = %[[VAL_19]], %[[VAL_24:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
// CHECK:               %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_23]], %[[VAL_21]] : index
// CHECK:               scf.condition(%[[VAL_25]]) %[[VAL_23]], %[[VAL_24]] : index, index
// CHECK:             } do {
// CHECK:             ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xindex>
// CHECK:               %[[VAL_29:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index
// CHECK:               scf.if %[[VAL_29]] {
// CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<?xindex>
// CHECK:                 %[[VAL_31:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index
// CHECK:                 %[[VAL_32:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<?xindex>
// CHECK:                 %[[VAL_33:.*]]:2 = scf.while (%[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
// CHECK:                   %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_32]] : index
// CHECK:                   scf.condition(%[[VAL_36]]) %[[VAL_34]], %[[VAL_35]] : index, index
// CHECK:                 } do {
// CHECK:                 ^bb0(%[[VAL_37:.*]]: index, %[[VAL_38:.*]]: index):
// CHECK:                   %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_37]]] : memref<?xindex>
// CHECK:                   %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index
// CHECK:                   scf.if %[[VAL_40]] {
// CHECK:                     %[[VAL_41:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_37]]] : memref<?xf32>
// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
// CHECK:                     %[[VAL_43:.*]] = arith.addf %[[VAL_41]], %[[VAL_42]] : f32
// CHECK:                     memref.store %[[VAL_43]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
// CHECK:                   } else {
// CHECK:                     scf.if %[[VAL_7]] {
// CHECK:                       %[[VAL_44:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
// CHECK:                       memref.store %[[VAL_44]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
// CHECK:                     } else {
// CHECK:                     }
// CHECK:                   }
// CHECK:                   %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index
// CHECK:                   %[[VAL_46:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index
// CHECK:                   %[[VAL_47:.*]] = arith.select %[[VAL_45]], %[[VAL_46]], %[[VAL_37]] : index
// CHECK:                   %[[VAL_48:.*]] = arith.addi %[[VAL_38]], %[[VAL_9]] : index
// CHECK:                   scf.yield %[[VAL_47]], %[[VAL_48]] : index, index
// CHECK:                 }
// CHECK:                 scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                   %[[VAL_51:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_49]]] : memref<32x16x8xf32>
// CHECK:                   memref.store %[[VAL_51]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_49]]] : memref<32x16x8xf32>
// CHECK:                 }
// CHECK:               } else {
// CHECK:                 scf.if %[[VAL_7]] {
// CHECK:                   scf.for %[[VAL_52:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                     %[[VAL_53:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_52]]] : memref<32x16x8xf32>
// CHECK:                     memref.store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_52]]] : memref<32x16x8xf32>
// CHECK:                   }
// CHECK:                 } else {
// CHECK:                 }
// CHECK:               }
// CHECK:               %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index
// CHECK:               %[[VAL_55:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index
// CHECK:               %[[VAL_56:.*]] = arith.select %[[VAL_54]], %[[VAL_55]], %[[VAL_26]] : index
// CHECK:               %[[VAL_57:.*]] = arith.addi %[[VAL_27]], %[[VAL_9]] : index
// CHECK:               scf.yield %[[VAL_56]], %[[VAL_57]] : index, index
// CHECK:             }
// CHECK:             scf.for %[[VAL_58:.*]] = %[[VAL_59:.*]]#1 to %[[VAL_5]] step %[[VAL_9]] {
// CHECK:               scf.for %[[VAL_60:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_58]], %[[VAL_60]]] : memref<32x16x8xf32>
// CHECK:                 memref.store %[[VAL_61]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_58]], %[[VAL_60]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_62:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_62]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @add_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tdss>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

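// Two innermost dimensions compressed under multiplication: two nested
// scf.for loops over the stored j and k entries.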
// CHECK-LABEL:   func @mul_dss(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
// CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK:             %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_6]] : index
// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_18]] step %[[VAL_6]] {
// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index
// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK:               scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_6]] {
// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]]] : memref<32x16x8xf32>
// CHECK:                 %[[VAL_28:.*]] = arith.mulf %[[VAL_26]], %[[VAL_27]] : f32
// CHECK:                 memref.store %[[VAL_28]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_29:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_29]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @mul_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tdss>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

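// Outermost dimension compressed: i is co-iterated with scf.while at
// function scope; both branches contain dense j/k loops, and a final
// dense i/j/k nest copies B for the missing rows.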
// CHECK-LABEL:   func @add_sdd(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref<?xindex>
// CHECK:           %[[VAL_17:.*]]:2 = scf.while (%[[VAL_18:.*]] = %[[VAL_15]], %[[VAL_19:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
// CHECK:             %[[VAL_20:.*]] = arith.cmpi ult, %[[VAL_18]], %[[VAL_16]] : index
// CHECK:             scf.condition(%[[VAL_20]]) %[[VAL_18]], %[[VAL_19]] : index, index
// CHECK:           } do {
// CHECK:           ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:             %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_22]] : index
// CHECK:             scf.if %[[VAL_24]] {
// CHECK:               scf.for %[[VAL_25:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
// CHECK:                 %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_4]] : index
// CHECK:                 %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index
// CHECK:                 scf.for %[[VAL_28:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                   %[[VAL_29:.*]] = arith.muli %[[VAL_27]], %[[VAL_5]] : index
// CHECK:                   %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_28]] : index
// CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf32>
// CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]] : memref<32x16x8xf32>
// CHECK:                   %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32
// CHECK:                   memref.store %[[VAL_33]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]] : memref<32x16x8xf32>
// CHECK:                 }
// CHECK:               }
// CHECK:             } else {
// CHECK:               scf.if %[[VAL_6]] {
// CHECK:                 scf.for %[[VAL_34:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
// CHECK:                   scf.for %[[VAL_35:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                     %[[VAL_36:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]], %[[VAL_34]], %[[VAL_35]]] : memref<32x16x8xf32>
// CHECK:                     memref.store %[[VAL_36]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_34]], %[[VAL_35]]] : memref<32x16x8xf32>
// CHECK:                   }
// CHECK:                 }
// CHECK:               } else {
// CHECK:               }
// CHECK:             }
// CHECK:             %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_22]] : index
// CHECK:             %[[VAL_38:.*]] = arith.addi %[[VAL_21]], %[[VAL_8]] : index
// CHECK:             %[[VAL_39:.*]] = arith.select %[[VAL_37]], %[[VAL_38]], %[[VAL_21]] : index
// CHECK:             %[[VAL_40:.*]] = arith.addi %[[VAL_22]], %[[VAL_8]] : index
// CHECK:             scf.yield %[[VAL_39]], %[[VAL_40]] : index, index
// CHECK:           }
// CHECK:           scf.for %[[VAL_41:.*]] = %[[VAL_42:.*]]#1 to %[[VAL_3]] step %[[VAL_8]] {
// CHECK:             scf.for %[[VAL_43:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
// CHECK:               scf.for %[[VAL_44:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                 %[[VAL_45:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_41]], %[[VAL_43]], %[[VAL_44]]] : memref<32x16x8xf32>
// CHECK:                 memref.store %[[VAL_45]], %[[VAL_14]]{{\[}}%[[VAL_41]], %[[VAL_43]], %[[VAL_44]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_46:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_46]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @add_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tsdd>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

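// Outermost dimension compressed under multiplication: one scf.for over the
// stored i entries, then dense j/k loops with linearized access into the
// values array.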
// CHECK-LABEL:   func @mul_sdd(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_6]] {
// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
// CHECK:               %[[VAL_18:.*]] = arith.muli %[[VAL_15]], %[[VAL_3]] : index
// CHECK:               %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_17]] : index
// CHECK:               scf.for %[[VAL_20:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
// CHECK:                 %[[VAL_21:.*]] = arith.muli %[[VAL_19]], %[[VAL_4]] : index
// CHECK:                 %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index
// CHECK:                 %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xf32>
// CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_20]]] : memref<32x16x8xf32>
// CHECK:                 %[[VAL_25:.*]] = arith.mulf %[[VAL_23]], %[[VAL_24]] : f32
// CHECK:                 memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_20]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_26]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @mul_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tsdd>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

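// Outer and inner dimensions compressed: scf.while co-iteration over i, a
// dense j loop, scf.while co-iteration over k, plus the corresponding
// cleanup copies of B at each level.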
616// CHECK-LABEL:   func @add_sds(
617// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
618// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
619// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
620// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
621// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
622// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
623// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
624// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 8 : index
625// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
626// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
627// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
628// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
629// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
630// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
631// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
632// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
633// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
634// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
635// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
636// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<?xindex>
637// CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref<?xindex>
638// CHECK:           %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_18]], %[[VAL_22:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
639// CHECK:             %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_21]], %[[VAL_19]] : index
640// CHECK:             scf.condition(%[[VAL_23]]) %[[VAL_21]], %[[VAL_22]] : index, index
641// CHECK:           } do {
642// CHECK:           ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index):
643// CHECK:             %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref<?xindex>
644// CHECK:             %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_25]] : index
645// CHECK:             scf.if %[[VAL_27]] {
646// CHECK:               scf.for %[[VAL_28:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
647// CHECK:                 %[[VAL_29:.*]] = arith.muli %[[VAL_24]], %[[VAL_5]] : index
// CHECK:                 %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_28]] : index
// CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_30]]] : memref<?xindex>
// CHECK:                 %[[VAL_32:.*]] = arith.addi %[[VAL_30]], %[[VAL_9]] : index
// CHECK:                 %[[VAL_33:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xindex>
// CHECK:                 %[[VAL_34:.*]]:2 = scf.while (%[[VAL_35:.*]] = %[[VAL_31]], %[[VAL_36:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
// CHECK:                   %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_33]] : index
// CHECK:                   scf.condition(%[[VAL_37]]) %[[VAL_35]], %[[VAL_36]] : index, index
// CHECK:                 } do {
// CHECK:                 ^bb0(%[[VAL_38:.*]]: index, %[[VAL_39:.*]]: index):
// CHECK:                   %[[VAL_40:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_38]]] : memref<?xindex>
// CHECK:                   %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_40]], %[[VAL_39]] : index
// CHECK:                   scf.if %[[VAL_41]] {
// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_38]]] : memref<?xf32>
// CHECK:                     %[[VAL_43:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
// CHECK:                     %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
// CHECK:                     memref.store %[[VAL_44]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
// CHECK:                   } else {
// CHECK:                     scf.if %[[VAL_7]] {
// CHECK:                       %[[VAL_45:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
// CHECK:                       memref.store %[[VAL_45]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
// CHECK:                     } else {
// CHECK:                     }
// CHECK:                   }
// CHECK:                   %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_40]], %[[VAL_39]] : index
// CHECK:                   %[[VAL_47:.*]] = arith.addi %[[VAL_38]], %[[VAL_9]] : index
// CHECK:                   %[[VAL_48:.*]] = arith.select %[[VAL_46]], %[[VAL_47]], %[[VAL_38]] : index
// CHECK:                   %[[VAL_49:.*]] = arith.addi %[[VAL_39]], %[[VAL_9]] : index
// CHECK:                   scf.yield %[[VAL_48]], %[[VAL_49]] : index, index
// CHECK:                 }
// CHECK:                 scf.for %[[VAL_50:.*]] = %[[VAL_51:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                   %[[VAL_52:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_50]]] : memref<32x16x8xf32>
// CHECK:                   memref.store %[[VAL_52]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_50]]] : memref<32x16x8xf32>
// CHECK:                 }
// CHECK:               }
// CHECK:             } else {
// CHECK:               scf.if %[[VAL_7]] {
// CHECK:                 scf.for %[[VAL_53:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
// CHECK:                   scf.for %[[VAL_54:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                     %[[VAL_55:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_53]], %[[VAL_54]]] : memref<32x16x8xf32>
// CHECK:                     memref.store %[[VAL_55]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_53]], %[[VAL_54]]] : memref<32x16x8xf32>
// CHECK:                   }
// CHECK:                 }
// CHECK:               } else {
// CHECK:               }
// CHECK:             }
// CHECK:             %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_25]] : index
// CHECK:             %[[VAL_57:.*]] = arith.addi %[[VAL_24]], %[[VAL_9]] : index
// CHECK:             %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_57]], %[[VAL_24]] : index
// CHECK:             %[[VAL_59:.*]] = arith.addi %[[VAL_25]], %[[VAL_9]] : index
// CHECK:             scf.yield %[[VAL_58]], %[[VAL_59]] : index, index
// CHECK:           }
// CHECK:           scf.for %[[VAL_60:.*]] = %[[VAL_61:.*]]#1 to %[[VAL_4]] step %[[VAL_9]] {
// CHECK:             scf.for %[[VAL_62:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
// CHECK:               scf.for %[[VAL_63:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                 %[[VAL_64:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_60]], %[[VAL_62]], %[[VAL_63]]] : memref<32x16x8xf32>
// CHECK:                 memref.store %[[VAL_64]], %[[VAL_17]]{{\[}}%[[VAL_60]], %[[VAL_62]], %[[VAL_63]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_65:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_65]] : tensor<32x16x8xf32>
// CHECK:         }
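// Addition with a compressed dimension cannot visit stored entries only:
// the loop nest above co-iterates the stored positions of A with all
// indices of that dimension through scf.while, uses scf.if to select
// between the A+B case (entry present) and a plain copy of B (entry
// absent), and closes each level with a cleanup scf.for over the
// remaining indices.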
func.func @add_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tsds>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

// CHECK-LABEL:   func @mul_sds(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_6]] {
// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
// CHECK:               %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index
// CHECK:               %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : index
// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_6]] : index
// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:               scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_6]] {
// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xindex>
// CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
// CHECK:                 %[[VAL_28:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]], %[[VAL_19]], %[[VAL_26]]] : memref<32x16x8xf32>
// CHECK:                 %[[VAL_29:.*]] = arith.mulf %[[VAL_27]], %[[VAL_28]] : f32
// CHECK:                 memref.store %[[VAL_29]], %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_19]], %[[VAL_26]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_30:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_30]] : tensor<32x16x8xf32>
// CHECK:         }
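// Multiplication annihilates at zero, so in contrast to addition the code
// above visits stored entries only: plain scf.for loops over the pointer
// ranges of the two compressed levels, with the dense middle level
// addressed by the linearized position p * 16 + j.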
func.func @mul_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tsds>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

// CHECK-LABEL:   func @add_ssd(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_16]] : memref<32x16x8xf32>)
// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref<?xindex>
// CHECK:           %[[VAL_19:.*]]:2 = scf.while (%[[VAL_20:.*]] = %[[VAL_17]], %[[VAL_21:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
// CHECK:             %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_18]] : index
// CHECK:             scf.condition(%[[VAL_22]]) %[[VAL_20]], %[[VAL_21]] : index, index
// CHECK:           } do {
// CHECK:           ^bb0(%[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index):
// CHECK:             %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:             %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index
// CHECK:             scf.if %[[VAL_26]] {
// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:               %[[VAL_28:.*]] = arith.addi %[[VAL_23]], %[[VAL_8]] : index
// CHECK:               %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref<?xindex>
// CHECK:               %[[VAL_30:.*]]:2 = scf.while (%[[VAL_31:.*]] = %[[VAL_27]], %[[VAL_32:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
// CHECK:                 %[[VAL_33:.*]] = arith.cmpi ult, %[[VAL_31]], %[[VAL_29]] : index
// CHECK:                 scf.condition(%[[VAL_33]]) %[[VAL_31]], %[[VAL_32]] : index, index
// CHECK:               } do {
// CHECK:               ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index):
// CHECK:                 %[[VAL_36:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_34]]] : memref<?xindex>
// CHECK:                 %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_35]] : index
// CHECK:                 scf.if %[[VAL_37]] {
// CHECK:                   scf.for %[[VAL_38:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                     %[[VAL_39:.*]] = arith.muli %[[VAL_34]], %[[VAL_5]] : index
// CHECK:                     %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_38]] : index
// CHECK:                     %[[VAL_41:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_40]]] : memref<?xf32>
// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_38]]] : memref<32x16x8xf32>
// CHECK:                     %[[VAL_43:.*]] = arith.addf %[[VAL_41]], %[[VAL_42]] : f32
// CHECK:                     memref.store %[[VAL_43]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_38]]] : memref<32x16x8xf32>
// CHECK:                   }
// CHECK:                 } else {
// CHECK:                   scf.if %[[VAL_6]] {
// CHECK:                     scf.for %[[VAL_44:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                       %[[VAL_45:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_44]]] : memref<32x16x8xf32>
// CHECK:                       memref.store %[[VAL_45]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_44]]] : memref<32x16x8xf32>
// CHECK:                     }
// CHECK:                   } else {
// CHECK:                   }
// CHECK:                 }
// CHECK:                 %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_35]] : index
// CHECK:                 %[[VAL_47:.*]] = arith.addi %[[VAL_34]], %[[VAL_8]] : index
// CHECK:                 %[[VAL_48:.*]] = arith.select %[[VAL_46]], %[[VAL_47]], %[[VAL_34]] : index
// CHECK:                 %[[VAL_49:.*]] = arith.addi %[[VAL_35]], %[[VAL_8]] : index
// CHECK:                 scf.yield %[[VAL_48]], %[[VAL_49]] : index, index
// CHECK:               }
// CHECK:               scf.for %[[VAL_50:.*]] = %[[VAL_51:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] {
// CHECK:                 scf.for %[[VAL_52:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                   %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_50]], %[[VAL_52]]] : memref<32x16x8xf32>
// CHECK:                   memref.store %[[VAL_53]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_50]], %[[VAL_52]]] : memref<32x16x8xf32>
// CHECK:                 }
// CHECK:               }
// CHECK:             } else {
// CHECK:               scf.if %[[VAL_6]] {
// CHECK:                 scf.for %[[VAL_54:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
// CHECK:                   scf.for %[[VAL_55:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                     %[[VAL_56:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_54]], %[[VAL_55]]] : memref<32x16x8xf32>
// CHECK:                     memref.store %[[VAL_56]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_54]], %[[VAL_55]]] : memref<32x16x8xf32>
// CHECK:                   }
// CHECK:                 }
// CHECK:               } else {
// CHECK:               }
// CHECK:             }
// CHECK:             %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index
// CHECK:             %[[VAL_58:.*]] = arith.addi %[[VAL_23]], %[[VAL_8]] : index
// CHECK:             %[[VAL_59:.*]] = arith.select %[[VAL_57]], %[[VAL_58]], %[[VAL_23]] : index
// CHECK:             %[[VAL_60:.*]] = arith.addi %[[VAL_24]], %[[VAL_8]] : index
// CHECK:             scf.yield %[[VAL_59]], %[[VAL_60]] : index, index
// CHECK:           }
// CHECK:           scf.for %[[VAL_61:.*]] = %[[VAL_62:.*]]#1 to %[[VAL_3]] step %[[VAL_8]] {
// CHECK:             scf.for %[[VAL_63:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
// CHECK:               scf.for %[[VAL_64:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:                 %[[VAL_65:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_61]], %[[VAL_63]], %[[VAL_64]]] : memref<32x16x8xf32>
// CHECK:                 memref.store %[[VAL_65]], %[[VAL_16]]{{\[}}%[[VAL_61]], %[[VAL_63]], %[[VAL_64]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_66:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_66]] : tensor<32x16x8xf32>
// CHECK:         }
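// For the ssd format the two compressed levels each lower to an scf.while
// co-iteration, while the trailing dense level remains a simple scf.for
// that addresses the values array at the linearized position p * 8 + k.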
func.func @add_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tssd>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

// CHECK-LABEL:   func @mul_ssd(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
// CHECK:             %[[VAL_19:.*]] = arith.addi %[[VAL_16]], %[[VAL_5]] : index
// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_5]] {
// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:               scf.for %[[VAL_23:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
// CHECK:                 %[[VAL_24:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index
// CHECK:                 %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_23]] : index
// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
// CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_23]]] : memref<32x16x8xf32>
// CHECK:                 %[[VAL_28:.*]] = arith.mulf %[[VAL_26]], %[[VAL_27]] : f32
// CHECK:                 memref.store %[[VAL_28]], %[[VAL_13]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_23]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_29:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_29]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @mul_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tssd>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

// CHECK-LABEL:   func @add_sss(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_19:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_19]] : memref<32x16x8xf32>)
// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<?xindex>
// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref<?xindex>
// CHECK:           %[[VAL_22:.*]]:2 = scf.while (%[[VAL_23:.*]] = %[[VAL_20]], %[[VAL_24:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
// CHECK:             %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_23]], %[[VAL_21]] : index
// CHECK:             scf.condition(%[[VAL_25]]) %[[VAL_23]], %[[VAL_24]] : index, index
// CHECK:           } do {
// CHECK:           ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
// CHECK:             %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xindex>
// CHECK:             %[[VAL_29:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index
// CHECK:             scf.if %[[VAL_29]] {
// CHECK:               %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<?xindex>
// CHECK:               %[[VAL_31:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index
// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<?xindex>
// CHECK:               %[[VAL_33:.*]]:2 = scf.while (%[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
// CHECK:                 %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_32]] : index
// CHECK:                 scf.condition(%[[VAL_36]]) %[[VAL_34]], %[[VAL_35]] : index, index
// CHECK:               } do {
// CHECK:               ^bb0(%[[VAL_37:.*]]: index, %[[VAL_38:.*]]: index):
// CHECK:                 %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_37]]] : memref<?xindex>
// CHECK:                 %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index
// CHECK:                 scf.if %[[VAL_40]] {
// CHECK:                   %[[VAL_41:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_37]]] : memref<?xindex>
// CHECK:                   %[[VAL_42:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index
// CHECK:                   %[[VAL_43:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_42]]] : memref<?xindex>
// CHECK:                   %[[VAL_44:.*]]:2 = scf.while (%[[VAL_45:.*]] = %[[VAL_41]], %[[VAL_46:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
// CHECK:                     %[[VAL_47:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_43]] : index
// CHECK:                     scf.condition(%[[VAL_47]]) %[[VAL_45]], %[[VAL_46]] : index, index
// CHECK:                   } do {
// CHECK:                   ^bb0(%[[VAL_48:.*]]: index, %[[VAL_49:.*]]: index):
// CHECK:                     %[[VAL_50:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_48]]] : memref<?xindex>
// CHECK:                     %[[VAL_51:.*]] = arith.cmpi eq, %[[VAL_50]], %[[VAL_49]] : index
// CHECK:                     scf.if %[[VAL_51]] {
// CHECK:                       %[[VAL_52:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_48]]] : memref<?xf32>
// CHECK:                       %[[VAL_53:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
// CHECK:                       %[[VAL_54:.*]] = arith.addf %[[VAL_52]], %[[VAL_53]] : f32
// CHECK:                       memref.store %[[VAL_54]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
// CHECK:                     } else {
// CHECK:                       scf.if %[[VAL_7]] {
// CHECK:                         %[[VAL_55:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
// CHECK:                         memref.store %[[VAL_55]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
// CHECK:                       } else {
// CHECK:                       }
// CHECK:                     }
// CHECK:                     %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_50]], %[[VAL_49]] : index
// CHECK:                     %[[VAL_57:.*]] = arith.addi %[[VAL_48]], %[[VAL_9]] : index
// CHECK:                     %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_57]], %[[VAL_48]] : index
// CHECK:                     %[[VAL_59:.*]] = arith.addi %[[VAL_49]], %[[VAL_9]] : index
// CHECK:                     scf.yield %[[VAL_58]], %[[VAL_59]] : index, index
// CHECK:                   }
// CHECK:                   scf.for %[[VAL_60:.*]] = %[[VAL_61:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                     %[[VAL_62:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_60]]] : memref<32x16x8xf32>
// CHECK:                     memref.store %[[VAL_62]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_60]]] : memref<32x16x8xf32>
// CHECK:                   }
// CHECK:                 } else {
// CHECK:                   scf.if %[[VAL_7]] {
// CHECK:                     scf.for %[[VAL_63:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                       %[[VAL_64:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_63]]] : memref<32x16x8xf32>
// CHECK:                       memref.store %[[VAL_64]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_63]]] : memref<32x16x8xf32>
// CHECK:                     }
// CHECK:                   } else {
// CHECK:                   }
// CHECK:                 }
// CHECK:                 %[[VAL_65:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index
// CHECK:                 %[[VAL_66:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index
// CHECK:                 %[[VAL_67:.*]] = arith.select %[[VAL_65]], %[[VAL_66]], %[[VAL_37]] : index
// CHECK:                 %[[VAL_68:.*]] = arith.addi %[[VAL_38]], %[[VAL_9]] : index
// CHECK:                 scf.yield %[[VAL_67]], %[[VAL_68]] : index, index
// CHECK:               }
// CHECK:               scf.for %[[VAL_69:.*]] = %[[VAL_70:.*]]#1 to %[[VAL_5]] step %[[VAL_9]] {
// CHECK:                 scf.for %[[VAL_71:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                   %[[VAL_72:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_69]], %[[VAL_71]]] : memref<32x16x8xf32>
// CHECK:                   memref.store %[[VAL_72]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_69]], %[[VAL_71]]] : memref<32x16x8xf32>
// CHECK:                 }
// CHECK:               }
// CHECK:             } else {
// CHECK:               scf.if %[[VAL_7]] {
// CHECK:                 scf.for %[[VAL_73:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
// CHECK:                   scf.for %[[VAL_74:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                     %[[VAL_75:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_73]], %[[VAL_74]]] : memref<32x16x8xf32>
// CHECK:                     memref.store %[[VAL_75]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_73]], %[[VAL_74]]] : memref<32x16x8xf32>
// CHECK:                   }
// CHECK:                 }
// CHECK:               } else {
// CHECK:               }
// CHECK:             }
// CHECK:             %[[VAL_76:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index
// CHECK:             %[[VAL_77:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index
// CHECK:             %[[VAL_78:.*]] = arith.select %[[VAL_76]], %[[VAL_77]], %[[VAL_26]] : index
// CHECK:             %[[VAL_79:.*]] = arith.addi %[[VAL_27]], %[[VAL_9]] : index
// CHECK:             scf.yield %[[VAL_78]], %[[VAL_79]] : index, index
// CHECK:           }
// CHECK:           scf.for %[[VAL_80:.*]] = %[[VAL_81:.*]]#1 to %[[VAL_4]] step %[[VAL_9]] {
// CHECK:             scf.for %[[VAL_82:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
// CHECK:               scf.for %[[VAL_83:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
// CHECK:                 %[[VAL_84:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_80]], %[[VAL_82]], %[[VAL_83]]] : memref<32x16x8xf32>
// CHECK:                 memref.store %[[VAL_84]], %[[VAL_19]]{{\[}}%[[VAL_80]], %[[VAL_82]], %[[VAL_83]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_85:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_85]] : tensor<32x16x8xf32>
// CHECK:         }
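// With all three levels compressed, sparsification nests three scf.while
// co-iterations, each with its own present/absent scf.if cases and its own
// cleanup scf.for over the tail of that dimension.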
func.func @add_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tsss>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.addf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

// CHECK-LABEL:   func @mul_sss(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16x8xf32>
// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_18:.*]] = %[[VAL_16]] to %[[VAL_17]] step %[[VAL_5]] {
// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK:             %[[VAL_21:.*]] = arith.addi %[[VAL_18]], %[[VAL_5]] : index
// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_23:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] {
// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:               %[[VAL_26:.*]] = arith.addi %[[VAL_23]], %[[VAL_5]] : index
// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xindex>
// CHECK:               scf.for %[[VAL_28:.*]] = %[[VAL_25]] to %[[VAL_27]] step %[[VAL_5]] {
// CHECK:                 %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref<?xindex>
// CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref<?xf32>
// CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]], %[[VAL_24]], %[[VAL_29]]] : memref<32x16x8xf32>
// CHECK:                 %[[VAL_32:.*]] = arith.mulf %[[VAL_30]], %[[VAL_31]] : f32
// CHECK:                 memref.store %[[VAL_32]], %[[VAL_15]]{{\[}}%[[VAL_19]], %[[VAL_24]], %[[VAL_29]]] : memref<32x16x8xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x16x8xf32>
// CHECK:           return %[[VAL_33]] : tensor<32x16x8xf32>
// CHECK:         }
func.func @mul_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %0 = linalg.generic #trait3
     ins(%arga, %argb: tensor<32x16x8xf32, #Tsss>, tensor<32x16x8xf32>)
    outs(%argx: tensor<32x16x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<32x16x8xf32>
  return %0 : tensor<32x16x8xf32>
}

#trait_kernel_3d = {
  indexing_maps = [
    affine_map<(i,j,k,l) -> (i,k,l)>,  // B
    affine_map<(i,j,k,l) -> (k,j)>,    // C
    affine_map<(i,j,k,l) -> (l,j)>,    // D
    affine_map<(i,j,k,l) -> (i,j)>     // A (out)
  ],
  iterator_types = ["parallel", "parallel", "reduction", "reduction"],
  doc = "A(i,j) += SUM_k,l B(i,k,l) * C(k,j) * D(l,j)"
}

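// In the kernel below only the innermost level of B is compressed (#Tdds),
// so the dense i and k loops linearize into the level-2 pointer array, the
// l loop visits just the stored entries, and the j loop stays innermost
// with dense loads of C and D.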
// CHECK-LABEL:   func @kernel_3d(
// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<?x?xf32>,
// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_2:.*2]]: tensor<?x?xf32>,
// CHECK-SAME:      %[[VAL_3:.*3]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_10:.*]] = tensor.dim %[[VAL_2]], %[[VAL_5]] : tensor<?x?xf32>
// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf32>
// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?x?xf32>
// CHECK-DAG:       %[[VAL_13:.*]] = tensor.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?xf32>
// CHECK-DAG:       %[[VAL_14:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?xf32>
// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?xf32>
// CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_13]] step %[[VAL_6]] {
// CHECK:             scf.for %[[VAL_18:.*]] = %[[VAL_5]] to %[[VAL_10]] step %[[VAL_6]] {
// CHECK:               %[[VAL_19:.*]] = arith.muli %[[VAL_10]], %[[VAL_17]] : index
// CHECK:               %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
// CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_6]] : index
// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK:               scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_6]] {
// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK:                 scf.for %[[VAL_27:.*]] = %[[VAL_5]] to %[[VAL_14]] step %[[VAL_6]] {
// CHECK:                   %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_18]], %[[VAL_27]]] : memref<?x?xf32>
// CHECK:                   %[[VAL_29:.*]] = arith.mulf %[[VAL_26]], %[[VAL_28]] : f32
// CHECK:                   %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_25]], %[[VAL_27]]] : memref<?x?xf32>
// CHECK:                   %[[VAL_31:.*]] = arith.mulf %[[VAL_29]], %[[VAL_30]] : f32
// CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_17]], %[[VAL_27]]] : memref<?x?xf32>
// CHECK:                   %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32
// CHECK:                   memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_17]], %[[VAL_27]]] : memref<?x?xf32>
// CHECK:                 }
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<?x?xf32>
// CHECK:           return %[[VAL_34]] : tensor<?x?xf32>
// CHECK:         }
func.func @kernel_3d(%arga: tensor<?x?xf32>,
                %argb: tensor<?x?x?xf32, #Tdds>,
                %argc: tensor<?x?xf32>,
                %argd: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic #trait_kernel_3d
       ins(%argb, %argc, %argd: tensor<?x?x?xf32, #Tdds>, tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arga: tensor<?x?xf32>) {
    ^bb(%b: f32, %c: f32, %d: f32, %a: f32):
      %0 = arith.mulf %b, %c : f32
      %1 = arith.mulf %0, %d : f32
      %2 = arith.addf %1, %a : f32
      linalg.yield %2 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

#trait_sum_reduction = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,j,k)>,  // A
    affine_map<(i,j,k) -> ()>        // x (scalar out)
  ],
  iterator_types = ["reduction", "reduction", "reduction"],
  doc = "x += SUM_ijk A(i,j,k)"
}

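// A full reduction threads the scalar accumulator through the iter_args of
// three nested scf.for loops over the pointer ranges of the compressed
// levels, and stores the final sum back into the 0-d output memref once.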
// CHECK-LABEL:   func @sum_reduction(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_11]]) -> (f32) {
// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK:             %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_3]] : index
// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK:             %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_3]] iter_args(%[[VAL_22:.*]] = %[[VAL_16]]) -> (f32) {
// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:               %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK:               %[[VAL_26:.*]] = scf.for %[[VAL_27:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_3]] iter_args(%[[VAL_28:.*]] = %[[VAL_22]]) -> (f32) {
// CHECK:                 %[[VAL_29:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref<?xf32>
// CHECK:                 %[[VAL_30:.*]] = arith.addf %[[VAL_28]], %[[VAL_29]] : f32
// CHECK:                 scf.yield %[[VAL_30]] : f32
// CHECK:               }
// CHECK:               scf.yield %[[VAL_26]] : f32
// CHECK:             }
// CHECK:             scf.yield %[[VAL_20]] : f32
// CHECK:           }
// CHECK:           memref.store %[[VAL_14]], %[[VAL_10]][] : memref<f32>
// CHECK:           %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<f32>
// CHECK:           return %[[VAL_34]] : tensor<f32>
// CHECK:         }
func.func @sum_reduction(%arga: tensor<10x20x30xf32, #Tsss>, %argx: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic #trait_sum_reduction
     ins(%arga: tensor<10x20x30xf32, #Tsss>)
    outs(%argx: tensor<f32>) {
      ^bb(%a: f32, %x: f32):
        %0 = arith.addf %x, %a : f32
        linalg.yield %0 : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}

#trait_sum_reduction_inv = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,j,k)>,  // A
    affine_map<(i,j,k) -> (i)>,      // b
    affine_map<(i,j,k) -> ()>        // x (scalar out)
  ],
  iterator_types = ["reduction", "reduction", "reduction"],
  doc = "x += SUM_ijk A(i,j,k) * b(i)"
}

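// b(i) is invariant in the j and k loops, so its load is hoisted into the
// outermost reduction loop in the code below.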
// CHECK-LABEL:   func @sum_reduction_inv(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
// CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?x?xf32>
// CHECK-DAG:       %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[VAL_5]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<f32>
// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_12]][] : memref<f32>
// CHECK:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_9]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_15]]] : memref<?xf32>
// CHECK:             %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_20:.*]] = %[[VAL_16]]) -> (f32) {
// CHECK:               %[[VAL_21:.*]] = scf.for %[[VAL_22:.*]] = %[[VAL_5]] to %[[VAL_7]] step %[[VAL_3]] iter_args(%[[VAL_23:.*]] = %[[VAL_20]]) -> (f32) {
// CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]], %[[VAL_19]], %[[VAL_22]]] : memref<?x?x?xf32>
// CHECK:                 %[[VAL_25:.*]] = arith.mulf %[[VAL_24]], %[[VAL_17]] : f32
// CHECK:                 %[[VAL_26:.*]] = arith.addf %[[VAL_23]], %[[VAL_25]] : f32
// CHECK:                 scf.yield %[[VAL_26]] : f32
// CHECK:               }
// CHECK:               scf.yield %[[VAL_21]] : f32
// CHECK:             }
// CHECK:             scf.yield %[[VAL_18]] : f32
// CHECK:           }
// CHECK:           memref.store %[[VAL_14]], %[[VAL_12]][] : memref<f32>
// CHECK:           %[[VAL_30:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<f32>
// CHECK:           return %[[VAL_30]] : tensor<f32>
// CHECK:         }
func.func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
                        %argb: tensor<?xf32, #Td>,
                        %argx: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic #trait_sum_reduction_inv
    ins(%arga, %argb: tensor<?x?x?xf32>, tensor<?xf32, #Td>)
    outs(%argx: tensor<f32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        %1 = arith.addf %x, %0 : f32
        linalg.yield %1 : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}

#trait_invariants = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i)>,      // a
    affine_map<(i,j,k) -> (j)>,      // b
    affine_map<(i,j,k) -> (k)>,      // c
    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel", "parallel"],
  doc = "X(i,j,k) = a(i) * b(j) * c(k)"
}

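// All three inputs are 1-d and invariant in two of the three parallel
// loops; a(i) is annotated all-dense (#Td), so the kernel still lowers to a
// perfect dense loop nest, with a(i) and b(j) loaded once per enclosing
// loop iteration.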
// CHECK-LABEL:   func @invariants(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<20xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<30xf32>,
// CHECK-SAME:      %[[VAL_3:.*]]: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 10 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 20 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 30 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20xf32>
// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<30xf32>
// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_3]] : memref<10x20x30xf32>
// CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<10x20x30xf32>)
// CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<?xf32>
// CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
// CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<20xf32>
// CHECK:               scf.for %[[VAL_18:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] {
// CHECK:                 %[[VAL_19:.*]] = arith.mulf %[[VAL_15]], %[[VAL_17]] : f32
// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<30xf32>
// CHECK:                 %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32
// CHECK:                 memref.store %[[VAL_21]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_16]], %[[VAL_18]]] : memref<10x20x30xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<10x20x30xf32>
// CHECK:           return %[[VAL_22]] : tensor<10x20x30xf32>
// CHECK:         }
func.func @invariants(%arga: tensor<10xf32, #Td>,
                 %argb: tensor<20xf32>,
                 %argc: tensor<30xf32>,
                 %argx: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
  %0 = linalg.generic #trait_invariants
     ins(%arga, %argb, %argc : tensor<10xf32, #Td>, tensor<20xf32>, tensor<30xf32>)
    outs(%argx: tensor<10x20x30xf32>) {
      ^bb(%a: f32, %b: f32, %c: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        %1 = arith.mulf %0, %c : f32
        linalg.yield %1 : f32
  } -> tensor<10x20x30xf32>
  return %0 : tensor<10x20x30xf32>
}