// RUN: mlir-opt %s -tensor-bufferize -cse -split-input-file | FileCheck %s

// Each split-input-file chunk below is bufferized independently; -cse runs
// afterwards so identical index constants are deduplicated before matching.

// CHECK-LABEL:   func @dim(
// CHECK-SAME:        %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME:        %[[INDEX:.*]]: index) -> index {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<f32>
// CHECK:           %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
// CHECK:           return %[[EXTENT]] : index
func.func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
  %0 = tensor.dim %arg0, %arg1 : tensor<f32>
  return %0 : index
}

// -----

// The rank value must flow straight into the return, as in the @dim case.
// CHECK-LABEL:   func @rank(
// CHECK-SAME:        %[[TENSOR:.*]]: tensor<*xf32>) -> index {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK:           %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
// CHECK:           return %[[EXTENT]] : index
func.func @rank(%arg0: tensor<*xf32>) -> index {
  %0 = tensor.rank %arg0 : tensor<*xf32>
  return %0 : index
}

// -----

// CHECK-LABEL:   func @tensor.cast(
// CHECK-SAME:        %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK:           %[[CASTED:.*]] = memref.cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED]]
// CHECK:           return %[[RET]] : tensor<2xindex>
func.func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
  return %0 : tensor<2xindex>
}

// -----

// CHECK-LABEL:   func @tensor.cast_from_unranked(
// CHECK-SAME:        %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
// CHECK:           %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32>
// CHECK:           return %[[RET]] : tensor<2xf32>
func.func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL:   func @tensor.cast_to_unranked(
// CHECK-SAME:        %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<2xf32>
// CHECK:           %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK:           return %[[RET]] : tensor<*xf32>
func.func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
  %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
  return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL:   func @tensor.extract(
// CHECK-SAME:        %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME:        %[[IDX:.*]]: index) -> f32 {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<?xf32>
// CHECK:           %[[RET:.*]] = memref.load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK:           return %[[RET]] : f32
// CHECK:         }
func.func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
  %0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
  return %0 : f32
}

// -----

// CHECK-LABEL:   func @tensor.from_elements_0d(
// CHECK-SAME:        %[[ELEM0:.*]]: index) -> tensor<index> {
// CHECK:           %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
// CHECK:           store %[[ELEM0]], %[[MEMREF]]
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:           return %[[RET]] : tensor<index>
func.func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
  %0 = tensor.from_elements %arg0 : tensor<index>
  return %0 : tensor<index>
}

// -----

// CHECK-LABEL:   func @tensor.from_elements_1d(
// CHECK-SAME:        %[[ELEM0:.*]]: index,
// CHECK-SAME:        %[[ELEM1:.*]]: index) -> tensor<2xindex> {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
// CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK:           store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:           return %[[RET]] : tensor<2xindex>
func.func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
  return %0 : tensor<2xindex>
}

// -----

// CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME:      %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
// CHECK-SAME:      -> tensor<3x2xindex> {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK:         return %[[RET]] : tensor<3x2xindex>
func.func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
         : tensor<3x2xindex>
  return %0 : tensor<3x2xindex>
}

// -----

// CHECK-LABEL: func @tensor.from_elements_3d(
//  CHECK-SAME:     %[[F0:.*]]: f32

// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index

// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>

// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]]
// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]]
// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]]
// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]]
// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]]
// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]]
// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]]
// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]]
// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C0]]]
// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]]

// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2x2xf32>
func.func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32
  %f6 = arith.constant 6.0 : f32
  %f7 = arith.constant 7.0 : f32
  %f8 = arith.constant 8.0 : f32
  %f9 = arith.constant 9.0 : f32
  %f10 = arith.constant 10.0 : f32
  %f11 = arith.constant 11.0 : f32
  %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
         : tensor<3x2x2xf32>
  return %0 : tensor<3x2x2xf32>
}

// -----

// CHECK-LABEL:   func @tensor.generate(
// CHECK-SAME:        %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME:        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
// CHECK:           scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK:             %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
// CHECK:             store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
// CHECK:             scf.yield
// CHECK:           }
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<?xindex>
// CHECK:           return %[[RET]] : tensor<?xindex>
// CHECK:         }
func.func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> {
  %result = tensor.generate %dynamic_extent {
  ^bb0(%i : index):
    %elem = tensor.dim %arg, %i : tensor<*xf32>
    tensor.yield %elem : index
  } : tensor<?xindex>
  return %result : tensor<?xindex>
}

// -----

// Checks tensor.generate when the result type mixes a static extent (16)
// with a dynamic one: only the dynamic extent is passed to memref.alloc, and
// the static one becomes a constant loop bound.
//
// CHECK-LABEL:   func @tensor.generate_static_and_dynamic(
// CHECK-SAME:        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
// CHECK:           scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
// CHECK:             %[[SUM:.*]] = arith.addi %[[I]], %[[J]] : index
// CHECK:             store %[[SUM]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
// CHECK:             scf.yield
// CHECK:           }
// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<16x?xindex>
// CHECK:           return %[[RET]] : tensor<16x?xindex>
// CHECK:         }
func.func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
  %result = tensor.generate %arg0 {
  ^bb0(%i: index, %j: index):
    %sum = arith.addi %i, %j : index
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  return %result : tensor<16x?xindex>
}

// -----

// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func.func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
  // CHECK-NOT: tensor.generate
  %tensor = tensor.generate %arg0 {
  ^bb0(%iv: index):
    // CHECK: test.source
    %0 = "test.source"() : () -> index
    tensor.yield %0 : index
  } : tensor<?xindex>
  return %tensor : tensor<?xindex>
}

// -----

// CHECK-DAG: #[[$MAP0a:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

// CHECK-LABEL: func @tensor.extract_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[idx1:.*]]: index, %[[idx2:.*]]: index
func.func @tensor.extract_slice(
    %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP0a]]>
  %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
      : tensor<?x?xf32> to tensor<?x10xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x10xf32>
}

// -----

// CHECK-DAG: #[[$MAP0b:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10x?xf32>, %[[idx1:.*]]: index,
//  CHECK-SAME:     %[[idx2:.*]]: index
func.func @tensor.extract_slice_rank_reducing(
    %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP0b]]>
  %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
      : tensor<?x10x?xf32> to tensor<?x15xf32>
  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
  // CHECK: return %[[r_tensor]]
  return %0 : tensor<?x15xf32>
}

// -----

// CHECK-LABEL: func @tensor.insert_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
//  CHECK-SAME:     %[[idx1:.*]]: index, %[[idx2:.*]]: index
func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
                               %idx1: index, %idx2: index) -> tensor<?x?xf32> {
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
  // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
  // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
  // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
  // CHECK: memref.copy %[[m2]], %[[subview]]
  %0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1]
      : tensor<?x10xf32> into tensor<?x?xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK: #[[$MAP11:.*]] = affine_map<()[s0] -> (s0)>

// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_1(
func.func @tensor.insert_slice_rank_reducing_1(
    %t1: tensor<?x?xf32>, %f: tensor<f32>, %idx1: index, %idx2: index)
  -> tensor<?x?xf32>
{
  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?xf32>
  // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref<?x?xf32> to memref<f32, #[[$MAP11]]>
  // CHECK: memref.copy {{.*}} : memref<f32> to memref<f32, #[[$MAP11]]>
  %0 = tensor.insert_slice %f into %t1[%idx1, %idx2][1, 1][1, 1]
      : tensor<f32> into tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK: #[[$MAP12:.*]] = affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5)>

// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_2(
func.func @tensor.insert_slice_rank_reducing_2(
    %t1: tensor<?x?x?x?x?x?x?xf32>, %t2: tensor<2x1x4x1x1xf32>, %i: index)
  -> tensor<?x?x?x?x?x?x?xf32>
{
  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?x?x?x?x?x?xf32>
  // CHECK: memref.subview %[[alloc]][{{.*}}] [1, 2, 1, 4, 1, 1, 1] [1, 1, 1, 1, 1, 1, 1] : memref<?x?x?x?x?x?x?xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
  // CHECK: memref.copy {{.*}} : memref<2x1x4x1x1xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
  %0 = tensor.insert_slice %t2 into %t1[%i, %i, %i, %i, %i, %i, %i][1, 2, 1, 4, 1, 1, 1][1, 1, 1, 1, 1, 1, 1]
      : tensor<2x1x4x1x1xf32> into tensor<?x?x?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?x?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.insert(
//  CHECK-SAME:     %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
//  CHECK-SAME:     %[[f:.*]]: f32
func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
  // CHECK-DAG: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32>
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
  %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
  // CHECK: return %[[r]]
  return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL: func @tensor.expand_shape(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
  // The reassociation list is split across two patterns because FileCheck
  // would parse a literal "[[" as the start of a variable reference.
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
  // CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
  %0 = tensor.expand_shape %t1 [[0, 1], [2]]
      : tensor<?x10xf32> into tensor<2x?x10xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %0 : tensor<2x?x10xf32>
}

// -----

// CHECK-DAG: #[[$MAP1b:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
// CHECK-DAG: #[[$MAP2b:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>

// CHECK-LABEL: func @tensor.expand_shape_of_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>
func.func @tensor.expand_shape_of_slice(
    %t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, #[[$MAP1b]]>
  %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
      tensor<?x20xf32> to tensor<?x10xf32>
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
  // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, #[[$MAP1b]]> into memref<?x7x2x5xf32, #[[$MAP2b]]>
  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
      tensor<?x10xf32> into tensor<?x7x2x5xf32>
  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %1 : tensor<?x7x2x5xf32>
}

// -----

// CHECK-DAG: #[[$MAP9:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-DAG: #[[$MAP10:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>

// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?xf32>
func.func @tensor.expand_shape_of_scalar_slice(
    %t1: tensor<?xf32>, %o1: index, %s1: index) -> tensor<1xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?xf32>
  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref<?xf32> to memref<f32, #[[$MAP9]]>
  %0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref<f32, #[[$MAP9]]> into memref<1xf32, #[[$MAP10]]>
  %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1xf32>
  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
  // CHECK: return %[[r]]
  return %1 : tensor<1xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape(
//  CHECK-SAME:     %[[t1:.*]]: tensor<2x?x?xf32>
func.func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<2x?x?xf32>
  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [
  // CHECK-SAME: [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
  %0 = tensor.collapse_shape %t1 [[0, 1], [2]]
      : tensor<2x?x?xf32> into tensor<?x?xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
  // CHECK: return %[[r]]
  return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_to_scalar(
//  CHECK-SAME:     %[[t1:.*]]: tensor<1x1x1xf32>
func.func @tensor.collapse_shape_to_scalar(%t1: tensor<1x1x1xf32>) -> tensor<f32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<1x1x1xf32>
  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [] : memref<1x1x1xf32> into memref<f32>
  %0 = tensor.collapse_shape %t1 []
      : tensor<1x1x1xf32> into tensor<f32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
  // CHECK: return %[[r]]
  return %0 : tensor<f32>
}

// -----

// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>

// CHECK-LABEL: func @tensor.collapse_shape_of_slice(
func.func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor<i32> {
  // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, #[[$MAP3]]>
  %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32>
  // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, #[[$MAP3]]> into memref<i32, #[[$MAP4]]>
  %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
  return %1 : tensor<i32>
}

// -----

// CHECK-LABEL: func @tensor.collapse_shape_of_slice2(
func.func @tensor.collapse_shape_of_slice2(
    %arg0: tensor<?x?x?x?xi64>, %o1: index, %o2: index, %o3: index, %o4: index)
    -> tensor<87x63648xi64> {
  // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<?x?x?x?xi64> to memref<87x78x68x12xi64, #{{.*}}>
  %0 = tensor.extract_slice %arg0[%o1, %o2, %o3, %o4] [87, 78, 68, 12] [1, 1, 1, 1] : tensor<?x?x?x?xi64> to tensor<87x78x68x12xi64>

  // The strided layout of the subview may not be collapsible, so the data is
  // first copied into a fresh allocation without a layout map, and that
  // allocation is collapsed instead.
  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<87x78x68x12xi64>
  // CHECK: memref.copy %[[subview]], %[[alloc]]
  // CHECK: memref.collapse_shape %[[alloc]] [
  // CHECK-SAME: [0], [1, 2, 3]] : memref<87x78x68x12xi64> into memref<87x63648xi64>
  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
  return %1 : tensor<87x63648xi64>
}

// -----

// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK-DAG: #[[$MAP6:.*]] = affine_map<(d0) -> (d0 * 2)>

// CHECK-LABEL: func @tensor.collapse_shape_of_slice3(
//  CHECK-SAME:     %[[t1:.*]]: tensor<1x2xf32>
func.func @tensor.collapse_shape_of_slice3(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
  // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]>
  %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
  // CHECK: memref.collapse_shape %{{.*}} [
  // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32, #[[$MAP6]]>
  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
  return %1 : tensor<1xf32>
}

// -----

// CHECK-DAG: #[[$MAP7:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 * 4 + d2)>
// CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>

// Both index arguments are bound explicitly; otherwise the greedy OFFSET
// pattern would swallow "%arg1: index, %arg2" and capture the wrong value.
// CHECK-LABEL: func @tensor.collapse_shape_of_slice4(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x2x4xf32>,
//  CHECK-SAME:     %[[OFFSET:.*]]: index, %[[SIZE:.*]]: index) -> tensor<8xf32> {
func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: index, %size: index) -> tensor<8xf32> {
  // CHECK: memref.subview %{{.*}}[0, 0, %[[OFFSET]]] [4, 2, 1] [1, 1, 1] : memref<?x2x4xf32> to memref<4x2x1xf32, #[[$MAP7]]>
  %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor<?x2x4xf32> to tensor<4x2x1xf32>
  // CHECK: memref.collapse_shape %{{.*}} [
  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, #[[$MAP7]]> into memref<8xf32, #[[$MAP8]]>
  %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
  return %ret: tensor<8xf32>
}

// -----

// CHECK-LABEL: func @tensor.reshape(
//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>

  // CHECK: %[[two:.*]] = arith.constant 2 : i64
  %two = arith.constant 2 : i64
  // CHECK: %[[five:.*]] = arith.constant 5 : i64
  %five = arith.constant 5 : i64

  // The shape operand comes from tensor.from_elements, so it is materialized
  // as an allocation filled element by element. The alloc attributes are not
  // pinned, matching the other alloc patterns in this file.
  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<3xi64>
  // CHECK: %[[zero_idx:.*]] = arith.constant 0 : index
  // CHECK: %[[one_idx:.*]] = arith.constant 1 : index
  // CHECK: %[[two_idx:.*]] = arith.constant 2 : index
  // CHECK: memref.store %[[two]], %[[alloc]][%[[zero_idx]]] : memref<3xi64>
  // CHECK: memref.store %[[two]], %[[alloc]][%[[one_idx]]] : memref<3xi64>
  // CHECK: memref.store %[[five]], %[[alloc]][%[[two_idx]]] : memref<3xi64>
  %shape = tensor.from_elements %two, %two, %five : tensor<3xi64>

  // CHECK: %[[reshaped:.*]] = memref.reshape %[[m1]](%[[alloc]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
  %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>

  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[reshaped]]
  // CHECK: return %[[r]]
  return %reshaped : tensor<2x2x5xf32>
}
