1// RUN: mlir-opt %s -linalg-tile="tile-sizes=2 loop-type=parallel" | FileCheck %s -check-prefix=TILE-2
2// RUN: mlir-opt %s -linalg-tile="tile-sizes=0,2 loop-type=parallel" | FileCheck %s -check-prefix=TILE-02
3// RUN: mlir-opt %s -linalg-tile="tile-sizes=0,0,2 loop-type=parallel" | FileCheck %s -check-prefix=TILE-002
4// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,3,4 loop-type=parallel" | FileCheck %s -check-prefix=TILE-234
5
// Identity affine map over two dimensions: (i, j) maps to (i, j).
6#id_2d = affine_map<(i, j) -> (i, j)>
// Attribute dictionary used as the trait of the linalg.generic op in @sum.
// It declares two inputs and one output, gives every operand the identity
// map #id_2d, and marks both iteration dimensions as "parallel".
7#pointwise_2d_trait = {
8  args_in = 2,
9  args_out = 1,
10  indexing_maps = [#id_2d, #id_2d, #id_2d],
11  iterator_types = ["parallel", "parallel"]
12}
13
// Test input: element-wise f32 addition expressed as a linalg.generic.
// All three arguments are 2-D f32 memrefs with dynamic sizes, a dynamic
// offset, a dynamic outer stride and a unit inner stride. The op reads %lhs
// and %rhs and writes to %sum; the function itself returns no values.
14func.func @sum(%lhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
15          %rhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
16          %sum: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
17  linalg.generic #pointwise_2d_trait
18     ins(%lhs, %rhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
19                     memref<?x?xf32, offset: ?, strides: [?, 1]>)
20    outs(%sum : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
// Region arguments are the scalar elements of %lhs, %rhs and %sum; the
// current output element (%sum_out) is not read by the body.
21  ^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32):
22    %result = arith.addf %lhs_in, %rhs_in : f32
23    linalg.yield %result : f32
24  }
25  return
26}
27// TILE-2-LABEL: func @sum(
28// TILE-2-SAME:    [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
29// TILE-2-DAG: [[C0:%.*]] = arith.constant 0 : index
30// TILE-2-DAG: [[C2:%.*]] = arith.constant 2 : index
31// TILE-2: [[LHS_ROWS:%.*]] = memref.dim [[LHS]], %c0
32// TILE-2: scf.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_ROWS]]) step ([[C2]]) {
33// TILE-2-NOT: scf.parallel
34// TILE-2:   [[LHS_SUBVIEW:%.*]] = memref.subview [[LHS]]
35// TILE-2:   [[RHS_SUBVIEW:%.*]] = memref.subview [[RHS]]
36// TILE-2:   [[SUM_SUBVIEW:%.*]] = memref.subview [[SUM]]
37// TILE-2:   linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]
38
39// TILE-02-LABEL: func @sum(
40// TILE-02-SAME:    [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
41// TILE-02-DAG: [[C0:%.*]] = arith.constant 0 : index
42// TILE-02-DAG: [[C2:%.*]] = arith.constant 2 : index
43// TILE-02: [[LHS_COLS:%.*]] = memref.dim [[LHS]], %c1
44// TILE-02: scf.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_COLS]]) step ([[C2]]) {
45// TILE-02-NOT: scf.parallel
46// TILE-02:   [[LHS_SUBVIEW:%.*]] = memref.subview [[LHS]]
47// TILE-02:   [[RHS_SUBVIEW:%.*]] = memref.subview [[RHS]]
48// TILE-02:   [[SUM_SUBVIEW:%.*]] = memref.subview [[SUM]]
49// TILE-02:    linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]
50
51// TILE-002-LABEL: func @sum(
52// TILE-002-SAME:    [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
53// TILE-002-NOT: scf.parallel
54// TILE-002:   linalg.generic {{.*}} ins([[LHS]], [[RHS]]{{.*}} outs([[SUM]]
55
56// TILE-234-LABEL: func @sum(
57// TILE-234-SAME:    [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
58// TILE-234-DAG: [[C0:%.*]] = arith.constant 0 : index
59// TILE-234-DAG: [[C2:%.*]] = arith.constant 2 : index
60// TILE-234-DAG: [[C3:%.*]] = arith.constant 3 : index
61// TILE-234: [[LHS_ROWS:%.*]] = memref.dim [[LHS]], %c0
62// TILE-234: [[LHS_COLS:%.*]] = memref.dim [[LHS]], %c1
63// TILE-234: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) to ([[LHS_ROWS]], [[LHS_COLS]]) step ([[C2]], [[C3]]) {
64// TILE-234-NOT: scf.parallel
65// TILE-234:   [[LHS_SUBVIEW:%.*]] = memref.subview [[LHS]]
66// TILE-234:   [[RHS_SUBVIEW:%.*]] = memref.subview [[RHS]]
67// TILE-234:   [[SUM_SUBVIEW:%.*]] = memref.subview [[SUM]]
68// TILE-234:   linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]
69