// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s

// Map corresponding to a 1D memory access with a dynamic offset and unit stride.
// CHECK-DAG: #[[$STRIDED_1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// Map corresponding to a 2D memory access where the stride along the last dim is known to be 1.
// CHECK-DAG: #[[$STRIDED_2D_u_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// Indexing maps expected on the interchanged linalg.generic in @permute_generic
// (names spell the results positionally with m/n/k standing for d0/d1/d2).
// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$nm:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
// CHECK-DAG: #[[$km:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>

// Op tagged "MEM". The checks expect an outer loop stepping by 8000 around a
// unit-step loop whose body is the scalar load/load/load/mulf/addf/store.
func.func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
          %y: memref<?xf32, offset: ?, strides: [1]>,
          %v: memref<f32>) {
  linalg.dot { __internal_linalg_transform__ = "MEM" }
    ins(%x, %y: memref<?xf32, offset: ?, strides: [1]>,
                memref<?xf32, offset: ?, strides: [1]>)
    outs(%v: memref<f32>)

  return
}
// CHECK-LABEL: func @dot
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[c8000:.*]] = arith.constant 8000 : index
// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] {
// CHECK:             load
// CHECK:             load
// CHECK:             load
// CHECK:             arith.mulf
// CHECK:             arith.addf
// CHECK:             store

// No transform marker on this op. The checks expect an scf.parallel stepping
// by 5 around an scf.for stepping by 6, both starting at 0, around a
// linalg.matvec whose operands carry strided layouts.
func.func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %x: memref<?xf32, offset: ?, strides: [1]>,
             %y: memref<?xf32, offset: ?, strides: [1]>) {
  linalg.matvec
    ins(%A, %x: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                memref<?xf32, offset: ?, strides: [1]>)
    outs(%y: memref<?xf32, offset: ?, strides: [1]>)
  return
}
// CHECK-LABEL: func @matvec
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c5:.*]] = arith.constant 5 : index
// CHECK-DAG:     %[[c6:.*]] = arith.constant 6 : index
// CHECK:         scf.parallel {{.*}} = (%[[c0]]) to {{.*}} step (%[[c5]])
// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
// CHECK:             linalg.matvec
// CHECK:               ins({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>, memref<?xf32, #[[$STRIDED_1D]]>)
// CHECK:              outs({{.*}}: memref<?xf32, #[[$STRIDED_1D]]>)

// Op tagged "MEM". The checks expect successive loop nests with steps
// 2000/3000/4000, 200/300/400, 20/30/40 and finally 2/3/4 (each starting at 0)
// around a linalg.matmul whose operands carry strided layouts.
func.func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
  linalg.matmul { __internal_linalg_transform__ = "MEM" }
    ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                memref<?x?xf32, offset: ?, strides: [?, 1]>)
    outs(%C: memref<?x?xf32, offset: ?, strides: [?, 1]>)
  return
}
// CHECK-LABEL: func @matmul
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c2:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[c3:.*]] = arith.constant 3 : index
// CHECK-DAG:     %[[c4:.*]] = arith.constant 4 : index
// CHECK-DAG:     %[[c20:.*]] = arith.constant 20 : index
// CHECK-DAG:     %[[c30:.*]] = arith.constant 30 : index
// CHECK-DAG:     %[[c40:.*]] = arith.constant 40 : index
// CHECK-DAG:     %[[c200:.*]] = arith.constant 200 : index
// CHECK-DAG:     %[[c300:.*]] = arith.constant 300 : index
// CHECK-DAG:     %[[c400:.*]] = arith.constant 400 : index
// CHECK-DAG:     %[[c2000:.*]] = arith.constant 2000 : index
// CHECK-DAG:     %[[c3000:.*]] = arith.constant 3000 : index
// CHECK-DAG:     %[[c4000:.*]] = arith.constant 4000 : index
// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
// CHECK:             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
// CHECK:               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
// CHECK:                 scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
// CHECK:                   scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
// CHECK:                     scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
// CHECK:                       scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
// CHECK:                         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
// CHECK:                           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
// CHECK:                             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
// CHECK:                               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
// CHECK:                                 linalg.matmul
// CHECK:                                   ins({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>, memref<?x?xf32, #[[$STRIDED_2D_u_1]]>)
// CHECK:                                  outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>)

#matmul_accesses = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
#generic_matmul_trait = {
  args_in = 2,
  args_out = 1,
  indexing_maps = #matmul_accesses,
  library_call = "linalg_matmul",
  iterator_types = ["parallel", "parallel", "reduction"]
}
// No transform marker on this op. The input iterators are
// [parallel, parallel, reduction]; the checks expect them to come out as
// [parallel, reduction, parallel] with the indexing maps permuted to match.
func.func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
           %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
           %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
  linalg.generic #generic_matmul_trait
    ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
                 memref<?x?xf32, offset: ?, strides: [?, 1]>)
   outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
    ^bb(%a: f32, %b: f32, %c: f32):
      %d = arith.mulf %a, %b: f32
      %e = arith.addf %c, %d: f32
      linalg.yield %e: f32
  }
  return
}
// CHECK-LABEL:  func @permute_generic
// CHECK:        linalg.generic {
// CHECK-SAME:   indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
// CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"],
// CHECK-SAME:   library_call = "linalg_matmul"}
// CHECK:          memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
// CHECK-SAME:     memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
// CHECK-SAME:     memref<?x?xf32, #[[$STRIDED_2D_u_1]]>

// Op tagged "__with_perm__". The checks expect the two tile loops in the
// opposite order from @matvec: step 6 outermost, then step 5.
func.func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %x: memref<?xf32, offset: ?, strides: [1]>,
             %y: memref<?xf32, offset: ?, strides: [1]>) {
  linalg.matvec {__internal_linalg_transform__ = "__with_perm__"}
    ins(%A, %x: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                memref<?xf32, offset: ?, strides: [1]>)
   outs(%y: memref<?xf32, offset: ?, strides: [1]>)
  return
}
// CHECK-LABEL: func @matvec_perm
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c5:.*]] = arith.constant 5 : index
// CHECK-DAG:     %[[c6:.*]] = arith.constant 6 : index
// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
// CHECK:             linalg.matvec
// CHECK:               ins({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>, memref<?xf32, #[[$STRIDED_1D]]>)
// CHECK:              outs({{.*}}: memref<?xf32, #[[$STRIDED_1D]]>)

// Op tagged "__with_perm__". The checks expect the outer two loop nests in a
// permuted order (3000, 4000, 2000 and then 300, 200, 400), followed by an
// innermost nest in the order 20, 30, 40.
func.func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
  linalg.matmul {__internal_linalg_transform__ = "__with_perm__"}
    ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                memref<?x?xf32, offset: ?, strides: [?, 1]>)
   outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>)
  return
}
// CHECK-LABEL: func @matmul_perm
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c20:.*]] = arith.constant 20 : index
// CHECK-DAG:     %[[c30:.*]] = arith.constant 30 : index
// CHECK-DAG:     %[[c40:.*]] = arith.constant 40 : index
// CHECK-DAG:     %[[c200:.*]] = arith.constant 200 : index
// CHECK-DAG:     %[[c300:.*]] = arith.constant 300 : index
// CHECK-DAG:     %[[c400:.*]] = arith.constant 400 : index
// CHECK-DAG:     %[[c2000:.*]] = arith.constant 2000 : index
// CHECK-DAG:     %[[c3000:.*]] = arith.constant 3000 : index
// CHECK-DAG:     %[[c4000:.*]] = arith.constant 4000 : index
// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
// CHECK:             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
// CHECK:               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
// CHECK:                 scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
// CHECK:                   scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
// CHECK:                     scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
// CHECK:                       scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
// CHECK:                         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
// CHECK:                                 linalg.matmul
// CHECK:                                   ins({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>, memref<?x?xf32, #[[$STRIDED_2D_u_1]]>)
// CHECK:                                  outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>)

// Op tagged "par__with_perm__". The checks expect three nested loops:
//   - a parallel loop over dim 1 of %arg1 with step 8,
//   - a sequential loop over dim 1 of %arg0 with step 4,
//   - a parallel loop over dim 0 of %arg0 with step 16.
// Constants are captured instead of being matched by literal SSA names
// (e.g. a bare "%c0"), so the test does not depend on printer name hints.
func.func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
                                 %arg1: memref<?x?xf32>,
                                 %arg2: memref<?x?xf32>) {
  linalg.matmul {__internal_linalg_transform__ = "par__with_perm__"}
    ins(%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>)
   outs(%arg2: memref<?x?xf32>)
  return
}
// CHECK-LABEL: func @tile_permute_parallel_loop
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG:     %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG:     %[[D2:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK:         scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D2]]) step (%[[C8]])
// CHECK:           scf.for %{{.*}} = %[[C0]] to %[[D1]] step %[[C4]]
// CHECK:             scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D0]]) step (%[[C16]])