// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-linalg-to-vector-patterns %s | FileCheck %s

/// 1-D NWC/WCF convolution with stride 3, dilation 1 and a filter of width 1.
/// The expected output holds one contraction per output position (w).
func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) {
  linalg.conv_1d_nwc_wcf
    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
    ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
    outs(%output : memref<4x2x8xf32>)
  return
}

//       CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
//       CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
//       CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

//       CHECK: func @conv1d_nwc_4x2x8_memref
//  CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)

//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32

/// Read the whole data in one shot.
//   CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
//   CHECK-DAG:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
//   CHECK-DAG:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]

/// Input slices, one per output position: offset along w is 3 * w.
//       CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
//       CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>

//       CHECK:   %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>

//       CHECK:   %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
//  CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
//       CHECK:   %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
//  CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>

/// w == 0, kw == 0
//       CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
//  CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
//  CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
//  CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>

/// w == 1, kw == 0
//       CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
//  CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
//  CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
//  CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>

/// Insert the result for w == 0.
//       CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
//  CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
/// Insert the result for w == 1.
//       CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
//  CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>

/// Write the result back in one shot.
//       CHECK:   vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

// -----

/// 1-D NWC/WCF convolution with stride 3, dilation 2 and a filter of width 2.
/// The expected output holds one contraction per (w, kw) pair; the input slice
/// for a pair starts at offset 3 * w + 2 * kw along the width dimension.
func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
  linalg.conv_1d_nwc_wcf
    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
    outs(%output : memref<4x2x8xf32>)
  return
}

//       CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
//       CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
//       CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

//       CHECK: func @conv1d_nwc_4x2x8_memref
//  CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)

//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32

/// Read the whole data in one shot.
//   CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
//   CHECK-DAG:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
//   CHECK-DAG:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]

/// Input slices in (kw, w) order: (0, 0), (0, 1), (1, 0), (1, 1).
//       CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
//       CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
//       CHECK:   %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
//       CHECK:   %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>

//       CHECK:   %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
//       CHECK:   %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>

//       CHECK:   %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
//  CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
//       CHECK:   %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
//  CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>

/// w == 0, kw == 0
//       CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
//  CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
//  CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]]
//  CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
/// w == 1, kw == 0
//       CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
//  CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
//  CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]]
//  CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
/// w == 0, kw == 1 (accumulates onto the w == 0, kw == 0 result)
//       CHECK:   %[[CONTRACT_2:.+]] = vector.contract {
//  CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
//  CHECK-SAME:     %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]]
//  CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
/// w == 1, kw == 1 (accumulates onto the w == 1, kw == 0 result)
//       CHECK:   %[[CONTRACT_3:.+]] = vector.contract {
//  CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
//  CHECK-SAME:     %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]]
//  CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>

/// Insert the result for w == 0 (accumulated over both kw).
//       CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_2]], %[[V_OUTPUT_R]]
//  CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
/// Insert the result for w == 1 (accumulated over both kw).
//       CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_3]], %[[RES_0]]
//  CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>

/// Write the result back in one shot.
//       CHECK:   vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

// -----

/// 1-D NWC/WCF convolution with stride 1, dilation 2 and a filter of width 2.
/// The expected output holds one contraction per kw, each covering every
/// output position at once; the input slice for kw starts at offset 2 * kw.
func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
  linalg.conv_1d_nwc_wcf
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
    outs(%output : memref<4x2x8xf32>)
  return
}

//       CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
//       CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
//       CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

//       CHECK: func @conv1d_nwc_4x2x8_memref
//  CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)

//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32

/// Read the whole data in one shot.
//   CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
//   CHECK-DAG:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
//   CHECK-DAG:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]

//       CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
//       CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>

//       CHECK:   %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
//       CHECK:   %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>

/// w == 0, kw == 0
//       CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
//  CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
//  CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
//  CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
/// w == 0, kw == 1
//       CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
//  CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
//  CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
//  CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>

/// Write the result back in one shot.
//       CHECK:   vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

// -----

/// 1-D depthwise NWC/WC convolution with stride 1, dilation 2 and a filter of
/// width 2. The expected output holds one broadcast + fma pair per kw.
func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
    outs(%output : memref<3x2x4xf32>)
  return
}

//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref
//  CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)

//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32

/// Read the whole data in one shot.
//       CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
//       CHECK:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
//       CHECK:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

//       CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
//       CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
//  CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>

//       CHECK:   %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
//       CHECK:   %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>

/// w == 0, kw == 0
//       CHECK:   %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
//       CHECK:   %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32>

/// w == 0, kw == 1
//       CHECK:   %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
//       CHECK:   %[[FMA_1:.*]] = vector.fma %[[V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x2x4xf32>

/// Write the result back in one shot.
//       CHECK:   vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]


// -----

/// 1-D NWC/WCF convolution with f16 input and filter accumulating into an f32
/// output (stride 1, dilation 1, filter width 1). The expected output is a
/// single mixed-type contraction.
func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {
  linalg.conv_1d_nwc_wcf
    {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : memref<1x2x3xf16>, memref<1x3x2xf16>)
    outs(%output : memref<1x2x2xf32>)
  return
}

//       CHECK: func @conv_1d_nwc_wcf_mixed_type_memref
//  CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xf16>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xf16>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)

//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32

/// Read the whole data in one shot.
//       CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
//       CHECK:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
//       CHECK:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
//       CHECK:   %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x2xf16>
/// The operands and vector types of the contraction are matched on the same
/// line as the op itself.
//       CHECK:   %[[CONT:.*]] = vector.contract
//  CHECK-SAME:     {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
//       CHECK:   vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]