1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
30 
31 using namespace mlir;
32 using namespace mlir::linalg;
33 
34 namespace {
35 struct TestLinalgTransforms
36     : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
37   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)
38 
39   TestLinalgTransforms() = default;
TestLinalgTransforms__anon52d1bf190111::TestLinalgTransforms40   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
41 
getDependentDialects__anon52d1bf190111::TestLinalgTransforms42   void getDependentDialects(DialectRegistry &registry) const override {
43     // clang-format off
44     registry.insert<AffineDialect,
45                     bufferization::BufferizationDialect,
46                     memref::MemRefDialect,
47                     scf::SCFDialect,
48                     linalg::LinalgDialect,
49                     vector::VectorDialect,
50                     gpu::GPUDialect>();
51     // clang-format on
52   }
getArgument__anon52d1bf190111::TestLinalgTransforms53   StringRef getArgument() const final {
54     return "test-linalg-transform-patterns";
55   }
getDescription__anon52d1bf190111::TestLinalgTransforms56   StringRef getDescription() const final {
57     return "Test Linalg transformation patterns by applying them greedily.";
58   }
59 
60   void runOnOperation() override;
61 
62   Option<bool> testPatterns{*this, "test-patterns",
63                             llvm::cl::desc("Test a mixed set of patterns"),
64                             llvm::cl::init(false)};
65   Option<bool> testTileAndDistributionOptions{
66       *this, "test-tile-and-distribute-options",
67       llvm::cl::desc("Test tile and distribute options"),
68       llvm::cl::init(false)};
69   Option<bool> testTileFuseAndDistributionOptions{
70       *this, "test-tile-fuse-and-distribute-options",
71       llvm::cl::desc("Test tile, fuse and distribute options"),
72       llvm::cl::init(false)};
73   Option<bool> testVectorTransferForwardingPatterns{
74       *this, "test-vector-transfer-forwarding-patterns",
75       llvm::cl::desc(
76           "Test a fused pass that forwards memref.copy to vector.transfer"),
77       llvm::cl::init(false)};
78   Option<bool> testGenericToVectorPattern{
79       *this, "test-linalg-to-vector-patterns",
80       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
81                      "in vector.contract form"),
82       llvm::cl::init(false)};
83   Option<bool> testTilePattern{*this, "test-tile-pattern",
84                                llvm::cl::desc("Test tile pattern"),
85                                llvm::cl::init(false)};
86   Option<bool> testTileScalarizeDynamicDims{
87       *this, "test-tile-scalarize-dynamic-dims",
88       llvm::cl::desc("Test tiling of dynamic dims by 1"),
89       llvm::cl::init(false)};
90   Option<bool> testTransformPadTensor{
91       *this, "test-transform-pad-tensor",
92       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
93       llvm::cl::init(false)};
94   Option<bool> testGeneralizePadTensor{
95       *this, "test-generalize-pad-tensor",
96       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
97       llvm::cl::init(false)};
98   Option<bool> testSwapSubTensorPadTensor{
99       *this, "test-swap-subtensor-padtensor",
100       llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
101                      "tensor.pad(subtensor)"),
102       llvm::cl::init(false)};
103   Option<bool> testSplitReduction{
104       *this, "test-split-reduction",
105       llvm::cl::desc("Test split reduction transformation"),
106       llvm::cl::init(false)};
107   ListOption<int64_t> peeledLoops{
108       *this, "peeled-loops",
109       llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
110   ListOption<int64_t> tileSizes{
111       *this, "tile-sizes",
112       llvm::cl::desc("Linalg tile sizes for test-tile-pattern")};
113   Option<bool> skipPartial{
114       *this, "skip-partial",
115       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
116       llvm::cl::init(false)};
117   Option<std::string> loopType{
118       *this, "loop-type",
119       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
120                      "tiled_loop"),
121       llvm::cl::init("for")};
122   Option<bool> testBubbleUpExtractSliceOpPattern{
123       *this, "test-bubble-up-extract-slice-op-pattern",
124       llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
125                      "extract_slice + linalgOp"),
126       llvm::cl::init(false)};
127 };
128 } // namespace
129 
applyPatterns(func::FuncOp funcOp)130 static void applyPatterns(func::FuncOp funcOp) {
131   MLIRContext *ctx = funcOp.getContext();
132   RewritePatternSet patterns(ctx);
133 
134   //===--------------------------------------------------------------------===//
135   // Linalg tiling patterns.
136   //===--------------------------------------------------------------------===//
137   patterns.add<LinalgTilingPattern>(
138       MatmulOp::getOperationName(), ctx,
139       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
140       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
141                                  StringAttr::get(ctx, "L3")));
142   patterns.add<LinalgTilingPattern>(
143       MatmulOp::getOperationName(), ctx,
144       LinalgTilingOptions().setTileSizes({200, 300, 400}),
145       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
146                                  StringAttr::get(ctx, "L2")));
147   patterns.add<LinalgTilingPattern>(
148       MatmulOp::getOperationName(), ctx,
149       LinalgTilingOptions().setTileSizes({20, 30, 40}),
150       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
151                                  StringAttr::get(ctx, "L1")));
152   patterns.add<LinalgTilingPattern>(
153       MatmulOp::getOperationName(), ctx,
154       LinalgTilingOptions().setTileSizes({2, 3, 4}),
155       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
156                                  StringAttr::get(ctx, "REG")));
157 
158   patterns.add<LinalgTilingPattern>(
159       MatvecOp::getOperationName(), ctx,
160       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
161           LinalgTilingLoopType::ParallelLoops),
162       LinalgTransformationFilter(ArrayRef<StringAttr>{},
163                                  StringAttr::get(ctx, "L1")));
164 
165   patterns.add<LinalgTilingPattern>(
166       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
167       LinalgTransformationFilter(
168           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
169                                StringAttr::get(ctx, "L3"),
170                                StringAttr::get(ctx, "L2")},
171           StringAttr::get(ctx, "REG")));
172 
173   //===--------------------------------------------------------------------===//
174   // Linalg tiling and permutation patterns.
175   //===--------------------------------------------------------------------===//
176   patterns.add<LinalgTilingPattern>(
177       MatmulOp::getOperationName(), ctx,
178       LinalgTilingOptions()
179           .setTileSizes({2000, 3000, 4000})
180           .setInterchange({1, 2, 0}),
181       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
182                                  StringAttr::get(ctx, "L2__with_perm__")));
183   patterns.add<LinalgTilingPattern>(
184       MatmulOp::getOperationName(), ctx,
185       LinalgTilingOptions()
186           .setTileSizes({200, 300, 400})
187           .setInterchange({1, 0, 2}),
188       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
189                                  StringAttr::get(ctx, "L1__with_perm__")));
190   patterns.add<LinalgTilingPattern>(
191       MatmulOp::getOperationName(), ctx,
192       LinalgTilingOptions().setTileSizes({20, 30, 40}),
193       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
194                                  StringAttr::get(ctx, "REG__with_perm__")));
195 
196   patterns.add<LinalgTilingPattern>(
197       MatvecOp::getOperationName(), ctx,
198       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
199       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
200                                  StringAttr::get(ctx, "L1__with_perm__")));
201 
202   patterns.add<LinalgTilingPattern>(
203       MatmulOp::getOperationName(), ctx,
204       LinalgTilingOptions()
205           .setTileSizes({16, 8, 4})
206           .setInterchange({1, 2, 0})
207           .setLoopType(LinalgTilingLoopType::ParallelLoops),
208       LinalgTransformationFilter(
209           StringAttr::get(ctx, "par__with_perm__"),
210           StringAttr::get(ctx, "after_par__with_perm__")));
211 
212   //===--------------------------------------------------------------------===//
213   // Linalg to loops patterns.
214   //===--------------------------------------------------------------------===//
215   patterns.add<LinalgLoweringPattern<DotOp>>(
216       ctx,
217       /*loweringType=*/LinalgLoweringType::Loops,
218       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
219 
220   //===--------------------------------------------------------------------===//
221   // Linalg distribution patterns.
222   //===--------------------------------------------------------------------===//
223   LinalgLoopDistributionOptions distributionOptions;
224 
225   //===--------------------------------------------------------------------===//
226   // Linalg to vector contraction patterns.
227   //===--------------------------------------------------------------------===//
228   patterns.add<LinalgVectorizationPattern>(
229       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
230                .addOpFilter<MatmulOp, FillOp, GenericOp>());
231   patterns.add<CopyVectorizationPattern>(ctx);
232 
233   //===--------------------------------------------------------------------===//
234   // Linalg generic interchange pattern.
235   //===--------------------------------------------------------------------===//
236   patterns.add<GenericOpInterchangePattern>(
237       ctx,
238       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
239       LinalgTransformationFilter(ArrayRef<StringAttr>{},
240                                  StringAttr::get(ctx, "PERMUTED")));
241 
242   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
243 
244   // Drop the marker.
245   funcOp.walk([](LinalgOp op) {
246     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
247   });
248 }
249 
250 template <typename IdOp, typename NProcsOp>
251 static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder & b,Location loc,ArrayRef<Range> parallelLoopRanges)252 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
253   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
254   SmallVector<ProcInfo, 2> procInfo(count);
255   Type indexType = b.getIndexType();
256   for (unsigned i = 0; i < count; ++i) {
257     gpu::Dimension dim = *gpu::symbolizeDimension(i);
258     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
259                                b.create<NProcsOp>(loc, indexType, dim)};
260   }
261   return procInfo;
262 }
263 
fillTileAndDistributePatterns(MLIRContext * context,RewritePatternSet & patterns)264 static void fillTileAndDistributePatterns(MLIRContext *context,
265                                           RewritePatternSet &patterns) {
266   {
267     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
268     cyclicNprocsEqNiters.distributionMethod.resize(
269         2, DistributionMethod::CyclicNumProcsEqNumIters);
270     cyclicNprocsEqNiters.procInfo =
271         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
272     patterns.add<LinalgTilingPattern>(
273         MatmulOp::getOperationName(), context,
274         LinalgTilingOptions()
275             .setTileSizes({8, 8, 4})
276             .setLoopType(LinalgTilingLoopType::ParallelLoops)
277             .setDistributionOptions(cyclicNprocsEqNiters),
278         LinalgTransformationFilter(
279             StringAttr::get(context, "distribute1"),
280             StringAttr::get(context, "after_distribute1")));
281   }
282 
283   {
284     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
285     cyclicNprocsGeNiters.distributionMethod.resize(
286         2, DistributionMethod::CyclicNumProcsGeNumIters);
287     cyclicNprocsGeNiters.procInfo =
288         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
289     patterns.add<LinalgTilingPattern>(
290         MatmulOp::getOperationName(), context,
291         LinalgTilingOptions()
292             .setTileSizes({8, 8, 4})
293             .setLoopType(LinalgTilingLoopType::ParallelLoops)
294             .setDistributionOptions(cyclicNprocsGeNiters),
295         LinalgTransformationFilter(
296             StringAttr::get(context, "distribute2"),
297             StringAttr::get(context, "after_distribute2")));
298   }
299 
300   {
301     LinalgLoopDistributionOptions cyclicNprocsDefault;
302     cyclicNprocsDefault.distributionMethod.resize(2,
303                                                   DistributionMethod::Cyclic);
304     cyclicNprocsDefault.procInfo =
305         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
306     patterns.add<LinalgTilingPattern>(
307         MatmulOp::getOperationName(), context,
308         LinalgTilingOptions()
309             .setTileSizes({8, 8, 4})
310             .setLoopType(LinalgTilingLoopType::ParallelLoops)
311             .setDistributionOptions(cyclicNprocsDefault),
312         LinalgTransformationFilter(
313             StringAttr::get(context, "distribute3"),
314             StringAttr::get(context, "after_distribute3")));
315   }
316 
317   {
318     LinalgLoopDistributionOptions cyclicNprocsMixed1;
319     cyclicNprocsMixed1.distributionMethod = {
320         DistributionMethod::CyclicNumProcsEqNumIters,
321         DistributionMethod::CyclicNumProcsGeNumIters};
322     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
323     patterns.add<LinalgTilingPattern>(
324         MatmulOp::getOperationName(), context,
325         LinalgTilingOptions()
326             .setTileSizes({8, 8, 4})
327             .setLoopType(LinalgTilingLoopType::ParallelLoops)
328             .setDistributionOptions(cyclicNprocsMixed1),
329         LinalgTransformationFilter(
330             StringAttr::get(context, "distribute4"),
331             StringAttr::get(context, "after_distribute4")));
332   }
333 
334   {
335     LinalgLoopDistributionOptions cyclicNprocsMixed2;
336     cyclicNprocsMixed2.distributionMethod = {
337         DistributionMethod::CyclicNumProcsGeNumIters,
338         DistributionMethod::Cyclic};
339     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
340     patterns.add<LinalgTilingPattern>(
341         MatmulOp::getOperationName(), context,
342         LinalgTilingOptions()
343             .setTileSizes({8, 8, 4})
344             .setLoopType(LinalgTilingLoopType::ParallelLoops)
345             .setDistributionOptions(cyclicNprocsMixed2),
346         LinalgTransformationFilter(
347             StringAttr::get(context, "distribute5"),
348             StringAttr::get(context, "after_distribute5")));
349   }
350 
351   {
352     LinalgLoopDistributionOptions cyclicNprocsMixed3;
353     cyclicNprocsMixed3.distributionMethod = {
354         DistributionMethod::Cyclic,
355         DistributionMethod::CyclicNumProcsEqNumIters};
356     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
357 
358     patterns.add<LinalgTilingPattern>(
359         MatmulOp::getOperationName(), context,
360         LinalgTilingOptions()
361             .setTileSizes({8, 8, 4})
362             .setLoopType(LinalgTilingLoopType::ParallelLoops)
363             .setDistributionOptions(cyclicNprocsMixed3),
364         LinalgTransformationFilter(
365             StringAttr::get(context, "distribute6"),
366             StringAttr::get(context, "after_distribute6")));
367   }
368 
369   {
370     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
371     cyclicNprocsEqNiters.distributionMethod.resize(2,
372                                                    DistributionMethod::Cyclic);
373     cyclicNprocsEqNiters.procInfo =
374         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
375     patterns.add<LinalgTilingPattern>(
376         MatmulOp::getOperationName(), context,
377         LinalgTilingOptions()
378             .setTileSizes({8, 8, 4})
379             .setLoopType(LinalgTilingLoopType::Loops)
380             .setDistributionOptions(cyclicNprocsEqNiters),
381         LinalgTransformationFilter(
382             StringAttr::get(context, "tensors_distribute1"),
383             StringAttr::get(context, "tensors_after_distribute1")));
384   }
385 }
386 
fillTileFuseAndDistributePatterns(MLIRContext * context,RewritePatternSet & patterns)387 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
388                                               RewritePatternSet &patterns) {
389   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
390   cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
391   cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
392   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
393       MatmulOp::getOperationName(), context,
394       LinalgTilingAndFusionOptions()
395           .setTileSizes({8, 8, 4})
396           .setDistributionOptions(cyclicNprocsEqNiters),
397       LinalgTransformationFilter(
398           StringAttr::get(context, "tensors_fuse_distribute1"),
399           StringAttr::get(context, "tensors_after_fuse_distribute1")));
400 }
401 
applyVectorTransferForwardingPatterns(func::FuncOp funcOp)402 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
403   RewritePatternSet forwardPattern(funcOp.getContext());
404   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
405   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
406   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
407 }
408 
applyLinalgToVectorPatterns(func::FuncOp funcOp)409 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
410   RewritePatternSet patterns(funcOp.getContext());
411   auto *ctx = funcOp.getContext();
412   patterns.add<LinalgVectorizationPattern>(
413       ctx, LinalgTransformationFilter()
414                .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
415   patterns.add<CopyVectorizationPattern>(ctx);
416   populatePadOpVectorizationPatterns(patterns);
417   populateConvolutionVectorizationPatterns(patterns);
418   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
419 }
420 
applyPadTensorToGenericPatterns(func::FuncOp funcOp)421 static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) {
422   RewritePatternSet patterns(funcOp.getContext());
423   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
424   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
425 }
426 
applyGeneralizePadTensorPatterns(func::FuncOp funcOp)427 static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
428   RewritePatternSet patterns(funcOp.getContext());
429   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
430   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
431 }
432 
applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp)433 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
434   RewritePatternSet patterns(funcOp.getContext());
435   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
436   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
437 }
438 
applyTilePattern(func::FuncOp funcOp,const std::string & loopType,ArrayRef<int64_t> tileSizes,ArrayRef<int64_t> peeledLoops,bool scalarizeDynamicDims)439 static void applyTilePattern(func::FuncOp funcOp, const std::string &loopType,
440                              ArrayRef<int64_t> tileSizes,
441                              ArrayRef<int64_t> peeledLoops,
442                              bool scalarizeDynamicDims) {
443   MLIRContext *context = funcOp.getContext();
444   RewritePatternSet tilingPattern(context);
445   LinalgTilingLoopType type =
446       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
447           .Case("for", LinalgTilingLoopType::Loops)
448           .Case("affine", LinalgTilingLoopType::AffineLoops)
449           .Case("parallel", LinalgTilingLoopType::ParallelLoops);
450   auto linalgTilingOptions = linalg::LinalgTilingOptions()
451                                  .setPeeledLoops(peeledLoops)
452                                  .setLoopType(type);
453   if (scalarizeDynamicDims) {
454     linalgTilingOptions.scalarizeDynamicDims();
455     assert(tileSizes.empty() &&
456            "tileSizes and scalarizeDynamicDims is mutually exclusive");
457   } else {
458     linalgTilingOptions.setTileSizes(tileSizes);
459   }
460   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
461   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
462       tilingPattern, linalgTilingOptions, f);
463   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
464 }
465 
applySplitReduction(func::FuncOp funcOp)466 static void applySplitReduction(func::FuncOp funcOp) {
467   RewritePatternSet patterns(funcOp.getContext());
468   linalg::populateSplitReductionPattern(
469       patterns,
470       [](LinalgOp op) {
471         unsigned insertDimIndex = op.getNumLoops() - 1;
472         return std::make_pair(4, insertDimIndex);
473       },
474       LinalgTransformationFilter(
475           ArrayRef<StringAttr>{},
476           StringAttr::get(funcOp.getContext(), "SPLIT")));
477   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
478 }
479 
applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp)480 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
481   RewritePatternSet patterns(funcOp.getContext());
482   populateBubbleUpExtractSliceOpPatterns(patterns);
483   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
484 }
485 
486 /// Apply transformations specified as patterns.
runOnOperation()487 void TestLinalgTransforms::runOnOperation() {
488   auto lambda = [&](void *) {
489     getOperation().walk([](LinalgOp op) {
490       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
491     });
492   };
493   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
494 
495   if (testTileAndDistributionOptions) {
496     RewritePatternSet patterns(&getContext());
497     fillTileAndDistributePatterns(&getContext(), patterns);
498     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
499     return;
500   }
501   if (testTileFuseAndDistributionOptions) {
502     RewritePatternSet patterns(&getContext());
503     fillTileFuseAndDistributePatterns(&getContext(), patterns);
504     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
505     return;
506   }
507   if (testPatterns)
508     return applyPatterns(getOperation());
509   if (testVectorTransferForwardingPatterns)
510     return applyVectorTransferForwardingPatterns(getOperation());
511   if (testGenericToVectorPattern)
512     return applyLinalgToVectorPatterns(getOperation());
513   if (testTransformPadTensor)
514     return applyPadTensorToGenericPatterns(getOperation());
515   if (testGeneralizePadTensor)
516     return applyGeneralizePadTensorPatterns(getOperation());
517   if (testSwapSubTensorPadTensor)
518     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
519   if (testTilePattern)
520     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
521                             /*scalarizeDynamicDims=*/false);
522   if (testTileScalarizeDynamicDims)
523     return applyTilePattern(getOperation(), loopType, tileSizes,
524                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
525   if (testSplitReduction)
526     return applySplitReduction(getOperation());
527   if (testBubbleUpExtractSliceOpPattern)
528     return applyBubbleUpExtractSliceOpPattern(getOperation());
529 }
530 
531 namespace mlir {
532 namespace test {
registerTestLinalgTransforms()533 void registerTestLinalgTransforms() {
534   PassRegistration<TestLinalgTransforms>();
535 }
536 } // namespace test
537 } // namespace mlir
538