// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction -split-input-file | FileCheck %s

// Split-reduction applied to a named matmul.
// Expected output: the contracted dimension of both operands is expanded
// (256 -> 4x64), partial products are accumulated into a zero-filled
// 16x32x4 intermediate, and a second linalg.generic adds over the extra
// trailing dimension to produce the original 16x32 result type.
func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
                    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
//  CHECK-LABEL: @matmul_split
//  CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
//  CHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:   , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
//      CHECK:   arith.mulf
//      CHECK:   arith.addf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<16x32x4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
// CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
//      CHECK:   arith.addf
//      CHECK:   linalg.yield %{{.*}} : f32
//      CHECK: } -> tensor<16x32xf32>
//      CHECK: return %[[R]] : tensor<16x32xf32>

// -----

// Split-reduction applied to a 1-D linalg.generic whose combiner is a
// multiplication and which also reads a rank-0 operand.
// Expected output: the 32-element input is expanded to 4x8, the rank-0
// operand keeps an empty indexing map, the 4-element intermediate is filled
// with 1.0, and a final 1-D reduction multiplies the partial results.
func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                          affine_map<(d0) -> ()>,
                                          affine_map<(d0) -> ()>],
   iterator_types = ["reduction"]}
   ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
   outs(%out : tensor<f32>) {
    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
      %40 = arith.subf %arg7, %arg8 : f32
      %41 = math.exp %40 : f32
      %42 = arith.mulf %41, %arg9 : f32
      linalg.yield %42 : f32
    } -> tensor<f32>
  return %red : tensor<f32>
}

// The attribute dictionary and operand list below use same-line continuation
// directives so they are pinned to the op that binds G rather than being
// allowed to match on any later line of the output.
//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
// CHECK-LABEL: @generic_split_1d
//      CHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
//      CHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
//      CHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
//      CHECK: %[[G:.*]] = linalg.generic
// CHECK-SAME:   {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
// CHECK-SAME:   iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<4x8xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
//      CHECK:   arith.subf
//      CHECK:   math.exp
//      CHECK:   arith.mulf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
//      CHECK:   arith.mulf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<f32>
//      CHECK: return %[[R]] : tensor<f32>

// -----

// Split-reduction applied to a 3-D linalg.generic with transposed operand
// accesses, a reduction iterator in the middle position, and a max combiner.
// Expected output: the size-32 dimension of each input is expanded to 4x8,
// the 5x2x4 intermediate is filled with -3.40282347E+38, and a final generic
// takes the max over the trailing dimension to yield the 5x2 result.
func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
    -> tensor<5x2xf32> {
  %reduced = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d2, d0)>],
      iterator_types = ["parallel", "reduction", "parallel"]}
      ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>)
      outs(%output : tensor<5x2xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %acc: f32):
      %sum = arith.addf %lhs, %rhs : f32
      %max = arith.maxf %sum, %acc : f32
      linalg.yield %max : f32
  } -> tensor<5x2xf32>
  return %reduced : tensor<5x2xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL:  func @generic_split_3d
//      CHECK: %[[IDENTITY:.*]] = arith.constant -3.40282347E+38 : f32
//  CHECK-DAG: %[[LHS:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
//  CHECK-DAG: %[[RHS:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
//      CHECK: %[[INIT:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32>
//      CHECK: %[[FILL:.*]] = linalg.fill ins(%[[IDENTITY]] : f32) outs(%[[INIT]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
//      CHECK: %[[PARTIAL:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME:   ins(%[[LHS]], %[[RHS]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[FILL]] : tensor<5x2x4xf32>) {
//      CHECK:   arith.addf
//      CHECK:   arith.maxf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<5x2x4xf32>
//      CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME:   ins(%[[PARTIAL]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
//      CHECK:   arith.maxf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<5x2xf32>
//      CHECK: return %[[RESULT]] : tensor<5x2xf32>