1// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
2
// A ranked-to-ranked tensor.cast bufferizes to a memref_cast between the
// corresponding memref types, wrapped in tensor_to_memref / tensor_load.
// CHECK-LABEL:   func @tensor.cast(
// CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xindex>
// CHECK:           %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED]] : memref<2xindex>
// CHECK:           return %[[RET]] : tensor<2xindex>
// CHECK:         }
func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
  return %0 : tensor<2xindex>
}
13
// Casting from an unranked tensor: the buffer-level cast goes from
// memref<*xf32> to the static memref<2xf32>.
// CHECK-LABEL:   func @tensor.cast_from_unranked(
// CHECK-SAME:      %[[SRC:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK:           %[[SRC_BUF:.*]] = tensor_to_memref %[[SRC]] : memref<*xf32>
// CHECK:           %[[CAST_BUF:.*]] = memref_cast %[[SRC_BUF]] : memref<*xf32> to memref<2xf32>
// CHECK:           %[[RESULT:.*]] = tensor_load %[[CAST_BUF]] : memref<2xf32>
// CHECK:           return %[[RESULT]] : tensor<2xf32>
func @tensor.cast_from_unranked(%unranked: tensor<*xf32>) -> tensor<2xf32> {
  %ranked = tensor.cast %unranked : tensor<*xf32> to tensor<2xf32>
  return %ranked : tensor<2xf32>
}
24
// Casting to an unranked tensor: the buffer-level cast goes from the static
// memref<2xf32> to memref<*xf32>.
// CHECK-LABEL:   func @tensor.cast_to_unranked(
// CHECK-SAME:      %[[SRC:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK:           %[[SRC_BUF:.*]] = tensor_to_memref %[[SRC]] : memref<2xf32>
// CHECK:           %[[CAST_BUF:.*]] = memref_cast %[[SRC_BUF]] : memref<2xf32> to memref<*xf32>
// CHECK:           %[[RESULT:.*]] = tensor_load %[[CAST_BUF]] : memref<*xf32>
// CHECK:           return %[[RESULT]] : tensor<*xf32>
func @tensor.cast_to_unranked(%ranked: tensor<2xf32>) -> tensor<*xf32> {
  %unranked = tensor.cast %ranked : tensor<2xf32> to tensor<*xf32>
  return %unranked : tensor<*xf32>
}
35
// tensor.extract becomes a plain load from the buffer backing the tensor,
// indexed by the same position operand.
// CHECK-LABEL:   func @tensor.extract(
// CHECK-SAME:      %[[SRC:.*]]: tensor<?xf32>,
// CHECK-SAME:      %[[POS:.*]]: index) -> f32 {
// CHECK:           %[[BUF:.*]] = tensor_to_memref %[[SRC]] : memref<?xf32>
// CHECK:           %[[ELT:.*]] = load %[[BUF]][%[[POS]]] : memref<?xf32>
// CHECK:           return %[[ELT]] : f32
// CHECK:         }
func @tensor.extract(%src: tensor<?xf32>, %pos: index) -> f32 {
  %elt = tensor.extract %src[%pos] : tensor<?xf32>
  return %elt : f32
}
47
// tensor.from_elements allocates a statically shaped buffer and stores each
// element operand at its constant index before loading it back as a tensor.
// CHECK-LABEL:   func @tensor.from_elements(
// CHECK-SAME:                               %[[ELEM0:.*]]: index,
// CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {
// CHECK:           %[[MEMREF:.*]] = alloc() : memref<2xindex>
// CHECK:           %[[C0:.*]] = constant 0 : index
// CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]] : memref<2xindex>
// CHECK:           %[[C1:.*]] = constant 1 : index
// CHECK:           store %[[ELEM1]], %[[MEMREF]][%[[C1]]] : memref<2xindex>
// CHECK:           %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<2xindex>
// CHECK:           return %[[RET]] : tensor<2xindex>
// CHECK:         }
func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
  return %0 : tensor<2xindex>
}
62
// A fully dynamic tensor.generate lowers to:
//   - an alloc sized by the dynamic extent,
//   - an scf.parallel loop over that extent whose body stores the yielded
//     value, and
//   - a tensor_load of the buffer.
// CHECK-LABEL:   func @tensor.generate(
// CHECK-SAME:      %[[SHAPED:.*]]: tensor<*xf32>,
// CHECK-SAME:      %[[SIZE:.*]]: index) -> tensor<?xindex> {
// CHECK:           %[[BUF:.*]] = alloc(%[[SIZE]]) : memref<?xindex>
// CHECK:           %[[ZERO:.*]] = constant 0 : index
// CHECK:           %[[ONE:.*]] = constant 1 : index
// CHECK:           scf.parallel (%[[IV:.*]]) = (%[[ZERO]]) to (%[[SIZE]]) step (%[[ONE]]) {
// CHECK:             %[[EXTENT:.*]] = dim %[[SHAPED]], %[[IV]] : tensor<*xf32>
// CHECK:             store %[[EXTENT]], %[[BUF]][%[[IV]]] : memref<?xindex>
// CHECK:             scf.yield
// CHECK:           }
// CHECK:           %[[RESULT:.*]] = tensor_load %[[BUF]] : memref<?xindex>
// CHECK:           return %[[RESULT]] : tensor<?xindex>
// CHECK:         }
func @tensor.generate(%shaped: tensor<*xf32>, %size: index) -> tensor<?xindex> {
  %result = tensor.generate %size {
  ^bb0(%iv: index):
    %extent = dim %shaped, %iv : tensor<*xf32>
    tensor.yield %extent : index
  } : tensor<?xindex>
  return %result : tensor<?xindex>
}
85
// Mixed static and dynamic extents. Only the dynamic extent is passed to
// alloc; the static extent (16) becomes a constant upper bound of the
// two-dimensional scf.parallel loop.
//
// CHECK-LABEL:   func @tensor.generate_static_and_dynamic(
// CHECK-SAME:      %[[DYN:.*]]: index) -> tensor<16x?xindex> {
// CHECK:           %[[BUF:.*]] = alloc(%[[DYN]]) : memref<16x?xindex>
// CHECK:           %[[ZERO:.*]] = constant 0 : index
// CHECK:           %[[ONE:.*]] = constant 1 : index
// CHECK:           %[[SIXTEEN:.*]] = constant 16 : index
// CHECK:           scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[ZERO]], %[[ZERO]]) to (%[[SIXTEEN]], %[[DYN]]) step (%[[ONE]], %[[ONE]]) {
// CHECK:             %[[SUM:.*]] = addi %[[I]], %[[J]] : index
// CHECK:             store %[[SUM]], %[[BUF]][%[[I]], %[[J]]] : memref<16x?xindex>
// CHECK:             scf.yield
// CHECK:           }
// CHECK:           %[[RESULT:.*]] = tensor_load %[[BUF]] : memref<16x?xindex>
// CHECK:           return %[[RESULT]] : tensor<16x?xindex>
// CHECK:         }
func @tensor.generate_static_and_dynamic(%dyn: index) -> tensor<16x?xindex> {
  %result = tensor.generate %dyn {
  ^bb0(%row: index, %col: index):
    %idx_sum = addi %row, %col : index
    tensor.yield %idx_sum : index
  } : tensor<16x?xindex>
  return %result : tensor<16x?xindex>
}
111
// The tensor.generate lowering must move its body into the resulting
// scf.parallel rather than clone it. Cloned ops would have to be legalized
// immediately, which is generally impossible because the body may contain
// ops from arbitrary other dialects (here, an op from the test dialect).
// Besides the tensor.generate disappearing, verify that the foreign op ends up
// inside the loop and that the yield is turned into a store.
//
// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
  // CHECK-NOT: tensor.generate
  // CHECK: scf.parallel
  %tensor = tensor.generate %arg0 {
  ^bb0(%iv: index):
    // CHECK: test.source
    %0 = "test.source"() : () -> index
    // CHECK: store
    tensor.yield %0 : index
  } : tensor<?xindex>
  return %tensor : tensor<?xindex>
}
129