1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
36   TestLinalgTransforms() = default;
37   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
38 
39   void getDependentDialects(DialectRegistry &registry) const override {
40     // clang-format off
41     registry.insert<AffineDialect,
42                     memref::MemRefDialect,
43                     scf::SCFDialect,
44                     StandardOpsDialect,
45                     vector::VectorDialect,
46                     gpu::GPUDialect>();
47     // clang-format on
48   }
49   StringRef getArgument() const final {
50     return "test-linalg-transform-patterns";
51   }
52   StringRef getDescription() const final {
53     return "Test Linalg transformation patterns by applying them greedily.";
54   }
55 
56   void runOnFunction() override;
57 
58   Option<bool> testPatterns{*this, "test-patterns",
59                             llvm::cl::desc("Test a mixed set of patterns"),
60                             llvm::cl::init(false)};
61   Option<bool> testMatmulToVectorPatterns1dTiling{
62       *this, "test-matmul-to-vector-patterns-tile-1d",
63       llvm::cl::desc(
64           "Test a fused pass that applies patterns from matmul to vectors via "
65           "1-d tiling"),
66       llvm::cl::init(false)};
67   Option<bool> testMatmulToVectorPatterns2dTiling{
68       *this, "test-matmul-to-vector-patterns-tile-2d",
69       llvm::cl::desc(
70           "Test a fused pass that applies patterns from matmul to vectors via "
71           "2-d tiling"),
72       llvm::cl::init(false)};
73   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
74                                     llvm::cl::desc("Test promotion options"),
75                                     llvm::cl::init(false)};
76   Option<bool> testTileAndDistributionOptions{
77       *this, "test-tile-and-distribute-options",
78       llvm::cl::desc("Test tile and distribute options"),
79       llvm::cl::init(false)};
80   Option<bool> testVectorTransferForwardingPatterns{
81       *this, "test-vector-transfer-forwarding-patterns",
82       llvm::cl::desc(
83           "Test a fused pass that forwards linalg.copy to vector.transfer"),
84       llvm::cl::init(false)};
85   Option<bool> testGenericToVectorPattern{
86       *this, "test-linalg-to-vector-patterns",
87       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
88                      "in vector.contract form"),
89       llvm::cl::init(false)};
90   Option<bool> testTilePattern{*this, "test-tile-pattern",
91                                llvm::cl::desc("Test tile pattern"),
92                                llvm::cl::init(false)};
93   Option<bool> testTileScalarizeDynamicDims{
94       *this, "test-tile-scalarize-dynamic-dims",
95       llvm::cl::desc("Test tiling of dynamic dims by 1"),
96       llvm::cl::init(false)};
97   Option<bool> testTransformPadTensor{
98       *this, "test-transform-pad-tensor",
99       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
100       llvm::cl::init(false)};
101   Option<bool> testGeneralizePadTensor{
102       *this, "test-generalize-pad-tensor",
103       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
104       llvm::cl::init(false)};
105   Option<bool> testSwapSubTensorPadTensor{
106       *this, "test-swap-subtensor-padtensor",
107       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
108                      "pad_tensor(subtensor)"),
109       llvm::cl::init(false)};
110   ListOption<int64_t> peeledLoops{
111       *this, "peeled-loops",
112       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
113       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
114   ListOption<int64_t> tileSizes{
115       *this, "tile-sizes",
116       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
117       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
118   ListOption<unsigned> testTiledLoopPeeling{
119       *this, "test-tiled-loop-peeling",
120       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
121       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
122   Option<bool> skipPartial{
123       *this, "skip-partial",
124       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
125       llvm::cl::init(false)};
126   Option<std::string> loopType{
127       *this, "loop-type",
128       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
129                      "tiled_loop"),
130       llvm::cl::init("for")};
131 };
132 } // end anonymous namespace
133 
134 static void applyPatterns(FuncOp funcOp) {
135   MLIRContext *ctx = funcOp.getContext();
136   RewritePatternSet patterns(ctx);
137 
138   //===--------------------------------------------------------------------===//
139   // Linalg tiling patterns.
140   //===--------------------------------------------------------------------===//
141   patterns.add<LinalgTilingPattern<MatmulOp>>(
142       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
143       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
144                                  StringAttr::get(ctx, "L3")));
145   patterns.add<LinalgTilingPattern<MatmulOp>>(
146       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
147       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
148                                  StringAttr::get(ctx, "L2")));
149   patterns.add<LinalgTilingPattern<MatmulOp>>(
150       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
151       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
152                                  StringAttr::get(ctx, "L1")));
153   patterns.add<LinalgTilingPattern<MatmulOp>>(
154       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
155       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
156                                  StringAttr::get(ctx, "REG")));
157 
158   patterns.add<LinalgTilingPattern<MatvecOp>>(
159       ctx,
160       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
161           LinalgTilingLoopType::ParallelLoops),
162       LinalgTransformationFilter(ArrayRef<StringAttr>{},
163                                  StringAttr::get(ctx, "L1")));
164 
165   patterns.add<LinalgTilingPattern<DotOp>>(
166       ctx, LinalgTilingOptions().setTileSizes(8000),
167       LinalgTransformationFilter(
168           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
169                                StringAttr::get(ctx, "L3"),
170                                StringAttr::get(ctx, "L2")},
171           StringAttr::get(ctx, "REG")));
172 
173   //===--------------------------------------------------------------------===//
174   // Linalg tiling and permutation patterns.
175   //===--------------------------------------------------------------------===//
176   patterns.add<LinalgTilingPattern<MatmulOp>>(
177       ctx,
178       LinalgTilingOptions()
179           .setTileSizes({2000, 3000, 4000})
180           .setInterchange({1, 2, 0}),
181       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
182                                  StringAttr::get(ctx, "L2__with_perm__")));
183   patterns.add<LinalgTilingPattern<MatmulOp>>(
184       ctx,
185       LinalgTilingOptions()
186           .setTileSizes({200, 300, 400})
187           .setInterchange({1, 0, 2}),
188       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
189                                  StringAttr::get(ctx, "L1__with_perm__")));
190   patterns.add<LinalgTilingPattern<MatmulOp>>(
191       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
192       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
193                                  StringAttr::get(ctx, "REG__with_perm__")));
194 
195   patterns.add<LinalgTilingPattern<MatvecOp>>(
196       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
197       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
198                                  StringAttr::get(ctx, "L1__with_perm__")));
199 
200   patterns.add<LinalgTilingPattern<MatmulOp>>(
201       ctx,
202       LinalgTilingOptions()
203           .setTileSizes({16, 8, 4})
204           .setInterchange({1, 2, 0})
205           .setLoopType(LinalgTilingLoopType::ParallelLoops),
206       LinalgTransformationFilter(
207           StringAttr::get(ctx, "par__with_perm__"),
208           StringAttr::get(ctx, "after_par__with_perm__")));
209 
210   //===--------------------------------------------------------------------===//
211   // Linalg to loops patterns.
212   //===--------------------------------------------------------------------===//
213   patterns.add<LinalgLoweringPattern<DotOp>>(
214       ctx,
215       /*loweringType=*/LinalgLoweringType::Loops,
216       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
217 
218   //===--------------------------------------------------------------------===//
219   // Linalg distribution patterns.
220   //===--------------------------------------------------------------------===//
221   LinalgLoopDistributionOptions distributionOptions;
222 
223   //===--------------------------------------------------------------------===//
224   // Linalg to vector contraction patterns.
225   //===--------------------------------------------------------------------===//
226   patterns.add<LinalgVectorizationPattern>(
227       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
228                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
229 
230   //===--------------------------------------------------------------------===//
231   // Linalg generic interchange pattern.
232   //===--------------------------------------------------------------------===//
233   patterns.add<GenericOpInterchangePattern>(
234       ctx,
235       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
236       LinalgTransformationFilter(ArrayRef<StringAttr>{},
237                                  StringAttr::get(ctx, "PERMUTED")));
238 
239   //===--------------------------------------------------------------------===//
240   // Linalg subview operands promotion.
241   //===--------------------------------------------------------------------===//
242   patterns.add<LinalgPromotionPattern<MatmulOp>>(
243       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
244       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
245                                  StringAttr::get(ctx, "_views_promoted_")));
246   patterns.add<LinalgPromotionPattern<MatmulOp>>(
247       ctx,
248       LinalgPromotionOptions()
249           .setOperandsToPromote({0})
250           .setUseFullTileBuffersByDefault(true),
251       LinalgTransformationFilter(
252           StringAttr::get(ctx, "_promote_first_view_"),
253           StringAttr::get(ctx, "_first_view_promoted_")));
254   patterns.add<LinalgPromotionPattern<FillOp>>(
255       ctx,
256       LinalgPromotionOptions()
257           .setOperandsToPromote({1})
258           .setUseFullTileBuffers({false, true})
259           .setAlignment(32),
260       LinalgTransformationFilter(
261           StringAttr::get(ctx, "_promote_views_aligned_"),
262           StringAttr::get(ctx, "_views_aligned_promoted_")));
263 
264   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
265 
266   // Drop the marker.
267   funcOp.walk([](LinalgOp op) {
268     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
269   });
270 }
271 
272 static void fillL1TilingAndMatmulToVectorPatterns(
273     FuncOp funcOp, StringRef startMarker,
274     SmallVectorImpl<RewritePatternSet> &patternsVector) {
275   MLIRContext *ctx = funcOp.getContext();
276   patternsVector.emplace_back(
277       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
278                ctx,
279                LinalgTilingOptions()
280                    .setTileSizes({8, 12, 16})
281                    .setInterchange({1, 0, 2}),
282                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
283                                           StringAttr::get(ctx, "L1"))));
284 
285   patternsVector.emplace_back(
286       ctx,
287       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
288           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
289           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
290                                      StringAttr::get(ctx, "VEC"))));
291 
292   patternsVector.emplace_back(
293       ctx, std::make_unique<LinalgVectorizationPattern>(
294                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
295                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
296   patternsVector.back().add<LinalgVectorizationPattern>(
297       ctx, LinalgTransformationFilter().addFilter(
298                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // Test promotion callbacks
303 //===----------------------------------------------------------------------===//
304 
305 // Allocation call back
306 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
307                                        ArrayRef<Value> boundingSubViewSize,
308                                        DataLayout &layout) {
309   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
310   return b
311       .create<memref::AllocOp>(
312           subView.getLoc(),
313           MemRefType::get(shape, subView.getType().getElementType(),
314                           /*affineMapComposition =*/{}, 3),
315           boundingSubViewSize)
316       .getResult();
317 }
318 
319 // Deallocation callback
320 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
321   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
322   return success();
323 }
324 
325 // Copy in call back
326 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
327                                     bool isOutput) {
328   auto floatType = src.getType().cast<MemRefType>().getElementType();
329   if (!floatType.isa<FloatType>())
330     return failure();
331   if (!isOutput) {
332     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
333                                             FloatAttr::get(floatType, 42.0));
334     b.create<FillOp>(src.getLoc(), cst, dst);
335   }
336   b.create<CopyOp>(src.getLoc(), src, dst);
337   return success();
338 }
339 
340 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
341                                           RewritePatternSet &patterns) {
342   patterns.add<LinalgTilingPattern<MatmulOp>>(
343       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
344       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
345                                  StringAttr::get(ctx, "PROMOTE")));
346   patterns.add<LinalgPromotionPattern<MatmulOp>>(
347       ctx,
348       LinalgPromotionOptions()
349           .setOperandsToPromote({0, 2})
350           .setUseFullTileBuffers({false, false})
351           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
352           .setCopyInOutFns(
353               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
354                 return copyCallBackFn(b, src, dst, false);
355               },
356               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
357                 return copyCallBackFn(b, src, dst, true);
358               }),
359       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
360 }
361 
362 template <typename IdOp, typename NProcsOp>
363 static SmallVector<ProcInfo, 2>
364 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
365   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
366   SmallVector<ProcInfo, 2> procInfo(count);
367   const char *xyz[] = {"x", "y", "z"};
368   Type indexType = b.getIndexType();
369   for (unsigned i = 0; i < count; ++i) {
370     procInfo[count - 1 - i] = {
371         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
372         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
373   }
374   return procInfo;
375 }
376 
377 static void fillTileAndDistributePatterns(MLIRContext *context,
378                                           RewritePatternSet &patterns) {
379   {
380     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
381     cyclicNprocsEqNiters.distributionMethod.resize(
382         2, DistributionMethod::CyclicNumProcsEqNumIters);
383     cyclicNprocsEqNiters.procInfo =
384         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
385     patterns.add<LinalgTilingPattern<MatmulOp>>(
386         context,
387         LinalgTilingOptions()
388             .setTileSizes({8, 8, 4})
389             .setLoopType(LinalgTilingLoopType::ParallelLoops)
390             .setDistributionOptions(cyclicNprocsEqNiters),
391         LinalgTransformationFilter(
392             StringAttr::get(context, "distribute1"),
393             StringAttr::get(context, "after_distribute1")));
394   }
395 
396   {
397     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
398     cyclicNprocsGeNiters.distributionMethod.resize(
399         2, DistributionMethod::CyclicNumProcsGeNumIters);
400     cyclicNprocsGeNiters.procInfo =
401         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
402     patterns.add<LinalgTilingPattern<MatmulOp>>(
403         context,
404         LinalgTilingOptions()
405             .setTileSizes({8, 8, 4})
406             .setLoopType(LinalgTilingLoopType::ParallelLoops)
407             .setDistributionOptions(cyclicNprocsGeNiters),
408         LinalgTransformationFilter(
409             StringAttr::get(context, "distribute2"),
410             StringAttr::get(context, "after_distribute2")));
411   }
412 
413   {
414     LinalgLoopDistributionOptions cyclicNprocsDefault;
415     cyclicNprocsDefault.distributionMethod.resize(2,
416                                                   DistributionMethod::Cyclic);
417     cyclicNprocsDefault.procInfo =
418         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
419     patterns.add<LinalgTilingPattern<MatmulOp>>(
420         context,
421         LinalgTilingOptions()
422             .setTileSizes({8, 8, 4})
423             .setLoopType(LinalgTilingLoopType::ParallelLoops)
424             .setDistributionOptions(cyclicNprocsDefault),
425         LinalgTransformationFilter(
426             StringAttr::get(context, "distribute3"),
427             StringAttr::get(context, "after_distribute3")));
428   }
429 
430   {
431     LinalgLoopDistributionOptions cyclicNprocsMixed1;
432     cyclicNprocsMixed1.distributionMethod = {
433         DistributionMethod::CyclicNumProcsEqNumIters,
434         DistributionMethod::CyclicNumProcsGeNumIters};
435     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
436     patterns.add<LinalgTilingPattern<MatmulOp>>(
437         context,
438         LinalgTilingOptions()
439             .setTileSizes({8, 8, 4})
440             .setLoopType(LinalgTilingLoopType::ParallelLoops)
441             .setDistributionOptions(cyclicNprocsMixed1),
442         LinalgTransformationFilter(
443             StringAttr::get(context, "distribute4"),
444             StringAttr::get(context, "after_distribute4")));
445   }
446 
447   {
448     LinalgLoopDistributionOptions cyclicNprocsMixed2;
449     cyclicNprocsMixed2.distributionMethod = {
450         DistributionMethod::CyclicNumProcsGeNumIters,
451         DistributionMethod::Cyclic};
452     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
453     patterns.add<LinalgTilingPattern<MatmulOp>>(
454         context,
455         LinalgTilingOptions()
456             .setTileSizes({8, 8, 4})
457             .setLoopType(LinalgTilingLoopType::ParallelLoops)
458             .setDistributionOptions(cyclicNprocsMixed2),
459         LinalgTransformationFilter(
460             StringAttr::get(context, "distribute5"),
461             StringAttr::get(context, "after_distribute5")));
462   }
463 
464   {
465     LinalgLoopDistributionOptions cyclicNprocsMixed3;
466     cyclicNprocsMixed3.distributionMethod = {
467         DistributionMethod::Cyclic,
468         DistributionMethod::CyclicNumProcsEqNumIters};
469     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
470 
471     patterns.add<LinalgTilingPattern<MatmulOp>>(
472         context,
473         LinalgTilingOptions()
474             .setTileSizes({8, 8, 4})
475             .setLoopType(LinalgTilingLoopType::ParallelLoops)
476             .setDistributionOptions(cyclicNprocsMixed3),
477         LinalgTransformationFilter(
478             StringAttr::get(context, "distribute6"),
479             StringAttr::get(context, "after_distribute6")));
480   }
481 
482   {
483     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
484     cyclicNprocsEqNiters.distributionMethod.resize(2,
485                                                    DistributionMethod::Cyclic);
486     cyclicNprocsEqNiters.procInfo =
487         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
488     patterns.add<LinalgTilingPattern<MatmulOp>>(
489         context,
490         LinalgTilingOptions()
491             .setTileSizes({8, 8, 4})
492             .setLoopType(LinalgTilingLoopType::Loops)
493             .setDistributionOptions(cyclicNprocsEqNiters),
494         LinalgTransformationFilter(
495             StringAttr::get(context, "tensors_distribute1"),
496             StringAttr::get(context, "tensors_after_distribute1")));
497   }
498 }
499 
500 static void
501 applyMatmulToVectorPatterns(FuncOp funcOp,
502                             bool testMatmulToVectorPatterns1dTiling,
503                             bool testMatmulToVectorPatterns2dTiling) {
504   MLIRContext *ctx = funcOp.getContext();
505   SmallVector<RewritePatternSet, 4> stage1Patterns;
506   if (testMatmulToVectorPatterns1dTiling) {
507     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
508   } else if (testMatmulToVectorPatterns2dTiling) {
509     stage1Patterns.emplace_back(
510         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
511                  ctx,
512                  LinalgTilingOptions()
513                      .setTileSizes({768, 264, 768})
514                      .setInterchange({1, 2, 0}),
515                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
516                                             StringAttr::get(ctx, "L2"))));
517     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
518   }
519   {
520     // Canonicalization patterns
521     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
522     vector::populateVectorTransferPermutationMapLoweringPatterns(
523         canonicalizationPatterns);
524     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
525     stage1Patterns.push_back(std::move(canonicalizationPatterns));
526   }
527   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
528   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
529   FrozenRewritePatternSet stage2Patterns =
530       getLinalgTilingCanonicalizationPatterns(ctx);
531   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
532                             std::move(stage2Patterns));
533 }
534 
535 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
536   RewritePatternSet forwardPattern(funcOp.getContext());
537   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
538   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
539   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
540 }
541 
542 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
543   RewritePatternSet patterns(funcOp.getContext());
544   patterns.add<LinalgVectorizationPattern>(
545       funcOp.getContext(),
546       LinalgTransformationFilter()
547           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
548   populatePadTensorOpVectorizationPatterns(patterns);
549   populateConvolutionVectorizationPatterns(patterns);
550   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
551 }
552 
553 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
554   RewritePatternSet patterns(funcOp.getContext());
555   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
556   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
557 }
558 
559 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
560   RewritePatternSet patterns(funcOp.getContext());
561   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
562   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
563 }
564 
565 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
566   RewritePatternSet patterns(funcOp.getContext());
567   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
568   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
569 }
570 
571 static void applyTilePattern(FuncOp funcOp, std::string loopType,
572                              ArrayRef<int64_t> tileSizes,
573                              ArrayRef<int64_t> peeledLoops,
574                              bool scalarizeDynamicDims) {
575   MLIRContext *context = funcOp.getContext();
576   RewritePatternSet tilingPattern(context);
577   LinalgTilingLoopType type =
578       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
579           .Case("for", LinalgTilingLoopType::Loops)
580           .Case("affine", LinalgTilingLoopType::AffineLoops)
581           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
582           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
583   auto linalgTilingOptions = linalg::LinalgTilingOptions()
584                                  .setPeeledLoops(peeledLoops)
585                                  .setLoopType(type);
586   if (scalarizeDynamicDims) {
587     linalgTilingOptions.scalarizeDynamicDims();
588     assert(tileSizes.empty() &&
589            "tileSizes and scalarizeDynamicDims is mutually exclusive");
590   } else {
591     linalgTilingOptions.setTileSizes(tileSizes);
592   }
593   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
594                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
595       context, linalgTilingOptions,
596       linalg::LinalgTransformationFilter(StringAttr::get(context, "tile")));
597   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
598 }
599 
/// Name of the I64 array attribute listing the loop indices of a TiledLoopOp
/// that have already been peeled.
constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
/// Name of the unit attribute set on the TiledLoopOp produced by peeling.
constexpr char kPartialIterationLabel[] = "__partial_iteration__";
602 
603 namespace {
604 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
605 /// `idx`-th loop contains only "full" iterations and a second loop for the
606 /// remaining partial iteration (if any).
607 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
608   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
609       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
610   }
611 
612   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
613                                 PatternRewriter &rewriter) const override {
614     SmallVector<int64_t> peeledLoops;
615     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
616       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
617       peeledLoops =
618           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
619             return attr.cast<IntegerAttr>().getInt();
620           }));
621       // Check if the loop was already peeled.
622       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
623         return failure();
624     }
625     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
626       // No peeling of loop nests with a partial iteration.
627       return failure();
628 
629     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
630       return failure();
631 
632     // Peel loop and canonicalize.
633     TiledLoopOp result;
634     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
635                                                     result)))
636       return failure();
637 
638     // Apply label, so that the same loop is not rewritten a second time.
639     peeledLoops.push_back(idx);
640     rewriter.updateRootInPlace(loopOp, [&]() {
641       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
642     });
643     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
644     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
645 
646     return success();
647   }
648 
649   /// Index of loop to peel.
650   int64_t idx;
651 
652   /// If set to true, do not peel TiledLoopOps with a partial iteration.
653   bool skipPartial;
654 };
655 } // namespace
656 
657 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
658                                          ArrayRef<unsigned> loops,
659                                          bool skipPartial) {
660   MLIRContext *ctx = funcOp.getContext();
661   RewritePatternSet patterns(ctx);
662   for (unsigned idx : loops)
663     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
664   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
665 
666   // Drop the markers.
667   funcOp.walk([](TiledLoopOp op) {
668     op->removeAttr(kPeeledLoopsLabel);
669     op->removeAttr(kPartialIterationLabel);
670   });
671 }
672 
673 /// Apply transformations specified as patterns.
674 void TestLinalgTransforms::runOnFunction() {
675   auto lambda = [&](void *) {
676     getFunction().walk([](LinalgOp op) {
677       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
678     });
679   };
680   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
681 
682   if (testPromotionOptions) {
683     RewritePatternSet patterns(&getContext());
684     fillPromotionCallBackPatterns(&getContext(), patterns);
685     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
686     return;
687   }
688   if (testTileAndDistributionOptions) {
689     RewritePatternSet patterns(&getContext());
690     fillTileAndDistributePatterns(&getContext(), patterns);
691     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
692     return;
693   }
694   if (testPatterns)
695     return applyPatterns(getFunction());
696   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
697     return applyMatmulToVectorPatterns(getFunction(),
698                                        testMatmulToVectorPatterns1dTiling,
699                                        testMatmulToVectorPatterns2dTiling);
700   if (testVectorTransferForwardingPatterns)
701     return applyVectorTransferForwardingPatterns(getFunction());
702   if (testGenericToVectorPattern)
703     return applyLinalgToVectorPatterns(getFunction());
704   if (testTransformPadTensor)
705     return applyPadTensorToGenericPatterns(getFunction());
706   if (testGeneralizePadTensor)
707     return applyGeneralizePadTensorPatterns(getFunction());
708   if (testSwapSubTensorPadTensor)
709     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
710   if (testTiledLoopPeeling.hasValue())
711     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
712                                         skipPartial);
713   if (testTilePattern)
714     return applyTilePattern(getFunction(), loopType, tileSizes, peeledLoops,
715                             /*scalarizeDynamicDims=*/false);
716   if (testTileScalarizeDynamicDims)
717     return applyTilePattern(getFunction(), loopType, tileSizes,
718                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
719 }
720 
721 namespace mlir {
722 namespace test {
723 void registerTestLinalgTransforms() {
724   PassRegistration<TestLinalgTransforms>();
725 }
726 } // namespace test
727 } // namespace mlir
728