1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/GPU/GPUDialect.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
19 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/Dialect/Linalg/Utils/Utils.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> {
36   TestLinalgTransforms() = default;
37   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
38 
39   void getDependentDialects(DialectRegistry &registry) const override {
40     // clang-format off
41     registry.insert<AffineDialect,
42                     memref::MemRefDialect,
43                     scf::SCFDialect,
44                     StandardOpsDialect,
45                     linalg::LinalgDialect,
46                     vector::VectorDialect,
47                     gpu::GPUDialect>();
48     // clang-format on
49   }
50   StringRef getArgument() const final {
51     return "test-linalg-transform-patterns";
52   }
53   StringRef getDescription() const final {
54     return "Test Linalg transformation patterns by applying them greedily.";
55   }
56 
57   void runOnOperation() override;
58 
59   Option<bool> testPatterns{*this, "test-patterns",
60                             llvm::cl::desc("Test a mixed set of patterns"),
61                             llvm::cl::init(false)};
62   Option<bool> testMatmulToVectorPatterns1dTiling{
63       *this, "test-matmul-to-vector-patterns-tile-1d",
64       llvm::cl::desc(
65           "Test a fused pass that applies patterns from matmul to vectors via "
66           "1-d tiling"),
67       llvm::cl::init(false)};
68   Option<bool> testMatmulToVectorPatterns2dTiling{
69       *this, "test-matmul-to-vector-patterns-tile-2d",
70       llvm::cl::desc(
71           "Test a fused pass that applies patterns from matmul to vectors via "
72           "2-d tiling"),
73       llvm::cl::init(false)};
74   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
75                                     llvm::cl::desc("Test promotion options"),
76                                     llvm::cl::init(false)};
77   Option<bool> testTileAndDistributionOptions{
78       *this, "test-tile-and-distribute-options",
79       llvm::cl::desc("Test tile and distribute options"),
80       llvm::cl::init(false)};
81   Option<bool> testTileFuseAndDistributionOptions{
82       *this, "test-tile-fuse-and-distribute-options",
83       llvm::cl::desc("Test tile, fuse and distribute options"),
84       llvm::cl::init(false)};
85   Option<bool> testVectorTransferForwardingPatterns{
86       *this, "test-vector-transfer-forwarding-patterns",
87       llvm::cl::desc(
88           "Test a fused pass that forwards memref.copy to vector.transfer"),
89       llvm::cl::init(false)};
90   Option<bool> testGenericToVectorPattern{
91       *this, "test-linalg-to-vector-patterns",
92       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
93                      "in vector.contract form"),
94       llvm::cl::init(false)};
95   Option<bool> testTilePattern{*this, "test-tile-pattern",
96                                llvm::cl::desc("Test tile pattern"),
97                                llvm::cl::init(false)};
98   Option<bool> testTileScalarizeDynamicDims{
99       *this, "test-tile-scalarize-dynamic-dims",
100       llvm::cl::desc("Test tiling of dynamic dims by 1"),
101       llvm::cl::init(false)};
102   Option<bool> testTransformPadTensor{
103       *this, "test-transform-pad-tensor",
104       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
105       llvm::cl::init(false)};
106   Option<bool> testGeneralizePadTensor{
107       *this, "test-generalize-pad-tensor",
108       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
109       llvm::cl::init(false)};
110   Option<bool> testSwapSubTensorPadTensor{
111       *this, "test-swap-subtensor-padtensor",
112       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
113                      "pad_tensor(subtensor)"),
114       llvm::cl::init(false)};
115   ListOption<int64_t> peeledLoops{
116       *this, "peeled-loops",
117       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
118       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
119   ListOption<int64_t> tileSizes{
120       *this, "tile-sizes",
121       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
122       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
123   Option<bool> skipPartial{
124       *this, "skip-partial",
125       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
126       llvm::cl::init(false)};
127   Option<std::string> loopType{
128       *this, "loop-type",
129       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
130                      "tiled_loop"),
131       llvm::cl::init("for")};
132 };
133 } // namespace
134 
135 static void applyPatterns(FuncOp funcOp) {
136   MLIRContext *ctx = funcOp.getContext();
137   RewritePatternSet patterns(ctx);
138 
139   //===--------------------------------------------------------------------===//
140   // Linalg tiling patterns.
141   //===--------------------------------------------------------------------===//
142   patterns.add<LinalgTilingPattern>(
143       MatmulOp::getOperationName(), ctx,
144       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
145       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
146                                  StringAttr::get(ctx, "L3")));
147   patterns.add<LinalgTilingPattern>(
148       MatmulOp::getOperationName(), ctx,
149       LinalgTilingOptions().setTileSizes({200, 300, 400}),
150       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
151                                  StringAttr::get(ctx, "L2")));
152   patterns.add<LinalgTilingPattern>(
153       MatmulOp::getOperationName(), ctx,
154       LinalgTilingOptions().setTileSizes({20, 30, 40}),
155       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
156                                  StringAttr::get(ctx, "L1")));
157   patterns.add<LinalgTilingPattern>(
158       MatmulOp::getOperationName(), ctx,
159       LinalgTilingOptions().setTileSizes({2, 3, 4}),
160       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
161                                  StringAttr::get(ctx, "REG")));
162 
163   patterns.add<LinalgTilingPattern>(
164       MatvecOp::getOperationName(), ctx,
165       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
166           LinalgTilingLoopType::ParallelLoops),
167       LinalgTransformationFilter(ArrayRef<StringAttr>{},
168                                  StringAttr::get(ctx, "L1")));
169 
170   patterns.add<LinalgTilingPattern>(
171       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
172       LinalgTransformationFilter(
173           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
174                                StringAttr::get(ctx, "L3"),
175                                StringAttr::get(ctx, "L2")},
176           StringAttr::get(ctx, "REG")));
177 
178   //===--------------------------------------------------------------------===//
179   // Linalg tiling and permutation patterns.
180   //===--------------------------------------------------------------------===//
181   patterns.add<LinalgTilingPattern>(
182       MatmulOp::getOperationName(), ctx,
183       LinalgTilingOptions()
184           .setTileSizes({2000, 3000, 4000})
185           .setInterchange({1, 2, 0}),
186       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
187                                  StringAttr::get(ctx, "L2__with_perm__")));
188   patterns.add<LinalgTilingPattern>(
189       MatmulOp::getOperationName(), ctx,
190       LinalgTilingOptions()
191           .setTileSizes({200, 300, 400})
192           .setInterchange({1, 0, 2}),
193       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
194                                  StringAttr::get(ctx, "L1__with_perm__")));
195   patterns.add<LinalgTilingPattern>(
196       MatmulOp::getOperationName(), ctx,
197       LinalgTilingOptions().setTileSizes({20, 30, 40}),
198       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
199                                  StringAttr::get(ctx, "REG__with_perm__")));
200 
201   patterns.add<LinalgTilingPattern>(
202       MatvecOp::getOperationName(), ctx,
203       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
204       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
205                                  StringAttr::get(ctx, "L1__with_perm__")));
206 
207   patterns.add<LinalgTilingPattern>(
208       MatmulOp::getOperationName(), ctx,
209       LinalgTilingOptions()
210           .setTileSizes({16, 8, 4})
211           .setInterchange({1, 2, 0})
212           .setLoopType(LinalgTilingLoopType::ParallelLoops),
213       LinalgTransformationFilter(
214           StringAttr::get(ctx, "par__with_perm__"),
215           StringAttr::get(ctx, "after_par__with_perm__")));
216 
217   //===--------------------------------------------------------------------===//
218   // Linalg to loops patterns.
219   //===--------------------------------------------------------------------===//
220   patterns.add<LinalgLoweringPattern<DotOp>>(
221       ctx,
222       /*loweringType=*/LinalgLoweringType::Loops,
223       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
224 
225   //===--------------------------------------------------------------------===//
226   // Linalg distribution patterns.
227   //===--------------------------------------------------------------------===//
228   LinalgLoopDistributionOptions distributionOptions;
229 
230   //===--------------------------------------------------------------------===//
231   // Linalg to vector contraction patterns.
232   //===--------------------------------------------------------------------===//
233   patterns.add<LinalgVectorizationPattern>(
234       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
235                .addOpFilter<MatmulOp, FillOp, GenericOp>());
236   patterns.add<CopyVectorizationPattern>(ctx);
237 
238   //===--------------------------------------------------------------------===//
239   // Linalg generic interchange pattern.
240   //===--------------------------------------------------------------------===//
241   patterns.add<GenericOpInterchangePattern>(
242       ctx,
243       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
244       LinalgTransformationFilter(ArrayRef<StringAttr>{},
245                                  StringAttr::get(ctx, "PERMUTED")));
246 
247   //===--------------------------------------------------------------------===//
248   // Linalg subview operands promotion.
249   //===--------------------------------------------------------------------===//
250   patterns.add<LinalgPromotionPattern<MatmulOp>>(
251       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
252       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
253                                  StringAttr::get(ctx, "_views_promoted_")));
254   patterns.add<LinalgPromotionPattern<MatmulOp>>(
255       ctx,
256       LinalgPromotionOptions()
257           .setOperandsToPromote({0})
258           .setUseFullTileBuffersByDefault(true),
259       LinalgTransformationFilter(
260           StringAttr::get(ctx, "_promote_first_view_"),
261           StringAttr::get(ctx, "_first_view_promoted_")));
262   patterns.add<LinalgPromotionPattern<FillOp>>(
263       ctx,
264       LinalgPromotionOptions()
265           .setOperandsToPromote({1})
266           .setUseFullTileBuffers({false, true})
267           .setAlignment(32),
268       LinalgTransformationFilter(
269           StringAttr::get(ctx, "_promote_views_aligned_"),
270           StringAttr::get(ctx, "_views_aligned_promoted_")));
271 
272   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
273 
274   // Drop the marker.
275   funcOp.walk([](LinalgOp op) {
276     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
277   });
278 }
279 
280 static void fillL1TilingAndMatmulToVectorPatterns(
281     FuncOp funcOp, StringRef startMarker,
282     SmallVectorImpl<RewritePatternSet> &patternsVector) {
283   MLIRContext *ctx = funcOp.getContext();
284   patternsVector.emplace_back(
285       ctx, std::make_unique<LinalgTilingPattern>(
286                MatmulOp::getOperationName(), ctx,
287                LinalgTilingOptions()
288                    .setTileSizes({8, 12, 16})
289                    .setInterchange({1, 0, 2}),
290                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
291                                           StringAttr::get(ctx, "L1"))));
292 
293   patternsVector.emplace_back(
294       ctx,
295       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
296           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
297           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
298                                      StringAttr::get(ctx, "VEC"))));
299 
300   patternsVector.emplace_back(
301       ctx, std::make_unique<LinalgVectorizationPattern>(
302                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
303                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
304   patternsVector.back().add<LinalgVectorizationPattern>(
305       ctx, LinalgTransformationFilter().addOpFilter<FillOp>());
306   patternsVector.back().add<CopyVectorizationPattern>(ctx);
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // Test promotion callbacks
311 //===----------------------------------------------------------------------===//
312 
313 // Allocation call back
314 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
315                                        ArrayRef<Value> boundingSubViewSize,
316                                        DataLayout &layout) {
317   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
318   return b
319       .create<memref::AllocOp>(
320           subView.getLoc(),
321           MemRefType::get(shape, subView.getType().getElementType(),
322                           /*affineMapComposition =*/{}, 3),
323           boundingSubViewSize)
324       .getResult();
325 }
326 
327 // Deallocation callback
328 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
329   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
330   return success();
331 }
332 
333 // Copy in call back
334 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
335                                     bool isOutput) {
336   auto floatType = src.getType().cast<MemRefType>().getElementType();
337   if (!floatType.isa<FloatType>())
338     return failure();
339   if (!isOutput) {
340     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
341                                             FloatAttr::get(floatType, 42.0));
342     b.create<FillOp>(src.getLoc(), cst, dst);
343   }
344   b.create<memref::CopyOp>(src.getLoc(), src, dst);
345   return success();
346 }
347 
348 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
349                                           RewritePatternSet &patterns) {
350   patterns.add<LinalgTilingPattern>(
351       MatmulOp::getOperationName(), ctx,
352       LinalgTilingOptions().setTileSizes({16, 16, 16}),
353       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
354                                  StringAttr::get(ctx, "PROMOTE")));
355   patterns.add<LinalgPromotionPattern<MatmulOp>>(
356       ctx,
357       LinalgPromotionOptions()
358           .setOperandsToPromote({0, 2})
359           .setUseFullTileBuffers({false, false})
360           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
361           .setCopyInOutFns(
362               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
363                 return copyCallBackFn(b, src, dst, false);
364               },
365               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
366                 return copyCallBackFn(b, src, dst, true);
367               }),
368       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
369 }
370 
371 template <typename IdOp, typename NProcsOp>
372 static SmallVector<ProcInfo, 2>
373 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
374   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
375   SmallVector<ProcInfo, 2> procInfo(count);
376   Type indexType = b.getIndexType();
377   for (unsigned i = 0; i < count; ++i) {
378     gpu::Dimension dim = *gpu::symbolizeDimension(i);
379     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
380                                b.create<NProcsOp>(loc, indexType, dim)};
381   }
382   return procInfo;
383 }
384 
385 static void fillTileAndDistributePatterns(MLIRContext *context,
386                                           RewritePatternSet &patterns) {
387   {
388     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
389     cyclicNprocsEqNiters.distributionMethod.resize(
390         2, DistributionMethod::CyclicNumProcsEqNumIters);
391     cyclicNprocsEqNiters.procInfo =
392         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
393     patterns.add<LinalgTilingPattern>(
394         MatmulOp::getOperationName(), context,
395         LinalgTilingOptions()
396             .setTileSizes({8, 8, 4})
397             .setLoopType(LinalgTilingLoopType::ParallelLoops)
398             .setDistributionOptions(cyclicNprocsEqNiters),
399         LinalgTransformationFilter(
400             StringAttr::get(context, "distribute1"),
401             StringAttr::get(context, "after_distribute1")));
402   }
403 
404   {
405     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
406     cyclicNprocsGeNiters.distributionMethod.resize(
407         2, DistributionMethod::CyclicNumProcsGeNumIters);
408     cyclicNprocsGeNiters.procInfo =
409         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
410     patterns.add<LinalgTilingPattern>(
411         MatmulOp::getOperationName(), context,
412         LinalgTilingOptions()
413             .setTileSizes({8, 8, 4})
414             .setLoopType(LinalgTilingLoopType::ParallelLoops)
415             .setDistributionOptions(cyclicNprocsGeNiters),
416         LinalgTransformationFilter(
417             StringAttr::get(context, "distribute2"),
418             StringAttr::get(context, "after_distribute2")));
419   }
420 
421   {
422     LinalgLoopDistributionOptions cyclicNprocsDefault;
423     cyclicNprocsDefault.distributionMethod.resize(2,
424                                                   DistributionMethod::Cyclic);
425     cyclicNprocsDefault.procInfo =
426         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
427     patterns.add<LinalgTilingPattern>(
428         MatmulOp::getOperationName(), context,
429         LinalgTilingOptions()
430             .setTileSizes({8, 8, 4})
431             .setLoopType(LinalgTilingLoopType::ParallelLoops)
432             .setDistributionOptions(cyclicNprocsDefault),
433         LinalgTransformationFilter(
434             StringAttr::get(context, "distribute3"),
435             StringAttr::get(context, "after_distribute3")));
436   }
437 
438   {
439     LinalgLoopDistributionOptions cyclicNprocsMixed1;
440     cyclicNprocsMixed1.distributionMethod = {
441         DistributionMethod::CyclicNumProcsEqNumIters,
442         DistributionMethod::CyclicNumProcsGeNumIters};
443     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
444     patterns.add<LinalgTilingPattern>(
445         MatmulOp::getOperationName(), context,
446         LinalgTilingOptions()
447             .setTileSizes({8, 8, 4})
448             .setLoopType(LinalgTilingLoopType::ParallelLoops)
449             .setDistributionOptions(cyclicNprocsMixed1),
450         LinalgTransformationFilter(
451             StringAttr::get(context, "distribute4"),
452             StringAttr::get(context, "after_distribute4")));
453   }
454 
455   {
456     LinalgLoopDistributionOptions cyclicNprocsMixed2;
457     cyclicNprocsMixed2.distributionMethod = {
458         DistributionMethod::CyclicNumProcsGeNumIters,
459         DistributionMethod::Cyclic};
460     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
461     patterns.add<LinalgTilingPattern>(
462         MatmulOp::getOperationName(), context,
463         LinalgTilingOptions()
464             .setTileSizes({8, 8, 4})
465             .setLoopType(LinalgTilingLoopType::ParallelLoops)
466             .setDistributionOptions(cyclicNprocsMixed2),
467         LinalgTransformationFilter(
468             StringAttr::get(context, "distribute5"),
469             StringAttr::get(context, "after_distribute5")));
470   }
471 
472   {
473     LinalgLoopDistributionOptions cyclicNprocsMixed3;
474     cyclicNprocsMixed3.distributionMethod = {
475         DistributionMethod::Cyclic,
476         DistributionMethod::CyclicNumProcsEqNumIters};
477     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
478 
479     patterns.add<LinalgTilingPattern>(
480         MatmulOp::getOperationName(), context,
481         LinalgTilingOptions()
482             .setTileSizes({8, 8, 4})
483             .setLoopType(LinalgTilingLoopType::ParallelLoops)
484             .setDistributionOptions(cyclicNprocsMixed3),
485         LinalgTransformationFilter(
486             StringAttr::get(context, "distribute6"),
487             StringAttr::get(context, "after_distribute6")));
488   }
489 
490   {
491     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
492     cyclicNprocsEqNiters.distributionMethod.resize(2,
493                                                    DistributionMethod::Cyclic);
494     cyclicNprocsEqNiters.procInfo =
495         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
496     patterns.add<LinalgTilingPattern>(
497         MatmulOp::getOperationName(), context,
498         LinalgTilingOptions()
499             .setTileSizes({8, 8, 4})
500             .setLoopType(LinalgTilingLoopType::Loops)
501             .setDistributionOptions(cyclicNprocsEqNiters),
502         LinalgTransformationFilter(
503             StringAttr::get(context, "tensors_distribute1"),
504             StringAttr::get(context, "tensors_after_distribute1")));
505   }
506 }
507 
508 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
509                                               RewritePatternSet &patterns) {
510   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
511   cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
512   cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
513   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
514       MatmulOp::getOperationName(), context,
515       LinalgTilingAndFusionOptions()
516           .setTileSizes({8, 8, 4})
517           .setDistributionOptions(cyclicNprocsEqNiters),
518       LinalgTransformationFilter(
519           StringAttr::get(context, "tensors_fuse_distribute1"),
520           StringAttr::get(context, "tensors_after_fuse_distribute1")));
521 }
522 
523 static void
524 applyMatmulToVectorPatterns(FuncOp funcOp,
525                             bool testMatmulToVectorPatterns1dTiling,
526                             bool testMatmulToVectorPatterns2dTiling) {
527   MLIRContext *ctx = funcOp.getContext();
528   SmallVector<RewritePatternSet, 4> stage1Patterns;
529   if (testMatmulToVectorPatterns1dTiling) {
530     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
531   } else if (testMatmulToVectorPatterns2dTiling) {
532     stage1Patterns.emplace_back(
533         ctx, std::make_unique<LinalgTilingPattern>(
534                  MatmulOp::getOperationName(), ctx,
535                  LinalgTilingOptions()
536                      .setTileSizes({768, 264, 768})
537                      .setInterchange({1, 2, 0}),
538                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
539                                             StringAttr::get(ctx, "L2"))));
540     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
541   }
542   {
543     // Canonicalization patterns
544     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
545     vector::populateVectorTransferPermutationMapLoweringPatterns(
546         canonicalizationPatterns);
547     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
548     stage1Patterns.push_back(std::move(canonicalizationPatterns));
549   }
550   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
551   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
552   FrozenRewritePatternSet stage2Patterns =
553       getLinalgTilingCanonicalizationPatterns(ctx);
554   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
555 }
556 
557 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
558   RewritePatternSet forwardPattern(funcOp.getContext());
559   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
560   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
561   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
562 }
563 
564 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
565   RewritePatternSet patterns(funcOp.getContext());
566   auto *ctx = funcOp.getContext();
567   patterns.add<LinalgVectorizationPattern>(
568       ctx, LinalgTransformationFilter()
569                .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
570   patterns.add<CopyVectorizationPattern>(ctx);
571   populatePadOpVectorizationPatterns(patterns);
572   populateConvolutionVectorizationPatterns(patterns);
573   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
574 }
575 
576 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
577   RewritePatternSet patterns(funcOp.getContext());
578   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
579   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
580 }
581 
582 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
583   RewritePatternSet patterns(funcOp.getContext());
584   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
585   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
586 }
587 
588 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
589   RewritePatternSet patterns(funcOp.getContext());
590   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
591   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
592 }
593 
594 static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
595                              ArrayRef<int64_t> tileSizes,
596                              ArrayRef<int64_t> peeledLoops,
597                              bool scalarizeDynamicDims) {
598   MLIRContext *context = funcOp.getContext();
599   RewritePatternSet tilingPattern(context);
600   LinalgTilingLoopType type =
601       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
602           .Case("for", LinalgTilingLoopType::Loops)
603           .Case("affine", LinalgTilingLoopType::AffineLoops)
604           .Case("parallel", LinalgTilingLoopType::ParallelLoops);
605   auto linalgTilingOptions = linalg::LinalgTilingOptions()
606                                  .setPeeledLoops(peeledLoops)
607                                  .setLoopType(type);
608   if (scalarizeDynamicDims) {
609     linalgTilingOptions.scalarizeDynamicDims();
610     assert(tileSizes.empty() &&
611            "tileSizes and scalarizeDynamicDims is mutually exclusive");
612   } else {
613     linalgTilingOptions.setTileSizes(tileSizes);
614   }
615   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
616   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
617       tilingPattern, linalgTilingOptions, f);
618   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
619 }
620 
621 /// Apply transformations specified as patterns.
622 void TestLinalgTransforms::runOnOperation() {
623   auto lambda = [&](void *) {
624     getOperation().walk([](LinalgOp op) {
625       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
626     });
627   };
628   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
629 
630   if (testPromotionOptions) {
631     RewritePatternSet patterns(&getContext());
632     fillPromotionCallBackPatterns(&getContext(), patterns);
633     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
634     return;
635   }
636   if (testTileAndDistributionOptions) {
637     RewritePatternSet patterns(&getContext());
638     fillTileAndDistributePatterns(&getContext(), patterns);
639     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
640     return;
641   }
642   if (testTileFuseAndDistributionOptions) {
643     RewritePatternSet patterns(&getContext());
644     fillTileFuseAndDistributePatterns(&getContext(), patterns);
645     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
646     return;
647   }
648   if (testPatterns)
649     return applyPatterns(getOperation());
650   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
651     return applyMatmulToVectorPatterns(getOperation(),
652                                        testMatmulToVectorPatterns1dTiling,
653                                        testMatmulToVectorPatterns2dTiling);
654   if (testVectorTransferForwardingPatterns)
655     return applyVectorTransferForwardingPatterns(getOperation());
656   if (testGenericToVectorPattern)
657     return applyLinalgToVectorPatterns(getOperation());
658   if (testTransformPadTensor)
659     return applyPadTensorToGenericPatterns(getOperation());
660   if (testGeneralizePadTensor)
661     return applyGeneralizePadTensorPatterns(getOperation());
662   if (testSwapSubTensorPadTensor)
663     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
664   if (testTilePattern)
665     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
666                             /*scalarizeDynamicDims=*/false);
667   if (testTileScalarizeDynamicDims)
668     return applyTilePattern(getOperation(), loopType, tileSizes,
669                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
670 }
671 
672 namespace mlir {
673 namespace test {
674 void registerTestLinalgTransforms() {
675   PassRegistration<TestLinalgTransforms>();
676 }
677 } // namespace test
678 } // namespace mlir
679