// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s

// Regression test for the -tensor-bufferize pass: each function exercises one
// tensor-dialect op and the CHECK lines pin the expected memref-based lowering
// (bufferization.to_memref / to_tensor bracketing the memref-level ops).

// CHECK-LABEL: func @dim(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<f32>
// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
// CHECK: return %[[EXTENT]] : index
func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
  %0 = tensor.dim %arg0, %arg1 : tensor<f32>
  return %0 : index
}

// CHECK-LABEL: func @rank(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> index {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK: %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
func @rank(%arg0: tensor<*xf32>) -> index {
  %0 = tensor.rank %arg0 : tensor<*xf32>
  return %0 : index
}

// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK: %[[CASTED:.*]] = memref.cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED]]
// CHECK: return %[[RET]] : tensor<2xindex>
func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
  return %0 : tensor<2xindex>
}

// CHECK-LABEL: func @tensor.cast_from_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32>
// CHECK: return %[[RET]] : tensor<2xf32>
func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @tensor.cast_to_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<2xf32>
// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK: return %[[RET]] : tensor<*xf32>
func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
  %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @tensor.extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<?xf32>
// CHECK: %[[RET:.*]] = memref.load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK: return %[[RET]] : f32
// CHECK: }
func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
  %0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
  return %0 : f32
}

// CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<0xindex>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<0xindex>
func @tensor.from_elements_no_elements() -> tensor<0xindex> {
  %0 = tensor.from_elements : tensor<0xindex>
  return %0 : tensor<0xindex>
}

// CHECK-LABEL: func @tensor.from_elements_0d(
// CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor<index> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<index>
// CHECK: store %[[ELEM0]], %[[MEMREF]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<index>
func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
  %0 = tensor.from_elements %arg0 : tensor<index>
  return %0 : tensor<index>
}

// CHECK-LABEL: func @tensor.from_elements_1d(
// CHECK-SAME: %[[ELEM0:.*]]: index,
// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<2xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<2xindex>
func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
  return %0 : tensor<2xindex>
}

// CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
// CHECK-SAME: -> tensor<3x2xindex> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2xindex>
func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
      : tensor<3x2xindex>
  return %0 : tensor<3x2xindex>
}

// CHECK-LABEL: func @tensor.from_elements_3d()

// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2x2xf32>

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index

// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]]
// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]]
// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]]
// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]]
// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]]
// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]]
// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]]
// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]]
// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C0]]]
// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]]

// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2x2xf32>
func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
  %f0 = arith.constant 0.0 : f32
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32
  %f6 = arith.constant 6.0 : f32
  %f7 = arith.constant 7.0 : f32
  %f8 = arith.constant 8.0 : f32
  %f9 = arith.constant 9.0 : f32
  %f10 = arith.constant 10.0 : f32
  %f11 = arith.constant 11.0 : f32
  %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
      : tensor<3x2x2xf32>
  return %0 : tensor<3x2x2xf32>
}

// CHECK-LABEL: func @tensor.generate(
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
// CHECK: scf.yield
// CHECK: }
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<?xindex>
// CHECK: return %[[RET]] : tensor<?xindex>
// CHECK: }
func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> {
  %result = tensor.generate %dynamic_extent {
  ^bb0(%i : index):
    %elem = tensor.dim %arg, %i : tensor<*xf32>
    tensor.yield %elem : index
  } : tensor<?xindex>
  return %result : tensor<?xindex>
}

// Additional test that checks the logic for intermixed static and dynamic
// extents.
//
// CHECK-LABEL: func @tensor.generate_static_and_dynamic(
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
// CHECK: %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
// CHECK: scf.yield
// CHECK: }
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<16x?xindex>
// CHECK: return %[[RET]] : tensor<16x?xindex>
// CHECK: }
func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
  %result = tensor.generate %arg0 {
  ^bb0(%i: index, %j: index):
    %sum = arith.addi %i, %j : index
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  return %result : tensor<16x?xindex>
}

// The tensor.generate op needs to put its body into the
// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
// the body because that would require the cloned ops to be legalized
// immediately, which is usually not possible since they might be from various
// other dialects.
//
// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
  // CHECK-NOT: tensor.generate
  %tensor = tensor.generate %arg0 {
  ^bb0(%iv: index):
    // CHECK: test.source
    %0 = "test.source"() : () -> index
    tensor.yield %0 : index
  } : tensor<?xindex>
  return %tensor : tensor<?xindex>
}