// RUN: mlir-opt %s -pass-pipeline="func.func(convert-vector-to-scf{lower-tensors=true})" -split-input-file -allow-unregistered-dialect | FileCheck %s

// A 2-D vector.transfer_read from a tensor is expected to be unrolled along
// its leading dimension. A memref.alloca'd buffer of vector<4x9xf32> is
// type-cast to memref<4xvector<9xf32>>. An scf.for loop then fills it with
// 1-D reads of vector<9xf32>, and the buffer is finally loaded back as a
// single vector<4x9xf32>.
// The padding operand is captured from its defining arith.constant instead of
// being matched by the SSA name the printer happens to pick. This keeps the
// test independent of constant naming and confirms that the 1-D reads reuse
// the original -42.0 padding value. The alloca is hoisted to the start of the
// function, so the alloca and the constant are matched in either order.

// CHECK-LABEL: func @transfer_read_2d(
//   CHECK-DAG: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
//   CHECK-DAG: %[[PAD:.*]] = arith.constant -4.200000e+01 : f32
//       CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4x9xf32>> to memref<4xvector<9xf32>>
//       CHECK: scf.for {{.*}} {
//       CHECK:   %[[READ:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %[[PAD]] {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
//       CHECK:   memref.store %[[READ]], %[[CASTED]][%{{.*}}] : memref<4xvector<9xf32>>
//       CHECK: }
//       CHECK: %[[LOADED:.*]] = memref.load %[[ALLOC]][] : memref<vector<4x9xf32>>
//       CHECK: return %[[LOADED]] : vector<4x9xf32>
func.func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
    -> (vector<4x9xf32>){
  %p = arith.constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<4x9xf32>
  return %f : vector<4x9xf32>
}

// -----

// A 2-D vector.transfer_write into a tensor is expected to stage the vector
// in a memref.alloca'd buffer and type-cast that buffer to
// memref<2xvector<3xf32>>. An scf.for loop then threads the tensor through
// iter_args, writing one vector<3xf32> slice per iteration and yielding the
// updated tensor. The loop result replaces the original write's result.

// CHECK-LABEL: func @transfer_write_2d(
//       CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<2x3xf32>>
//       CHECK: memref.store {{.*}}, %[[ALLOC]][] : memref<vector<2x3xf32>>
//       CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<2x3xf32>> to memref<2xvector<3xf32>>
//       CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[STATE:.*]] = %{{.*}}) -> (tensor<?x?xf32>) {
//       CHECK:   %[[LOADED:.*]] = memref.load %[[CASTED]][%{{.*}}] : memref<2xvector<3xf32>>
//       CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[LOADED]], %[[STATE]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
//       CHECK:   scf.yield %[[WRITE]] : tensor<?x?xf32>
//       CHECK: }
//       CHECK: return %[[RESULT]] : tensor<?x?xf32>
func.func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
                             %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
  %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
      : vector<2x3xf32>, tensor<?x?xf32>
  return %t : tensor<?x?xf32>
}