// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,4 -split-input-file | FileCheck %s --check-prefix=CHECK2D

// The file is run twice with different distribution multiplicities. The first
// invocation is verified with the default prefix; the second one only verifies
// the contraction test at the end of the file through its own prefix.

// Element-wise add on a 1-D vector: both operands are extracted per id, added
// as vector<1xf32>, and the result is inserted back into the full vector.
// CHECK-LABEL: func @distribute_vector_add
// CHECK-SAME: (%[[ID:.*]]: index
// CHECK-NEXT: %[[ADDV:.*]] = arith.addf %{{.*}}, %{{.*}} : vector<32xf32>
// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<1xf32>
// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
// CHECK-NEXT: return %[[INS]] : vector<32xf32>
func.func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
  %0 = arith.addf %A, %B : vector<32xf32>
  return %0: vector<32xf32>
}

// -----

// Same as above with a math.exp feeding the add: the exp is expected to be
// applied to the extracted vector<1xf32> slice as well.
// CHECK-LABEL: func @distribute_vector_add_exp
// CHECK-SAME: (%[[ID:.*]]: index
// CHECK-NEXT: %[[EXPV:.*]] = math.exp %{{.*}} : vector<32xf32>
// CHECK-NEXT: %[[ADDV:.*]] = arith.addf %[[EXPV]], %{{.*}} : vector<32xf32>
// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
// CHECK-NEXT: %[[EXC:.*]] = math.exp %[[EXA]] : vector<1xf32>
// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[EXC]], %[[EXB]] : vector<1xf32>
// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
// CHECK-NEXT: return %[[INS]] : vector<32xf32>
func.func @distribute_vector_add_exp(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
  %C = math.exp %A : vector<32xf32>
  %0 = arith.addf %C, %B : vector<32xf32>
  return %0: vector<32xf32>
}

// -----

// Reads, adds and a write on memref<32xf32>: every transfer is expected to be
// narrowed to vector<1xf32> and indexed directly by the id.
// CHECK-LABEL: func @vector_add_read_write
// CHECK-SAME: (%[[ID:.*]]: index
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<1xf32>
// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[EXC]] : vector<1xf32>
// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] {{.*}} : vector<1xf32>, memref<32xf32>
// CHECK-NEXT: return
func.func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0], %cf0: memref<32xf32>, vector<32xf32>
  %b = vector.transfer_read %B[%c0], %cf0: memref<32xf32>, vector<32xf32>
  %acc = arith.addf %a, %b: vector<32xf32>
  %c = vector.transfer_read %C[%c0], %cf0: memref<32xf32>, vector<32xf32>
  %d = arith.addf %acc, %c: vector<32xf32>
  vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
  return
}

// -----

// 64-element vectors: each id is expected to handle a vector<2xf32> slice
// starting at offset id * 2.
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>

// CHECK: func @vector_add_cycle
// CHECK-SAME: (%[[ID:.*]]: index
// CHECK: %[[ID1:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
// CHECK-NEXT: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID1]]], %{{.*}} : memref<64xf32>, vector<2xf32>
// CHECK-NEXT: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]]], %{{.*}} : memref<64xf32>, vector<2xf32>
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<2xf32>
// CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]]] {{.*}} : vector<2xf32>, memref<64xf32>
// CHECK-NEXT: return
func.func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<64xf32>
  %b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<64xf32>
  %acc = arith.addf %a, %b: vector<64xf32>
  vector.transfer_write %acc, %C[%c0]: vector<64xf32>, memref<64xf32>
  return
}

// -----

// Negative test to make sure nothing is done in case the vector size is not a
// multiple of multiplicity.
// CHECK-LABEL: func @vector_negative_test
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<16xf32>
// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]]] {{.*}} : vector<16xf32>, memref<64xf32>
// CHECK-NEXT: return
func.func @vector_negative_test(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<16xf32>
  %b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<16xf32>
  %acc = arith.addf %a, %b: vector<16xf32>
  vector.transfer_write %acc, %C[%c0]: vector<16xf32>, memref<64xf32>
  return
}

// -----

// 3-D element-wise add with two ids: vector<64x4x32xf32> operands are expected
// to be extracted as vector<2x4x1xf32> slices.
// CHECK-LABEL: func @distribute_vector_add_3d
// CHECK-SAME: (%[[ID0:.*]]: index, %[[ID1:.*]]: index
// CHECK-NEXT: %[[ADDV:.*]] = arith.addf %{{.*}}, %{{.*}} : vector<64x4x32xf32>
// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32>
// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32>
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID0]], %[[ID1]]] : vector<2x4x1xf32> into vector<64x4x32xf32>
// CHECK-NEXT: return %[[INS]] : vector<64x4x32xf32>
func.func @distribute_vector_add_3d(%id0 : index, %id1 : index,
  %A: vector<64x4x32xf32>, %B: vector<64x4x32xf32>) -> vector<64x4x32xf32> {
  %0 = arith.addf %A, %B : vector<64x4x32xf32>
  return %0: vector<64x4x32xf32>
}

// -----

// 3-D transfers with two ids: the first index is expected to become id0 * 2,
// the middle index stays at the zero constant, and the last index becomes id1.
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>

// CHECK: func @vector_add_transfer_3d
// CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[ID1:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
// CHECK-NEXT: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID1]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32>
// CHECK-NEXT: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32>
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
// CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]], %[[C0]], %[[ID_1]]] {{.*}} : vector<2x4x1xf32>, memref<64x64x64xf32>
// CHECK-NEXT: return
func.func @vector_add_transfer_3d(%id0 : index, %id1 : index, %A: memref<64x64x64xf32>,
  %B: memref<64x64x64xf32>, %C: memref<64x64x64xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32>
  %b = vector.transfer_read %B[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32>
  %acc = arith.addf %a, %b: vector<64x4x32xf32>
  vector.transfer_write %acc, %C[%c0, %c0, %c0]: vector<64x4x32xf32>, memref<64x64x64xf32>
  return
}

// -----

#map0 = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
#map1 = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>

// Transfers on 4-D memrefs with explicit permutation maps: the permutation
// maps are expected to be preserved on the narrowed transfers.
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>

// CHECK: func @vector_add_transfer_permutation
// CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
// CHECK-NEXT: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[ID2]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID_0]], %[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP2]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
// CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]], %[[ID_1]], %[[C0]], %[[ID3]]] {permutation_map = #[[MAP3]]} : vector<2x4x1xf32>, memref<?x?x?x?xf32>
// CHECK-NEXT: return
func.func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?x?x?xf32>,
  %B: memref<?x?x?x?xf32>, %C: memref<?x?x?x?xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<64x4x32xf32>
  %b = vector.transfer_read %B[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map1}: memref<?x?x?x?xf32>, vector<64x4x32xf32>
  %acc = arith.addf %a, %b: vector<64x4x32xf32>
  vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
  return
}

// -----

// Contraction followed by an add, verified only by the second invocation.
// Function arguments and the zero index constant are captured as FileCheck
// variables instead of relying on printer-generated SSA names, so the checks
// do not depend on value numbering.
// CHECK2D-LABEL: vector_add_contract
// CHECK2D-SAME: (%[[ID0:.*]]: index, %[[ID1:.*]]: index, %[[MEMA:.*]]: memref<?x?xf32>, %[[MEMB:.*]]: memref<?x?xf32>, %[[MEMC:.*]]: memref<?x?xf32>, %[[MEMD:.*]]: memref<?x?xf32>
// CHECK2D: %[[C0:.*]] = arith.constant 0 : index
// CHECK2D: %[[A:.+]] = vector.transfer_read %[[MEMA]][%{{.*}}, %[[C0]]], %{{.*}} : memref<?x?xf32>, vector<2x4xf32>
// CHECK2D: %[[B:.+]] = vector.transfer_read %[[MEMB]][%{{.*}}, %[[C0]]], %{{.*}} : memref<?x?xf32>, vector<16x4xf32>
// CHECK2D: %[[C:.+]] = vector.transfer_read %[[MEMC]][%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<2x16xf32>
// CHECK2D: %[[E:.+]] = vector.transfer_read %[[MEMD]][%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<2x16xf32>
// CHECK2D: %[[D:.+]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]] : vector<2x4xf32>, vector<16x4xf32> into vector<2x16xf32>
// CHECK2D: %[[R:.+]] = arith.addf %[[D]], %[[E]] : vector<2x16xf32>
// CHECK2D: vector.transfer_write %[[R]], {{.*}} : vector<2x16xf32>, memref<?x?xf32>
func.func @vector_add_contract(%id0 : index, %id1 : index, %A: memref<?x?xf32>,
  %B: memref<?x?xf32>, %C: memref<?x?xf32>, %D: memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
  %b = vector.transfer_read %B[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
  %c = vector.transfer_read %C[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
  %d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                         affine_map<(d0, d1, d2) -> (d1, d2)>,
                                         affine_map<(d0, d1, d2) -> (d0, d1)>],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
    %a, %b, %c : vector<64x4xf32>, vector<64x4xf32> into vector<64x64xf32>
  %e = vector.transfer_read %D[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
  %r = arith.addf %d, %e : vector<64x64xf32>
  vector.transfer_write %r, %C[%c0, %c0] : vector<64x64xf32>, memref<?x?xf32>
  return
}