1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 namespace {
31 struct TestLinalgTransforms
32     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
33   TestLinalgTransforms() = default;
34   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
35 
36   void getDependentDialects(DialectRegistry &registry) const override {
37     // clang-format off
38     registry.insert<AffineDialect,
39                     memref::MemRefDialect,
40                     scf::SCFDialect,
41                     StandardOpsDialect,
42                     vector::VectorDialect,
43                     gpu::GPUDialect>();
44     // clang-format on
45   }
46   StringRef getArgument() const final {
47     return "test-linalg-transform-patterns";
48   }
49   StringRef getDescription() const final {
50     return "Test Linalg transformation patterns by applying them greedily.";
51   }
52 
53   void runOnFunction() override;
54 
55   Option<bool> testPatterns{*this, "test-patterns",
56                             llvm::cl::desc("Test a mixed set of patterns"),
57                             llvm::cl::init(false)};
58   Option<bool> testMatmulToVectorPatterns1dTiling{
59       *this, "test-matmul-to-vector-patterns-tile-1d",
60       llvm::cl::desc(
61           "Test a fused pass that applies patterns from matmul to vectors via "
62           "1-d tiling"),
63       llvm::cl::init(false)};
64   Option<bool> testMatmulToVectorPatterns2dTiling{
65       *this, "test-matmul-to-vector-patterns-tile-2d",
66       llvm::cl::desc(
67           "Test a fused pass that applies patterns from matmul to vectors via "
68           "2-d tiling"),
69       llvm::cl::init(false)};
70   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
71                                     llvm::cl::desc("Test promotion options"),
72                                     llvm::cl::init(false)};
73   Option<bool> testTileAndDistributionOptions{
74       *this, "test-tile-and-distribute-options",
75       llvm::cl::desc("Test tile and distribute options"),
76       llvm::cl::init(false)};
77   Option<bool> testVectorTransferForwardingPatterns{
78       *this, "test-vector-transfer-forwarding-patterns",
79       llvm::cl::desc(
80           "Test a fused pass that forwards linalg.copy to vector.transfer"),
81       llvm::cl::init(false)};
82   Option<bool> testGenericToVectorPattern{
83       *this, "test-linalg-to-vector-patterns",
84       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
85                      "in vector.contract form"),
86       llvm::cl::init(false)};
87   Option<bool> testTilePattern{*this, "test-tile-pattern",
88                                llvm::cl::desc("Test tile pattern"),
89                                llvm::cl::init(false)};
90   Option<bool> testTileScalarizeDynamicDims{
91       *this, "test-tile-scalarize-dynamic-dims",
92       llvm::cl::desc("Test tiling of dynamic dims by 1"),
93       llvm::cl::init(false)};
94   Option<int> testHoistPadding{*this, "test-hoist-padding",
95                                llvm::cl::desc("Test hoist padding"),
96                                llvm::cl::init(0)};
97   Option<bool> testTransformPadTensor{
98       *this, "test-transform-pad-tensor",
99       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
100       llvm::cl::init(false)};
101   Option<bool> testGeneralizePadTensor{
102       *this, "test-generalize-pad-tensor",
103       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
104       llvm::cl::init(false)};
105   Option<bool> testSwapSubTensorPadTensor{
106       *this, "test-swap-subtensor-padtensor",
107       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
108                      "pad_tensor(subtensor)"),
109       llvm::cl::init(false)};
110   Option<bool> padTiles{*this, "pad-tiles",
111                         llvm::cl::desc("Pad tiles when test-tile-pattern"),
112                         llvm::cl::init(false)};
113   ListOption<int64_t> peeledLoops{
114       *this, "peeled-loops",
115       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
116       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
117   ListOption<int64_t> tileSizes{
118       *this, "tile-sizes",
119       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
120       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
121   ListOption<unsigned> testInterchangePattern{
122       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
123       llvm::cl::desc("Test the interchange pattern.")};
124   ListOption<unsigned> testTiledLoopPeeling{
125       *this, "test-tiled-loop-peeling",
126       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
127       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
128   Option<bool> skipPartial{
129       *this, "skip-partial",
130       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
131       llvm::cl::init(false)};
132 };
133 } // end anonymous namespace
134 
135 static void applyPatterns(FuncOp funcOp) {
136   MLIRContext *ctx = funcOp.getContext();
137   RewritePatternSet patterns(ctx);
138 
139   //===--------------------------------------------------------------------===//
140   // Linalg tiling patterns.
141   //===--------------------------------------------------------------------===//
142   patterns.add<LinalgTilingPattern<MatmulOp>>(
143       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
144       LinalgTransformationFilter(Identifier::get("MEM", ctx),
145                                  Identifier::get("L3", ctx)));
146   patterns.add<LinalgTilingPattern<MatmulOp>>(
147       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
148       LinalgTransformationFilter(Identifier::get("L3", ctx),
149                                  Identifier::get("L2", ctx)));
150   patterns.add<LinalgTilingPattern<MatmulOp>>(
151       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
152       LinalgTransformationFilter(Identifier::get("L2", ctx),
153                                  Identifier::get("L1", ctx)));
154   patterns.add<LinalgTilingPattern<MatmulOp>>(
155       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
156       LinalgTransformationFilter(Identifier::get("L1", ctx),
157                                  Identifier::get("REG", ctx)));
158 
159   patterns.add<LinalgTilingPattern<MatvecOp>>(
160       ctx,
161       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
162           LinalgTilingLoopType::ParallelLoops),
163       LinalgTransformationFilter(ArrayRef<Identifier>{},
164                                  Identifier::get("L1", ctx)));
165 
166   patterns.add<LinalgTilingPattern<DotOp>>(
167       ctx, LinalgTilingOptions().setTileSizes(8000),
168       LinalgTransformationFilter(
169           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
170                                Identifier::get("L3", ctx),
171                                Identifier::get("L2", ctx)},
172           Identifier::get("REG", ctx)));
173 
174   //===--------------------------------------------------------------------===//
175   // Linalg tiling and permutation patterns.
176   //===--------------------------------------------------------------------===//
177   patterns.add<LinalgTilingPattern<MatmulOp>>(
178       ctx,
179       LinalgTilingOptions()
180           .setTileSizes({2000, 3000, 4000})
181           .setInterchange({1, 2, 0}),
182       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
183                                  Identifier::get("L2__with_perm__", ctx)));
184   patterns.add<LinalgTilingPattern<MatmulOp>>(
185       ctx,
186       LinalgTilingOptions()
187           .setTileSizes({200, 300, 400})
188           .setInterchange({1, 0, 2}),
189       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
190                                  Identifier::get("L1__with_perm__", ctx)));
191   patterns.add<LinalgTilingPattern<MatmulOp>>(
192       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
193       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
194                                  Identifier::get("REG__with_perm__", ctx)));
195 
196   patterns.add<LinalgTilingPattern<MatvecOp>>(
197       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
198       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
199                                  Identifier::get("L1__with_perm__", ctx)));
200 
201   patterns.add<LinalgTilingPattern<MatmulOp>>(
202       ctx,
203       LinalgTilingOptions()
204           .setTileSizes({16, 8, 4})
205           .setInterchange({1, 2, 0})
206           .setLoopType(LinalgTilingLoopType::ParallelLoops),
207       LinalgTransformationFilter(
208           Identifier::get("par__with_perm__", ctx),
209           Identifier::get("after_par__with_perm__", ctx)));
210 
211   //===--------------------------------------------------------------------===//
212   // Linalg to loops patterns.
213   //===--------------------------------------------------------------------===//
214   patterns.add<LinalgLoweringPattern<DotOp>>(
215       ctx,
216       /*loweringType=*/LinalgLoweringType::Loops,
217       LinalgTransformationFilter(Identifier::get("REG", ctx)));
218 
219   //===--------------------------------------------------------------------===//
220   // Linalg distribution patterns.
221   //===--------------------------------------------------------------------===//
222   LinalgLoopDistributionOptions distributionOptions;
223 
224   //===--------------------------------------------------------------------===//
225   // Linalg to vector contraction patterns.
226   //===--------------------------------------------------------------------===//
227   patterns.add<LinalgVectorizationPattern>(
228       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
229                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
230 
231   //===--------------------------------------------------------------------===//
232   // Linalg generic interchange pattern.
233   //===--------------------------------------------------------------------===//
234   patterns.add<GenericOpInterchangePattern>(
235       ctx,
236       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
237       LinalgTransformationFilter(ArrayRef<Identifier>{},
238                                  Identifier::get("PERMUTED", ctx)));
239 
240   //===--------------------------------------------------------------------===//
241   // Linalg subview operands promotion.
242   //===--------------------------------------------------------------------===//
243   patterns.add<LinalgPromotionPattern<MatmulOp>>(
244       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
245       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
246                                  Identifier::get("_views_promoted_", ctx)));
247   patterns.add<LinalgPromotionPattern<MatmulOp>>(
248       ctx,
249       LinalgPromotionOptions()
250           .setOperandsToPromote({0})
251           .setUseFullTileBuffersByDefault(true),
252       LinalgTransformationFilter(
253           Identifier::get("_promote_first_view_", ctx),
254           Identifier::get("_first_view_promoted_", ctx)));
255   patterns.add<LinalgPromotionPattern<FillOp>>(
256       ctx,
257       LinalgPromotionOptions()
258           .setOperandsToPromote({1})
259           .setUseFullTileBuffers({false, true})
260           .setAlignment(32),
261       LinalgTransformationFilter(
262           Identifier::get("_promote_views_aligned_", ctx),
263           Identifier::get("_views_aligned_promoted_", ctx)));
264 
265   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
266 
267   // Drop the marker.
268   funcOp.walk([](LinalgOp op) {
269     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
270   });
271 }
272 
273 static void fillL1TilingAndMatmulToVectorPatterns(
274     FuncOp funcOp, StringRef startMarker,
275     SmallVectorImpl<RewritePatternSet> &patternsVector) {
276   MLIRContext *ctx = funcOp.getContext();
277   patternsVector.emplace_back(
278       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
279                ctx,
280                LinalgTilingOptions()
281                    .setTileSizes({8, 12, 16})
282                    .setInterchange({1, 0, 2}),
283                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
284                                           Identifier::get("L1", ctx))));
285 
286   patternsVector.emplace_back(
287       ctx,
288       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
289           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
290           LinalgTransformationFilter(Identifier::get("L1", ctx),
291                                      Identifier::get("VEC", ctx))));
292 
293   patternsVector.emplace_back(
294       ctx, std::make_unique<LinalgVectorizationPattern>(
295                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
296                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
297   patternsVector.back().add<LinalgVectorizationPattern>(
298       ctx, LinalgTransformationFilter().addFilter(
299                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // Test promotion callbacks
304 //===----------------------------------------------------------------------===//
305 
306 // Allocation call back
307 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
308                                        ArrayRef<Value> boundingSubViewSize,
309                                        DataLayout &layout) {
310   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
311   return b
312       .create<memref::AllocOp>(
313           subView.getLoc(),
314           MemRefType::get(shape, subView.getType().getElementType(),
315                           /*affineMapComposition =*/{}, 3),
316           boundingSubViewSize)
317       .getResult();
318 }
319 
320 // Deallocation callback
321 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
322   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
323   return success();
324 }
325 
326 // Copy in call back
327 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
328                                     bool isOutput) {
329   auto floatType = src.getType().cast<MemRefType>().getElementType();
330   if (!floatType.isa<FloatType>())
331     return failure();
332   if (!isOutput) {
333     Value cst =
334         b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0));
335     b.create<FillOp>(src.getLoc(), cst, dst);
336   }
337   b.create<CopyOp>(src.getLoc(), src, dst);
338   return success();
339 }
340 
341 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
342                                           RewritePatternSet &patterns) {
343   patterns.add<LinalgTilingPattern<MatmulOp>>(
344       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
345       LinalgTransformationFilter(Identifier::get("START", ctx),
346                                  Identifier::get("PROMOTE", ctx)));
347   patterns.add<LinalgPromotionPattern<MatmulOp>>(
348       ctx,
349       LinalgPromotionOptions()
350           .setOperandsToPromote({0, 2})
351           .setUseFullTileBuffers({false, false})
352           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
353           .setCopyInOutFns(
354               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
355                 return copyCallBackFn(b, src, dst, false);
356               },
357               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
358                 return copyCallBackFn(b, src, dst, true);
359               }),
360       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
361 }
362 
363 template <typename IdOp, typename NProcsOp>
364 static SmallVector<ProcInfo, 2>
365 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
366   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
367   SmallVector<ProcInfo, 2> procInfo(count);
368   const char *xyz[] = {"x", "y", "z"};
369   Type indexType = b.getIndexType();
370   for (unsigned i = 0; i < count; ++i) {
371     procInfo[count - 1 - i] = {
372         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
373         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
374   }
375   return procInfo;
376 }
377 
378 static void fillTileAndDistributePatterns(MLIRContext *context,
379                                           RewritePatternSet &patterns) {
380   {
381     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
382     cyclicNprocsEqNiters.distributionMethod.resize(
383         2, DistributionMethod::CyclicNumProcsEqNumIters);
384     cyclicNprocsEqNiters.procInfo =
385         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
386     patterns.add<LinalgTilingPattern<MatmulOp>>(
387         context,
388         LinalgTilingOptions()
389             .setTileSizes({8, 8, 4})
390             .setLoopType(LinalgTilingLoopType::ParallelLoops)
391             .setDistributionOptions(cyclicNprocsEqNiters),
392         LinalgTransformationFilter(
393             Identifier::get("distribute1", context),
394             Identifier::get("after_distribute1", context)));
395   }
396 
397   {
398     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
399     cyclicNprocsGeNiters.distributionMethod.resize(
400         2, DistributionMethod::CyclicNumProcsGeNumIters);
401     cyclicNprocsGeNiters.procInfo =
402         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
403     patterns.add<LinalgTilingPattern<MatmulOp>>(
404         context,
405         LinalgTilingOptions()
406             .setTileSizes({8, 8, 4})
407             .setLoopType(LinalgTilingLoopType::ParallelLoops)
408             .setDistributionOptions(cyclicNprocsGeNiters),
409         LinalgTransformationFilter(
410             Identifier::get("distribute2", context),
411             Identifier::get("after_distribute2", context)));
412   }
413 
414   {
415     LinalgLoopDistributionOptions cyclicNprocsDefault;
416     cyclicNprocsDefault.distributionMethod.resize(2,
417                                                   DistributionMethod::Cyclic);
418     cyclicNprocsDefault.procInfo =
419         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
420     patterns.add<LinalgTilingPattern<MatmulOp>>(
421         context,
422         LinalgTilingOptions()
423             .setTileSizes({8, 8, 4})
424             .setLoopType(LinalgTilingLoopType::ParallelLoops)
425             .setDistributionOptions(cyclicNprocsDefault),
426         LinalgTransformationFilter(
427             Identifier::get("distribute3", context),
428             Identifier::get("after_distribute3", context)));
429   }
430 
431   {
432     LinalgLoopDistributionOptions cyclicNprocsMixed1;
433     cyclicNprocsMixed1.distributionMethod = {
434         DistributionMethod::CyclicNumProcsEqNumIters,
435         DistributionMethod::CyclicNumProcsGeNumIters};
436     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
437     patterns.add<LinalgTilingPattern<MatmulOp>>(
438         context,
439         LinalgTilingOptions()
440             .setTileSizes({8, 8, 4})
441             .setLoopType(LinalgTilingLoopType::ParallelLoops)
442             .setDistributionOptions(cyclicNprocsMixed1),
443         LinalgTransformationFilter(
444             Identifier::get("distribute4", context),
445             Identifier::get("after_distribute4", context)));
446   }
447 
448   {
449     LinalgLoopDistributionOptions cyclicNprocsMixed2;
450     cyclicNprocsMixed2.distributionMethod = {
451         DistributionMethod::CyclicNumProcsGeNumIters,
452         DistributionMethod::Cyclic};
453     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
454     patterns.add<LinalgTilingPattern<MatmulOp>>(
455         context,
456         LinalgTilingOptions()
457             .setTileSizes({8, 8, 4})
458             .setLoopType(LinalgTilingLoopType::ParallelLoops)
459             .setDistributionOptions(cyclicNprocsMixed2),
460         LinalgTransformationFilter(
461             Identifier::get("distribute5", context),
462             Identifier::get("after_distribute5", context)));
463   }
464 
465   {
466     LinalgLoopDistributionOptions cyclicNprocsMixed3;
467     cyclicNprocsMixed3.distributionMethod = {
468         DistributionMethod::Cyclic,
469         DistributionMethod::CyclicNumProcsEqNumIters};
470     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
471 
472     patterns.add<LinalgTilingPattern<MatmulOp>>(
473         context,
474         LinalgTilingOptions()
475             .setTileSizes({8, 8, 4})
476             .setLoopType(LinalgTilingLoopType::ParallelLoops)
477             .setDistributionOptions(cyclicNprocsMixed3),
478         LinalgTransformationFilter(
479             Identifier::get("distribute6", context),
480             Identifier::get("after_distribute6", context)));
481   }
482 
483   {
484     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
485     cyclicNprocsEqNiters.distributionMethod.resize(2,
486                                                    DistributionMethod::Cyclic);
487     cyclicNprocsEqNiters.procInfo =
488         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
489     patterns.add<LinalgTilingPattern<MatmulOp>>(
490         context,
491         LinalgTilingOptions()
492             .setTileSizes({8, 8, 4})
493             .setLoopType(LinalgTilingLoopType::Loops)
494             .setDistributionOptions(cyclicNprocsEqNiters),
495         LinalgTransformationFilter(
496             Identifier::get("tensors_distribute1", context),
497             Identifier::get("tensors_after_distribute1", context)));
498   }
499 }
500 
501 static void
502 applyMatmulToVectorPatterns(FuncOp funcOp,
503                             bool testMatmulToVectorPatterns1dTiling,
504                             bool testMatmulToVectorPatterns2dTiling) {
505   MLIRContext *ctx = funcOp.getContext();
506   SmallVector<RewritePatternSet, 4> stage1Patterns;
507   if (testMatmulToVectorPatterns1dTiling) {
508     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
509                                           stage1Patterns);
510   } else if (testMatmulToVectorPatterns2dTiling) {
511     stage1Patterns.emplace_back(
512         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
513                  ctx,
514                  LinalgTilingOptions()
515                      .setTileSizes({768, 264, 768})
516                      .setInterchange({1, 2, 0}),
517                  LinalgTransformationFilter(Identifier::get("START", ctx),
518                                             Identifier::get("L2", ctx))));
519     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
520                                           stage1Patterns);
521   }
522   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
523   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
524   FrozenRewritePatternSet stage2Patterns =
525       getLinalgTilingCanonicalizationPatterns(ctx);
526   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
527                             std::move(stage2Patterns));
528 }
529 
530 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
531   RewritePatternSet forwardPattern(funcOp.getContext());
532   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
533   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
534   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
535 }
536 
537 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
538   RewritePatternSet patterns(funcOp.getContext());
539   patterns.add<LinalgVectorizationPattern>(
540       funcOp.getContext(),
541       LinalgTransformationFilter()
542           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
543   populatePadTensorOpVectorizationPatterns(patterns);
544   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
545 }
546 
547 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
548   RewritePatternSet patterns(funcOp.getContext());
549   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
550   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
551 }
552 
553 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
554   RewritePatternSet patterns(funcOp.getContext());
555   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
556   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
557 }
558 
559 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
560   RewritePatternSet patterns(funcOp.getContext());
561   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
562   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
563 }
564 
565 // For now, just assume it is the zero of type.
566 // In the future, it should be the zero of type + op.
567 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
568   auto t = getElementTypeOrSelf(op.get());
569   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
570 }
571 
572 static void applyTilePattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes,
573                              bool padTiles, ArrayRef<int64_t> peeledLoops,
574                              bool scalarizeDynamicDims) {
575   MLIRContext *context = funcOp.getContext();
576   RewritePatternSet tilingPattern(context);
577   auto linalgTilingOptions =
578       linalg::LinalgTilingOptions().setPeeledLoops(peeledLoops);
579   if (scalarizeDynamicDims) {
580     linalgTilingOptions.scalarizeDynamicDims();
581     assert(tileSizes.empty() &&
582            "tileSizes and scalarizeDynamicDims is mutually exclusive");
583   } else {
584     linalgTilingOptions.setTileSizes(tileSizes);
585   }
586   if (padTiles)
587     linalgTilingOptions.setPaddingValueComputationFunction(
588         getNeutralOfLinalgOp);
589 
590   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
591                     linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
592                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
593       context, linalgTilingOptions,
594       linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
595   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
596 }
597 
598 static void applyInterchangePattern(FuncOp funcOp,
599                                     ArrayRef<unsigned> interchangeVector) {
600   MLIRContext *context = funcOp.getContext();
601   RewritePatternSet interchangePattern(context);
602   interchangePattern.add<GenericOpInterchangePattern>(
603       context, interchangeVector,
604       LinalgTransformationFilter(ArrayRef<Identifier>{},
605                                  Identifier::get("interchange", context)));
606   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
607 }
608 
// Bookkeeping attribute names used by the tiled-loop peeling test pattern:
// the list of dimensions already peeled, and a marker for the op that holds a
// partial iteration.
namespace {
constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
constexpr char kPartialIterationLabel[] = "__partial_iteration__";
} // namespace
611 
612 namespace {
613 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
614 /// `idx`-th loop contains only "full" iterations and a second loop for the
615 /// remaining partial iteration (if any).
616 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
617   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
618       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
619   }
620 
621   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
622                                 PatternRewriter &rewriter) const override {
623     SmallVector<int64_t> peeledLoops;
624     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
625       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
626       peeledLoops =
627           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
628             return attr.cast<IntegerAttr>().getInt();
629           }));
630       // Check if the loop was already peeled.
631       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
632         return failure();
633     }
634     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
635       // No peeling of loop nests with a partial iteration.
636       return failure();
637 
638     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
639       return failure();
640 
641     // Peel loop and canonicalize.
642     TiledLoopOp result;
643     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
644                                                     result)))
645       return failure();
646 
647     // Apply label, so that the same loop is not rewritten a second time.
648     peeledLoops.push_back(idx);
649     rewriter.updateRootInPlace(loopOp, [&]() {
650       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
651     });
652     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
653     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
654 
655     return success();
656   }
657 
658   /// Index of loop to peel.
659   int64_t idx;
660 
661   /// If set to true, do not peel TiledLoopOps with a partial iteration.
662   bool skipPartial;
663 };
664 } // namespace
665 
666 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
667                                          ArrayRef<unsigned> loops,
668                                          bool skipPartial) {
669   MLIRContext *ctx = funcOp.getContext();
670   RewritePatternSet patterns(ctx);
671   for (unsigned idx : loops)
672     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
673   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
674 
675   // Drop the markers.
676   funcOp.walk([](TiledLoopOp op) {
677     op->removeAttr(kPeeledLoopsLabel);
678     op->removeAttr(kPartialIterationLabel);
679   });
680 }
681 
682 /// Apply transformations specified as patterns.
683 void TestLinalgTransforms::runOnFunction() {
684   auto lambda = [&](void *) {
685     getFunction().walk([](LinalgOp op) {
686       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
687     });
688   };
689   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
690 
691   if (testPromotionOptions) {
692     RewritePatternSet patterns(&getContext());
693     fillPromotionCallBackPatterns(&getContext(), patterns);
694     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
695     return;
696   }
697   if (testTileAndDistributionOptions) {
698     RewritePatternSet patterns(&getContext());
699     fillTileAndDistributePatterns(&getContext(), patterns);
700     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
701     return;
702   }
703   if (testPatterns)
704     return applyPatterns(getFunction());
705   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
706     return applyMatmulToVectorPatterns(getFunction(),
707                                        testMatmulToVectorPatterns1dTiling,
708                                        testMatmulToVectorPatterns2dTiling);
709   if (testVectorTransferForwardingPatterns)
710     return applyVectorTransferForwardingPatterns(getFunction());
711   if (testGenericToVectorPattern)
712     return applyLinalgToVectorPatterns(getFunction());
713   if (testTransformPadTensor)
714     return applyPadTensorToGenericPatterns(getFunction());
715   if (testGeneralizePadTensor)
716     return applyGeneralizePadTensorPatterns(getFunction());
717   if (testSwapSubTensorPadTensor)
718     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
719   if (testTiledLoopPeeling.hasValue())
720     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
721                                         skipPartial);
722   if (testTilePattern)
723     return applyTilePattern(getFunction(), tileSizes, padTiles, peeledLoops,
724                             /*scalarizeDynamicDims=*/false);
725   if (testTileScalarizeDynamicDims)
726     return applyTilePattern(getFunction(), tileSizes, padTiles,
727                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
728   if (testHoistPadding) {
729     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
730       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
731     });
732   }
733   if (testInterchangePattern.hasValue())
734     return applyInterchangePattern(getFunction(), testInterchangePattern);
735 }
736 
737 namespace mlir {
738 namespace test {
739 void registerTestLinalgTransforms() {
740   PassRegistration<TestLinalgTransforms>();
741 }
742 } // namespace test
743 } // namespace mlir
744