// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s

// Each function below applies one tensor-dialect op and verifies the
// memref-based IR that -tensor-bufferize rewrites it into.

// Strided layout map produced by memref.subview; referenced by the
// extract_slice tests later in this file.
// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

// CHECK-LABEL: func @dim(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME:    %[[INDEX:.*]]: index) -> index {
// CHECK:         %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<f32>
// CHECK:         %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
// CHECK:         return %[[EXTENT]] : index
func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
  %0 = tensor.dim %arg0, %arg1 : tensor<f32>
  return %0 : index
}

// CHECK-LABEL: func @rank(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<*xf32>) -> index {
// CHECK:         %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK:         %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
// CHECK:         return %[[EXTENT]] : index
func @rank(%arg0: tensor<*xf32>) -> index {
  %0 = tensor.rank %arg0 : tensor<*xf32>
  return %0 : index
}

// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK:         %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK:         %[[CASTED:.*]] = memref.cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[CASTED]]
// CHECK:         return %[[RET]] : tensor<2xindex>
func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
  return %0 : tensor<2xindex>
}

// CHECK-LABEL: func @tensor.cast_from_unranked(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK:         %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
// CHECK:         %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32>
// CHECK:         return %[[RET]] : tensor<2xf32>
func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @tensor.cast_to_unranked(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK:         %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<2xf32>
// CHECK:         %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK:         return %[[RET]] : tensor<*xf32>
func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
  %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @tensor.extract(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME:    %[[IDX:.*]]: index) -> f32 {
// CHECK:         %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<?xf32>
// CHECK:         %[[RET:.*]] = memref.load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK:         return %[[RET]] : f32
// CHECK:       }
func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
  %0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
  return %0 : f32
}

// An empty from_elements is expected to become an `arith.constant dense<>`
// rather than a buffer allocation.
// CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> {
// CHECK:         %[[RET:.*]] = arith.constant dense<> : tensor<0xindex>
// CHECK:         return %[[RET]] : tensor<0xindex>
func @tensor.from_elements_no_elements() -> tensor<0xindex> {
  %0 = tensor.from_elements : tensor<0xindex>
  return %0 : tensor<0xindex>
}

// CHECK-LABEL: func @tensor.from_elements_0d(
// CHECK-SAME:    %[[ELEM0:.*]]: index) -> tensor<index> {
// CHECK:         %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
// CHECK:         store %[[ELEM0]], %[[MEMREF]]
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:         return %[[RET]] : tensor<index>
func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
  %0 = tensor.from_elements %arg0 : tensor<index>
  return %0 : tensor<index>
}

// CHECK-LABEL: func @tensor.from_elements_1d(
// CHECK-SAME:    %[[ELEM0:.*]]: index,
// CHECK-SAME:    %[[ELEM1:.*]]: index) -> tensor<2xindex> {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:         return %[[RET]] : tensor<2xindex>
func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
  return %0 : tensor<2xindex>
}

// CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME:    %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
// CHECK-SAME:    -> tensor<3x2xindex> {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:         return %[[RET]] : tensor<3x2xindex>
func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
         : tensor<3x2xindex>
  return %0 : tensor<3x2xindex>
}

// CHECK-LABEL: func @tensor.from_elements_3d(
// CHECK-SAME:    %[[F0:.*]]: f32

// The exponent forms pin 1.0, 10.0 and 11.0 so they cannot be confused
// with each other (all three start with "1.").
// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index

// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>

// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]]
// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]]
// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]]
// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]]
// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]]
// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]]
// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]]
// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]]
// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C0]]]
// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]]

// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2x2xf32>
func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32
  %f6 = arith.constant 6.0 : f32
  %f7 = arith.constant 7.0 : f32
  %f8 = arith.constant 8.0 : f32
  %f9 = arith.constant 9.0 : f32
  %f10 = arith.constant 10.0 : f32
  %f11 = arith.constant 11.0 : f32
  %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
         : tensor<3x2x2xf32>
  return %0 : tensor<3x2x2xf32>
}

// CHECK-LABEL: func @tensor.generate(
// CHECK-SAME:    %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME:    %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK:         %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK:         %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
// CHECK:         scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK:           %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
// CHECK:           store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
// CHECK:           scf.yield
// CHECK:         }
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<?xindex>
// CHECK:         return %[[RET]] : tensor<?xindex>
// CHECK:       }
func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> {
  %result = tensor.generate %dynamic_extent {
  ^bb0(%i : index):
    %elem = tensor.dim %arg, %i : tensor<*xf32>
    tensor.yield %elem : index
  } : tensor<?xindex>
  return %result : tensor<?xindex>
}

// Additional test that checks the logic for intermixed static and dynamic
// extents.
//
// CHECK-LABEL: func @tensor.generate_static_and_dynamic(
// CHECK-SAME:    %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
// CHECK:         %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
// CHECK:         scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
// CHECK:           %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
// CHECK:           store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
// CHECK:           scf.yield
// CHECK:         }
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<16x?xindex>
// CHECK:         return %[[RET]] : tensor<16x?xindex>
// CHECK:       }
func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
  %result = tensor.generate %arg0 {
  ^bb0(%i: index, %j: index):
    %sum = arith.addi %i, %j : index
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  return %result : tensor<16x?xindex>
}

// Ops from an unregistered/test dialect inside the generate body must be
// carried over into the lowered loop body unchanged.
// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
  // CHECK-NOT: tensor.generate
  %tensor = tensor.generate %arg0 {
  ^bb0(%iv: index):
    // CHECK: test.source
    %0 = "test.source"() : () -> index
    tensor.yield %0 : index
  } : tensor<?xindex>
  return %tensor : tensor<?xindex>
}

// CHECK-LABEL: func @tensor.extract_slice(
// CHECK-SAME:    %[[t1:.*]]: tensor<?x?xf32>, %[[idx1:.*]]: index, %[[idx2:.*]]: index
func @tensor.extract_slice(
    %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP]]>
  %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
      : tensor<?x?xf32> to tensor<?x10xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x10xf32>
}

// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
// CHECK-SAME:    %[[t1:.*]]: tensor<?x10x?xf32>, %[[idx1:.*]]: index,
// CHECK-SAME:    %[[idx2:.*]]: index
func @tensor.extract_slice_rank_reducing(
    %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP]]>
  %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
      : tensor<?x10x?xf32> to tensor<?x15xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x15xf32>
}

// The destination is copied into a fresh allocation before the source is
// copied into a subview of it, so %t1 itself is never written.
// CHECK-LABEL: func @tensor.insert_slice(
// CHECK-SAME:    %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
// CHECK-SAME:    %[[idx1:.*]]: index, %[[idx2:.*]]: index
func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
                          %idx1: index, %idx2: index) -> tensor<?x?xf32> {
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
  // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
  // CHECK: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
  // CHECK: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
  // CHECK: memref.copy %[[m2]], %[[subview]]
  %0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1]
      : tensor<?x10xf32> into tensor<?x?xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: func @tensor.insert(
// CHECK-SAME:    %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
// CHECK-SAME:    %[[f:.*]]: f32
func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
  // CHECK-DAG: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32>
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
  %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK: return %[[r]]
  return %0 : tensor<5xf32>
}

// CHECK-LABEL: func @tensor.expand_shape(
// CHECK-SAME:    %[[t1:.*]]: tensor<?x10xf32>
func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
  // CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
  %0 = tensor.expand_shape %t1 [[0, 1], [2]]
      : tensor<?x10xf32> into tensor<2x?x10xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %0 : tensor<2x?x10xf32>
}

// CHECK-LABEL: func @tensor.collapse_shape(
// CHECK-SAME:    %[[t1:.*]]: tensor<2x?x?xf32>
func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<2x?x?xf32>
  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [
  // CHECK-SAME: [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
  %0 = tensor.collapse_shape %t1 [[0, 1], [2]]
      : tensor<2x?x?xf32> into tensor<?x?xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
  // CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}