// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction -split-input-file | FileCheck %s

// Split reduction of a plain linalg.matmul: the 256-long reduction dimension
// is expanded into 4x64, a tensor<16x32x4xf32> of partial results is filled
// with 0.0, the inner 64-long dimension is reduced into it, and a second
// linalg.generic reduces the size-4 dimension into the 16x32 output with
// arith.addf.
func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
//  CHECK-LABEL: @matmul_split
//  CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
//  CHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:   , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
//      CHECK:   arith.mulf
//      CHECK:   arith.addf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<16x32x4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
// CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
//      CHECK:   arith.addf
//      CHECK:   linalg.yield %{{.*}} : f32
//      CHECK: } -> tensor<16x32xf32>
//      CHECK: return %[[R]] : tensor<16x32xf32>

// -----

// Split reduction of a 1-D linalg.generic whose combiner is arith.mulf: the
// 32-long reduction is expanded into 4x8, the tensor<4xf32> of partial results
// is filled with 1.0 (the multiplicative identity), and the final generic
// multiplies the four partials into the scalar output.
func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                          affine_map<(d0) -> ()>,
                                          affine_map<(d0) -> ()>],
   iterator_types = ["reduction"]}
   ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
   outs(%out : tensor<f32>) {
    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
      %40 = arith.subf %arg7, %arg8 : f32
      %41 = math.exp %40 : f32
      %42 = arith.mulf %41, %arg9 : f32
      linalg.yield %42 : f32
    } -> tensor<f32>
  return %red : tensor<f32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
// CHECK-LABEL: @generic_split_1d
//      CHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
//      CHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
//      CHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
// The attributes and operands below are printed on the same line as the op
// name, so they are pinned to it with CHECK-SAME.
//      CHECK: %[[G:.*]] = linalg.generic
// CHECK-SAME:   {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
// CHECK-SAME:   iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<4x8xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
//      CHECK:   arith.subf
//      CHECK:   math.exp
//      CHECK:   arith.mulf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
//      CHECK:   arith.mulf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<f32>
//      CHECK: return %[[R]] : tensor<f32>

// -----

// Split reduction where the reduction dimension (d1, size 32) sits between two
// parallel dimensions and the operands use permuted indexing maps. The
// combiner is arith.maxf, so the partial results are filled with
// -3.40282347E+38 (the lowest finite f32). The 32-long dimension is expanded
// into 4x8, and the new size-4 partial-result dimension is appended last
// (tensor<5x2x4xf32>).
func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
  -> tensor<5x2xf32>
{
  %0 = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d1, d0)>,
        affine_map<(d0, d1, d2) -> (d2, d1)>,
        affine_map<(d0, d1, d2) -> (d2, d0)>
      ],
      iterator_types = ["parallel", "reduction", "parallel"]
    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %3 = arith.addf %arg0, %arg1 : f32
      %4 = arith.maxf %3, %arg2 : f32
      linalg.yield %4 : f32
    } -> tensor<5x2xf32>
  return %0 : tensor<5x2xf32>
}

//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: @generic_split_3d
//      CHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
//      CHECK: %[[INI:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32>
//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
//      CHECK:   arith.addf
//      CHECK:   arith.maxf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<5x2x4xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME:   ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
//      CHECK:   arith.maxf
//      CHECK:   linalg.yield
//      CHECK: } -> tensor<5x2xf32>
//      CHECK: return %[[R]] : tensor<5x2xf32>
