1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/GPU/GPUDialect.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> {
36   TestLinalgTransforms() = default;
37   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
38 
39   void getDependentDialects(DialectRegistry &registry) const override {
40     // clang-format off
41     registry.insert<AffineDialect,
42                     memref::MemRefDialect,
43                     scf::SCFDialect,
44                     linalg::LinalgDialect,
45                     vector::VectorDialect,
46                     gpu::GPUDialect>();
47     // clang-format on
48   }
49   StringRef getArgument() const final {
50     return "test-linalg-transform-patterns";
51   }
52   StringRef getDescription() const final {
53     return "Test Linalg transformation patterns by applying them greedily.";
54   }
55 
56   void runOnOperation() override;
57 
58   Option<bool> testPatterns{*this, "test-patterns",
59                             llvm::cl::desc("Test a mixed set of patterns"),
60                             llvm::cl::init(false)};
61   Option<bool> testMatmulToVectorPatterns1dTiling{
62       *this, "test-matmul-to-vector-patterns-tile-1d",
63       llvm::cl::desc(
64           "Test a fused pass that applies patterns from matmul to vectors via "
65           "1-d tiling"),
66       llvm::cl::init(false)};
67   Option<bool> testMatmulToVectorPatterns2dTiling{
68       *this, "test-matmul-to-vector-patterns-tile-2d",
69       llvm::cl::desc(
70           "Test a fused pass that applies patterns from matmul to vectors via "
71           "2-d tiling"),
72       llvm::cl::init(false)};
73   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
74                                     llvm::cl::desc("Test promotion options"),
75                                     llvm::cl::init(false)};
76   Option<bool> testTileAndDistributionOptions{
77       *this, "test-tile-and-distribute-options",
78       llvm::cl::desc("Test tile and distribute options"),
79       llvm::cl::init(false)};
80   Option<bool> testTileFuseAndDistributionOptions{
81       *this, "test-tile-fuse-and-distribute-options",
82       llvm::cl::desc("Test tile, fuse and distribute options"),
83       llvm::cl::init(false)};
84   Option<bool> testVectorTransferForwardingPatterns{
85       *this, "test-vector-transfer-forwarding-patterns",
86       llvm::cl::desc(
87           "Test a fused pass that forwards memref.copy to vector.transfer"),
88       llvm::cl::init(false)};
89   Option<bool> testGenericToVectorPattern{
90       *this, "test-linalg-to-vector-patterns",
91       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
92                      "in vector.contract form"),
93       llvm::cl::init(false)};
94   Option<bool> testTilePattern{*this, "test-tile-pattern",
95                                llvm::cl::desc("Test tile pattern"),
96                                llvm::cl::init(false)};
97   Option<bool> testTileScalarizeDynamicDims{
98       *this, "test-tile-scalarize-dynamic-dims",
99       llvm::cl::desc("Test tiling of dynamic dims by 1"),
100       llvm::cl::init(false)};
101   Option<bool> testTransformPadTensor{
102       *this, "test-transform-pad-tensor",
103       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
104       llvm::cl::init(false)};
105   Option<bool> testGeneralizePadTensor{
106       *this, "test-generalize-pad-tensor",
107       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
108       llvm::cl::init(false)};
109   Option<bool> testSwapSubTensorPadTensor{
110       *this, "test-swap-subtensor-padtensor",
111       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
112                      "pad_tensor(subtensor)"),
113       llvm::cl::init(false)};
114   ListOption<int64_t> peeledLoops{
115       *this, "peeled-loops",
116       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
117       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
118   ListOption<int64_t> tileSizes{
119       *this, "tile-sizes",
120       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
121       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
122   Option<bool> skipPartial{
123       *this, "skip-partial",
124       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
125       llvm::cl::init(false)};
126   Option<std::string> loopType{
127       *this, "loop-type",
128       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
129                      "tiled_loop"),
130       llvm::cl::init("for")};
131 };
132 } // namespace
133 
134 static void applyPatterns(FuncOp funcOp) {
135   MLIRContext *ctx = funcOp.getContext();
136   RewritePatternSet patterns(ctx);
137 
138   //===--------------------------------------------------------------------===//
139   // Linalg tiling patterns.
140   //===--------------------------------------------------------------------===//
141   patterns.add<LinalgTilingPattern>(
142       MatmulOp::getOperationName(), ctx,
143       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
144       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
145                                  StringAttr::get(ctx, "L3")));
146   patterns.add<LinalgTilingPattern>(
147       MatmulOp::getOperationName(), ctx,
148       LinalgTilingOptions().setTileSizes({200, 300, 400}),
149       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
150                                  StringAttr::get(ctx, "L2")));
151   patterns.add<LinalgTilingPattern>(
152       MatmulOp::getOperationName(), ctx,
153       LinalgTilingOptions().setTileSizes({20, 30, 40}),
154       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
155                                  StringAttr::get(ctx, "L1")));
156   patterns.add<LinalgTilingPattern>(
157       MatmulOp::getOperationName(), ctx,
158       LinalgTilingOptions().setTileSizes({2, 3, 4}),
159       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
160                                  StringAttr::get(ctx, "REG")));
161 
162   patterns.add<LinalgTilingPattern>(
163       MatvecOp::getOperationName(), ctx,
164       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
165           LinalgTilingLoopType::ParallelLoops),
166       LinalgTransformationFilter(ArrayRef<StringAttr>{},
167                                  StringAttr::get(ctx, "L1")));
168 
169   patterns.add<LinalgTilingPattern>(
170       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
171       LinalgTransformationFilter(
172           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
173                                StringAttr::get(ctx, "L3"),
174                                StringAttr::get(ctx, "L2")},
175           StringAttr::get(ctx, "REG")));
176 
177   //===--------------------------------------------------------------------===//
178   // Linalg tiling and permutation patterns.
179   //===--------------------------------------------------------------------===//
180   patterns.add<LinalgTilingPattern>(
181       MatmulOp::getOperationName(), ctx,
182       LinalgTilingOptions()
183           .setTileSizes({2000, 3000, 4000})
184           .setInterchange({1, 2, 0}),
185       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
186                                  StringAttr::get(ctx, "L2__with_perm__")));
187   patterns.add<LinalgTilingPattern>(
188       MatmulOp::getOperationName(), ctx,
189       LinalgTilingOptions()
190           .setTileSizes({200, 300, 400})
191           .setInterchange({1, 0, 2}),
192       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
193                                  StringAttr::get(ctx, "L1__with_perm__")));
194   patterns.add<LinalgTilingPattern>(
195       MatmulOp::getOperationName(), ctx,
196       LinalgTilingOptions().setTileSizes({20, 30, 40}),
197       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
198                                  StringAttr::get(ctx, "REG__with_perm__")));
199 
200   patterns.add<LinalgTilingPattern>(
201       MatvecOp::getOperationName(), ctx,
202       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
203       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
204                                  StringAttr::get(ctx, "L1__with_perm__")));
205 
206   patterns.add<LinalgTilingPattern>(
207       MatmulOp::getOperationName(), ctx,
208       LinalgTilingOptions()
209           .setTileSizes({16, 8, 4})
210           .setInterchange({1, 2, 0})
211           .setLoopType(LinalgTilingLoopType::ParallelLoops),
212       LinalgTransformationFilter(
213           StringAttr::get(ctx, "par__with_perm__"),
214           StringAttr::get(ctx, "after_par__with_perm__")));
215 
216   //===--------------------------------------------------------------------===//
217   // Linalg to loops patterns.
218   //===--------------------------------------------------------------------===//
219   patterns.add<LinalgLoweringPattern<DotOp>>(
220       ctx,
221       /*loweringType=*/LinalgLoweringType::Loops,
222       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
223 
224   //===--------------------------------------------------------------------===//
225   // Linalg distribution patterns.
226   //===--------------------------------------------------------------------===//
227   LinalgLoopDistributionOptions distributionOptions;
228 
229   //===--------------------------------------------------------------------===//
230   // Linalg to vector contraction patterns.
231   //===--------------------------------------------------------------------===//
232   patterns.add<LinalgVectorizationPattern>(
233       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
234                .addOpFilter<MatmulOp, FillOp, GenericOp>());
235   patterns.add<CopyVectorizationPattern>(ctx);
236 
237   //===--------------------------------------------------------------------===//
238   // Linalg generic interchange pattern.
239   //===--------------------------------------------------------------------===//
240   patterns.add<GenericOpInterchangePattern>(
241       ctx,
242       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
243       LinalgTransformationFilter(ArrayRef<StringAttr>{},
244                                  StringAttr::get(ctx, "PERMUTED")));
245 
246   //===--------------------------------------------------------------------===//
247   // Linalg subview operands promotion.
248   //===--------------------------------------------------------------------===//
249   patterns.add<LinalgPromotionPattern<MatmulOp>>(
250       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
251       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
252                                  StringAttr::get(ctx, "_views_promoted_")));
253   patterns.add<LinalgPromotionPattern<MatmulOp>>(
254       ctx,
255       LinalgPromotionOptions()
256           .setOperandsToPromote({0})
257           .setUseFullTileBuffersByDefault(true),
258       LinalgTransformationFilter(
259           StringAttr::get(ctx, "_promote_first_view_"),
260           StringAttr::get(ctx, "_first_view_promoted_")));
261   patterns.add<LinalgPromotionPattern<FillOp>>(
262       ctx,
263       LinalgPromotionOptions()
264           .setOperandsToPromote({1})
265           .setUseFullTileBuffers({false, true})
266           .setAlignment(32),
267       LinalgTransformationFilter(
268           StringAttr::get(ctx, "_promote_views_aligned_"),
269           StringAttr::get(ctx, "_views_aligned_promoted_")));
270 
271   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
272 
273   // Drop the marker.
274   funcOp.walk([](LinalgOp op) {
275     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
276   });
277 }
278 
279 static void fillL1TilingAndMatmulToVectorPatterns(
280     FuncOp funcOp, StringRef startMarker,
281     SmallVectorImpl<RewritePatternSet> &patternsVector) {
282   MLIRContext *ctx = funcOp.getContext();
283   patternsVector.emplace_back(
284       ctx, std::make_unique<LinalgTilingPattern>(
285                MatmulOp::getOperationName(), ctx,
286                LinalgTilingOptions()
287                    .setTileSizes({8, 12, 16})
288                    .setInterchange({1, 0, 2}),
289                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
290                                           StringAttr::get(ctx, "L1"))));
291 
292   patternsVector.emplace_back(
293       ctx,
294       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
295           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
296           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
297                                      StringAttr::get(ctx, "VEC"))));
298 
299   patternsVector.emplace_back(
300       ctx, std::make_unique<LinalgVectorizationPattern>(
301                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
302                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
303   patternsVector.back().add<LinalgVectorizationPattern>(
304       ctx, LinalgTransformationFilter().addOpFilter<FillOp>());
305   patternsVector.back().add<CopyVectorizationPattern>(ctx);
306 }
307 
308 //===----------------------------------------------------------------------===//
309 // Test promotion callbacks
310 //===----------------------------------------------------------------------===//
311 
312 // Allocation call back
313 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
314                                        ArrayRef<Value> boundingSubViewSize,
315                                        DataLayout &layout) {
316   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
317   return b
318       .create<memref::AllocOp>(
319           subView.getLoc(),
320           MemRefType::get(shape, subView.getType().getElementType(),
321                           /*affineMapComposition =*/{}, 3),
322           boundingSubViewSize)
323       .getResult();
324 }
325 
326 // Deallocation callback
327 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
328   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
329   return success();
330 }
331 
332 // Copy in call back
333 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
334                                     bool isOutput) {
335   auto floatType = src.getType().cast<MemRefType>().getElementType();
336   if (!floatType.isa<FloatType>())
337     return failure();
338   if (!isOutput) {
339     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
340                                             FloatAttr::get(floatType, 42.0));
341     b.create<FillOp>(src.getLoc(), cst, dst);
342   }
343   b.create<memref::CopyOp>(src.getLoc(), src, dst);
344   return success();
345 }
346 
347 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
348                                           RewritePatternSet &patterns) {
349   patterns.add<LinalgTilingPattern>(
350       MatmulOp::getOperationName(), ctx,
351       LinalgTilingOptions().setTileSizes({16, 16, 16}),
352       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
353                                  StringAttr::get(ctx, "PROMOTE")));
354   patterns.add<LinalgPromotionPattern<MatmulOp>>(
355       ctx,
356       LinalgPromotionOptions()
357           .setOperandsToPromote({0, 2})
358           .setUseFullTileBuffers({false, false})
359           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
360           .setCopyInOutFns(
361               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
362                 return copyCallBackFn(b, src, dst, false);
363               },
364               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
365                 return copyCallBackFn(b, src, dst, true);
366               }),
367       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
368 }
369 
370 template <typename IdOp, typename NProcsOp>
371 static SmallVector<ProcInfo, 2>
372 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
373   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
374   SmallVector<ProcInfo, 2> procInfo(count);
375   Type indexType = b.getIndexType();
376   for (unsigned i = 0; i < count; ++i) {
377     gpu::Dimension dim = *gpu::symbolizeDimension(i);
378     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
379                                b.create<NProcsOp>(loc, indexType, dim)};
380   }
381   return procInfo;
382 }
383 
384 static void fillTileAndDistributePatterns(MLIRContext *context,
385                                           RewritePatternSet &patterns) {
386   {
387     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
388     cyclicNprocsEqNiters.distributionMethod.resize(
389         2, DistributionMethod::CyclicNumProcsEqNumIters);
390     cyclicNprocsEqNiters.procInfo =
391         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
392     patterns.add<LinalgTilingPattern>(
393         MatmulOp::getOperationName(), context,
394         LinalgTilingOptions()
395             .setTileSizes({8, 8, 4})
396             .setLoopType(LinalgTilingLoopType::ParallelLoops)
397             .setDistributionOptions(cyclicNprocsEqNiters),
398         LinalgTransformationFilter(
399             StringAttr::get(context, "distribute1"),
400             StringAttr::get(context, "after_distribute1")));
401   }
402 
403   {
404     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
405     cyclicNprocsGeNiters.distributionMethod.resize(
406         2, DistributionMethod::CyclicNumProcsGeNumIters);
407     cyclicNprocsGeNiters.procInfo =
408         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
409     patterns.add<LinalgTilingPattern>(
410         MatmulOp::getOperationName(), context,
411         LinalgTilingOptions()
412             .setTileSizes({8, 8, 4})
413             .setLoopType(LinalgTilingLoopType::ParallelLoops)
414             .setDistributionOptions(cyclicNprocsGeNiters),
415         LinalgTransformationFilter(
416             StringAttr::get(context, "distribute2"),
417             StringAttr::get(context, "after_distribute2")));
418   }
419 
420   {
421     LinalgLoopDistributionOptions cyclicNprocsDefault;
422     cyclicNprocsDefault.distributionMethod.resize(2,
423                                                   DistributionMethod::Cyclic);
424     cyclicNprocsDefault.procInfo =
425         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
426     patterns.add<LinalgTilingPattern>(
427         MatmulOp::getOperationName(), context,
428         LinalgTilingOptions()
429             .setTileSizes({8, 8, 4})
430             .setLoopType(LinalgTilingLoopType::ParallelLoops)
431             .setDistributionOptions(cyclicNprocsDefault),
432         LinalgTransformationFilter(
433             StringAttr::get(context, "distribute3"),
434             StringAttr::get(context, "after_distribute3")));
435   }
436 
437   {
438     LinalgLoopDistributionOptions cyclicNprocsMixed1;
439     cyclicNprocsMixed1.distributionMethod = {
440         DistributionMethod::CyclicNumProcsEqNumIters,
441         DistributionMethod::CyclicNumProcsGeNumIters};
442     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
443     patterns.add<LinalgTilingPattern>(
444         MatmulOp::getOperationName(), context,
445         LinalgTilingOptions()
446             .setTileSizes({8, 8, 4})
447             .setLoopType(LinalgTilingLoopType::ParallelLoops)
448             .setDistributionOptions(cyclicNprocsMixed1),
449         LinalgTransformationFilter(
450             StringAttr::get(context, "distribute4"),
451             StringAttr::get(context, "after_distribute4")));
452   }
453 
454   {
455     LinalgLoopDistributionOptions cyclicNprocsMixed2;
456     cyclicNprocsMixed2.distributionMethod = {
457         DistributionMethod::CyclicNumProcsGeNumIters,
458         DistributionMethod::Cyclic};
459     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
460     patterns.add<LinalgTilingPattern>(
461         MatmulOp::getOperationName(), context,
462         LinalgTilingOptions()
463             .setTileSizes({8, 8, 4})
464             .setLoopType(LinalgTilingLoopType::ParallelLoops)
465             .setDistributionOptions(cyclicNprocsMixed2),
466         LinalgTransformationFilter(
467             StringAttr::get(context, "distribute5"),
468             StringAttr::get(context, "after_distribute5")));
469   }
470 
471   {
472     LinalgLoopDistributionOptions cyclicNprocsMixed3;
473     cyclicNprocsMixed3.distributionMethod = {
474         DistributionMethod::Cyclic,
475         DistributionMethod::CyclicNumProcsEqNumIters};
476     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
477 
478     patterns.add<LinalgTilingPattern>(
479         MatmulOp::getOperationName(), context,
480         LinalgTilingOptions()
481             .setTileSizes({8, 8, 4})
482             .setLoopType(LinalgTilingLoopType::ParallelLoops)
483             .setDistributionOptions(cyclicNprocsMixed3),
484         LinalgTransformationFilter(
485             StringAttr::get(context, "distribute6"),
486             StringAttr::get(context, "after_distribute6")));
487   }
488 
489   {
490     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
491     cyclicNprocsEqNiters.distributionMethod.resize(2,
492                                                    DistributionMethod::Cyclic);
493     cyclicNprocsEqNiters.procInfo =
494         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
495     patterns.add<LinalgTilingPattern>(
496         MatmulOp::getOperationName(), context,
497         LinalgTilingOptions()
498             .setTileSizes({8, 8, 4})
499             .setLoopType(LinalgTilingLoopType::Loops)
500             .setDistributionOptions(cyclicNprocsEqNiters),
501         LinalgTransformationFilter(
502             StringAttr::get(context, "tensors_distribute1"),
503             StringAttr::get(context, "tensors_after_distribute1")));
504   }
505 }
506 
507 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
508                                               RewritePatternSet &patterns) {
509   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
510   cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
511   cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
512   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
513       MatmulOp::getOperationName(), context,
514       LinalgTilingAndFusionOptions()
515           .setTileSizes({8, 8, 4})
516           .setDistributionOptions(cyclicNprocsEqNiters),
517       LinalgTransformationFilter(
518           StringAttr::get(context, "tensors_fuse_distribute1"),
519           StringAttr::get(context, "tensors_after_fuse_distribute1")));
520 }
521 
522 static void
523 applyMatmulToVectorPatterns(FuncOp funcOp,
524                             bool testMatmulToVectorPatterns1dTiling,
525                             bool testMatmulToVectorPatterns2dTiling) {
526   MLIRContext *ctx = funcOp.getContext();
527   SmallVector<RewritePatternSet, 4> stage1Patterns;
528   if (testMatmulToVectorPatterns1dTiling) {
529     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
530   } else if (testMatmulToVectorPatterns2dTiling) {
531     stage1Patterns.emplace_back(
532         ctx, std::make_unique<LinalgTilingPattern>(
533                  MatmulOp::getOperationName(), ctx,
534                  LinalgTilingOptions()
535                      .setTileSizes({768, 264, 768})
536                      .setInterchange({1, 2, 0}),
537                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
538                                             StringAttr::get(ctx, "L2"))));
539     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
540   }
541   {
542     // Canonicalization patterns
543     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
544     vector::populateVectorTransferPermutationMapLoweringPatterns(
545         canonicalizationPatterns);
546     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
547     stage1Patterns.push_back(std::move(canonicalizationPatterns));
548   }
549   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
550   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
551   FrozenRewritePatternSet stage2Patterns =
552       getLinalgTilingCanonicalizationPatterns(ctx);
553   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
554 }
555 
556 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
557   RewritePatternSet forwardPattern(funcOp.getContext());
558   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
559   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
560   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
561 }
562 
563 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
564   RewritePatternSet patterns(funcOp.getContext());
565   auto *ctx = funcOp.getContext();
566   patterns.add<LinalgVectorizationPattern>(
567       ctx, LinalgTransformationFilter()
568                .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
569   patterns.add<CopyVectorizationPattern>(ctx);
570   populatePadOpVectorizationPatterns(patterns);
571   populateConvolutionVectorizationPatterns(patterns);
572   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
573 }
574 
575 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
576   RewritePatternSet patterns(funcOp.getContext());
577   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
578   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
579 }
580 
581 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
582   RewritePatternSet patterns(funcOp.getContext());
583   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
584   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
585 }
586 
587 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
588   RewritePatternSet patterns(funcOp.getContext());
589   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
590   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
591 }
592 
593 static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
594                              ArrayRef<int64_t> tileSizes,
595                              ArrayRef<int64_t> peeledLoops,
596                              bool scalarizeDynamicDims) {
597   MLIRContext *context = funcOp.getContext();
598   RewritePatternSet tilingPattern(context);
599   LinalgTilingLoopType type =
600       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
601           .Case("for", LinalgTilingLoopType::Loops)
602           .Case("affine", LinalgTilingLoopType::AffineLoops)
603           .Case("parallel", LinalgTilingLoopType::ParallelLoops);
604   auto linalgTilingOptions = linalg::LinalgTilingOptions()
605                                  .setPeeledLoops(peeledLoops)
606                                  .setLoopType(type);
607   if (scalarizeDynamicDims) {
608     linalgTilingOptions.scalarizeDynamicDims();
609     assert(tileSizes.empty() &&
610            "tileSizes and scalarizeDynamicDims is mutually exclusive");
611   } else {
612     linalgTilingOptions.setTileSizes(tileSizes);
613   }
614   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
615   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
616       tilingPattern, linalgTilingOptions, f);
617   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
618 }
619 
620 /// Apply transformations specified as patterns.
621 void TestLinalgTransforms::runOnOperation() {
622   auto lambda = [&](void *) {
623     getOperation().walk([](LinalgOp op) {
624       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
625     });
626   };
627   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
628 
629   if (testPromotionOptions) {
630     RewritePatternSet patterns(&getContext());
631     fillPromotionCallBackPatterns(&getContext(), patterns);
632     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
633     return;
634   }
635   if (testTileAndDistributionOptions) {
636     RewritePatternSet patterns(&getContext());
637     fillTileAndDistributePatterns(&getContext(), patterns);
638     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
639     return;
640   }
641   if (testTileFuseAndDistributionOptions) {
642     RewritePatternSet patterns(&getContext());
643     fillTileFuseAndDistributePatterns(&getContext(), patterns);
644     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
645     return;
646   }
647   if (testPatterns)
648     return applyPatterns(getOperation());
649   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
650     return applyMatmulToVectorPatterns(getOperation(),
651                                        testMatmulToVectorPatterns1dTiling,
652                                        testMatmulToVectorPatterns2dTiling);
653   if (testVectorTransferForwardingPatterns)
654     return applyVectorTransferForwardingPatterns(getOperation());
655   if (testGenericToVectorPattern)
656     return applyLinalgToVectorPatterns(getOperation());
657   if (testTransformPadTensor)
658     return applyPadTensorToGenericPatterns(getOperation());
659   if (testGeneralizePadTensor)
660     return applyGeneralizePadTensorPatterns(getOperation());
661   if (testSwapSubTensorPadTensor)
662     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
663   if (testTilePattern)
664     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
665                             /*scalarizeDynamicDims=*/false);
666   if (testTileScalarizeDynamicDims)
667     return applyTilePattern(getOperation(), loopType, tileSizes,
668                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
669 }
670 
671 namespace mlir {
672 namespace test {
673 void registerTestLinalgTransforms() {
674   PassRegistration<TestLinalgTransforms>();
675 }
676 } // namespace test
677 } // namespace mlir
678