1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 namespace {
33 struct TestLinalgTransforms
34     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
35   TestLinalgTransforms() = default;
36   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
37 
38   void getDependentDialects(DialectRegistry &registry) const override {
39     // clang-format off
40     registry.insert<AffineDialect,
41                     memref::MemRefDialect,
42                     scf::SCFDialect,
43                     StandardOpsDialect,
44                     vector::VectorDialect,
45                     gpu::GPUDialect>();
46     // clang-format on
47   }
48   StringRef getArgument() const final {
49     return "test-linalg-transform-patterns";
50   }
51   StringRef getDescription() const final {
52     return "Test Linalg transformation patterns by applying them greedily.";
53   }
54 
55   void runOnFunction() override;
56 
57   Option<bool> testPatterns{*this, "test-patterns",
58                             llvm::cl::desc("Test a mixed set of patterns"),
59                             llvm::cl::init(false)};
60   Option<bool> testMatmulToVectorPatterns1dTiling{
61       *this, "test-matmul-to-vector-patterns-tile-1d",
62       llvm::cl::desc(
63           "Test a fused pass that applies patterns from matmul to vectors via "
64           "1-d tiling"),
65       llvm::cl::init(false)};
66   Option<bool> testMatmulToVectorPatterns2dTiling{
67       *this, "test-matmul-to-vector-patterns-tile-2d",
68       llvm::cl::desc(
69           "Test a fused pass that applies patterns from matmul to vectors via "
70           "2-d tiling"),
71       llvm::cl::init(false)};
72   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
73                                     llvm::cl::desc("Test promotion options"),
74                                     llvm::cl::init(false)};
75   Option<bool> testTileAndDistributionOptions{
76       *this, "test-tile-and-distribute-options",
77       llvm::cl::desc("Test tile and distribute options"),
78       llvm::cl::init(false)};
79   Option<bool> testVectorTransferForwardingPatterns{
80       *this, "test-vector-transfer-forwarding-patterns",
81       llvm::cl::desc(
82           "Test a fused pass that forwards linalg.copy to vector.transfer"),
83       llvm::cl::init(false)};
84   Option<bool> testGenericToVectorPattern{
85       *this, "test-linalg-to-vector-patterns",
86       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
87                      "in vector.contract form"),
88       llvm::cl::init(false)};
89   Option<bool> testTilePattern{*this, "test-tile-pattern",
90                                llvm::cl::desc("Test tile pattern"),
91                                llvm::cl::init(false)};
92   Option<bool> testTileScalarizeDynamicDims{
93       *this, "test-tile-scalarize-dynamic-dims",
94       llvm::cl::desc("Test tiling of dynamic dims by 1"),
95       llvm::cl::init(false)};
96   Option<bool> testTransformPadTensor{
97       *this, "test-transform-pad-tensor",
98       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
99       llvm::cl::init(false)};
100   Option<bool> testGeneralizePadTensor{
101       *this, "test-generalize-pad-tensor",
102       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
103       llvm::cl::init(false)};
104   Option<bool> testSwapSubTensorPadTensor{
105       *this, "test-swap-subtensor-padtensor",
106       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
107                      "pad_tensor(subtensor)"),
108       llvm::cl::init(false)};
109   ListOption<int64_t> peeledLoops{
110       *this, "peeled-loops",
111       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
112       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
113   ListOption<int64_t> tileSizes{
114       *this, "tile-sizes",
115       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
116       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
117   ListOption<unsigned> testTiledLoopPeeling{
118       *this, "test-tiled-loop-peeling",
119       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
120       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
121   Option<bool> skipPartial{
122       *this, "skip-partial",
123       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
124       llvm::cl::init(false)};
125   Option<std::string> loopType{
126       *this, "loop-type",
127       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
128                      "tiled_loop"),
129       llvm::cl::init("for")};
130   Option<bool> testDecomposeConvolutionPattern{
131       *this, "test-decompose-convolution-patterns",
132       llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops "
133                      "into low-D ones"),
134       llvm::cl::init(false)};
135 };
136 } // end anonymous namespace
137 
138 static void applyPatterns(FuncOp funcOp) {
139   MLIRContext *ctx = funcOp.getContext();
140   RewritePatternSet patterns(ctx);
141 
142   //===--------------------------------------------------------------------===//
143   // Linalg tiling patterns.
144   //===--------------------------------------------------------------------===//
145   patterns.add<LinalgTilingPattern<MatmulOp>>(
146       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
147       LinalgTransformationFilter(Identifier::get("MEM", ctx),
148                                  Identifier::get("L3", ctx)));
149   patterns.add<LinalgTilingPattern<MatmulOp>>(
150       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
151       LinalgTransformationFilter(Identifier::get("L3", ctx),
152                                  Identifier::get("L2", ctx)));
153   patterns.add<LinalgTilingPattern<MatmulOp>>(
154       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
155       LinalgTransformationFilter(Identifier::get("L2", ctx),
156                                  Identifier::get("L1", ctx)));
157   patterns.add<LinalgTilingPattern<MatmulOp>>(
158       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
159       LinalgTransformationFilter(Identifier::get("L1", ctx),
160                                  Identifier::get("REG", ctx)));
161 
162   patterns.add<LinalgTilingPattern<MatvecOp>>(
163       ctx,
164       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
165           LinalgTilingLoopType::ParallelLoops),
166       LinalgTransformationFilter(ArrayRef<Identifier>{},
167                                  Identifier::get("L1", ctx)));
168 
169   patterns.add<LinalgTilingPattern<DotOp>>(
170       ctx, LinalgTilingOptions().setTileSizes(8000),
171       LinalgTransformationFilter(
172           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
173                                Identifier::get("L3", ctx),
174                                Identifier::get("L2", ctx)},
175           Identifier::get("REG", ctx)));
176 
177   //===--------------------------------------------------------------------===//
178   // Linalg tiling and permutation patterns.
179   //===--------------------------------------------------------------------===//
180   patterns.add<LinalgTilingPattern<MatmulOp>>(
181       ctx,
182       LinalgTilingOptions()
183           .setTileSizes({2000, 3000, 4000})
184           .setInterchange({1, 2, 0}),
185       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
186                                  Identifier::get("L2__with_perm__", ctx)));
187   patterns.add<LinalgTilingPattern<MatmulOp>>(
188       ctx,
189       LinalgTilingOptions()
190           .setTileSizes({200, 300, 400})
191           .setInterchange({1, 0, 2}),
192       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
193                                  Identifier::get("L1__with_perm__", ctx)));
194   patterns.add<LinalgTilingPattern<MatmulOp>>(
195       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
196       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
197                                  Identifier::get("REG__with_perm__", ctx)));
198 
199   patterns.add<LinalgTilingPattern<MatvecOp>>(
200       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
201       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
202                                  Identifier::get("L1__with_perm__", ctx)));
203 
204   patterns.add<LinalgTilingPattern<MatmulOp>>(
205       ctx,
206       LinalgTilingOptions()
207           .setTileSizes({16, 8, 4})
208           .setInterchange({1, 2, 0})
209           .setLoopType(LinalgTilingLoopType::ParallelLoops),
210       LinalgTransformationFilter(
211           Identifier::get("par__with_perm__", ctx),
212           Identifier::get("after_par__with_perm__", ctx)));
213 
214   //===--------------------------------------------------------------------===//
215   // Linalg to loops patterns.
216   //===--------------------------------------------------------------------===//
217   patterns.add<LinalgLoweringPattern<DotOp>>(
218       ctx,
219       /*loweringType=*/LinalgLoweringType::Loops,
220       LinalgTransformationFilter(Identifier::get("REG", ctx)));
221 
222   //===--------------------------------------------------------------------===//
223   // Linalg distribution patterns.
224   //===--------------------------------------------------------------------===//
225   LinalgLoopDistributionOptions distributionOptions;
226 
227   //===--------------------------------------------------------------------===//
228   // Linalg to vector contraction patterns.
229   //===--------------------------------------------------------------------===//
230   patterns.add<LinalgVectorizationPattern>(
231       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
232                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
233 
234   //===--------------------------------------------------------------------===//
235   // Linalg generic interchange pattern.
236   //===--------------------------------------------------------------------===//
237   patterns.add<GenericOpInterchangePattern>(
238       ctx,
239       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
240       LinalgTransformationFilter(ArrayRef<Identifier>{},
241                                  Identifier::get("PERMUTED", ctx)));
242 
243   //===--------------------------------------------------------------------===//
244   // Linalg subview operands promotion.
245   //===--------------------------------------------------------------------===//
246   patterns.add<LinalgPromotionPattern<MatmulOp>>(
247       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
248       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
249                                  Identifier::get("_views_promoted_", ctx)));
250   patterns.add<LinalgPromotionPattern<MatmulOp>>(
251       ctx,
252       LinalgPromotionOptions()
253           .setOperandsToPromote({0})
254           .setUseFullTileBuffersByDefault(true),
255       LinalgTransformationFilter(
256           Identifier::get("_promote_first_view_", ctx),
257           Identifier::get("_first_view_promoted_", ctx)));
258   patterns.add<LinalgPromotionPattern<FillOp>>(
259       ctx,
260       LinalgPromotionOptions()
261           .setOperandsToPromote({1})
262           .setUseFullTileBuffers({false, true})
263           .setAlignment(32),
264       LinalgTransformationFilter(
265           Identifier::get("_promote_views_aligned_", ctx),
266           Identifier::get("_views_aligned_promoted_", ctx)));
267 
268   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
269 
270   // Drop the marker.
271   funcOp.walk([](LinalgOp op) {
272     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
273   });
274 }
275 
276 static void fillL1TilingAndMatmulToVectorPatterns(
277     FuncOp funcOp, StringRef startMarker,
278     SmallVectorImpl<RewritePatternSet> &patternsVector) {
279   MLIRContext *ctx = funcOp.getContext();
280   patternsVector.emplace_back(
281       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
282                ctx,
283                LinalgTilingOptions()
284                    .setTileSizes({8, 12, 16})
285                    .setInterchange({1, 0, 2}),
286                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
287                                           Identifier::get("L1", ctx))));
288 
289   patternsVector.emplace_back(
290       ctx,
291       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
292           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
293           LinalgTransformationFilter(Identifier::get("L1", ctx),
294                                      Identifier::get("VEC", ctx))));
295 
296   patternsVector.emplace_back(
297       ctx, std::make_unique<LinalgVectorizationPattern>(
298                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
299                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
300   patternsVector.back().add<LinalgVectorizationPattern>(
301       ctx, LinalgTransformationFilter().addFilter(
302                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // Test promotion callbacks
307 //===----------------------------------------------------------------------===//
308 
309 // Allocation call back
310 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
311                                        ArrayRef<Value> boundingSubViewSize,
312                                        DataLayout &layout) {
313   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
314   return b
315       .create<memref::AllocOp>(
316           subView.getLoc(),
317           MemRefType::get(shape, subView.getType().getElementType(),
318                           /*affineMapComposition =*/{}, 3),
319           boundingSubViewSize)
320       .getResult();
321 }
322 
323 // Deallocation callback
324 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
325   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
326   return success();
327 }
328 
329 // Copy in call back
330 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
331                                     bool isOutput) {
332   auto floatType = src.getType().cast<MemRefType>().getElementType();
333   if (!floatType.isa<FloatType>())
334     return failure();
335   if (!isOutput) {
336     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
337                                             FloatAttr::get(floatType, 42.0));
338     b.create<FillOp>(src.getLoc(), cst, dst);
339   }
340   b.create<CopyOp>(src.getLoc(), src, dst);
341   return success();
342 }
343 
344 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
345                                           RewritePatternSet &patterns) {
346   patterns.add<LinalgTilingPattern<MatmulOp>>(
347       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
348       LinalgTransformationFilter(Identifier::get("START", ctx),
349                                  Identifier::get("PROMOTE", ctx)));
350   patterns.add<LinalgPromotionPattern<MatmulOp>>(
351       ctx,
352       LinalgPromotionOptions()
353           .setOperandsToPromote({0, 2})
354           .setUseFullTileBuffers({false, false})
355           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
356           .setCopyInOutFns(
357               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
358                 return copyCallBackFn(b, src, dst, false);
359               },
360               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
361                 return copyCallBackFn(b, src, dst, true);
362               }),
363       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
364 }
365 
366 template <typename IdOp, typename NProcsOp>
367 static SmallVector<ProcInfo, 2>
368 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
369   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
370   SmallVector<ProcInfo, 2> procInfo(count);
371   const char *xyz[] = {"x", "y", "z"};
372   Type indexType = b.getIndexType();
373   for (unsigned i = 0; i < count; ++i) {
374     procInfo[count - 1 - i] = {
375         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
376         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
377   }
378   return procInfo;
379 }
380 
381 static void fillTileAndDistributePatterns(MLIRContext *context,
382                                           RewritePatternSet &patterns) {
383   {
384     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
385     cyclicNprocsEqNiters.distributionMethod.resize(
386         2, DistributionMethod::CyclicNumProcsEqNumIters);
387     cyclicNprocsEqNiters.procInfo =
388         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
389     patterns.add<LinalgTilingPattern<MatmulOp>>(
390         context,
391         LinalgTilingOptions()
392             .setTileSizes({8, 8, 4})
393             .setLoopType(LinalgTilingLoopType::ParallelLoops)
394             .setDistributionOptions(cyclicNprocsEqNiters),
395         LinalgTransformationFilter(
396             Identifier::get("distribute1", context),
397             Identifier::get("after_distribute1", context)));
398   }
399 
400   {
401     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
402     cyclicNprocsGeNiters.distributionMethod.resize(
403         2, DistributionMethod::CyclicNumProcsGeNumIters);
404     cyclicNprocsGeNiters.procInfo =
405         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
406     patterns.add<LinalgTilingPattern<MatmulOp>>(
407         context,
408         LinalgTilingOptions()
409             .setTileSizes({8, 8, 4})
410             .setLoopType(LinalgTilingLoopType::ParallelLoops)
411             .setDistributionOptions(cyclicNprocsGeNiters),
412         LinalgTransformationFilter(
413             Identifier::get("distribute2", context),
414             Identifier::get("after_distribute2", context)));
415   }
416 
417   {
418     LinalgLoopDistributionOptions cyclicNprocsDefault;
419     cyclicNprocsDefault.distributionMethod.resize(2,
420                                                   DistributionMethod::Cyclic);
421     cyclicNprocsDefault.procInfo =
422         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
423     patterns.add<LinalgTilingPattern<MatmulOp>>(
424         context,
425         LinalgTilingOptions()
426             .setTileSizes({8, 8, 4})
427             .setLoopType(LinalgTilingLoopType::ParallelLoops)
428             .setDistributionOptions(cyclicNprocsDefault),
429         LinalgTransformationFilter(
430             Identifier::get("distribute3", context),
431             Identifier::get("after_distribute3", context)));
432   }
433 
434   {
435     LinalgLoopDistributionOptions cyclicNprocsMixed1;
436     cyclicNprocsMixed1.distributionMethod = {
437         DistributionMethod::CyclicNumProcsEqNumIters,
438         DistributionMethod::CyclicNumProcsGeNumIters};
439     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
440     patterns.add<LinalgTilingPattern<MatmulOp>>(
441         context,
442         LinalgTilingOptions()
443             .setTileSizes({8, 8, 4})
444             .setLoopType(LinalgTilingLoopType::ParallelLoops)
445             .setDistributionOptions(cyclicNprocsMixed1),
446         LinalgTransformationFilter(
447             Identifier::get("distribute4", context),
448             Identifier::get("after_distribute4", context)));
449   }
450 
451   {
452     LinalgLoopDistributionOptions cyclicNprocsMixed2;
453     cyclicNprocsMixed2.distributionMethod = {
454         DistributionMethod::CyclicNumProcsGeNumIters,
455         DistributionMethod::Cyclic};
456     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
457     patterns.add<LinalgTilingPattern<MatmulOp>>(
458         context,
459         LinalgTilingOptions()
460             .setTileSizes({8, 8, 4})
461             .setLoopType(LinalgTilingLoopType::ParallelLoops)
462             .setDistributionOptions(cyclicNprocsMixed2),
463         LinalgTransformationFilter(
464             Identifier::get("distribute5", context),
465             Identifier::get("after_distribute5", context)));
466   }
467 
468   {
469     LinalgLoopDistributionOptions cyclicNprocsMixed3;
470     cyclicNprocsMixed3.distributionMethod = {
471         DistributionMethod::Cyclic,
472         DistributionMethod::CyclicNumProcsEqNumIters};
473     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
474 
475     patterns.add<LinalgTilingPattern<MatmulOp>>(
476         context,
477         LinalgTilingOptions()
478             .setTileSizes({8, 8, 4})
479             .setLoopType(LinalgTilingLoopType::ParallelLoops)
480             .setDistributionOptions(cyclicNprocsMixed3),
481         LinalgTransformationFilter(
482             Identifier::get("distribute6", context),
483             Identifier::get("after_distribute6", context)));
484   }
485 
486   {
487     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
488     cyclicNprocsEqNiters.distributionMethod.resize(2,
489                                                    DistributionMethod::Cyclic);
490     cyclicNprocsEqNiters.procInfo =
491         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
492     patterns.add<LinalgTilingPattern<MatmulOp>>(
493         context,
494         LinalgTilingOptions()
495             .setTileSizes({8, 8, 4})
496             .setLoopType(LinalgTilingLoopType::Loops)
497             .setDistributionOptions(cyclicNprocsEqNiters),
498         LinalgTransformationFilter(
499             Identifier::get("tensors_distribute1", context),
500             Identifier::get("tensors_after_distribute1", context)));
501   }
502 }
503 
504 static void
505 applyMatmulToVectorPatterns(FuncOp funcOp,
506                             bool testMatmulToVectorPatterns1dTiling,
507                             bool testMatmulToVectorPatterns2dTiling) {
508   MLIRContext *ctx = funcOp.getContext();
509   SmallVector<RewritePatternSet, 4> stage1Patterns;
510   if (testMatmulToVectorPatterns1dTiling) {
511     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
512                                           stage1Patterns);
513   } else if (testMatmulToVectorPatterns2dTiling) {
514     stage1Patterns.emplace_back(
515         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
516                  ctx,
517                  LinalgTilingOptions()
518                      .setTileSizes({768, 264, 768})
519                      .setInterchange({1, 2, 0}),
520                  LinalgTransformationFilter(Identifier::get("START", ctx),
521                                             Identifier::get("L2", ctx))));
522     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
523                                           stage1Patterns);
524   }
525   {
526     // Canonicalization patterns
527     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
528     vector::populateVectorTransferPermutationMapLoweringPatterns(
529         canonicalizationPatterns);
530     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
531     stage1Patterns.push_back(std::move(canonicalizationPatterns));
532   }
533   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
534   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
535   FrozenRewritePatternSet stage2Patterns =
536       getLinalgTilingCanonicalizationPatterns(ctx);
537   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
538                             std::move(stage2Patterns));
539 }
540 
541 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
542   RewritePatternSet forwardPattern(funcOp.getContext());
543   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
544   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
545   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
546 }
547 
548 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
549   RewritePatternSet patterns(funcOp.getContext());
550   patterns.add<LinalgVectorizationPattern>(
551       funcOp.getContext(),
552       LinalgTransformationFilter()
553           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
554   populatePadTensorOpVectorizationPatterns(patterns);
555   populateConvolutionVectorizationPatterns(patterns);
556   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
557 }
558 
559 static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
560   RewritePatternSet patterns(funcOp.getContext());
561   populateDecomposeConvolutionPatterns(patterns);
562   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
563 }
564 
565 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
566   RewritePatternSet patterns(funcOp.getContext());
567   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
568   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
569 }
570 
571 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
572   RewritePatternSet patterns(funcOp.getContext());
573   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
574   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
575 }
576 
577 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
578   RewritePatternSet patterns(funcOp.getContext());
579   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
580   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
581 }
582 
583 static void applyTilePattern(FuncOp funcOp, std::string loopType,
584                              ArrayRef<int64_t> tileSizes,
585                              ArrayRef<int64_t> peeledLoops,
586                              bool scalarizeDynamicDims) {
587   MLIRContext *context = funcOp.getContext();
588   RewritePatternSet tilingPattern(context);
589   LinalgTilingLoopType type =
590       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
591           .Case("for", LinalgTilingLoopType::Loops)
592           .Case("affine", LinalgTilingLoopType::AffineLoops)
593           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
594           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
595   auto linalgTilingOptions = linalg::LinalgTilingOptions()
596                                  .setPeeledLoops(peeledLoops)
597                                  .setLoopType(type);
598   if (scalarizeDynamicDims) {
599     linalgTilingOptions.scalarizeDynamicDims();
600     assert(tileSizes.empty() &&
601            "tileSizes and scalarizeDynamicDims is mutually exclusive");
602   } else {
603     linalgTilingOptions.setTileSizes(tileSizes);
604   }
605   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
606                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
607       context, linalgTilingOptions,
608       linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
609   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
610 }
611 
// Attribute names used by TiledLoopPeelingPattern for bookkeeping: the list of
// dimensions already peeled, and the marker placed on partial-iteration loops.
static constexpr const char *const kPeeledLoopsLabel = "__peeled_loops__";
static constexpr const char *const kPartialIterationLabel =
    "__partial_iteration__";
614 
615 namespace {
616 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
617 /// `idx`-th loop contains only "full" iterations and a second loop for the
618 /// remaining partial iteration (if any).
619 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
620   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
621       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
622   }
623 
624   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
625                                 PatternRewriter &rewriter) const override {
626     SmallVector<int64_t> peeledLoops;
627     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
628       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
629       peeledLoops =
630           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
631             return attr.cast<IntegerAttr>().getInt();
632           }));
633       // Check if the loop was already peeled.
634       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
635         return failure();
636     }
637     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
638       // No peeling of loop nests with a partial iteration.
639       return failure();
640 
641     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
642       return failure();
643 
644     // Peel loop and canonicalize.
645     TiledLoopOp result;
646     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
647                                                     result)))
648       return failure();
649 
650     // Apply label, so that the same loop is not rewritten a second time.
651     peeledLoops.push_back(idx);
652     rewriter.updateRootInPlace(loopOp, [&]() {
653       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
654     });
655     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
656     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
657 
658     return success();
659   }
660 
661   /// Index of loop to peel.
662   int64_t idx;
663 
664   /// If set to true, do not peel TiledLoopOps with a partial iteration.
665   bool skipPartial;
666 };
667 } // namespace
668 
669 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
670                                          ArrayRef<unsigned> loops,
671                                          bool skipPartial) {
672   MLIRContext *ctx = funcOp.getContext();
673   RewritePatternSet patterns(ctx);
674   for (unsigned idx : loops)
675     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
676   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
677 
678   // Drop the markers.
679   funcOp.walk([](TiledLoopOp op) {
680     op->removeAttr(kPeeledLoopsLabel);
681     op->removeAttr(kPartialIterationLabel);
682   });
683 }
684 
685 /// Apply transformations specified as patterns.
686 void TestLinalgTransforms::runOnFunction() {
687   auto lambda = [&](void *) {
688     getFunction().walk([](LinalgOp op) {
689       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
690     });
691   };
692   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
693 
694   if (testPromotionOptions) {
695     RewritePatternSet patterns(&getContext());
696     fillPromotionCallBackPatterns(&getContext(), patterns);
697     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
698     return;
699   }
700   if (testTileAndDistributionOptions) {
701     RewritePatternSet patterns(&getContext());
702     fillTileAndDistributePatterns(&getContext(), patterns);
703     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
704     return;
705   }
706   if (testPatterns)
707     return applyPatterns(getFunction());
708   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
709     return applyMatmulToVectorPatterns(getFunction(),
710                                        testMatmulToVectorPatterns1dTiling,
711                                        testMatmulToVectorPatterns2dTiling);
712   if (testVectorTransferForwardingPatterns)
713     return applyVectorTransferForwardingPatterns(getFunction());
714   if (testGenericToVectorPattern)
715     return applyLinalgToVectorPatterns(getFunction());
716   if (testTransformPadTensor)
717     return applyPadTensorToGenericPatterns(getFunction());
718   if (testGeneralizePadTensor)
719     return applyGeneralizePadTensorPatterns(getFunction());
720   if (testSwapSubTensorPadTensor)
721     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
722   if (testTiledLoopPeeling.hasValue())
723     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
724                                         skipPartial);
725   if (testTilePattern)
726     return applyTilePattern(getFunction(), loopType, tileSizes, peeledLoops,
727                             /*scalarizeDynamicDims=*/false);
728   if (testTileScalarizeDynamicDims)
729     return applyTilePattern(getFunction(), loopType, tileSizes,
730                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
731   if (testDecomposeConvolutionPattern)
732     return applyDecomposeConvolutionPatterns(getFunction());
733 }
734 
735 namespace mlir {
736 namespace test {
737 void registerTestLinalgTransforms() {
738   PassRegistration<TestLinalgTransforms>();
739 }
740 } // namespace test
741 } // namespace mlir
742