// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s

// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK:         %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
// CHECK:         %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
// CHECK:         %[[RET:.*]] = tensor_load %[[CASTED]]
// CHECK:         return %[[RET]] : tensor<2xindex>
func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<2xindex>
  return %0 : tensor<2xindex>
}

// CHECK-LABEL: func @tensor.cast_from_unranked(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK:         %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
// CHECK:         %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
// CHECK:         %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
// CHECK:         return %[[RET]] : tensor<2xf32>
func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @tensor.cast_to_unranked(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
// CHECK:         %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
// CHECK:         %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
// CHECK:         %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
// CHECK:         return %[[RET]] : tensor<*xf32>
func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
  %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @tensor.extract(
// CHECK-SAME:    %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME:    %[[IDX:.*]]: index) -> f32 {
// CHECK:         %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
// CHECK:         %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK:         return %[[RET]] : f32
// CHECK:       }
func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
  %0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
  return %0 : f32
}

// CHECK-LABEL: func @tensor.from_elements(
// CHECK-SAME:    %[[ELEM0:.*]]: index,
// CHECK-SAME:    %[[ELEM1:.*]]: index) -> tensor<2xindex> {
// CHECK:         %[[MEMREF:.*]] = alloc()
// CHECK:         %[[C0:.*]] = constant 0 : index
// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK:         %[[C1:.*]] = constant 1 : index
// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK:         %[[RET:.*]] = tensor_load %[[MEMREF]]
// CHECK:         return %[[RET]] : tensor<2xindex>
func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
  %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
  return %0 : tensor<2xindex>
}

// CHECK-LABEL: func @tensor.generate(
// CHECK-SAME:    %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME:    %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK:         %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
// CHECK:         %[[C0:.*]] = constant 0 : index
// CHECK:         %[[C1:.*]] = constant 1 : index
// CHECK:         scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK:           %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
// CHECK:           store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
// CHECK:           scf.yield
// CHECK:         }
// CHECK:         %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<?xindex>
// CHECK:         return %[[RET]] : tensor<?xindex>
// CHECK:       }
func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> {
  %result = tensor.generate %dynamic_extent {
  ^bb0(%i : index):
    %elem = dim %arg, %i : tensor<*xf32>
    tensor.yield %elem : index
  } : tensor<?xindex>
  return %result : tensor<?xindex>
}

// Additional test that checks the logic for intermixed static and dynamic
// extents.
//
// CHECK-LABEL: func @tensor.generate_static_and_dynamic(
// CHECK-SAME:    %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
// CHECK:         %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex>
// CHECK:         %[[C0:.*]] = constant 0 : index
// CHECK:         %[[C1:.*]] = constant 1 : index
// CHECK:         %[[C16:.*]] = constant 16 : index
// CHECK:         scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
// CHECK:           %[[VAL_7:.*]] = addi %[[I]], %[[J]] : index
// CHECK:           store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
// CHECK:           scf.yield
// CHECK:         }
// CHECK:         %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<16x?xindex>
// CHECK:         return %[[RET]] : tensor<16x?xindex>
// CHECK:       }
func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
  %result = tensor.generate %arg0 {
  ^bb0(%i: index, %j: index):
    %sum = addi %i, %j : index
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  return %result : tensor<16x?xindex>
}

// The tensor.generate op needs to put its body into the
// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
// the body because that would require the cloned ops to be legalized
// immediately, which is usually not possible since they might be from various
// other dialects.
//
// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
  // CHECK-NOT: tensor.generate
  %tensor = tensor.generate %arg0 {
  ^bb0(%iv: index):
    // CHECK: test.source
    %0 = "test.source"() : () -> index
    tensor.yield %0 : index
  } : tensor<?xindex>
  return %tensor : tensor<?xindex>
}