// RUN: mlir-opt %s -linalg-tile="tile-sizes=2 loop-type=parallel" | FileCheck %s -check-prefix=TILE-2
// RUN: mlir-opt %s -linalg-tile="tile-sizes=0,2 loop-type=parallel" | FileCheck %s -check-prefix=TILE-02
// RUN: mlir-opt %s -linalg-tile="tile-sizes=0,0,2 loop-type=parallel" | FileCheck %s -check-prefix=TILE-002
// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,3,4 loop-type=parallel" | FileCheck %s -check-prefix=TILE-234

// Exercises the linalg tiling pass with the parallel loop type on a 2-D
// pointwise addition. Each invocation above uses a different tile-size list
// and is verified against its own check prefix below.

// Identity indexing map shared by both inputs and the output.
#id_2d = affine_map<(i, j) -> (i, j)>

// Trait of a 2-D pointwise operation: two inputs, one output, both loops
// parallel.
#pointwise_2d_trait = {
  args_in = 2,
  args_out = 1,
  indexing_maps = [#id_2d, #id_2d, #id_2d],
  iterator_types = ["parallel", "parallel"]
}

// Element-wise %sum = %lhs + %rhs over dynamically shaped strided memrefs.
func.func @sum(%lhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
               %rhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
               %sum: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
  linalg.generic #pointwise_2d_trait
    ins(%lhs, %rhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                    memref<?x?xf32, offset: ?, strides: [?, 1]>)
    outs(%sum : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
  ^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32):
    %result = arith.addf %lhs_in, %rhs_in : f32
    linalg.yield %result : f32
  }
  return
}

// Tile size 2 on the first loop only: a single one-dimensional parallel loop
// over the rows, and no second parallel loop before the subviews.
// TILE-2-LABEL: func @sum(
// TILE-2-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
// TILE-2-DAG: [[C0:%.*]] = arith.constant 0 : index
// TILE-2-DAG: [[C2:%.*]] = arith.constant 2 : index
// TILE-2: [[LHS_ROWS:%.*]] = memref.dim [[LHS]], %c0
// TILE-2: scf.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_ROWS]]) step ([[C2]]) {
// TILE-2-NOT: scf.parallel
// TILE-2: [[LHS_SUBVIEW:%.*]] = memref.subview [[LHS]]
// TILE-2: [[RHS_SUBVIEW:%.*]] = memref.subview [[RHS]]
// TILE-2: [[SUM_SUBVIEW:%.*]] = memref.subview [[SUM]]
// TILE-2: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]

// Tile size 2 on the second loop only: a single one-dimensional parallel loop
// over the columns, and no second parallel loop before the subviews.
// TILE-02-LABEL: func @sum(
// TILE-02-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
// TILE-02-DAG: [[C0:%.*]] = arith.constant 0 : index
// TILE-02-DAG: [[C2:%.*]] = arith.constant 2 : index
// TILE-02: [[LHS_COLS:%.*]] = memref.dim [[LHS]], %c1
// TILE-02: scf.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_COLS]]) step ([[C2]]) {
// TILE-02-NOT: scf.parallel
// TILE-02: [[LHS_SUBVIEW:%.*]] = memref.subview [[LHS]]
// TILE-02: [[RHS_SUBVIEW:%.*]] = memref.subview [[RHS]]
// TILE-02: [[SUM_SUBVIEW:%.*]] = memref.subview [[SUM]]
// TILE-02: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]

// Zero tile sizes for both loops of the op: no parallel loop is expected and
// the generic op still operates directly on the function arguments.
// TILE-002-LABEL: func @sum(
// TILE-002-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
// TILE-002-NOT: scf.parallel
// TILE-002: linalg.generic {{.*}} ins([[LHS]], [[RHS]]{{.*}} outs([[SUM]]

// Tile sizes 2 and 3 on the two loops: one two-dimensional parallel loop, and
// no further parallel loop before the subviews.
// TILE-234-LABEL: func @sum(
// TILE-234-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
// TILE-234-DAG: [[C0:%.*]] = arith.constant 0 : index
// TILE-234-DAG: [[C2:%.*]] = arith.constant 2 : index
// TILE-234-DAG: [[C3:%.*]] = arith.constant 3 : index
// TILE-234: [[LHS_ROWS:%.*]] = memref.dim [[LHS]], %c0
// TILE-234: [[LHS_COLS:%.*]] = memref.dim [[LHS]], %c1
// TILE-234: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) to ([[LHS_ROWS]], [[LHS_COLS]]) step ([[C2]], [[C3]]) {
// TILE-234-NOT: scf.parallel
// TILE-234: [[LHS_SUBVIEW:%.*]] = memref.subview [[LHS]]
// TILE-234: [[RHS_SUBVIEW:%.*]] = memref.subview [[RHS]]
// TILE-234: [[SUM_SUBVIEW:%.*]] = memref.subview [[SUM]]
// TILE-234: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]