1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
36   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)
37 
38   TestLinalgTransforms() = default;
39   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
40 
41   void getDependentDialects(DialectRegistry &registry) const override {
42     // clang-format off
43     registry.insert<AffineDialect,
44                     memref::MemRefDialect,
45                     scf::SCFDialect,
46                     linalg::LinalgDialect,
47                     vector::VectorDialect,
48                     gpu::GPUDialect>();
49     // clang-format on
50   }
51   StringRef getArgument() const final {
52     return "test-linalg-transform-patterns";
53   }
54   StringRef getDescription() const final {
55     return "Test Linalg transformation patterns by applying them greedily.";
56   }
57 
58   void runOnOperation() override;
59 
60   Option<bool> testPatterns{*this, "test-patterns",
61                             llvm::cl::desc("Test a mixed set of patterns"),
62                             llvm::cl::init(false)};
63   Option<bool> testMatmulToVectorPatterns1dTiling{
64       *this, "test-matmul-to-vector-patterns-tile-1d",
65       llvm::cl::desc(
66           "Test a fused pass that applies patterns from matmul to vectors via "
67           "1-d tiling"),
68       llvm::cl::init(false)};
69   Option<bool> testMatmulToVectorPatterns2dTiling{
70       *this, "test-matmul-to-vector-patterns-tile-2d",
71       llvm::cl::desc(
72           "Test a fused pass that applies patterns from matmul to vectors via "
73           "2-d tiling"),
74       llvm::cl::init(false)};
75   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
76                                     llvm::cl::desc("Test promotion options"),
77                                     llvm::cl::init(false)};
78   Option<bool> testTileAndDistributionOptions{
79       *this, "test-tile-and-distribute-options",
80       llvm::cl::desc("Test tile and distribute options"),
81       llvm::cl::init(false)};
82   Option<bool> testTileFuseAndDistributionOptions{
83       *this, "test-tile-fuse-and-distribute-options",
84       llvm::cl::desc("Test tile, fuse and distribute options"),
85       llvm::cl::init(false)};
86   Option<bool> testVectorTransferForwardingPatterns{
87       *this, "test-vector-transfer-forwarding-patterns",
88       llvm::cl::desc(
89           "Test a fused pass that forwards memref.copy to vector.transfer"),
90       llvm::cl::init(false)};
91   Option<bool> testGenericToVectorPattern{
92       *this, "test-linalg-to-vector-patterns",
93       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
94                      "in vector.contract form"),
95       llvm::cl::init(false)};
96   Option<bool> testTilePattern{*this, "test-tile-pattern",
97                                llvm::cl::desc("Test tile pattern"),
98                                llvm::cl::init(false)};
99   Option<bool> testTileScalarizeDynamicDims{
100       *this, "test-tile-scalarize-dynamic-dims",
101       llvm::cl::desc("Test tiling of dynamic dims by 1"),
102       llvm::cl::init(false)};
103   Option<bool> testTransformPadTensor{
104       *this, "test-transform-pad-tensor",
105       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
106       llvm::cl::init(false)};
107   Option<bool> testGeneralizePadTensor{
108       *this, "test-generalize-pad-tensor",
109       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
110       llvm::cl::init(false)};
111   Option<bool> testSwapSubTensorPadTensor{
112       *this, "test-swap-subtensor-padtensor",
113       llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
114                      "tensor.pad(subtensor)"),
115       llvm::cl::init(false)};
116   Option<bool> testSplitReduction{
117       *this, "test-split-reduction",
118       llvm::cl::desc("Test split reduction transformation"),
119       llvm::cl::init(false)};
120   ListOption<int64_t> peeledLoops{
121       *this, "peeled-loops",
122       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
123       llvm::cl::ZeroOrMore};
124   ListOption<int64_t> tileSizes{
125       *this, "tile-sizes",
126       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
127       llvm::cl::ZeroOrMore};
128   Option<bool> skipPartial{
129       *this, "skip-partial",
130       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
131       llvm::cl::init(false)};
132   Option<std::string> loopType{
133       *this, "loop-type",
134       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
135                      "tiled_loop"),
136       llvm::cl::init("for")};
137   Option<bool> testBubbleUpExtractSliceOpPattern{
138       *this, "test-bubble-up-extract-slice-op-pattern",
139       llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
140                      "extract_slice + linalgOp"),
141       llvm::cl::init(false)};
142 };
143 } // namespace
144 
145 static void applyPatterns(func::FuncOp funcOp) {
146   MLIRContext *ctx = funcOp.getContext();
147   RewritePatternSet patterns(ctx);
148 
149   //===--------------------------------------------------------------------===//
150   // Linalg tiling patterns.
151   //===--------------------------------------------------------------------===//
152   patterns.add<LinalgTilingPattern>(
153       MatmulOp::getOperationName(), ctx,
154       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
155       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
156                                  StringAttr::get(ctx, "L3")));
157   patterns.add<LinalgTilingPattern>(
158       MatmulOp::getOperationName(), ctx,
159       LinalgTilingOptions().setTileSizes({200, 300, 400}),
160       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
161                                  StringAttr::get(ctx, "L2")));
162   patterns.add<LinalgTilingPattern>(
163       MatmulOp::getOperationName(), ctx,
164       LinalgTilingOptions().setTileSizes({20, 30, 40}),
165       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
166                                  StringAttr::get(ctx, "L1")));
167   patterns.add<LinalgTilingPattern>(
168       MatmulOp::getOperationName(), ctx,
169       LinalgTilingOptions().setTileSizes({2, 3, 4}),
170       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
171                                  StringAttr::get(ctx, "REG")));
172 
173   patterns.add<LinalgTilingPattern>(
174       MatvecOp::getOperationName(), ctx,
175       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
176           LinalgTilingLoopType::ParallelLoops),
177       LinalgTransformationFilter(ArrayRef<StringAttr>{},
178                                  StringAttr::get(ctx, "L1")));
179 
180   patterns.add<LinalgTilingPattern>(
181       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
182       LinalgTransformationFilter(
183           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
184                                StringAttr::get(ctx, "L3"),
185                                StringAttr::get(ctx, "L2")},
186           StringAttr::get(ctx, "REG")));
187 
188   //===--------------------------------------------------------------------===//
189   // Linalg tiling and permutation patterns.
190   //===--------------------------------------------------------------------===//
191   patterns.add<LinalgTilingPattern>(
192       MatmulOp::getOperationName(), ctx,
193       LinalgTilingOptions()
194           .setTileSizes({2000, 3000, 4000})
195           .setInterchange({1, 2, 0}),
196       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
197                                  StringAttr::get(ctx, "L2__with_perm__")));
198   patterns.add<LinalgTilingPattern>(
199       MatmulOp::getOperationName(), ctx,
200       LinalgTilingOptions()
201           .setTileSizes({200, 300, 400})
202           .setInterchange({1, 0, 2}),
203       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
204                                  StringAttr::get(ctx, "L1__with_perm__")));
205   patterns.add<LinalgTilingPattern>(
206       MatmulOp::getOperationName(), ctx,
207       LinalgTilingOptions().setTileSizes({20, 30, 40}),
208       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
209                                  StringAttr::get(ctx, "REG__with_perm__")));
210 
211   patterns.add<LinalgTilingPattern>(
212       MatvecOp::getOperationName(), ctx,
213       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
214       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
215                                  StringAttr::get(ctx, "L1__with_perm__")));
216 
217   patterns.add<LinalgTilingPattern>(
218       MatmulOp::getOperationName(), ctx,
219       LinalgTilingOptions()
220           .setTileSizes({16, 8, 4})
221           .setInterchange({1, 2, 0})
222           .setLoopType(LinalgTilingLoopType::ParallelLoops),
223       LinalgTransformationFilter(
224           StringAttr::get(ctx, "par__with_perm__"),
225           StringAttr::get(ctx, "after_par__with_perm__")));
226 
227   //===--------------------------------------------------------------------===//
228   // Linalg to loops patterns.
229   //===--------------------------------------------------------------------===//
230   patterns.add<LinalgLoweringPattern<DotOp>>(
231       ctx,
232       /*loweringType=*/LinalgLoweringType::Loops,
233       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
234 
235   //===--------------------------------------------------------------------===//
236   // Linalg distribution patterns.
237   //===--------------------------------------------------------------------===//
238   LinalgLoopDistributionOptions distributionOptions;
239 
240   //===--------------------------------------------------------------------===//
241   // Linalg to vector contraction patterns.
242   //===--------------------------------------------------------------------===//
243   patterns.add<LinalgVectorizationPattern>(
244       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
245                .addOpFilter<MatmulOp, FillOp, GenericOp>());
246   patterns.add<CopyVectorizationPattern>(ctx);
247 
248   //===--------------------------------------------------------------------===//
249   // Linalg generic interchange pattern.
250   //===--------------------------------------------------------------------===//
251   patterns.add<GenericOpInterchangePattern>(
252       ctx,
253       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
254       LinalgTransformationFilter(ArrayRef<StringAttr>{},
255                                  StringAttr::get(ctx, "PERMUTED")));
256 
257   //===--------------------------------------------------------------------===//
258   // Linalg subview operands promotion.
259   //===--------------------------------------------------------------------===//
260   patterns.add<LinalgPromotionPattern<MatmulOp>>(
261       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
262       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
263                                  StringAttr::get(ctx, "_views_promoted_")));
264   patterns.add<LinalgPromotionPattern<MatmulOp>>(
265       ctx,
266       LinalgPromotionOptions()
267           .setOperandsToPromote({0})
268           .setUseFullTileBuffersByDefault(true),
269       LinalgTransformationFilter(
270           StringAttr::get(ctx, "_promote_first_view_"),
271           StringAttr::get(ctx, "_first_view_promoted_")));
272   patterns.add<LinalgPromotionPattern<FillOp>>(
273       ctx,
274       LinalgPromotionOptions()
275           .setOperandsToPromote({1})
276           .setUseFullTileBuffers({false, true})
277           .setAlignment(32),
278       LinalgTransformationFilter(
279           StringAttr::get(ctx, "_promote_views_aligned_"),
280           StringAttr::get(ctx, "_views_aligned_promoted_")));
281 
282   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
283 
284   // Drop the marker.
285   funcOp.walk([](LinalgOp op) {
286     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
287   });
288 }
289 
290 static void fillL1TilingAndMatmulToVectorPatterns(
291     func::FuncOp funcOp, StringRef startMarker,
292     SmallVectorImpl<RewritePatternSet> &patternsVector) {
293   MLIRContext *ctx = funcOp.getContext();
294   patternsVector.emplace_back(
295       ctx, std::make_unique<LinalgTilingPattern>(
296                MatmulOp::getOperationName(), ctx,
297                LinalgTilingOptions()
298                    .setTileSizes({8, 12, 16})
299                    .setInterchange({1, 0, 2}),
300                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
301                                           StringAttr::get(ctx, "L1"))));
302 
303   patternsVector.emplace_back(
304       ctx,
305       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
306           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
307           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
308                                      StringAttr::get(ctx, "VEC"))));
309 
310   patternsVector.emplace_back(
311       ctx, std::make_unique<LinalgVectorizationPattern>(
312                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
313                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
314   patternsVector.back().add<LinalgVectorizationPattern>(
315       ctx, LinalgTransformationFilter().addOpFilter<FillOp>());
316   patternsVector.back().add<CopyVectorizationPattern>(ctx);
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // Test promotion callbacks
321 //===----------------------------------------------------------------------===//
322 
323 // Allocation call back
324 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
325                                        ArrayRef<Value> boundingSubViewSize,
326                                        DataLayout &layout) {
327   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
328   return b
329       .create<memref::AllocOp>(
330           subView.getLoc(),
331           MemRefType::get(shape, subView.getType().getElementType(),
332                           /*affineMapComposition =*/{}, 3),
333           boundingSubViewSize)
334       .getResult();
335 }
336 
337 // Deallocation callback
338 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
339   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
340   return success();
341 }
342 
343 // Copy in call back
344 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
345                                     bool isOutput) {
346   auto floatType = src.getType().cast<MemRefType>().getElementType();
347   if (!floatType.isa<FloatType>())
348     return failure();
349   if (!isOutput) {
350     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
351                                             FloatAttr::get(floatType, 42.0));
352     b.create<FillOp>(src.getLoc(), cst, dst);
353   }
354   b.create<memref::CopyOp>(src.getLoc(), src, dst);
355   return success();
356 }
357 
358 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
359                                           RewritePatternSet &patterns) {
360   patterns.add<LinalgTilingPattern>(
361       MatmulOp::getOperationName(), ctx,
362       LinalgTilingOptions().setTileSizes({16, 16, 16}),
363       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
364                                  StringAttr::get(ctx, "PROMOTE")));
365   patterns.add<LinalgPromotionPattern<MatmulOp>>(
366       ctx,
367       LinalgPromotionOptions()
368           .setOperandsToPromote({0, 2})
369           .setUseFullTileBuffers({false, false})
370           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
371           .setCopyInOutFns(
372               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
373                 return copyCallBackFn(b, src, dst, false);
374               },
375               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
376                 return copyCallBackFn(b, src, dst, true);
377               }),
378       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
379 }
380 
381 template <typename IdOp, typename NProcsOp>
382 static SmallVector<ProcInfo, 2>
383 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
384   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
385   SmallVector<ProcInfo, 2> procInfo(count);
386   Type indexType = b.getIndexType();
387   for (unsigned i = 0; i < count; ++i) {
388     gpu::Dimension dim = *gpu::symbolizeDimension(i);
389     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
390                                b.create<NProcsOp>(loc, indexType, dim)};
391   }
392   return procInfo;
393 }
394 
395 static void fillTileAndDistributePatterns(MLIRContext *context,
396                                           RewritePatternSet &patterns) {
397   {
398     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
399     cyclicNprocsEqNiters.distributionMethod.resize(
400         2, DistributionMethod::CyclicNumProcsEqNumIters);
401     cyclicNprocsEqNiters.procInfo =
402         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
403     patterns.add<LinalgTilingPattern>(
404         MatmulOp::getOperationName(), context,
405         LinalgTilingOptions()
406             .setTileSizes({8, 8, 4})
407             .setLoopType(LinalgTilingLoopType::ParallelLoops)
408             .setDistributionOptions(cyclicNprocsEqNiters),
409         LinalgTransformationFilter(
410             StringAttr::get(context, "distribute1"),
411             StringAttr::get(context, "after_distribute1")));
412   }
413 
414   {
415     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
416     cyclicNprocsGeNiters.distributionMethod.resize(
417         2, DistributionMethod::CyclicNumProcsGeNumIters);
418     cyclicNprocsGeNiters.procInfo =
419         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
420     patterns.add<LinalgTilingPattern>(
421         MatmulOp::getOperationName(), context,
422         LinalgTilingOptions()
423             .setTileSizes({8, 8, 4})
424             .setLoopType(LinalgTilingLoopType::ParallelLoops)
425             .setDistributionOptions(cyclicNprocsGeNiters),
426         LinalgTransformationFilter(
427             StringAttr::get(context, "distribute2"),
428             StringAttr::get(context, "after_distribute2")));
429   }
430 
431   {
432     LinalgLoopDistributionOptions cyclicNprocsDefault;
433     cyclicNprocsDefault.distributionMethod.resize(2,
434                                                   DistributionMethod::Cyclic);
435     cyclicNprocsDefault.procInfo =
436         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
437     patterns.add<LinalgTilingPattern>(
438         MatmulOp::getOperationName(), context,
439         LinalgTilingOptions()
440             .setTileSizes({8, 8, 4})
441             .setLoopType(LinalgTilingLoopType::ParallelLoops)
442             .setDistributionOptions(cyclicNprocsDefault),
443         LinalgTransformationFilter(
444             StringAttr::get(context, "distribute3"),
445             StringAttr::get(context, "after_distribute3")));
446   }
447 
448   {
449     LinalgLoopDistributionOptions cyclicNprocsMixed1;
450     cyclicNprocsMixed1.distributionMethod = {
451         DistributionMethod::CyclicNumProcsEqNumIters,
452         DistributionMethod::CyclicNumProcsGeNumIters};
453     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
454     patterns.add<LinalgTilingPattern>(
455         MatmulOp::getOperationName(), context,
456         LinalgTilingOptions()
457             .setTileSizes({8, 8, 4})
458             .setLoopType(LinalgTilingLoopType::ParallelLoops)
459             .setDistributionOptions(cyclicNprocsMixed1),
460         LinalgTransformationFilter(
461             StringAttr::get(context, "distribute4"),
462             StringAttr::get(context, "after_distribute4")));
463   }
464 
465   {
466     LinalgLoopDistributionOptions cyclicNprocsMixed2;
467     cyclicNprocsMixed2.distributionMethod = {
468         DistributionMethod::CyclicNumProcsGeNumIters,
469         DistributionMethod::Cyclic};
470     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
471     patterns.add<LinalgTilingPattern>(
472         MatmulOp::getOperationName(), context,
473         LinalgTilingOptions()
474             .setTileSizes({8, 8, 4})
475             .setLoopType(LinalgTilingLoopType::ParallelLoops)
476             .setDistributionOptions(cyclicNprocsMixed2),
477         LinalgTransformationFilter(
478             StringAttr::get(context, "distribute5"),
479             StringAttr::get(context, "after_distribute5")));
480   }
481 
482   {
483     LinalgLoopDistributionOptions cyclicNprocsMixed3;
484     cyclicNprocsMixed3.distributionMethod = {
485         DistributionMethod::Cyclic,
486         DistributionMethod::CyclicNumProcsEqNumIters};
487     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
488 
489     patterns.add<LinalgTilingPattern>(
490         MatmulOp::getOperationName(), context,
491         LinalgTilingOptions()
492             .setTileSizes({8, 8, 4})
493             .setLoopType(LinalgTilingLoopType::ParallelLoops)
494             .setDistributionOptions(cyclicNprocsMixed3),
495         LinalgTransformationFilter(
496             StringAttr::get(context, "distribute6"),
497             StringAttr::get(context, "after_distribute6")));
498   }
499 
500   {
501     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
502     cyclicNprocsEqNiters.distributionMethod.resize(2,
503                                                    DistributionMethod::Cyclic);
504     cyclicNprocsEqNiters.procInfo =
505         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
506     patterns.add<LinalgTilingPattern>(
507         MatmulOp::getOperationName(), context,
508         LinalgTilingOptions()
509             .setTileSizes({8, 8, 4})
510             .setLoopType(LinalgTilingLoopType::Loops)
511             .setDistributionOptions(cyclicNprocsEqNiters),
512         LinalgTransformationFilter(
513             StringAttr::get(context, "tensors_distribute1"),
514             StringAttr::get(context, "tensors_after_distribute1")));
515   }
516 }
517 
518 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
519                                               RewritePatternSet &patterns) {
520   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
521   cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
522   cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
523   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
524       MatmulOp::getOperationName(), context,
525       LinalgTilingAndFusionOptions()
526           .setTileSizes({8, 8, 4})
527           .setDistributionOptions(cyclicNprocsEqNiters),
528       LinalgTransformationFilter(
529           StringAttr::get(context, "tensors_fuse_distribute1"),
530           StringAttr::get(context, "tensors_after_fuse_distribute1")));
531 }
532 
533 static void
534 applyMatmulToVectorPatterns(func::FuncOp funcOp,
535                             bool testMatmulToVectorPatterns1dTiling,
536                             bool testMatmulToVectorPatterns2dTiling) {
537   MLIRContext *ctx = funcOp.getContext();
538   SmallVector<RewritePatternSet, 4> stage1Patterns;
539   if (testMatmulToVectorPatterns1dTiling) {
540     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
541   } else if (testMatmulToVectorPatterns2dTiling) {
542     stage1Patterns.emplace_back(
543         ctx, std::make_unique<LinalgTilingPattern>(
544                  MatmulOp::getOperationName(), ctx,
545                  LinalgTilingOptions()
546                      .setTileSizes({768, 264, 768})
547                      .setInterchange({1, 2, 0}),
548                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
549                                             StringAttr::get(ctx, "L2"))));
550     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
551   }
552   {
553     // Canonicalization patterns
554     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
555     vector::populateVectorTransferPermutationMapLoweringPatterns(
556         canonicalizationPatterns);
557     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
558     stage1Patterns.push_back(std::move(canonicalizationPatterns));
559   }
560   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
561   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
562   FrozenRewritePatternSet stage2Patterns =
563       getLinalgTilingCanonicalizationPatterns(ctx);
564   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
565 }
566 
567 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
568   RewritePatternSet forwardPattern(funcOp.getContext());
569   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
570   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
571   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
572 }
573 
574 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
575   RewritePatternSet patterns(funcOp.getContext());
576   auto *ctx = funcOp.getContext();
577   patterns.add<LinalgVectorizationPattern>(
578       ctx, LinalgTransformationFilter()
579                .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
580   patterns.add<CopyVectorizationPattern>(ctx);
581   populatePadOpVectorizationPatterns(patterns);
582   populateConvolutionVectorizationPatterns(patterns);
583   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
584 }
585 
586 static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) {
587   RewritePatternSet patterns(funcOp.getContext());
588   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
589   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
590 }
591 
592 static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
593   RewritePatternSet patterns(funcOp.getContext());
594   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
595   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
596 }
597 
598 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
599   RewritePatternSet patterns(funcOp.getContext());
600   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
601   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
602 }
603 
604 static void applyTilePattern(func::FuncOp funcOp, const std::string &loopType,
605                              ArrayRef<int64_t> tileSizes,
606                              ArrayRef<int64_t> peeledLoops,
607                              bool scalarizeDynamicDims) {
608   MLIRContext *context = funcOp.getContext();
609   RewritePatternSet tilingPattern(context);
610   LinalgTilingLoopType type =
611       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
612           .Case("for", LinalgTilingLoopType::Loops)
613           .Case("affine", LinalgTilingLoopType::AffineLoops)
614           .Case("parallel", LinalgTilingLoopType::ParallelLoops);
615   auto linalgTilingOptions = linalg::LinalgTilingOptions()
616                                  .setPeeledLoops(peeledLoops)
617                                  .setLoopType(type);
618   if (scalarizeDynamicDims) {
619     linalgTilingOptions.scalarizeDynamicDims();
620     assert(tileSizes.empty() &&
621            "tileSizes and scalarizeDynamicDims is mutually exclusive");
622   } else {
623     linalgTilingOptions.setTileSizes(tileSizes);
624   }
625   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
626   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
627       tilingPattern, linalgTilingOptions, f);
628   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
629 }
630 
631 static void applySplitReduction(func::FuncOp funcOp) {
632   RewritePatternSet patterns(funcOp.getContext());
633   linalg::populateSplitReductionPattern(
634       patterns,
635       [](LinalgOp op) {
636         unsigned insertDimIndex = op.getNumLoops() - 1;
637         return std::make_pair(4, insertDimIndex);
638       },
639       LinalgTransformationFilter(
640           ArrayRef<StringAttr>{},
641           StringAttr::get(funcOp.getContext(), "SPLIT")));
642   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
643 }
644 
645 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
646   RewritePatternSet patterns(funcOp.getContext());
647   populateBubbleUpExtractSliceOpPatterns(patterns);
648   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
649 }
650 
651 /// Apply transformations specified as patterns.
652 void TestLinalgTransforms::runOnOperation() {
653   auto lambda = [&](void *) {
654     getOperation().walk([](LinalgOp op) {
655       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
656     });
657   };
658   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
659 
660   if (testPromotionOptions) {
661     RewritePatternSet patterns(&getContext());
662     fillPromotionCallBackPatterns(&getContext(), patterns);
663     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
664     return;
665   }
666   if (testTileAndDistributionOptions) {
667     RewritePatternSet patterns(&getContext());
668     fillTileAndDistributePatterns(&getContext(), patterns);
669     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
670     return;
671   }
672   if (testTileFuseAndDistributionOptions) {
673     RewritePatternSet patterns(&getContext());
674     fillTileFuseAndDistributePatterns(&getContext(), patterns);
675     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
676     return;
677   }
678   if (testPatterns)
679     return applyPatterns(getOperation());
680   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
681     return applyMatmulToVectorPatterns(getOperation(),
682                                        testMatmulToVectorPatterns1dTiling,
683                                        testMatmulToVectorPatterns2dTiling);
684   if (testVectorTransferForwardingPatterns)
685     return applyVectorTransferForwardingPatterns(getOperation());
686   if (testGenericToVectorPattern)
687     return applyLinalgToVectorPatterns(getOperation());
688   if (testTransformPadTensor)
689     return applyPadTensorToGenericPatterns(getOperation());
690   if (testGeneralizePadTensor)
691     return applyGeneralizePadTensorPatterns(getOperation());
692   if (testSwapSubTensorPadTensor)
693     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
694   if (testTilePattern)
695     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
696                             /*scalarizeDynamicDims=*/false);
697   if (testTileScalarizeDynamicDims)
698     return applyTilePattern(getOperation(), loopType, tileSizes,
699                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
700   if (testSplitReduction)
701     return applySplitReduction(getOperation());
702   if (testBubbleUpExtractSliceOpPattern)
703     return applyBubbleUpExtractSliceOpPattern(getOperation());
704 }
705 
706 namespace mlir {
707 namespace test {
708 void registerTestLinalgTransforms() {
709   PassRegistration<TestLinalgTransforms>();
710 }
711 } // namespace test
712 } // namespace mlir
713