1// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
2
// CHECK-LABEL: func.func @matmul_split
// Payload for the transform script at the end of this file. It is a plain
// 16x256 * 256x32 matmul accumulating into a 16x32 tensor. The expected output
// after split_reduction (split_factor = 4, insert_split_dimension = 2) is a pair
// of linalg.generic ops:
//   1. A partial reduction. Its inputs have types tensor<16x4x64xf32> and
//      tensor<4x64x32xf32> (the 256-long reduction extent is split as 4 x 64).
//      It writes a tensor<16x32x4xf32> intermediate, with the size-4 split
//      dimension at index 2.
//   2. A final reduction that folds the size-4 dimension of that intermediate
//      back into the tensor<16x32xf32> result.
func.func @matmul_split(%lhs : tensor<16x256xf32>, %rhs: tensor<256x32xf32>,
                        %acc: tensor<16x32xf32>) -> tensor<16x32xf32> {

  //      CHECK: linalg.generic
  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>)
  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) {

  //      CHECK: linalg.generic
  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>)
  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32xf32>) {
  %result = linalg.matmul
      ins(%lhs, %rhs : tensor<16x256xf32>, tensor<256x32xf32>)
      outs(%acc : tensor<16x32xf32>) -> tensor<16x32xf32>
  return %result : tensor<16x32xf32>
}
19
// Transform script executed by --test-transform-dialect-interpreter (see the
// RUN line). It matches every linalg.matmul in the payload and applies
// split_reduction with a split factor of 4. The new split dimension is inserted
// at position 2.
transform.with_pdl_patterns {
^bb0(%root: !pdl.operation):
  transform.sequence %root {
  ^bb1(%payload: !pdl.operation):
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %payload
    // The op yields four handles. They are bound here but not used further by
    // this test.
    %split:4 = transform.structured.split_reduction %matmul
        { split_factor = 4, insert_split_dimension = 2 }
  }
}
28