1// RUN: mlir-opt %s -tensor-bufferize -cse | FileCheck %s
2
3 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
4 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
5 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
6 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
7 // CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>
8 // CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
9 // CHECK-DAG: #[[$MAP6:.*]] = affine_map<(d0) -> (d0 * 2)>
10 // CHECK-DAG: #[[$MAP7:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 * 4 + d2)>
11 // CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
12
// tensor.dim is rewritten into memref.dim on the buffer obtained from the
// tensor argument.
// CHECK-LABEL: func @dim(
// CHECK-SAME:      %[[T:.*]]: tensor<f32>,
// CHECK-SAME:      %[[IDX:.*]]: index) -> index {
// CHECK:         %[[BUF:.*]] = bufferization.to_memref %[[T]] : memref<f32>
// CHECK:         %[[DIM:.*]] = memref.dim %[[BUF]], %[[IDX]] : memref<f32>
// CHECK:         return %[[DIM]] : index
func.func @dim(%t: tensor<f32>, %idx: index) -> index {
  %dim = tensor.dim %t, %idx : tensor<f32>
  return %dim : index
}
23
// tensor.rank on an unranked tensor is rewritten into memref.rank on the
// unranked buffer, and that rank is what the function returns.
// CHECK-LABEL: func @rank(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<*xf32>) -> index {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
// CHECK:           %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
// CHECK:           return %[[EXTENT]] : index
func.func @rank(%arg0: tensor<*xf32>) -> index {
  %0 = tensor.rank %arg0 : tensor<*xf32>
  return %0 : index
}
32
// A ranked-to-ranked tensor.cast becomes a memref.cast of the argument's
// buffer, wrapped back into a tensor for the return.
// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME:      %[[SRC:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]]
// CHECK:         %[[CAST_BUF:.*]] = memref.cast %[[SRC_BUF]] : memref<?xindex> to memref<2xindex>
// CHECK:         %[[RESULT:.*]] = bufferization.to_tensor %[[CAST_BUF]]
// CHECK:         return %[[RESULT]] : tensor<2xindex>
func.func @tensor.cast(%src: tensor<?xindex>) -> tensor<2xindex> {
  %casted = tensor.cast %src : tensor<?xindex> to tensor<2xindex>
  return %casted : tensor<2xindex>
}
43
// Casting away the unranked type is done with a memref.cast from the unranked
// buffer to the static ranked memref type.
// CHECK-LABEL: func @tensor.cast_from_unranked(
// CHECK-SAME:      %[[IN:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK:         %[[IN_BUF:.*]] = bufferization.to_memref %[[IN]] : memref<*xf32>
// CHECK:         %[[RANKED:.*]] = memref.cast %[[IN_BUF]] : memref<*xf32> to memref<2xf32>
// CHECK:         %[[OUT:.*]] = bufferization.to_tensor %[[RANKED]] : memref<2xf32>
// CHECK:         return %[[OUT]] : tensor<2xf32>
func.func @tensor.cast_from_unranked(%in: tensor<*xf32>) -> tensor<2xf32> {
  %ranked = tensor.cast %in : tensor<*xf32> to tensor<2xf32>
  return %ranked : tensor<2xf32>
}
54
// Casting to an unranked tensor erases the rank of the buffer via memref.cast.
// CHECK-LABEL: func @tensor.cast_to_unranked(
// CHECK-SAME:      %[[IN:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK:         %[[IN_BUF:.*]] = bufferization.to_memref %[[IN]] : memref<2xf32>
// CHECK:         %[[UNRANKED:.*]] = memref.cast %[[IN_BUF]] : memref<2xf32> to memref<*xf32>
// CHECK:         %[[OUT:.*]] = bufferization.to_tensor %[[UNRANKED]] : memref<*xf32>
// CHECK:         return %[[OUT]] : tensor<*xf32>
func.func @tensor.cast_to_unranked(%in: tensor<2xf32>) -> tensor<*xf32> {
  %unranked = tensor.cast %in : tensor<2xf32> to tensor<*xf32>
  return %unranked : tensor<*xf32>
}
65
// tensor.extract becomes a memref.load from the argument's buffer at the same
// index.
// CHECK-LABEL: func @tensor.extract(
// CHECK-SAME:      %[[SRC:.*]]: tensor<?xf32>,
// CHECK-SAME:      %[[POS:.*]]: index) -> f32 {
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<?xf32>
// CHECK:         %[[VAL:.*]] = memref.load %[[SRC_BUF]][%[[POS]]] : memref<?xf32>
// CHECK:         return %[[VAL]] : f32
// CHECK:       }
func.func @tensor.extract(%src: tensor<?xf32>, %pos: index) -> f32 {
  %val = tensor.extract %src[%pos] : tensor<?xf32>
  return %val : f32
}
77
// A 0-d from_elements allocates a scalar buffer and stores its single element.
// CHECK-LABEL: func @tensor.from_elements_0d(
// CHECK-SAME:      %[[E:.*]]: index) -> tensor<index> {
// CHECK:         %[[BUF:.*]] = memref.alloc() {{.*}} : memref<index>
// CHECK:         store %[[E]], %[[BUF]]
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[BUF]]
// CHECK:         return %[[RES]] : tensor<index>
func.func @tensor.from_elements_0d(%elem: index) -> tensor<index> {
  %t = tensor.from_elements %elem : tensor<index>
  return %t : tensor<index>
}
88
// Each element of a 1-d from_elements is stored at its own constant index of a
// freshly allocated buffer.
// CHECK-LABEL: func @tensor.from_elements_1d(
// CHECK-SAME:      %[[A:.*]]: index,
// CHECK-SAME:      %[[B:.*]]: index) -> tensor<2xindex> {
// CHECK-DAG:     %[[IDX0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[IDX1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[BUF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
// CHECK:         store %[[A]], %[[BUF]][%[[IDX0]]]
// CHECK:         store %[[B]], %[[BUF]][%[[IDX1]]]
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[BUF]]
// CHECK:         return %[[RES]] : tensor<2xindex>
func.func @tensor.from_elements_1d(%a: index, %b: index) -> tensor<2xindex> {
  %vec = tensor.from_elements %a, %b : tensor<2xindex>
  return %vec : tensor<2xindex>
}
103
// The six elements are stored into the 3x2 buffer in row-major order, with
// the index constants shared across stores after cse.
// CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME:      %[[A:.*]]: index, %[[B:.*]]: index)
// CHECK-SAME:      -> tensor<3x2xindex> {
// CHECK-DAG:     %[[I0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[I1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[I2:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[BUF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK:         store %[[A]], %[[BUF]][%[[I0]], %[[I0]]]
// CHECK:         store %[[B]], %[[BUF]][%[[I0]], %[[I1]]]
// CHECK:         store %[[A]], %[[BUF]][%[[I1]], %[[I0]]]
// CHECK:         store %[[B]], %[[BUF]][%[[I1]], %[[I1]]]
// CHECK:         store %[[A]], %[[BUF]][%[[I2]], %[[I0]]]
// CHECK:         store %[[B]], %[[BUF]][%[[I2]], %[[I1]]]
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[BUF]]
// CHECK:         return %[[RES]] : tensor<3x2xindex>
func.func @tensor.from_elements_2d(%a: index, %b: index) -> tensor<3x2xindex> {
  %mat = tensor.from_elements %a, %b, %a, %b, %a, %b : tensor<3x2xindex>
  return %mat : tensor<3x2xindex>
}
124
// Bufferizing a 3-d from_elements: one allocation, then twelve stores in
// row-major order over the 3x2x2 index space, then a to_tensor of the buffer.
// CHECK-LABEL: func @tensor.from_elements_3d(
//  CHECK-SAME:     %[[F0:.*]]: f32

// The constants 1.0, 10.0 and 11.0 are matched together with their printed
// exponent: a bare "1.0" prefix would match both 1.0 (e+00) and 10.0 (e+01),
// so the exponent is what keeps F1 and F10 apart.
// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

// Index constants used as store coordinates (shared by all stores after cse).
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index

// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>

// Element k of the from_elements operand list lands at the k-th position of
// the row-major enumeration of the 3x2x2 shape.
// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]]
// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]]
// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]]
// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]]
// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]]
// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]]
// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]]
// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]]
// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C0]]]
// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]]

// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2x2xf32>
func.func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32
  %f6 = arith.constant 6.0 : f32
  %f7 = arith.constant 7.0 : f32
  %f8 = arith.constant 8.0 : f32
  %f9 = arith.constant 9.0 : f32
  %f10 = arith.constant 10.0 : f32
  %f11 = arith.constant 11.0 : f32
  %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
         : tensor<3x2x2xf32>
  return %0 : tensor<3x2x2xf32>
}
177
// tensor.generate becomes an allocation of the dynamic extent, filled by an
// scf.parallel loop that runs the generator body once per index.
// CHECK-LABEL: func @tensor.generate(
// CHECK-SAME:      %[[SRC:.*]]: tensor<*xf32>,
// CHECK-SAME:      %[[N:.*]]: index) -> tensor<?xindex> {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<*xf32>
// CHECK-DAG:     %[[OUT:.*]] = memref.alloc(%[[N]]) {{.*}} : memref<?xindex>
// CHECK:         scf.parallel (%[[IV:.*]]) = (%[[C0]]) to (%[[N]]) step (%[[C1]]) {
// CHECK:           %[[EXT:.*]] = memref.dim %[[SRC_BUF]], %[[IV]] : memref<*xf32>
// CHECK:           store %[[EXT]], %[[OUT]][%[[IV]]] : memref<?xindex>
// CHECK:           scf.yield
// CHECK:         }
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[OUT]] : memref<?xindex>
// CHECK:         return %[[RES]] : tensor<?xindex>
// CHECK:       }
func.func @tensor.generate(%src: tensor<*xf32>, %n: index) -> tensor<?xindex> {
  %generated = tensor.generate %n {
  ^bb0(%iv : index):
    %extent = tensor.dim %src, %iv : tensor<*xf32>
    tensor.yield %extent : index
  } : tensor<?xindex>
  return %generated : tensor<?xindex>
}
201
// Mixed static and dynamic extents: the static extent (16) becomes a constant
// upper bound of the parallel loop, while the dynamic extent is both the
// memref.alloc operand and the other loop bound.
//
// CHECK-LABEL: func @tensor.generate_static_and_dynamic(
// CHECK-SAME:      %[[N:.*]]: index) -> tensor<16x?xindex> {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG:     %[[OUT:.*]] = memref.alloc(%[[N]]) {{.*}} : memref<16x?xindex>
// CHECK:         scf.parallel (%[[ROW:.*]], %[[COL:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[N]]) step (%[[C1]], %[[C1]]) {
// CHECK:           %[[SUM:.*]] = arith.addi %[[ROW]], %[[COL]] : index
// CHECK:           store %[[SUM]], %[[OUT]][%[[ROW]], %[[COL]]] : memref<16x?xindex>
// CHECK:           scf.yield
// CHECK:         }
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[OUT]] : memref<16x?xindex>
// CHECK:         return %[[RES]] : tensor<16x?xindex>
// CHECK:       }
func.func @tensor.generate_static_and_dynamic(%n: index) -> tensor<16x?xindex> {
  %generated = tensor.generate %n {
  ^bb0(%row: index, %col: index):
    %total = arith.addi %row, %col : index
    tensor.yield %total : index
  } : tensor<16x?xindex>
  return %generated : tensor<16x?xindex>
}
227
// An op inside the generator body that has no tensor semantics (test.source)
// is preserved in the output, and no tensor.generate appears before it.
// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
// CHECK-NOT:     tensor.generate
// CHECK:         test.source
func.func @tensor.generate_unknown_ops_in_body(%extent: index) -> tensor<?xindex> {
  %generated = tensor.generate %extent {
  ^bb0(%iv: index):
    %value = "test.source"() : () -> index
    tensor.yield %value : index
  } : tensor<?xindex>
  return %generated : tensor<?xindex>
}
239
// extract_slice bufferizes to a memref.subview of the source buffer; the
// result memref carries the strided layout MAP0 (symbolic offset and stride).
// CHECK-LABEL: func @tensor.extract_slice(
// CHECK-SAME:      %[[SRC:.*]]: tensor<?x?xf32>, %[[ROWS:.*]]: index, %[[COL:.*]]: index
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<?x?xf32>
// CHECK:         %[[VIEW:.*]] = memref.subview %[[SRC_BUF]][5, %[[COL]]] [%[[ROWS]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP0]]>
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[VIEW]]
// CHECK:         return %[[RES]]
func.func @tensor.extract_slice(
    %src: tensor<?x?xf32>, %rows: index, %col: index) -> tensor<?x10xf32> {
  %slice = tensor.extract_slice %src[5, %col][%rows, 10][1, 1]
      : tensor<?x?xf32> to tensor<?x10xf32>
  return %slice : tensor<?x10xf32>
}
252
// A rank-reducing extract_slice (the unit-sized middle dimension is dropped)
// becomes a rank-reducing memref.subview with the strided layout MAP0.
// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
// CHECK-SAME:      %[[SRC:.*]]: tensor<?x10x?xf32>, %[[MID:.*]]: index,
// CHECK-SAME:      %[[ROWS:.*]]: index
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<?x10x?xf32>
// CHECK:         %[[VIEW:.*]] = memref.subview %[[SRC_BUF]][5, %[[MID]], 10] [%[[ROWS]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP0]]>
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[VIEW]]
// CHECK:         return %[[RES]]
func.func @tensor.extract_slice_rank_reducing(
    %src: tensor<?x10x?xf32>, %mid: index, %rows: index) -> tensor<?x15xf32> {
  %slice = tensor.extract_slice %src[5, %mid, 10][%rows, 1, 15][1, 1, 1]
      : tensor<?x10x?xf32> to tensor<?x15xf32>
  return %slice : tensor<?x15xf32>
}
266
// insert_slice copies the destination buffer into a fresh allocation sized by
// its dims, then copies the source into a subview of that allocation.
// CHECK-LABEL: func @tensor.insert_slice(
// CHECK-SAME:      %[[DEST:.*]]: tensor<?x?xf32>, %[[SRC:.*]]: tensor<?x10xf32>,
// CHECK-SAME:      %[[OFF:.*]]: index, %[[SZ:.*]]: index
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[DEST_BUF:.*]] = bufferization.to_memref %[[DEST]] : memref<?x?xf32>
// CHECK-DAG:     %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<?x10xf32>
// CHECK-DAG:     %[[D0:.*]] = memref.dim %[[DEST_BUF]], %[[C0]]
// CHECK-DAG:     %[[D1:.*]] = memref.dim %[[DEST_BUF]], %[[C1]]
// CHECK:         %[[COPY:.*]] = memref.alloc(%[[D0]], %[[D1]])
// CHECK:         memref.copy %[[DEST_BUF]], %[[COPY]]
// CHECK:         %[[VIEW:.*]] = memref.subview %[[COPY]][%[[OFF]], 5] [%[[SZ]], 10] [1, 1]
// CHECK:         memref.copy %[[SRC_BUF]], %[[VIEW]]
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[COPY]]
// CHECK:         return %[[RES]]
func.func @tensor.insert_slice(%dest: tensor<?x?xf32>, %src: tensor<?x10xf32>,
                               %off: index, %sz: index) -> tensor<?x?xf32> {
  %updated = tensor.insert_slice %src into %dest[%off, 5][%sz, 10][1, 1]
      : tensor<?x10xf32> into tensor<?x?xf32>
  return %updated : tensor<?x?xf32>
}
289
// tensor.insert copies the destination into a new buffer and stores the
// scalar into that copy.
// CHECK-LABEL: func @tensor.insert(
// CHECK-SAME:      %[[DEST:.*]]: tensor<5xf32>, %[[POS:.*]]: index,
// CHECK-SAME:      %[[VAL:.*]]: f32
// CHECK-DAG:     %[[COPY:.*]] = memref.alloc() {{.*}} : memref<5xf32>
// CHECK-DAG:     %[[DEST_BUF:.*]] = bufferization.to_memref %[[DEST]] : memref<5xf32>
// CHECK:         memref.copy %[[DEST_BUF]], %[[COPY]]
// CHECK:         memref.store %[[VAL]], %[[COPY]][%[[POS]]]
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[COPY]]
// CHECK:         return %[[RES]]
func.func @tensor.insert(%dest: tensor<5xf32>, %pos: index, %val: f32) -> tensor<5xf32> {
  %updated = tensor.insert %val into %dest[%pos] : tensor<5xf32>
  return %updated : tensor<5xf32>
}
304
// expand_shape maps directly onto memref.expand_shape of the argument's
// buffer. The reassociation list is split over two directive lines so that its
// opening double bracket is not parsed as a FileCheck substitution.
// CHECK-LABEL: func @tensor.expand_shape(
// CHECK-SAME:      %[[SRC:.*]]: tensor<?x10xf32>
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<?x10xf32>
// CHECK:         %[[EXP:.*]] = memref.expand_shape %[[SRC_BUF]] [
// CHECK-SAME:      [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[EXP]]
// CHECK:         return %[[RES]]
func.func @tensor.expand_shape(%src: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
  %expanded = tensor.expand_shape %src [[0, 1], [2]]
      : tensor<?x10xf32> into tensor<2x?x10xf32>
  return %expanded : tensor<2x?x10xf32>
}
318
// Expanding a strided subview produces a memref.expand_shape whose result
// keeps a strided layout (MAP2) derived from the subview's layout (MAP1).
// CHECK-LABEL: func @tensor.expand_shape_of_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>, %[[o1:.*]]: index, %[[s1:.*]]: index
func.func @tensor.expand_shape_of_slice(
    %t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
  // The dynamic offset and size of the slice must reach the subview in the
  // right operand positions, so match them against the captured arguments
  // instead of accepting any SSA value.
  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%[[o1]], 5] [%[[s1]], 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, #[[$MAP1]]>
  %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
      tensor<?x20xf32> to tensor<?x10xf32>
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
  // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, #[[$MAP1]]> into memref<?x7x2x5xf32, #[[$MAP2]]>
  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
      tensor<?x10xf32> into tensor<?x7x2x5xf32>
  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %1 : tensor<?x7x2x5xf32>
}
335
// collapse_shape of an identity-layout buffer maps directly onto
// memref.collapse_shape; the result also has no layout map.
// CHECK-LABEL: func @tensor.collapse_shape(
// CHECK-SAME:      %[[SRC:.*]]: tensor<2x?x?xf32>
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<2x?x?xf32>
// CHECK:         %[[COL:.*]] = memref.collapse_shape %[[SRC_BUF]] [
// CHECK-SAME:      [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[COL]]
// CHECK:         return %[[RES]]
func.func @tensor.collapse_shape(%src: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
  %collapsed = tensor.collapse_shape %src [[0, 1], [2]]
      : tensor<2x?x?xf32> into tensor<?x?xf32>
  return %collapsed : tensor<?x?xf32>
}
349
// An empty reassociation collapses an all-unit-dims buffer to a 0-d memref.
// CHECK-LABEL: func @tensor.collapse_shape_to_scalar(
// CHECK-SAME:      %[[SRC:.*]]: tensor<1x1x1xf32>
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<1x1x1xf32>
// CHECK:         %[[COL:.*]] = memref.collapse_shape %[[SRC_BUF]] [] : memref<1x1x1xf32> into memref<f32>
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[COL]]
// CHECK:         return %[[RES]]
func.func @tensor.collapse_shape_to_scalar(%src: tensor<1x1x1xf32>) -> tensor<f32> {
  %scalar = tensor.collapse_shape %src []
      : tensor<1x1x1xf32> into tensor<f32>
  return %scalar : tensor<f32>
}
362
// Collapsing a one-element subview at offset 1 down to 0-d keeps that offset
// in the result's layout map (MAP4 is the constant map to 1).
// CHECK-LABEL: func @tensor.collapse_shape_of_slice(
// CHECK:         memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, #[[$MAP3]]>
// CHECK:         memref.collapse_shape %{{.*}} [] : memref<1xi32, #[[$MAP3]]> into memref<i32, #[[$MAP4]]>
func.func @tensor.collapse_shape_of_slice(%src: tensor<2xi32>) -> tensor<i32> {
  %elem = tensor.extract_slice %src[1] [1] [1] : tensor<2xi32> to tensor<1xi32>
  %scalar = tensor.collapse_shape %elem [] : tensor<1xi32> into tensor<i32>
  return %scalar : tensor<i32>
}
371
// Collapsing dims 1-3 of a subview taken from a fully dynamic source: the
// bufferization copies the subview into an identity-layout allocation and
// collapses that copy instead of the subview itself.
// CHECK-LABEL: func @tensor.collapse_shape_of_slice2(
func.func @tensor.collapse_shape_of_slice2(
    %arg0: tensor<?x?x?x?xi64>, %o1: index, %o2: index, %o3: index, %o4: index)
    -> tensor<87x63648xi64> {
  // The subview result carries some non-identity layout map (matched loosely).
  // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<?x?x?x?xi64> to memref<87x78x68x12xi64, #{{.*}}>
  %0 = tensor.extract_slice %arg0[%o1, %o2, %o3, %o4] [87, 78, 68, 12] [1, 1, 1, 1] : tensor<?x?x?x?xi64> to tensor<87x78x68x12xi64>

  // This memref may not be collapsible, so the buffer must be copied to get rid
  // of the layout map. Presumably this is because the source sizes are all
  // dynamic, so the strides of dims 1-3 cannot be shown to be contiguous at
  // compile time. The allocation below has no layout map, and the collapse
  // operates on it rather than on the subview.
  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<87x78x68x12xi64>
  // CHECK: memref.copy %[[subview]], %[[alloc]]
  // CHECK: memref.collapse_shape %[[alloc]] [
  // CHECK-SAME: [0], [1, 2, 3]] : memref<87x78x68x12xi64> into memref<87x63648xi64>
  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
  return %1 : tensor<87x63648xi64>
}
388
// Collapsing a 1x1 subview of a 1x2 buffer yields a memref that keeps a
// strided layout map (MAP6) rather than an identity layout.
// CHECK-LABEL: func @tensor.collapse_shape_of_slice3(
// CHECK-SAME:      %[[SRC:.*]]: tensor<1x2xf32>
// CHECK:         memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]>
// CHECK:         memref.collapse_shape %{{.*}} [
// CHECK-SAME:      [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32, #[[$MAP6]]>
func.func @tensor.collapse_shape_of_slice3(%src: tensor<1x2xf32>) -> tensor<1xf32> {
  %corner = tensor.extract_slice %src[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
  %flat = tensor.collapse_shape %corner [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
  return %flat : tensor<1xf32>
}
399
// Collapsing a 4x2x1 subview whose innermost offset is dynamic: the subview
// has layout MAP7 and the collapsed 1-d memref keeps a strided layout (MAP8).
// Both index arguments are captured so that OFFSET binds to the second
// argument alone; %size is not used by the body and is captured only so the
// greedy match for OFFSET cannot swallow it.
// CHECK-LABEL: func @tensor.collapse_shape_of_slice4(
// CHECK-SAME:      %[[t1:.*]]: tensor<?x2x4xf32>,
// CHECK-SAME:      %[[OFFSET:.*]]: index, %[[SIZE:.*]]: index) -> tensor<8xf32> {
func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: index, %size: index) -> tensor<8xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x2x4xf32>
  // The dynamic offset must appear as the innermost subview offset.
  // CHECK: memref.subview %[[m1]][0, 0, %[[OFFSET]]] [4, 2, 1] [1, 1, 1] : memref<?x2x4xf32> to memref<4x2x1xf32, #[[$MAP7]]>
  %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor<?x2x4xf32> to tensor<4x2x1xf32>
  // CHECK: memref.collapse_shape %{{.*}} [
  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, #[[$MAP7]]> into memref<8xf32, #[[$MAP8]]>
  %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
  return %ret: tensor<8xf32>
}
411