// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s --test-transform-dialect-interpreter --canonicalize --split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CANON

// Static split: every linalg.generic in this section is split after index 42
// along dimension 0. The second function has only 10 elements, so the split
// point lies past the end of the iteration space.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  transform.sequence %arg0 {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK: #[[$ADD_42_MAP:.+]] = affine_map<(d0) -> (d0 + 42)>
// CHECK: #[[$ADD_10_MAP:.+]] = affine_map<(d0) -> (d0 + 10)>

// CHECK-LABEL: @one_d_static
// CHECK-SAME:    %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_LOW]]
  // CHECK:   outs(%[[OUT_SLICE_LOW]]
  // CHECK:   linalg.index 0
  // CHECK:   func.call @elem
  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [42] [1]
  //
  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_HIGH]]
  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
  // CHECK:   %[[IDX:.+]] = linalg.index 0
  // CHECK:   affine.apply #[[$ADD_42_MAP]](%[[IDX]])
  // CHECK:   func.call @elem
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][42] [58] [1]
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<100xf32>

  // CHECK: return %[[RES]]
  return %0 : tensor<100xf32>
}

// CHECK-LABEL: @one_d_static_overflow
// CHECK-SAME:    %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
// CANON-LABEL: @one_d_static_overflow
// CANON-SAME:    %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_LOW]]
  // CHECK:   outs(%[[OUT_SLICE_LOW]]
  // CHECK:   linalg.index 0
  // CHECK:   func.call @elem
  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1]
  //
  // Due to overflow, the first part of the split computes everything and the
  // insert/extract slices are folded away by the canonicalizer.
  // CANON: %[[RES_PARTIAL:.+]] = linalg.generic
  // CANON:   ins(%[[IN]]
  // CANON:   outs(%[[OUT]]
  // CANON:   linalg.index 0
  // CANON:   func.call @elem
  // The second part operates on zero-sized slices that are not currently
  // folded away.
  //
  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_HIGH]]
  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
  // CHECK:   %[[IDX:.+]] = linalg.index 0
  // CHECK:   affine.apply #[[$ADD_10_MAP]](%[[IDX]])
  // CHECK:   func.call @elem
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][10] [0] [1]
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10xf32>
  return %0 : tensor<10xf32>
}

// -----

// Dynamic split: the split point is the index-typed result of the func.call
// matched in the payload, passed to the split op as a second handle.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  transform.sequence %arg0 {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    %1 = transform.structured.match ops{["func.call"]} in %arg1
    transform.structured.split %0 after %1 { dimension = 0 }
  }
}

func.func private @get_size() -> index

// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<()[s0] -> (s0, 100)>
// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>

// IN and OUT are bound once on the function signature so that every later use
// references the captured value instead of silently re-binding it.
// CHECK-LABEL: @dynamic
// CHECK-SAME:    %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // CHECK: %[[SPLIT:.+]] = call @get_size
  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]]()[%[[SPLIT]]
  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_LOW]]
  // CHECK:   outs(%[[OUT_SLICE_LOW]]
  // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
  //
  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
  // CHECK:   ins(%[[IN_SLICE_HIGH]]
  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1]
  %0 = func.call @get_size() : () -> index
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    %5 = arith.addf %3, %4 : f32
    linalg.yield %5 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

// Chained split: split after 4 along dimension 0, then split only the second
// half after 16 along dimension 1.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  transform.sequence %arg0 {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    %1:2 = transform.structured.split %0 after 4 { dimension = 0}
    %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 }
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK-LABEL: @two_d
func.func @two_d(%arg0: tensor<10x34xf32>,
                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
  // Check the overall structure: split along the dimension 0, and then split
  // the second half only along the dimension 1.
  // CHECK:      %[[IN_1:.+]] = tensor.extract_slice %[[IN:.+]][0, 0]
  // CHECK:      %[[OUT_1:.+]] = tensor.extract_slice %[[OUT:.+]][0, 0]
  // CHECK:      %[[RES_1:.+]] = linalg.generic
  // CHECK-SAME:   ins(%[[IN_1]] : tensor<4x34xf32>)
  // CHECK-SAME:   outs(%[[OUT_1]] : tensor<4x34xf32>)
  // CHECK:      %[[PARTIAL_1:.+]] = tensor.insert_slice %[[RES_1]] into %[[OUT]]
  //
  // CHECK:      %[[IN_2:.+]] = tensor.extract_slice %[[IN]]
  // CHECK:      %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]]
  // CHECK:      %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]]
  // CHECK:      %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]]
  // CHECK:      %[[RES_21:.+]] = linalg.generic
  // CHECK-SAME:   ins(%[[IN_21]] : tensor<6x16xf32>)
  // CHECK-SAME:   outs(%[[OUT_21]] : tensor<6x16xf32>)
  // CHECK:      %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]]
  //
  // CHECK:      %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]]
  // CHECK:      %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]]
  // CHECK:      %[[RES_22:.+]] = linalg.generic
  // CHECK-SAME:   ins(%[[IN_22]] : tensor<6x18xf32>)
  // CHECK-SAME:   outs(%[[OUT_22]] : tensor<6x18xf32>)
  // CHECK:      %[[PARTIAL_22:.+]] = tensor.insert_slice %[[RES_22]] into %[[PARTIAL_21]]
  // CHECK:      %[[PARTIAL_2:.+]] = tensor.insert_slice %[[PARTIAL_22]] into %[[PARTIAL_1]]
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>,
                     affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]
  }
  ins(%arg0: tensor<10x34xf32>)
  outs(%arg1: tensor<10x34xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10x34xf32>
  return %0 : tensor<10x34xf32>
}

// -----

// Verifier diagnostic: the generic form below supplies no dynamic split point
// operand and sets static_split_point to -1.
transform.sequence {
^bb1(%arg1: !pdl.operation):
  // expected-error @below {{expects either a dynamic or a static split point to be provided}}
  %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -1 } : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
}

// -----

// Runtime diagnostic: the op behind the dynamic split point handle returns
// i64 rather than index.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  transform.sequence %arg0 {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    %1 = transform.structured.match ops{["func.call"]} in %arg1
    // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
    transform.structured.split %0 after %1 { dimension = 0 }
  }
}

func.func private @get_size() -> i64

func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{dynamic split point}}
  %0 = func.call @get_size() : () -> i64
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    linalg.yield %3 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

// Runtime diagnostic: the payload contains no func.call, so the split point
// handle is empty while the target handle holds one linalg.generic.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  transform.sequence %arg0 {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    %1 = transform.structured.match ops{["func.call"]} in %arg1
    // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
    transform.structured.split %0 after %1 { dimension = 0 }
  }
}

func.func private @get_size() -> i64

func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    linalg.yield %3 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

// Runtime diagnostic: the split target is a func.return, which is not a
// structured op.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  pdl.pattern @func_return : benefit(1) {
    %0 = pdl.operands
    %1 = pdl.types
    %2 = pdl.operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
    pdl.rewrite %2 with "transform.dialect"
  }

  transform.sequence %arg0 {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["func.return"]} in %arg1
    // expected-error @below {{only applies to structured ops}}
    transform.structured.split %0 after 16 { dimension = 1 }
  }
}

func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{target op}}
  return %arg0 : tensor<100xf32>
}

// -----

// Runtime diagnostic: dimension 1 is requested on a linalg.generic that has a
// single loop.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  pdl.pattern @linalg_generic : benefit(1) {
    %0 = pdl.operands
    %1 = pdl.types
    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
    pdl.rewrite %2 with "transform.dialect"
  }

  transform.sequence %arg0 {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    // expected-error @below {{dimension 1 does not exist in target op}}
    transform.structured.split %0 after 16 { dimension = 1 }
  }
}

func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{target op}}
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%0: f32, %1: f32):
    linalg.yield %0 : f32
  } -> tensor<100xf32>
  return %0 : tensor<100xf32>
}
