1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
30
31 using namespace mlir;
32 using namespace mlir::linalg;
33
34 namespace {
35 struct TestLinalgTransforms
36 : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
37 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)
38
39 TestLinalgTransforms() = default;
TestLinalgTransforms__anon52d1bf190111::TestLinalgTransforms40 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
41
getDependentDialects__anon52d1bf190111::TestLinalgTransforms42 void getDependentDialects(DialectRegistry ®istry) const override {
43 // clang-format off
44 registry.insert<AffineDialect,
45 bufferization::BufferizationDialect,
46 memref::MemRefDialect,
47 scf::SCFDialect,
48 linalg::LinalgDialect,
49 vector::VectorDialect,
50 gpu::GPUDialect>();
51 // clang-format on
52 }
getArgument__anon52d1bf190111::TestLinalgTransforms53 StringRef getArgument() const final {
54 return "test-linalg-transform-patterns";
55 }
getDescription__anon52d1bf190111::TestLinalgTransforms56 StringRef getDescription() const final {
57 return "Test Linalg transformation patterns by applying them greedily.";
58 }
59
60 void runOnOperation() override;
61
62 Option<bool> testPatterns{*this, "test-patterns",
63 llvm::cl::desc("Test a mixed set of patterns"),
64 llvm::cl::init(false)};
65 Option<bool> testTileAndDistributionOptions{
66 *this, "test-tile-and-distribute-options",
67 llvm::cl::desc("Test tile and distribute options"),
68 llvm::cl::init(false)};
69 Option<bool> testTileFuseAndDistributionOptions{
70 *this, "test-tile-fuse-and-distribute-options",
71 llvm::cl::desc("Test tile, fuse and distribute options"),
72 llvm::cl::init(false)};
73 Option<bool> testVectorTransferForwardingPatterns{
74 *this, "test-vector-transfer-forwarding-patterns",
75 llvm::cl::desc(
76 "Test a fused pass that forwards memref.copy to vector.transfer"),
77 llvm::cl::init(false)};
78 Option<bool> testGenericToVectorPattern{
79 *this, "test-linalg-to-vector-patterns",
80 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
81 "in vector.contract form"),
82 llvm::cl::init(false)};
83 Option<bool> testTilePattern{*this, "test-tile-pattern",
84 llvm::cl::desc("Test tile pattern"),
85 llvm::cl::init(false)};
86 Option<bool> testTileScalarizeDynamicDims{
87 *this, "test-tile-scalarize-dynamic-dims",
88 llvm::cl::desc("Test tiling of dynamic dims by 1"),
89 llvm::cl::init(false)};
90 Option<bool> testTransformPadTensor{
91 *this, "test-transform-pad-tensor",
92 llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
93 llvm::cl::init(false)};
94 Option<bool> testGeneralizePadTensor{
95 *this, "test-generalize-pad-tensor",
96 llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
97 llvm::cl::init(false)};
98 Option<bool> testSwapSubTensorPadTensor{
99 *this, "test-swap-subtensor-padtensor",
100 llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
101 "tensor.pad(subtensor)"),
102 llvm::cl::init(false)};
103 Option<bool> testSplitReduction{
104 *this, "test-split-reduction",
105 llvm::cl::desc("Test split reduction transformation"),
106 llvm::cl::init(false)};
107 ListOption<int64_t> peeledLoops{
108 *this, "peeled-loops",
109 llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
110 ListOption<int64_t> tileSizes{
111 *this, "tile-sizes",
112 llvm::cl::desc("Linalg tile sizes for test-tile-pattern")};
113 Option<bool> skipPartial{
114 *this, "skip-partial",
115 llvm::cl::desc("Skip loops inside partial iterations during peeling"),
116 llvm::cl::init(false)};
117 Option<std::string> loopType{
118 *this, "loop-type",
119 llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
120 "tiled_loop"),
121 llvm::cl::init("for")};
122 Option<bool> testBubbleUpExtractSliceOpPattern{
123 *this, "test-bubble-up-extract-slice-op-pattern",
124 llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
125 "extract_slice + linalgOp"),
126 llvm::cl::init(false)};
127 };
128 } // namespace
129
applyPatterns(func::FuncOp funcOp)130 static void applyPatterns(func::FuncOp funcOp) {
131 MLIRContext *ctx = funcOp.getContext();
132 RewritePatternSet patterns(ctx);
133
134 //===--------------------------------------------------------------------===//
135 // Linalg tiling patterns.
136 //===--------------------------------------------------------------------===//
137 patterns.add<LinalgTilingPattern>(
138 MatmulOp::getOperationName(), ctx,
139 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
140 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
141 StringAttr::get(ctx, "L3")));
142 patterns.add<LinalgTilingPattern>(
143 MatmulOp::getOperationName(), ctx,
144 LinalgTilingOptions().setTileSizes({200, 300, 400}),
145 LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
146 StringAttr::get(ctx, "L2")));
147 patterns.add<LinalgTilingPattern>(
148 MatmulOp::getOperationName(), ctx,
149 LinalgTilingOptions().setTileSizes({20, 30, 40}),
150 LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
151 StringAttr::get(ctx, "L1")));
152 patterns.add<LinalgTilingPattern>(
153 MatmulOp::getOperationName(), ctx,
154 LinalgTilingOptions().setTileSizes({2, 3, 4}),
155 LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
156 StringAttr::get(ctx, "REG")));
157
158 patterns.add<LinalgTilingPattern>(
159 MatvecOp::getOperationName(), ctx,
160 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
161 LinalgTilingLoopType::ParallelLoops),
162 LinalgTransformationFilter(ArrayRef<StringAttr>{},
163 StringAttr::get(ctx, "L1")));
164
165 patterns.add<LinalgTilingPattern>(
166 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
167 LinalgTransformationFilter(
168 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
169 StringAttr::get(ctx, "L3"),
170 StringAttr::get(ctx, "L2")},
171 StringAttr::get(ctx, "REG")));
172
173 //===--------------------------------------------------------------------===//
174 // Linalg tiling and permutation patterns.
175 //===--------------------------------------------------------------------===//
176 patterns.add<LinalgTilingPattern>(
177 MatmulOp::getOperationName(), ctx,
178 LinalgTilingOptions()
179 .setTileSizes({2000, 3000, 4000})
180 .setInterchange({1, 2, 0}),
181 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
182 StringAttr::get(ctx, "L2__with_perm__")));
183 patterns.add<LinalgTilingPattern>(
184 MatmulOp::getOperationName(), ctx,
185 LinalgTilingOptions()
186 .setTileSizes({200, 300, 400})
187 .setInterchange({1, 0, 2}),
188 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
189 StringAttr::get(ctx, "L1__with_perm__")));
190 patterns.add<LinalgTilingPattern>(
191 MatmulOp::getOperationName(), ctx,
192 LinalgTilingOptions().setTileSizes({20, 30, 40}),
193 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
194 StringAttr::get(ctx, "REG__with_perm__")));
195
196 patterns.add<LinalgTilingPattern>(
197 MatvecOp::getOperationName(), ctx,
198 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
199 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
200 StringAttr::get(ctx, "L1__with_perm__")));
201
202 patterns.add<LinalgTilingPattern>(
203 MatmulOp::getOperationName(), ctx,
204 LinalgTilingOptions()
205 .setTileSizes({16, 8, 4})
206 .setInterchange({1, 2, 0})
207 .setLoopType(LinalgTilingLoopType::ParallelLoops),
208 LinalgTransformationFilter(
209 StringAttr::get(ctx, "par__with_perm__"),
210 StringAttr::get(ctx, "after_par__with_perm__")));
211
212 //===--------------------------------------------------------------------===//
213 // Linalg to loops patterns.
214 //===--------------------------------------------------------------------===//
215 patterns.add<LinalgLoweringPattern<DotOp>>(
216 ctx,
217 /*loweringType=*/LinalgLoweringType::Loops,
218 LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
219
220 //===--------------------------------------------------------------------===//
221 // Linalg distribution patterns.
222 //===--------------------------------------------------------------------===//
223 LinalgLoopDistributionOptions distributionOptions;
224
225 //===--------------------------------------------------------------------===//
226 // Linalg to vector contraction patterns.
227 //===--------------------------------------------------------------------===//
228 patterns.add<LinalgVectorizationPattern>(
229 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
230 .addOpFilter<MatmulOp, FillOp, GenericOp>());
231 patterns.add<CopyVectorizationPattern>(ctx);
232
233 //===--------------------------------------------------------------------===//
234 // Linalg generic interchange pattern.
235 //===--------------------------------------------------------------------===//
236 patterns.add<GenericOpInterchangePattern>(
237 ctx,
238 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
239 LinalgTransformationFilter(ArrayRef<StringAttr>{},
240 StringAttr::get(ctx, "PERMUTED")));
241
242 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
243
244 // Drop the marker.
245 funcOp.walk([](LinalgOp op) {
246 op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
247 });
248 }
249
250 template <typename IdOp, typename NProcsOp>
251 static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder & b,Location loc,ArrayRef<Range> parallelLoopRanges)252 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
253 size_t count = std::min<size_t>(3, parallelLoopRanges.size());
254 SmallVector<ProcInfo, 2> procInfo(count);
255 Type indexType = b.getIndexType();
256 for (unsigned i = 0; i < count; ++i) {
257 gpu::Dimension dim = *gpu::symbolizeDimension(i);
258 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
259 b.create<NProcsOp>(loc, indexType, dim)};
260 }
261 return procInfo;
262 }
263
fillTileAndDistributePatterns(MLIRContext * context,RewritePatternSet & patterns)264 static void fillTileAndDistributePatterns(MLIRContext *context,
265 RewritePatternSet &patterns) {
266 {
267 LinalgLoopDistributionOptions cyclicNprocsEqNiters;
268 cyclicNprocsEqNiters.distributionMethod.resize(
269 2, DistributionMethod::CyclicNumProcsEqNumIters);
270 cyclicNprocsEqNiters.procInfo =
271 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
272 patterns.add<LinalgTilingPattern>(
273 MatmulOp::getOperationName(), context,
274 LinalgTilingOptions()
275 .setTileSizes({8, 8, 4})
276 .setLoopType(LinalgTilingLoopType::ParallelLoops)
277 .setDistributionOptions(cyclicNprocsEqNiters),
278 LinalgTransformationFilter(
279 StringAttr::get(context, "distribute1"),
280 StringAttr::get(context, "after_distribute1")));
281 }
282
283 {
284 LinalgLoopDistributionOptions cyclicNprocsGeNiters;
285 cyclicNprocsGeNiters.distributionMethod.resize(
286 2, DistributionMethod::CyclicNumProcsGeNumIters);
287 cyclicNprocsGeNiters.procInfo =
288 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
289 patterns.add<LinalgTilingPattern>(
290 MatmulOp::getOperationName(), context,
291 LinalgTilingOptions()
292 .setTileSizes({8, 8, 4})
293 .setLoopType(LinalgTilingLoopType::ParallelLoops)
294 .setDistributionOptions(cyclicNprocsGeNiters),
295 LinalgTransformationFilter(
296 StringAttr::get(context, "distribute2"),
297 StringAttr::get(context, "after_distribute2")));
298 }
299
300 {
301 LinalgLoopDistributionOptions cyclicNprocsDefault;
302 cyclicNprocsDefault.distributionMethod.resize(2,
303 DistributionMethod::Cyclic);
304 cyclicNprocsDefault.procInfo =
305 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
306 patterns.add<LinalgTilingPattern>(
307 MatmulOp::getOperationName(), context,
308 LinalgTilingOptions()
309 .setTileSizes({8, 8, 4})
310 .setLoopType(LinalgTilingLoopType::ParallelLoops)
311 .setDistributionOptions(cyclicNprocsDefault),
312 LinalgTransformationFilter(
313 StringAttr::get(context, "distribute3"),
314 StringAttr::get(context, "after_distribute3")));
315 }
316
317 {
318 LinalgLoopDistributionOptions cyclicNprocsMixed1;
319 cyclicNprocsMixed1.distributionMethod = {
320 DistributionMethod::CyclicNumProcsEqNumIters,
321 DistributionMethod::CyclicNumProcsGeNumIters};
322 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
323 patterns.add<LinalgTilingPattern>(
324 MatmulOp::getOperationName(), context,
325 LinalgTilingOptions()
326 .setTileSizes({8, 8, 4})
327 .setLoopType(LinalgTilingLoopType::ParallelLoops)
328 .setDistributionOptions(cyclicNprocsMixed1),
329 LinalgTransformationFilter(
330 StringAttr::get(context, "distribute4"),
331 StringAttr::get(context, "after_distribute4")));
332 }
333
334 {
335 LinalgLoopDistributionOptions cyclicNprocsMixed2;
336 cyclicNprocsMixed2.distributionMethod = {
337 DistributionMethod::CyclicNumProcsGeNumIters,
338 DistributionMethod::Cyclic};
339 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
340 patterns.add<LinalgTilingPattern>(
341 MatmulOp::getOperationName(), context,
342 LinalgTilingOptions()
343 .setTileSizes({8, 8, 4})
344 .setLoopType(LinalgTilingLoopType::ParallelLoops)
345 .setDistributionOptions(cyclicNprocsMixed2),
346 LinalgTransformationFilter(
347 StringAttr::get(context, "distribute5"),
348 StringAttr::get(context, "after_distribute5")));
349 }
350
351 {
352 LinalgLoopDistributionOptions cyclicNprocsMixed3;
353 cyclicNprocsMixed3.distributionMethod = {
354 DistributionMethod::Cyclic,
355 DistributionMethod::CyclicNumProcsEqNumIters};
356 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
357
358 patterns.add<LinalgTilingPattern>(
359 MatmulOp::getOperationName(), context,
360 LinalgTilingOptions()
361 .setTileSizes({8, 8, 4})
362 .setLoopType(LinalgTilingLoopType::ParallelLoops)
363 .setDistributionOptions(cyclicNprocsMixed3),
364 LinalgTransformationFilter(
365 StringAttr::get(context, "distribute6"),
366 StringAttr::get(context, "after_distribute6")));
367 }
368
369 {
370 LinalgLoopDistributionOptions cyclicNprocsEqNiters;
371 cyclicNprocsEqNiters.distributionMethod.resize(2,
372 DistributionMethod::Cyclic);
373 cyclicNprocsEqNiters.procInfo =
374 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
375 patterns.add<LinalgTilingPattern>(
376 MatmulOp::getOperationName(), context,
377 LinalgTilingOptions()
378 .setTileSizes({8, 8, 4})
379 .setLoopType(LinalgTilingLoopType::Loops)
380 .setDistributionOptions(cyclicNprocsEqNiters),
381 LinalgTransformationFilter(
382 StringAttr::get(context, "tensors_distribute1"),
383 StringAttr::get(context, "tensors_after_distribute1")));
384 }
385 }
386
fillTileFuseAndDistributePatterns(MLIRContext * context,RewritePatternSet & patterns)387 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
388 RewritePatternSet &patterns) {
389 LinalgLoopDistributionOptions cyclicNprocsEqNiters;
390 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
391 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
392 patterns.add<LinalgTileAndFuseTensorOpsPattern>(
393 MatmulOp::getOperationName(), context,
394 LinalgTilingAndFusionOptions()
395 .setTileSizes({8, 8, 4})
396 .setDistributionOptions(cyclicNprocsEqNiters),
397 LinalgTransformationFilter(
398 StringAttr::get(context, "tensors_fuse_distribute1"),
399 StringAttr::get(context, "tensors_after_fuse_distribute1")));
400 }
401
applyVectorTransferForwardingPatterns(func::FuncOp funcOp)402 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
403 RewritePatternSet forwardPattern(funcOp.getContext());
404 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
405 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
406 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
407 }
408
applyLinalgToVectorPatterns(func::FuncOp funcOp)409 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
410 RewritePatternSet patterns(funcOp.getContext());
411 auto *ctx = funcOp.getContext();
412 patterns.add<LinalgVectorizationPattern>(
413 ctx, LinalgTransformationFilter()
414 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
415 patterns.add<CopyVectorizationPattern>(ctx);
416 populatePadOpVectorizationPatterns(patterns);
417 populateConvolutionVectorizationPatterns(patterns);
418 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
419 }
420
applyPadTensorToGenericPatterns(func::FuncOp funcOp)421 static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) {
422 RewritePatternSet patterns(funcOp.getContext());
423 patterns.add<PadOpTransformationPattern>(funcOp.getContext());
424 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
425 }
426
applyGeneralizePadTensorPatterns(func::FuncOp funcOp)427 static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
428 RewritePatternSet patterns(funcOp.getContext());
429 patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
430 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
431 }
432
applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp)433 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
434 RewritePatternSet patterns(funcOp.getContext());
435 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
436 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
437 }
438
applyTilePattern(func::FuncOp funcOp,const std::string & loopType,ArrayRef<int64_t> tileSizes,ArrayRef<int64_t> peeledLoops,bool scalarizeDynamicDims)439 static void applyTilePattern(func::FuncOp funcOp, const std::string &loopType,
440 ArrayRef<int64_t> tileSizes,
441 ArrayRef<int64_t> peeledLoops,
442 bool scalarizeDynamicDims) {
443 MLIRContext *context = funcOp.getContext();
444 RewritePatternSet tilingPattern(context);
445 LinalgTilingLoopType type =
446 llvm::StringSwitch<LinalgTilingLoopType>(loopType)
447 .Case("for", LinalgTilingLoopType::Loops)
448 .Case("affine", LinalgTilingLoopType::AffineLoops)
449 .Case("parallel", LinalgTilingLoopType::ParallelLoops);
450 auto linalgTilingOptions = linalg::LinalgTilingOptions()
451 .setPeeledLoops(peeledLoops)
452 .setLoopType(type);
453 if (scalarizeDynamicDims) {
454 linalgTilingOptions.scalarizeDynamicDims();
455 assert(tileSizes.empty() &&
456 "tileSizes and scalarizeDynamicDims is mutually exclusive");
457 } else {
458 linalgTilingOptions.setTileSizes(tileSizes);
459 }
460 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
461 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
462 tilingPattern, linalgTilingOptions, f);
463 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
464 }
465
applySplitReduction(func::FuncOp funcOp)466 static void applySplitReduction(func::FuncOp funcOp) {
467 RewritePatternSet patterns(funcOp.getContext());
468 linalg::populateSplitReductionPattern(
469 patterns,
470 [](LinalgOp op) {
471 unsigned insertDimIndex = op.getNumLoops() - 1;
472 return std::make_pair(4, insertDimIndex);
473 },
474 LinalgTransformationFilter(
475 ArrayRef<StringAttr>{},
476 StringAttr::get(funcOp.getContext(), "SPLIT")));
477 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
478 }
479
applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp)480 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
481 RewritePatternSet patterns(funcOp.getContext());
482 populateBubbleUpExtractSliceOpPatterns(patterns);
483 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
484 }
485
486 /// Apply transformations specified as patterns.
runOnOperation()487 void TestLinalgTransforms::runOnOperation() {
488 auto lambda = [&](void *) {
489 getOperation().walk([](LinalgOp op) {
490 op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
491 });
492 };
493 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
494
495 if (testTileAndDistributionOptions) {
496 RewritePatternSet patterns(&getContext());
497 fillTileAndDistributePatterns(&getContext(), patterns);
498 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
499 return;
500 }
501 if (testTileFuseAndDistributionOptions) {
502 RewritePatternSet patterns(&getContext());
503 fillTileFuseAndDistributePatterns(&getContext(), patterns);
504 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
505 return;
506 }
507 if (testPatterns)
508 return applyPatterns(getOperation());
509 if (testVectorTransferForwardingPatterns)
510 return applyVectorTransferForwardingPatterns(getOperation());
511 if (testGenericToVectorPattern)
512 return applyLinalgToVectorPatterns(getOperation());
513 if (testTransformPadTensor)
514 return applyPadTensorToGenericPatterns(getOperation());
515 if (testGeneralizePadTensor)
516 return applyGeneralizePadTensorPatterns(getOperation());
517 if (testSwapSubTensorPadTensor)
518 return applyExtractSliceOfPadTensorSwapPattern(getOperation());
519 if (testTilePattern)
520 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
521 /*scalarizeDynamicDims=*/false);
522 if (testTileScalarizeDynamicDims)
523 return applyTilePattern(getOperation(), loopType, tileSizes,
524 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
525 if (testSplitReduction)
526 return applySplitReduction(getOperation());
527 if (testBubbleUpExtractSliceOpPattern)
528 return applyBubbleUpExtractSliceOpPattern(getOperation());
529 }
530
531 namespace mlir {
532 namespace test {
registerTestLinalgTransforms()533 void registerTestLinalgTransforms() {
534 PassRegistration<TestLinalgTransforms>();
535 }
536 } // namespace test
537 } // namespace mlir
538