1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
30 
31 using namespace mlir;
32 using namespace mlir::linalg;
33 
34 namespace {
35 struct TestLinalgTransforms
36     : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
37   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)
38 
39   TestLinalgTransforms() = default;
40   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
41 
42   void getDependentDialects(DialectRegistry &registry) const override {
43     // clang-format off
44     registry.insert<AffineDialect,
45                     bufferization::BufferizationDialect,
46                     memref::MemRefDialect,
47                     scf::SCFDialect,
48                     linalg::LinalgDialect,
49                     vector::VectorDialect,
50                     gpu::GPUDialect>();
51     // clang-format on
52   }
53   StringRef getArgument() const final {
54     return "test-linalg-transform-patterns";
55   }
56   StringRef getDescription() const final {
57     return "Test Linalg transformation patterns by applying them greedily.";
58   }
59 
60   void runOnOperation() override;
61 
62   Option<bool> testPatterns{*this, "test-patterns",
63                             llvm::cl::desc("Test a mixed set of patterns"),
64                             llvm::cl::init(false)};
65   Option<bool> testMatmulToVectorPatterns1dTiling{
66       *this, "test-matmul-to-vector-patterns-tile-1d",
67       llvm::cl::desc(
68           "Test a fused pass that applies patterns from matmul to vectors via "
69           "1-d tiling"),
70       llvm::cl::init(false)};
71   Option<bool> testMatmulToVectorPatterns2dTiling{
72       *this, "test-matmul-to-vector-patterns-tile-2d",
73       llvm::cl::desc(
74           "Test a fused pass that applies patterns from matmul to vectors via "
75           "2-d tiling"),
76       llvm::cl::init(false)};
77   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
78                                     llvm::cl::desc("Test promotion options"),
79                                     llvm::cl::init(false)};
80   Option<bool> testTileAndDistributionOptions{
81       *this, "test-tile-and-distribute-options",
82       llvm::cl::desc("Test tile and distribute options"),
83       llvm::cl::init(false)};
84   Option<bool> testTileFuseAndDistributionOptions{
85       *this, "test-tile-fuse-and-distribute-options",
86       llvm::cl::desc("Test tile, fuse and distribute options"),
87       llvm::cl::init(false)};
88   Option<bool> testVectorTransferForwardingPatterns{
89       *this, "test-vector-transfer-forwarding-patterns",
90       llvm::cl::desc(
91           "Test a fused pass that forwards memref.copy to vector.transfer"),
92       llvm::cl::init(false)};
93   Option<bool> testGenericToVectorPattern{
94       *this, "test-linalg-to-vector-patterns",
95       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
96                      "in vector.contract form"),
97       llvm::cl::init(false)};
98   Option<bool> testTilePattern{*this, "test-tile-pattern",
99                                llvm::cl::desc("Test tile pattern"),
100                                llvm::cl::init(false)};
101   Option<bool> testTileScalarizeDynamicDims{
102       *this, "test-tile-scalarize-dynamic-dims",
103       llvm::cl::desc("Test tiling of dynamic dims by 1"),
104       llvm::cl::init(false)};
105   Option<bool> testTransformPadTensor{
106       *this, "test-transform-pad-tensor",
107       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
108       llvm::cl::init(false)};
109   Option<bool> testGeneralizePadTensor{
110       *this, "test-generalize-pad-tensor",
111       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
112       llvm::cl::init(false)};
113   Option<bool> testSwapSubTensorPadTensor{
114       *this, "test-swap-subtensor-padtensor",
115       llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
116                      "tensor.pad(subtensor)"),
117       llvm::cl::init(false)};
118   Option<bool> testSplitReduction{
119       *this, "test-split-reduction",
120       llvm::cl::desc("Test split reduction transformation"),
121       llvm::cl::init(false)};
122   ListOption<int64_t> peeledLoops{
123       *this, "peeled-loops",
124       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
125       llvm::cl::ZeroOrMore};
126   ListOption<int64_t> tileSizes{
127       *this, "tile-sizes",
128       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
129       llvm::cl::ZeroOrMore};
130   Option<bool> skipPartial{
131       *this, "skip-partial",
132       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
133       llvm::cl::init(false)};
134   Option<std::string> loopType{
135       *this, "loop-type",
136       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
137                      "tiled_loop"),
138       llvm::cl::init("for")};
139   Option<bool> testBubbleUpExtractSliceOpPattern{
140       *this, "test-bubble-up-extract-slice-op-pattern",
141       llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
142                      "extract_slice + linalgOp"),
143       llvm::cl::init(false)};
144 };
145 } // namespace
146 
147 static void applyPatterns(func::FuncOp funcOp) {
148   MLIRContext *ctx = funcOp.getContext();
149   RewritePatternSet patterns(ctx);
150 
151   //===--------------------------------------------------------------------===//
152   // Linalg tiling patterns.
153   //===--------------------------------------------------------------------===//
154   patterns.add<LinalgTilingPattern>(
155       MatmulOp::getOperationName(), ctx,
156       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
157       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
158                                  StringAttr::get(ctx, "L3")));
159   patterns.add<LinalgTilingPattern>(
160       MatmulOp::getOperationName(), ctx,
161       LinalgTilingOptions().setTileSizes({200, 300, 400}),
162       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
163                                  StringAttr::get(ctx, "L2")));
164   patterns.add<LinalgTilingPattern>(
165       MatmulOp::getOperationName(), ctx,
166       LinalgTilingOptions().setTileSizes({20, 30, 40}),
167       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
168                                  StringAttr::get(ctx, "L1")));
169   patterns.add<LinalgTilingPattern>(
170       MatmulOp::getOperationName(), ctx,
171       LinalgTilingOptions().setTileSizes({2, 3, 4}),
172       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
173                                  StringAttr::get(ctx, "REG")));
174 
175   patterns.add<LinalgTilingPattern>(
176       MatvecOp::getOperationName(), ctx,
177       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
178           LinalgTilingLoopType::ParallelLoops),
179       LinalgTransformationFilter(ArrayRef<StringAttr>{},
180                                  StringAttr::get(ctx, "L1")));
181 
182   patterns.add<LinalgTilingPattern>(
183       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
184       LinalgTransformationFilter(
185           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
186                                StringAttr::get(ctx, "L3"),
187                                StringAttr::get(ctx, "L2")},
188           StringAttr::get(ctx, "REG")));
189 
190   //===--------------------------------------------------------------------===//
191   // Linalg tiling and permutation patterns.
192   //===--------------------------------------------------------------------===//
193   patterns.add<LinalgTilingPattern>(
194       MatmulOp::getOperationName(), ctx,
195       LinalgTilingOptions()
196           .setTileSizes({2000, 3000, 4000})
197           .setInterchange({1, 2, 0}),
198       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
199                                  StringAttr::get(ctx, "L2__with_perm__")));
200   patterns.add<LinalgTilingPattern>(
201       MatmulOp::getOperationName(), ctx,
202       LinalgTilingOptions()
203           .setTileSizes({200, 300, 400})
204           .setInterchange({1, 0, 2}),
205       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
206                                  StringAttr::get(ctx, "L1__with_perm__")));
207   patterns.add<LinalgTilingPattern>(
208       MatmulOp::getOperationName(), ctx,
209       LinalgTilingOptions().setTileSizes({20, 30, 40}),
210       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
211                                  StringAttr::get(ctx, "REG__with_perm__")));
212 
213   patterns.add<LinalgTilingPattern>(
214       MatvecOp::getOperationName(), ctx,
215       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
216       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
217                                  StringAttr::get(ctx, "L1__with_perm__")));
218 
219   patterns.add<LinalgTilingPattern>(
220       MatmulOp::getOperationName(), ctx,
221       LinalgTilingOptions()
222           .setTileSizes({16, 8, 4})
223           .setInterchange({1, 2, 0})
224           .setLoopType(LinalgTilingLoopType::ParallelLoops),
225       LinalgTransformationFilter(
226           StringAttr::get(ctx, "par__with_perm__"),
227           StringAttr::get(ctx, "after_par__with_perm__")));
228 
229   //===--------------------------------------------------------------------===//
230   // Linalg to loops patterns.
231   //===--------------------------------------------------------------------===//
232   patterns.add<LinalgLoweringPattern<DotOp>>(
233       ctx,
234       /*loweringType=*/LinalgLoweringType::Loops,
235       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
236 
237   //===--------------------------------------------------------------------===//
238   // Linalg distribution patterns.
239   //===--------------------------------------------------------------------===//
240   LinalgLoopDistributionOptions distributionOptions;
241 
242   //===--------------------------------------------------------------------===//
243   // Linalg to vector contraction patterns.
244   //===--------------------------------------------------------------------===//
245   patterns.add<LinalgVectorizationPattern>(
246       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
247                .addOpFilter<MatmulOp, FillOp, GenericOp>());
248   patterns.add<CopyVectorizationPattern>(ctx);
249 
250   //===--------------------------------------------------------------------===//
251   // Linalg generic interchange pattern.
252   //===--------------------------------------------------------------------===//
253   patterns.add<GenericOpInterchangePattern>(
254       ctx,
255       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
256       LinalgTransformationFilter(ArrayRef<StringAttr>{},
257                                  StringAttr::get(ctx, "PERMUTED")));
258 
259   //===--------------------------------------------------------------------===//
260   // Linalg subview operands promotion.
261   //===--------------------------------------------------------------------===//
262   patterns.add<LinalgPromotionPattern<MatmulOp>>(
263       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
264       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
265                                  StringAttr::get(ctx, "_views_promoted_")));
266   patterns.add<LinalgPromotionPattern<MatmulOp>>(
267       ctx,
268       LinalgPromotionOptions()
269           .setOperandsToPromote({0})
270           .setUseFullTileBuffersByDefault(true),
271       LinalgTransformationFilter(
272           StringAttr::get(ctx, "_promote_first_view_"),
273           StringAttr::get(ctx, "_first_view_promoted_")));
274   patterns.add<LinalgPromotionPattern<FillOp>>(
275       ctx,
276       LinalgPromotionOptions()
277           .setOperandsToPromote({1})
278           .setUseFullTileBuffers({false, true})
279           .setAlignment(32),
280       LinalgTransformationFilter(
281           StringAttr::get(ctx, "_promote_views_aligned_"),
282           StringAttr::get(ctx, "_views_aligned_promoted_")));
283 
284   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
285 
286   // Drop the marker.
287   funcOp.walk([](LinalgOp op) {
288     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
289   });
290 }
291 
292 static void fillL1TilingAndMatmulToVectorPatterns(
293     func::FuncOp funcOp, StringRef startMarker,
294     SmallVectorImpl<RewritePatternSet> &patternsVector) {
295   MLIRContext *ctx = funcOp.getContext();
296   patternsVector.emplace_back(
297       ctx, std::make_unique<LinalgTilingPattern>(
298                MatmulOp::getOperationName(), ctx,
299                LinalgTilingOptions()
300                    .setTileSizes({8, 12, 16})
301                    .setInterchange({1, 0, 2}),
302                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
303                                           StringAttr::get(ctx, "L1"))));
304 
305   patternsVector.emplace_back(
306       ctx,
307       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
308           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
309           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
310                                      StringAttr::get(ctx, "VEC"))));
311 
312   patternsVector.emplace_back(
313       ctx, std::make_unique<LinalgVectorizationPattern>(
314                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
315                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
316   patternsVector.back().add<LinalgVectorizationPattern>(
317       ctx, LinalgTransformationFilter().addOpFilter<FillOp>());
318   patternsVector.back().add<CopyVectorizationPattern>(ctx);
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // Test promotion callbacks
323 //===----------------------------------------------------------------------===//
324 
325 // Allocation call back
326 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
327                                        ArrayRef<Value> boundingSubViewSize,
328                                        DataLayout &layout) {
329   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
330   return b
331       .create<memref::AllocOp>(
332           subView.getLoc(),
333           MemRefType::get(shape, subView.getType().getElementType(),
334                           /*affineMapComposition =*/{}, 3),
335           boundingSubViewSize)
336       .getResult();
337 }
338 
339 // Deallocation callback
340 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
341   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
342   return success();
343 }
344 
345 // Copy in call back
346 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
347                                     bool isOutput) {
348   auto floatType = src.getType().cast<MemRefType>().getElementType();
349   if (!floatType.isa<FloatType>())
350     return failure();
351   if (!isOutput) {
352     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
353                                             FloatAttr::get(floatType, 42.0));
354     b.create<FillOp>(src.getLoc(), cst, dst);
355   }
356   b.create<memref::CopyOp>(src.getLoc(), src, dst);
357   return success();
358 }
359 
360 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
361                                           RewritePatternSet &patterns) {
362   patterns.add<LinalgTilingPattern>(
363       MatmulOp::getOperationName(), ctx,
364       LinalgTilingOptions().setTileSizes({16, 16, 16}),
365       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
366                                  StringAttr::get(ctx, "PROMOTE")));
367   patterns.add<LinalgPromotionPattern<MatmulOp>>(
368       ctx,
369       LinalgPromotionOptions()
370           .setOperandsToPromote({0, 2})
371           .setUseFullTileBuffers({false, false})
372           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
373           .setCopyInOutFns(
374               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
375                 return copyCallBackFn(b, src, dst, false);
376               },
377               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
378                 return copyCallBackFn(b, src, dst, true);
379               }),
380       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
381 }
382 
383 template <typename IdOp, typename NProcsOp>
384 static SmallVector<ProcInfo, 2>
385 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
386   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
387   SmallVector<ProcInfo, 2> procInfo(count);
388   Type indexType = b.getIndexType();
389   for (unsigned i = 0; i < count; ++i) {
390     gpu::Dimension dim = *gpu::symbolizeDimension(i);
391     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
392                                b.create<NProcsOp>(loc, indexType, dim)};
393   }
394   return procInfo;
395 }
396 
397 static void fillTileAndDistributePatterns(MLIRContext *context,
398                                           RewritePatternSet &patterns) {
399   {
400     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
401     cyclicNprocsEqNiters.distributionMethod.resize(
402         2, DistributionMethod::CyclicNumProcsEqNumIters);
403     cyclicNprocsEqNiters.procInfo =
404         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
405     patterns.add<LinalgTilingPattern>(
406         MatmulOp::getOperationName(), context,
407         LinalgTilingOptions()
408             .setTileSizes({8, 8, 4})
409             .setLoopType(LinalgTilingLoopType::ParallelLoops)
410             .setDistributionOptions(cyclicNprocsEqNiters),
411         LinalgTransformationFilter(
412             StringAttr::get(context, "distribute1"),
413             StringAttr::get(context, "after_distribute1")));
414   }
415 
416   {
417     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
418     cyclicNprocsGeNiters.distributionMethod.resize(
419         2, DistributionMethod::CyclicNumProcsGeNumIters);
420     cyclicNprocsGeNiters.procInfo =
421         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
422     patterns.add<LinalgTilingPattern>(
423         MatmulOp::getOperationName(), context,
424         LinalgTilingOptions()
425             .setTileSizes({8, 8, 4})
426             .setLoopType(LinalgTilingLoopType::ParallelLoops)
427             .setDistributionOptions(cyclicNprocsGeNiters),
428         LinalgTransformationFilter(
429             StringAttr::get(context, "distribute2"),
430             StringAttr::get(context, "after_distribute2")));
431   }
432 
433   {
434     LinalgLoopDistributionOptions cyclicNprocsDefault;
435     cyclicNprocsDefault.distributionMethod.resize(2,
436                                                   DistributionMethod::Cyclic);
437     cyclicNprocsDefault.procInfo =
438         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
439     patterns.add<LinalgTilingPattern>(
440         MatmulOp::getOperationName(), context,
441         LinalgTilingOptions()
442             .setTileSizes({8, 8, 4})
443             .setLoopType(LinalgTilingLoopType::ParallelLoops)
444             .setDistributionOptions(cyclicNprocsDefault),
445         LinalgTransformationFilter(
446             StringAttr::get(context, "distribute3"),
447             StringAttr::get(context, "after_distribute3")));
448   }
449 
450   {
451     LinalgLoopDistributionOptions cyclicNprocsMixed1;
452     cyclicNprocsMixed1.distributionMethod = {
453         DistributionMethod::CyclicNumProcsEqNumIters,
454         DistributionMethod::CyclicNumProcsGeNumIters};
455     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
456     patterns.add<LinalgTilingPattern>(
457         MatmulOp::getOperationName(), context,
458         LinalgTilingOptions()
459             .setTileSizes({8, 8, 4})
460             .setLoopType(LinalgTilingLoopType::ParallelLoops)
461             .setDistributionOptions(cyclicNprocsMixed1),
462         LinalgTransformationFilter(
463             StringAttr::get(context, "distribute4"),
464             StringAttr::get(context, "after_distribute4")));
465   }
466 
467   {
468     LinalgLoopDistributionOptions cyclicNprocsMixed2;
469     cyclicNprocsMixed2.distributionMethod = {
470         DistributionMethod::CyclicNumProcsGeNumIters,
471         DistributionMethod::Cyclic};
472     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
473     patterns.add<LinalgTilingPattern>(
474         MatmulOp::getOperationName(), context,
475         LinalgTilingOptions()
476             .setTileSizes({8, 8, 4})
477             .setLoopType(LinalgTilingLoopType::ParallelLoops)
478             .setDistributionOptions(cyclicNprocsMixed2),
479         LinalgTransformationFilter(
480             StringAttr::get(context, "distribute5"),
481             StringAttr::get(context, "after_distribute5")));
482   }
483 
484   {
485     LinalgLoopDistributionOptions cyclicNprocsMixed3;
486     cyclicNprocsMixed3.distributionMethod = {
487         DistributionMethod::Cyclic,
488         DistributionMethod::CyclicNumProcsEqNumIters};
489     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
490 
491     patterns.add<LinalgTilingPattern>(
492         MatmulOp::getOperationName(), context,
493         LinalgTilingOptions()
494             .setTileSizes({8, 8, 4})
495             .setLoopType(LinalgTilingLoopType::ParallelLoops)
496             .setDistributionOptions(cyclicNprocsMixed3),
497         LinalgTransformationFilter(
498             StringAttr::get(context, "distribute6"),
499             StringAttr::get(context, "after_distribute6")));
500   }
501 
502   {
503     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
504     cyclicNprocsEqNiters.distributionMethod.resize(2,
505                                                    DistributionMethod::Cyclic);
506     cyclicNprocsEqNiters.procInfo =
507         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
508     patterns.add<LinalgTilingPattern>(
509         MatmulOp::getOperationName(), context,
510         LinalgTilingOptions()
511             .setTileSizes({8, 8, 4})
512             .setLoopType(LinalgTilingLoopType::Loops)
513             .setDistributionOptions(cyclicNprocsEqNiters),
514         LinalgTransformationFilter(
515             StringAttr::get(context, "tensors_distribute1"),
516             StringAttr::get(context, "tensors_after_distribute1")));
517   }
518 }
519 
520 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
521                                               RewritePatternSet &patterns) {
522   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
523   cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
524   cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
525   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
526       MatmulOp::getOperationName(), context,
527       LinalgTilingAndFusionOptions()
528           .setTileSizes({8, 8, 4})
529           .setDistributionOptions(cyclicNprocsEqNiters),
530       LinalgTransformationFilter(
531           StringAttr::get(context, "tensors_fuse_distribute1"),
532           StringAttr::get(context, "tensors_after_fuse_distribute1")));
533 }
534 
535 static void
536 applyMatmulToVectorPatterns(func::FuncOp funcOp,
537                             bool testMatmulToVectorPatterns1dTiling,
538                             bool testMatmulToVectorPatterns2dTiling) {
539   MLIRContext *ctx = funcOp.getContext();
540   SmallVector<RewritePatternSet, 4> stage1Patterns;
541   if (testMatmulToVectorPatterns1dTiling) {
542     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
543   } else if (testMatmulToVectorPatterns2dTiling) {
544     stage1Patterns.emplace_back(
545         ctx, std::make_unique<LinalgTilingPattern>(
546                  MatmulOp::getOperationName(), ctx,
547                  LinalgTilingOptions()
548                      .setTileSizes({768, 264, 768})
549                      .setInterchange({1, 2, 0}),
550                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
551                                             StringAttr::get(ctx, "L2"))));
552     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
553   }
554   {
555     // Canonicalization patterns
556     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
557     vector::populateVectorTransferPermutationMapLoweringPatterns(
558         canonicalizationPatterns);
559     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
560     stage1Patterns.push_back(std::move(canonicalizationPatterns));
561   }
562   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
563   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
564   FrozenRewritePatternSet stage2Patterns =
565       getLinalgTilingCanonicalizationPatterns(ctx);
566   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
567 }
568 
569 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
570   RewritePatternSet forwardPattern(funcOp.getContext());
571   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
572   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
573   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
574 }
575 
576 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
577   RewritePatternSet patterns(funcOp.getContext());
578   auto *ctx = funcOp.getContext();
579   patterns.add<LinalgVectorizationPattern>(
580       ctx, LinalgTransformationFilter()
581                .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
582   patterns.add<CopyVectorizationPattern>(ctx);
583   populatePadOpVectorizationPatterns(patterns);
584   populateConvolutionVectorizationPatterns(patterns);
585   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
586 }
587 
588 static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) {
589   RewritePatternSet patterns(funcOp.getContext());
590   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
591   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
592 }
593 
594 static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
595   RewritePatternSet patterns(funcOp.getContext());
596   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
597   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
598 }
599 
600 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
601   RewritePatternSet patterns(funcOp.getContext());
602   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
603   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
604 }
605 
606 static void applyTilePattern(func::FuncOp funcOp, const std::string &loopType,
607                              ArrayRef<int64_t> tileSizes,
608                              ArrayRef<int64_t> peeledLoops,
609                              bool scalarizeDynamicDims) {
610   MLIRContext *context = funcOp.getContext();
611   RewritePatternSet tilingPattern(context);
612   LinalgTilingLoopType type =
613       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
614           .Case("for", LinalgTilingLoopType::Loops)
615           .Case("affine", LinalgTilingLoopType::AffineLoops)
616           .Case("parallel", LinalgTilingLoopType::ParallelLoops);
617   auto linalgTilingOptions = linalg::LinalgTilingOptions()
618                                  .setPeeledLoops(peeledLoops)
619                                  .setLoopType(type);
620   if (scalarizeDynamicDims) {
621     linalgTilingOptions.scalarizeDynamicDims();
622     assert(tileSizes.empty() &&
623            "tileSizes and scalarizeDynamicDims is mutually exclusive");
624   } else {
625     linalgTilingOptions.setTileSizes(tileSizes);
626   }
627   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
628   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
629       tilingPattern, linalgTilingOptions, f);
630   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
631 }
632 
633 static void applySplitReduction(func::FuncOp funcOp) {
634   RewritePatternSet patterns(funcOp.getContext());
635   linalg::populateSplitReductionPattern(
636       patterns,
637       [](LinalgOp op) {
638         unsigned insertDimIndex = op.getNumLoops() - 1;
639         return std::make_pair(4, insertDimIndex);
640       },
641       LinalgTransformationFilter(
642           ArrayRef<StringAttr>{},
643           StringAttr::get(funcOp.getContext(), "SPLIT")));
644   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
645 }
646 
647 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
648   RewritePatternSet patterns(funcOp.getContext());
649   populateBubbleUpExtractSliceOpPatterns(patterns);
650   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
651 }
652 
653 /// Apply transformations specified as patterns.
654 void TestLinalgTransforms::runOnOperation() {
655   auto lambda = [&](void *) {
656     getOperation().walk([](LinalgOp op) {
657       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
658     });
659   };
660   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
661 
662   if (testPromotionOptions) {
663     RewritePatternSet patterns(&getContext());
664     fillPromotionCallBackPatterns(&getContext(), patterns);
665     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
666     return;
667   }
668   if (testTileAndDistributionOptions) {
669     RewritePatternSet patterns(&getContext());
670     fillTileAndDistributePatterns(&getContext(), patterns);
671     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
672     return;
673   }
674   if (testTileFuseAndDistributionOptions) {
675     RewritePatternSet patterns(&getContext());
676     fillTileFuseAndDistributePatterns(&getContext(), patterns);
677     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
678     return;
679   }
680   if (testPatterns)
681     return applyPatterns(getOperation());
682   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
683     return applyMatmulToVectorPatterns(getOperation(),
684                                        testMatmulToVectorPatterns1dTiling,
685                                        testMatmulToVectorPatterns2dTiling);
686   if (testVectorTransferForwardingPatterns)
687     return applyVectorTransferForwardingPatterns(getOperation());
688   if (testGenericToVectorPattern)
689     return applyLinalgToVectorPatterns(getOperation());
690   if (testTransformPadTensor)
691     return applyPadTensorToGenericPatterns(getOperation());
692   if (testGeneralizePadTensor)
693     return applyGeneralizePadTensorPatterns(getOperation());
694   if (testSwapSubTensorPadTensor)
695     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
696   if (testTilePattern)
697     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
698                             /*scalarizeDynamicDims=*/false);
699   if (testTileScalarizeDynamicDims)
700     return applyTilePattern(getOperation(), loopType, tileSizes,
701                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
702   if (testSplitReduction)
703     return applySplitReduction(getOperation());
704   if (testBubbleUpExtractSliceOpPattern)
705     return applyBubbleUpExtractSliceOpPattern(getOperation());
706 }
707 
namespace mlir {
namespace test {
/// Registration hook for the TestLinalgTransforms pass: constructs a
/// `PassRegistration<TestLinalgTransforms>` temporary.
/// NOTE(review): presumably this exposes the pass under its argument
/// "test-linalg-transform-patterns" — confirm against the caller that invokes
/// this hook.
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir
715