// RUN: mlir-opt %s -tensor-bufferize -cse -split-input-file | FileCheck %s
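
// Bufferization tests for ops from the tensor dialect: each case below runs a
// tensor-level op through -tensor-bufferize and checks the resulting
// memref-level IR.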

// CHECK-LABEL:   func @dim(
// CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME:              %[[INDEX:.*]]: index) -> index {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<f32>
// CHECK:           %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
// CHECK:           return %[[EXTENT]] : index
func.func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
  %0 = tensor.dim %arg0, %arg1 : tensor<f32>
  return %0 : index
}

// -----

// CHECK-LABEL: func @rank(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<*xf32>) -> index {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK:           %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
func.func @rank(%arg0: tensor<*xf32>) -> index {
  %0 = tensor.rank %arg0 : tensor<*xf32>
  return %0 : index
}

// -----

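// tensor.cast between ranked and/or unranked tensors bufferizes to a
// memref.cast on the underlying buffer, as the three tests below check.
//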
// CHECK-LABEL:   func @tensor.cast(
// CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK:           %[[CASTED:.*]] = memref.cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED]]
// CHECK:           return %[[RET]] : tensor<2xindex>
func.func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
  return %0 : tensor<2xindex>
}

// -----

// CHECK-LABEL:   func @tensor.cast_from_unranked(
// CHECK-SAME:                                    %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
// CHECK:           %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32>
// CHECK:           return %[[RET]] : tensor<2xf32>
func.func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL:   func @tensor.cast_to_unranked(
// CHECK-SAME:                                  %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<2xf32>
// CHECK:           %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK:           return %[[RET]] : tensor<*xf32>
func.func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
  %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
  return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL:   func @tensor.extract(
// CHECK-SAME:                  %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME:                  %[[IDX:.*]]: index) -> f32 {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<?xf32>
// CHECK:           %[[RET:.*]] = memref.load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK:           return %[[RET]] : f32
// CHECK:         }
func.func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
  %0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
  return %0 : f32
}

// -----

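// tensor.from_elements bufferizes to a memref.alloc followed by one store per
// element; the 0-d, 1-d, 2-d and 3-d tests below check the generated stores
// and their (row-major) indices.
//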
// CHECK-LABEL:   func @tensor.from_elements_0d(
// CHECK-SAME:        %[[ELEM0:.*]]: index) -> tensor<index> {
// CHECK:           %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
// CHECK:           store %[[ELEM0]], %[[MEMREF]]
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:           return %[[RET]] : tensor<index>
func.func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
  %0 = tensor.from_elements %arg0 : tensor<index>
  return %0 : tensor<index>
}

// -----

// CHECK-LABEL:   func @tensor.from_elements_1d(
// CHECK-SAME:                               %[[ELEM0:.*]]: index,
// CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
// CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK:           store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:           return %[[RET]] : tensor<2xindex>
func.func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
  return %0 : tensor<2xindex>
}

// -----

// CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME:      %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
// CHECK-SAME:      -> tensor<3x2xindex> {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:         return %[[RET]] : tensor<3x2xindex>
func.func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
         : tensor<3x2xindex>
  return %0 : tensor<3x2xindex>
}

// -----

// CHECK-LABEL: func @tensor.from_elements_3d(
//  CHECK-SAME:     %[[F0:.*]]: f32

// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index

// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>

// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]]
// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]]
// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]]
// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]]
// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]]
// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]]
// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]]
// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]]
// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C0]]]
// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]]

// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2x2xf32>
func.func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32
  %f6 = arith.constant 6.0 : f32
  %f7 = arith.constant 7.0 : f32
  %f8 = arith.constant 8.0 : f32
  %f9 = arith.constant 9.0 : f32
  %f10 = arith.constant 10.0 : f32
  %f11 = arith.constant 11.0 : f32
  %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
         : tensor<3x2x2xf32>
  return %0 : tensor<3x2x2xf32>
}

// -----

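// tensor.generate bufferizes to a memref.alloc of the result shape plus an
// scf.parallel loop that evaluates the body and stores one element per
// iteration.
//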
// CHECK-LABEL:   func @tensor.generate(
// CHECK-SAME:                                       %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME:                                       %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
// CHECK:           scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK:             %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
// CHECK:             store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
// CHECK:             scf.yield
// CHECK:           }
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<?xindex>
// CHECK:           return %[[RET]] : tensor<?xindex>
// CHECK:         }
func.func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> {
  %result = tensor.generate %dynamic_extent {
  ^bb0(%i : index):
    %elem = tensor.dim %arg, %i : tensor<*xf32>
    tensor.yield %elem : index
  } : tensor<?xindex>
  return %result : tensor<?xindex>
}

// -----

// Additional test that checks the logic for intermixed static and dynamic
// extents.
//
// CHECK-LABEL:   func @tensor.generate_static_and_dynamic(
// CHECK-SAME:        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
// CHECK:           scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
// CHECK:             %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
// CHECK:             store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
// CHECK:             scf.yield
// CHECK:           }
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<16x?xindex>
// CHECK:           return %[[RET]] : tensor<16x?xindex>
// CHECK:         }
func.func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
  %result = tensor.generate %arg0 {
  ^bb0(%i: index, %j: index):
    %sum = arith.addi %i, %j : index
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  return %result : tensor<16x?xindex>
}

// -----

// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func.func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
  // CHECK-NOT: tensor.generate
  %tensor = tensor.generate %arg0 {
  ^bb0(%iv: index):
    // CHECK: test.source
    %0 = "test.source"() : () -> index
    tensor.yield %0 : index
  } : tensor<?xindex>
  return %tensor : tensor<?xindex>
}

// -----

// CHECK-DAG: #[[$MAP0a:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

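// tensor.extract_slice bufferizes to a memref.subview; the resulting memref
// carries a strided layout (the affine_map captured above).
//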
// CHECK-LABEL: func @tensor.extract_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[idx1:.*]]: index, %[[idx2:.*]]: index
func.func @tensor.extract_slice(
    %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP0a]]>
  %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
      : tensor<?x?xf32> to tensor<?x10xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x10xf32>
}

// -----

// CHECK-DAG: #[[$MAP0b:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10x?xf32>, %[[idx1:.*]]: index,
//  CHECK-SAME:     %[[idx2:.*]]: index
func.func @tensor.extract_slice_rank_reducing(
    %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP0b]]>
  %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
      : tensor<?x10x?xf32> to tensor<?x15xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x15xf32>
}

// -----

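// tensor.insert_slice bufferizes out of place: the destination buffer is
// copied into a fresh allocation, and the source buffer is then copied into a
// subview of that allocation.
//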
// CHECK-LABEL: func @tensor.insert_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
//  CHECK-SAME:     %[[idx1:.*]]: index, %[[idx2:.*]]: index
func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
                               %idx1: index, %idx2: index) -> tensor<?x?xf32> {
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
  // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
  // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
  // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
  //     CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
  //     CHECK: memref.copy %[[m1]], %[[alloc]]
  //     CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
  //     CHECK: memref.copy %[[m2]], %[[subview]]
  %0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1]
      : tensor<?x10xf32> into tensor<?x?xf32>

  //     CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  //     CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK: #[[$MAP11:.*]] = affine_map<()[s0] -> (s0)>

// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_1(
func.func @tensor.insert_slice_rank_reducing_1(
    %t1: tensor<?x?xf32>, %f: tensor<f32>, %idx1: index, %idx2: index)
  -> tensor<?x?xf32>
{
  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?xf32>
  // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref<?x?xf32> to memref<f32, #[[$MAP11]]>
  // CHECK: memref.copy {{.*}} : memref<f32> to memref<f32, #[[$MAP11]]>
  %0 = tensor.insert_slice %f into %t1[%idx1, %idx2][1, 1][1, 1]
      : tensor<f32> into tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK: #[[$MAP12:.*]] = affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5)>

// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_2(
func.func @tensor.insert_slice_rank_reducing_2(
    %t1: tensor<?x?x?x?x?x?x?xf32>, %t2: tensor<2x1x4x1x1xf32>, %i: index)
  -> tensor<?x?x?x?x?x?x?xf32>
{
  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?x?x?x?x?x?xf32>
  // CHECK: memref.subview %[[alloc]][{{.*}}] [1, 2, 1, 4, 1, 1, 1] [1, 1, 1, 1, 1, 1, 1] : memref<?x?x?x?x?x?x?xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
  // CHECK: memref.copy {{.*}} : memref<2x1x4x1x1xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
  %0 = tensor.insert_slice %t2 into %t1[%i, %i, %i, %i, %i, %i, %i][1, 2, 1, 4, 1, 1, 1][1, 1, 1, 1, 1, 1, 1]
      : tensor<2x1x4x1x1xf32> into tensor<?x?x?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?x?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.insert(
//  CHECK-SAME:     %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
//  CHECK-SAME:     %[[f:.*]]: f32
func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
  // CHECK-DAG: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32>
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
  %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK: return %[[r]]
  return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL: func @tensor.expand_shape(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
  // CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
  %0 = tensor.expand_shape %t1 [[0, 1], [2]]
      : tensor<?x10xf32> into tensor<2x?x10xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %0 : tensor<2x?x10xf32>
}

// -----

// CHECK-DAG: #[[$MAP1b:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
// CHECK-DAG: #[[$MAP2b:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>

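// When the source of a tensor.expand_shape is a slice, the strided layout of
// the subview is propagated into the layout map of the expanded memref.
//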
// CHECK-LABEL: func @tensor.expand_shape_of_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>
func.func @tensor.expand_shape_of_slice(
    %t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, #[[$MAP1b]]>
  %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
      tensor<?x20xf32> to tensor<?x10xf32>
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
  // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, #[[$MAP1b]]> into memref<?x7x2x5xf32, #[[$MAP2b]]>
  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
      tensor<?x10xf32> into tensor<?x7x2x5xf32>
  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %1 : tensor<?x7x2x5xf32>
}

// -----

// CHECK-DAG: #[[$MAP9:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-DAG: #[[$MAP10:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>

// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?xf32>
func.func @tensor.expand_shape_of_scalar_slice(
    %t1: tensor<?xf32>, %o1: index, %s1: index) -> tensor<1xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?xf32>
  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] :  memref<?xf32> to memref<f32, #[[$MAP9]]>
  %0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref<f32, #[[$MAP9]]> into memref<1xf32, #[[$MAP10]]>
  %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1xf32>
  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %1 : tensor<1xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape(
//  CHECK-SAME:     %[[t1:.*]]: tensor<2x?x?xf32>
func.func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<2x?x?xf32>
  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [
  // CHECK-SAME: [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
  %0 = tensor.collapse_shape %t1 [[0, 1], [2]]
      : tensor<2x?x?xf32> into tensor<?x?xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
  // CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_to_scalar(
//  CHECK-SAME:     %[[t1:.*]]: tensor<1x1x1xf32>
func.func @tensor.collapse_shape_to_scalar(%t1: tensor<1x1x1xf32>) -> tensor<f32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<1x1x1xf32>
  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [] : memref<1x1x1xf32> into memref<f32>
  %0 = tensor.collapse_shape %t1 []
      : tensor<1x1x1xf32> into tensor<f32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
  // CHECK: return %[[r]]
  return %0 : tensor<f32>
}

// -----

// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>

// CHECK-LABEL: func @tensor.collapse_shape_of_slice(
func.func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor<i32> {
  // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, #[[$MAP3]]>
  %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32>
  // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, #[[$MAP3]]> into memref<i32, #[[$MAP4]]>
  %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
  return %1 : tensor<i32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice2(
func.func @tensor.collapse_shape_of_slice2(
    %arg0: tensor<?x?x?x?xi64>, %o1: index, %o2: index, %o3: index, %o4: index)
    -> tensor<87x63648xi64> {
  // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<?x?x?x?xi64> to memref<87x78x68x12xi64, #{{.*}}>
  %0 = tensor.extract_slice %arg0[%o1, %o2, %o3, %o4] [87, 78, 68, 12] [1, 1, 1, 1] : tensor<?x?x?x?xi64> to tensor<87x78x68x12xi64>

  // This memref may not be collapsible, so the buffer must be copied to get rid
  // of the layout map.
  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<87x78x68x12xi64>
  // CHECK: memref.copy %[[subview]], %[[alloc]]
  // CHECK: memref.collapse_shape %[[alloc]] [
  // CHECK-SAME: [0], [1, 2, 3]] : memref<87x78x68x12xi64> into memref<87x63648xi64>
  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
  return %1 : tensor<87x63648xi64>
}

// -----

// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK-DAG: #[[$MAP6:.*]] = affine_map<(d0) -> (d0 * 2)>

// CHECK-LABEL: func @tensor.collapse_shape_of_slice3(
//  CHECK-SAME:     %[[t1:.*]]: tensor<1x2xf32>
func.func @tensor.collapse_shape_of_slice3(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
  // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]>
  %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
  // CHECK: memref.collapse_shape %{{.*}} [
  // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32, #[[$MAP6]]>
  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
  return %1 : tensor<1xf32>
}

// -----

// CHECK-DAG: #[[$MAP7:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 * 4 + d2)>
// CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>

// CHECK-LABEL:   func @tensor.collapse_shape_of_slice4(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x2x4xf32>,
// CHECK-SAME:      %[[OFFSET:.*]]: index) -> tensor<8xf32> {
func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: index, %size: index) -> tensor<8xf32> {
  // CHECK: memref.subview %{{.*}} : memref<?x2x4xf32> to memref<4x2x1xf32, #[[$MAP7]]>
  %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor<?x2x4xf32> to tensor<4x2x1xf32>
  // CHECK: memref.collapse_shape %{{.*}} [
  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, #[[$MAP7]]> into memref<8xf32, #[[$MAP8]]>
  %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
  return %ret: tensor<8xf32>
}

// -----

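// tensor.reshape bufferizes to memref.reshape; the shape operand (built here
// with tensor.from_elements) is itself bufferized to an allocation plus
// per-element stores.
//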
// CHECK-LABEL: func @tensor.reshape(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>

  // CHECK: %[[two:.*]] = arith.constant 2 : i64
  %two = arith.constant 2 : i64
  // CHECK: %[[five:.*]] = arith.constant 5 : i64
  %five = arith.constant 5 : i64

  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 128 : i64} : memref<3xi64>
  // CHECK: %[[zero_idx:.*]] = arith.constant 0 : index
  // CHECK: %[[one_idx:.*]] = arith.constant 1 : index
  // CHECK: %[[two_idx:.*]] = arith.constant 2 : index
  // CHECK: memref.store %[[two]], %[[alloc]][%[[zero_idx]]] : memref<3xi64>
  // CHECK: memref.store %[[two]], %[[alloc]][%[[one_idx]]] : memref<3xi64>
  // CHECK: memref.store %[[five]], %[[alloc]][%[[two_idx]]] : memref<3xi64>
  %shape = tensor.from_elements %two, %two, %five : tensor<3xi64>

  // CHECK: %[[reshaped:.*]] = memref.reshape %[[m1]](%[[alloc]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
  %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[reshaped]]
  // CHECK: return %[[r]]
  return %reshaped : tensor<2x2x5xf32>
}