1// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
2
// tensor.dim is rewritten to memref.dim on the buffer produced by
// bufferization.to_memref; the extent is returned as-is.
// CHECK-LABEL: func @dim(
// CHECK-SAME:    %[[T:.*]]: tensor<f32>,
// CHECK-SAME:    %[[IDX:.*]]: index) -> index {
// CHECK:         %[[BUF:.*]] = bufferization.to_memref %[[T]] : memref<f32>
// CHECK:         %[[SIZE:.*]] = memref.dim %[[BUF]], %[[IDX]] : memref<f32>
// CHECK:         return %[[SIZE]] : index
func @dim(%t: tensor<f32>, %idx: index) -> index {
  %size = tensor.dim %t, %idx : tensor<f32>
  return %size : index
}
13
// tensor.rank on an unranked tensor is rewritten to memref.rank on the
// unranked buffer from bufferization.to_memref, and that rank is what the
// function returns.
// CHECK-LABEL: func @rank(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<*xf32>) -> index {
// CHECK:         %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
// CHECK:         %[[RANK:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
// CHECK:         return %[[RANK]] : index
func @rank(%arg0: tensor<*xf32>) -> index {
  %0 = tensor.rank %arg0 : tensor<*xf32>
  return %0 : index
}
22
// A ranked-to-ranked tensor.cast becomes a memref.cast between the
// corresponding memref types, wrapped back with bufferization.to_tensor.
// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME:    %[[SRC:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]]
// CHECK:         %[[CAST_BUF:.*]] = memref.cast %[[SRC_BUF]] : memref<?xindex> to memref<2xindex>
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[CAST_BUF]]
// CHECK:         return %[[RES]] : tensor<2xindex>
func @tensor.cast(%src: tensor<?xindex>) -> tensor<2xindex> {
  %cast = tensor.cast %src : tensor<?xindex> to tensor<2xindex>
  return %cast : tensor<2xindex>
}
33
// Casting from an unranked tensor to a ranked one produces a memref.cast
// from the unranked buffer to the ranked buffer type.
// CHECK-LABEL: func @tensor.cast_from_unranked(
// CHECK-SAME:    %[[SRC:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<*xf32>
// CHECK:         %[[CAST_BUF:.*]] = memref.cast %[[SRC_BUF]] : memref<*xf32> to memref<2xf32>
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[CAST_BUF]] : memref<2xf32>
// CHECK:         return %[[RES]] : tensor<2xf32>
func @tensor.cast_from_unranked(%src: tensor<*xf32>) -> tensor<2xf32> {
  %ranked = tensor.cast %src : tensor<*xf32> to tensor<2xf32>
  return %ranked : tensor<2xf32>
}
44
// Casting from a ranked tensor to an unranked one produces a memref.cast
// from the ranked buffer to the unranked buffer type.
// CHECK-LABEL: func @tensor.cast_to_unranked(
// CHECK-SAME:    %[[SRC:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<2xf32>
// CHECK:         %[[CAST_BUF:.*]] = memref.cast %[[SRC_BUF]] : memref<2xf32> to memref<*xf32>
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[CAST_BUF]] : memref<*xf32>
// CHECK:         return %[[RES]] : tensor<*xf32>
func @tensor.cast_to_unranked(%src: tensor<2xf32>) -> tensor<*xf32> {
  %unranked = tensor.cast %src : tensor<2xf32> to tensor<*xf32>
  return %unranked : tensor<*xf32>
}
55
// tensor.extract becomes a memref.load from the buffer at the same index.
// CHECK-LABEL: func @tensor.extract(
// CHECK-SAME:    %[[SRC:.*]]: tensor<?xf32>,
// CHECK-SAME:    %[[POS:.*]]: index) -> f32 {
// CHECK:         %[[BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<?xf32>
// CHECK:         %[[VAL:.*]] = memref.load %[[BUF]][%[[POS]]] : memref<?xf32>
// CHECK:         return %[[VAL]] : f32
// CHECK:       }
func @tensor.extract(%src: tensor<?xf32>, %pos: index) -> f32 {
  %val = tensor.extract %src[%pos] : tensor<?xf32>
  return %val : f32
}
67
// An empty tensor.from_elements lowers to a memref.alloc of the zero-sized
// buffer, which is converted straight back with bufferization.to_tensor.
// CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> {
// CHECK:         %[[BUF:.*]] = memref.alloc() : memref<0xindex>
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[BUF]]
// CHECK:         return %[[RES]] : tensor<0xindex>
func @tensor.from_elements_no_elements() -> tensor<0xindex> {
  %empty = tensor.from_elements : tensor<0xindex>
  return %empty : tensor<0xindex>
}
76
// A rank-0 tensor.from_elements allocates a rank-0 buffer and stores its
// single element with an empty index list; no index constants are needed.
// CHECK-LABEL: func @tensor.from_elements_0d(
// CHECK-SAME:    %[[ELEM0:.*]]: index) -> tensor<index> {
// CHECK:         %[[MEMREF:.*]] = memref.alloc() : memref<index>
// CHECK-NOT:     arith.constant
// CHECK:         store %[[ELEM0]], %[[MEMREF]][] : memref<index>
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:         return %[[RET]] : tensor<index>
func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
  %0 = tensor.from_elements %arg0 : tensor<index>
  return %0 : tensor<index>
}
87
// Each element of a 1-d tensor.from_elements is stored at its position using
// index constants materialized after the allocation.
// CHECK-LABEL: func @tensor.from_elements_1d(
// CHECK-SAME:    %[[A:.*]]: index,
// CHECK-SAME:    %[[B:.*]]: index) -> tensor<2xindex> {
// CHECK:         %[[BUF:.*]] = memref.alloc() : memref<2xindex>
// CHECK:         %[[IDX0:.*]] = arith.constant 0 : index
// CHECK:         %[[IDX1:.*]] = arith.constant 1 : index
// CHECK:         store %[[A]], %[[BUF]][%[[IDX0]]]
// CHECK:         store %[[B]], %[[BUF]][%[[IDX1]]]
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[BUF]]
// CHECK:         return %[[RES]] : tensor<2xindex>
func @tensor.from_elements_1d(%a: index, %b: index) -> tensor<2xindex> {
  %vec = tensor.from_elements %a, %b : tensor<2xindex>
  return %vec : tensor<2xindex>
}
102
// For a 2-d tensor.from_elements the stores walk the buffer with the last
// index varying fastest, reusing one set of index constants.
// CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index)
// CHECK-SAME:    -> tensor<3x2xindex> {
// CHECK:         %[[BUF:.*]] = memref.alloc() : memref<3x2xindex>
// CHECK:         %[[IDX0:.*]] = arith.constant 0 : index
// CHECK:         %[[IDX1:.*]] = arith.constant 1 : index
// CHECK:         %[[IDX2:.*]] = arith.constant 2 : index
// CHECK:         store %[[A]], %[[BUF]][%[[IDX0]], %[[IDX0]]]
// CHECK:         store %[[B]], %[[BUF]][%[[IDX0]], %[[IDX1]]]
// CHECK:         store %[[A]], %[[BUF]][%[[IDX1]], %[[IDX0]]]
// CHECK:         store %[[B]], %[[BUF]][%[[IDX1]], %[[IDX1]]]
// CHECK:         store %[[A]], %[[BUF]][%[[IDX2]], %[[IDX0]]]
// CHECK:         store %[[B]], %[[BUF]][%[[IDX2]], %[[IDX1]]]
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[BUF]]
// CHECK:         return %[[RES]] : tensor<3x2xindex>
func @tensor.from_elements_2d(%a: index, %b: index) -> tensor<3x2xindex> {
  %mat = tensor.from_elements %a, %b, %a, %b, %a, %b : tensor<3x2xindex>
  return %mat : tensor<3x2xindex>
}
123
// For a 3-d tensor.from_elements the operands are stored in order with the
// innermost index varying fastest.
// CHECK-LABEL: func @tensor.from_elements_3d()

// CHECK-DAG: %[[V0:.*]] = arith.constant 0.0
// CHECK-DAG: %[[V1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[V2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[V3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[V4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[V5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[V6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[V7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[V8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[V9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[V10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[V11:.*]] = arith.constant 1.1{{0+}}e+01

// CHECK: %[[BUF:.*]] = memref.alloc() : memref<3x2x2xf32>

// CHECK: %[[I0:.*]] = arith.constant 0 : index
// CHECK: %[[I1:.*]] = arith.constant 1 : index
// CHECK: %[[I2:.*]] = arith.constant 2 : index

// CHECK: store %[[V0]], %[[BUF]][%[[I0]], %[[I0]], %[[I0]]]
// CHECK: store %[[V1]], %[[BUF]][%[[I0]], %[[I0]], %[[I1]]]
// CHECK: store %[[V2]], %[[BUF]][%[[I0]], %[[I1]], %[[I0]]]
// CHECK: store %[[V3]], %[[BUF]][%[[I0]], %[[I1]], %[[I1]]]
// CHECK: store %[[V4]], %[[BUF]][%[[I1]], %[[I0]], %[[I0]]]
// CHECK: store %[[V5]], %[[BUF]][%[[I1]], %[[I0]], %[[I1]]]
// CHECK: store %[[V6]], %[[BUF]][%[[I1]], %[[I1]], %[[I0]]]
// CHECK: store %[[V7]], %[[BUF]][%[[I1]], %[[I1]], %[[I1]]]
// CHECK: store %[[V8]], %[[BUF]][%[[I2]], %[[I0]], %[[I0]]]
// CHECK: store %[[V9]], %[[BUF]][%[[I2]], %[[I0]], %[[I1]]]
// CHECK: store %[[V10]], %[[BUF]][%[[I2]], %[[I1]], %[[I0]]]
// CHECK: store %[[V11]], %[[BUF]][%[[I2]], %[[I1]], %[[I1]]]

// CHECK: %[[RES:.*]] = bufferization.to_tensor %[[BUF]]
// CHECK: return %[[RES]] : tensor<3x2x2xf32>
func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32
  %v3 = arith.constant 3.0 : f32
  %v4 = arith.constant 4.0 : f32
  %v5 = arith.constant 5.0 : f32
  %v6 = arith.constant 6.0 : f32
  %v7 = arith.constant 7.0 : f32
  %v8 = arith.constant 8.0 : f32
  %v9 = arith.constant 9.0 : f32
  %v10 = arith.constant 10.0 : f32
  %v11 = arith.constant 11.0 : f32
  %t = tensor.from_elements %v0, %v1, %v2, %v3, %v4, %v5,
                            %v6, %v7, %v8, %v9, %v10, %v11
         : tensor<3x2x2xf32>
  return %t : tensor<3x2x2xf32>
}
177
// tensor.generate becomes an allocation plus an scf.parallel loop whose body
// is the generator region, with tensor.yield replaced by a store.
// CHECK-LABEL: func @tensor.generate(
// CHECK-SAME:    %[[SRC:.*]]: tensor<*xf32>,
// CHECK-SAME:    %[[N:.*]]: index) -> tensor<?xindex> {
// CHECK:         %[[SRC_BUF:.*]] = bufferization.to_memref %[[SRC]] : memref<*xf32>
// CHECK:         %[[OUT:.*]] = memref.alloc(%[[N]]) : memref<?xindex>
// CHECK:         %[[LB:.*]] = arith.constant 0 : index
// CHECK:         %[[STEP:.*]] = arith.constant 1 : index
// CHECK:         scf.parallel (%[[IV:.*]]) = (%[[LB]]) to (%[[N]]) step (%[[STEP]]) {
// CHECK:           %[[D:.*]] = memref.dim %[[SRC_BUF]], %[[IV]] : memref<*xf32>
// CHECK:           store %[[D]], %[[OUT]][%[[IV]]] : memref<?xindex>
// CHECK:           scf.yield
// CHECK:         }
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[OUT]] : memref<?xindex>
// CHECK:         return %[[RES]] : tensor<?xindex>
// CHECK:       }
func @tensor.generate(%src: tensor<*xf32>, %n: index) -> tensor<?xindex> {
  %out = tensor.generate %n {
  ^bb0(%iv: index):
    %d = tensor.dim %src, %iv : tensor<*xf32>
    tensor.yield %d : index
  } : tensor<?xindex>
  return %out : tensor<?xindex>
}
201
202// Additional test that checks the logic for intermixed static and dynamic
203// extents.
204//
// CHECK-LABEL: func @tensor.generate_static_and_dynamic(
// CHECK-SAME:    %[[N:.*]]: index) -> tensor<16x?xindex> {
// CHECK:         %[[OUT:.*]] = memref.alloc(%[[N]]) : memref<16x?xindex>
// CHECK:         %[[ZERO:.*]] = arith.constant 0 : index
// CHECK:         %[[ONE:.*]] = arith.constant 1 : index
// CHECK:         %[[SIXTEEN:.*]] = arith.constant 16 : index
// CHECK:         scf.parallel (%[[ROW:.*]], %[[COL:.*]]) = (%[[ZERO]], %[[ZERO]]) to (%[[SIXTEEN]], %[[N]]) step (%[[ONE]], %[[ONE]]) {
// CHECK:           %[[SUM:.*]] = arith.addi %[[ROW]], %[[COL]] : index
// CHECK:           store %[[SUM]], %[[OUT]][%[[ROW]], %[[COL]]] : memref<16x?xindex>
// CHECK:           scf.yield
// CHECK:         }
// CHECK:         %[[RES:.*]] = bufferization.to_tensor %[[OUT]] : memref<16x?xindex>
// CHECK:         return %[[RES]] : tensor<16x?xindex>
// CHECK:       }
func @tensor.generate_static_and_dynamic(%n: index) -> tensor<16x?xindex> {
  %out = tensor.generate %n {
  ^bb0(%row: index, %col: index):
    %sum = arith.addi %row, %col : index
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  return %out : tensor<16x?xindex>
}
227
// When bufferized, the tensor.generate op must move its body into the
// resulting scf.parallel. To support ops from unknown dialects in the body,
// the conversion cannot clone the body: cloning would require the cloned ops
// to be legalized immediately, which is usually not possible because they may
// come from arbitrary other dialects.
//
// Verifies that an op from an unknown dialect in the generator body ends up
// inside the scf.parallel, and that its result is what gets stored.
// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
  // CHECK-NOT: tensor.generate
  // CHECK: scf.parallel
  %tensor = tensor.generate %arg0 {
  ^bb0(%iv: index):
    // CHECK: %[[ELEM:.*]] = {{.*}}test.source
    %0 = "test.source"() : () -> index
    // The yielded value becomes the operand of the store.
    // CHECK: store %[[ELEM]]
    tensor.yield %0 : index
  } : tensor<?xindex>
  return %tensor : tensor<?xindex>
}
245