1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
36   TestLinalgTransforms() = default;
37   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
38 
39   void getDependentDialects(DialectRegistry &registry) const override {
40     // clang-format off
41     registry.insert<AffineDialect,
42                     memref::MemRefDialect,
43                     scf::SCFDialect,
44                     StandardOpsDialect,
45                     vector::VectorDialect,
46                     gpu::GPUDialect>();
47     // clang-format on
48   }
49   StringRef getArgument() const final {
50     return "test-linalg-transform-patterns";
51   }
52   StringRef getDescription() const final {
53     return "Test Linalg transformation patterns by applying them greedily.";
54   }
55 
56   void runOnFunction() override;
57 
58   Option<bool> testPatterns{*this, "test-patterns",
59                             llvm::cl::desc("Test a mixed set of patterns"),
60                             llvm::cl::init(false)};
61   Option<bool> testMatmulToVectorPatterns1dTiling{
62       *this, "test-matmul-to-vector-patterns-tile-1d",
63       llvm::cl::desc(
64           "Test a fused pass that applies patterns from matmul to vectors via "
65           "1-d tiling"),
66       llvm::cl::init(false)};
67   Option<bool> testMatmulToVectorPatterns2dTiling{
68       *this, "test-matmul-to-vector-patterns-tile-2d",
69       llvm::cl::desc(
70           "Test a fused pass that applies patterns from matmul to vectors via "
71           "2-d tiling"),
72       llvm::cl::init(false)};
73   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
74                                     llvm::cl::desc("Test promotion options"),
75                                     llvm::cl::init(false)};
76   Option<bool> testTileAndDistributionOptions{
77       *this, "test-tile-and-distribute-options",
78       llvm::cl::desc("Test tile and distribute options"),
79       llvm::cl::init(false)};
80   Option<bool> testVectorTransferForwardingPatterns{
81       *this, "test-vector-transfer-forwarding-patterns",
82       llvm::cl::desc(
83           "Test a fused pass that forwards linalg.copy to vector.transfer"),
84       llvm::cl::init(false)};
85   Option<bool> testGenericToVectorPattern{
86       *this, "test-linalg-to-vector-patterns",
87       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
88                      "in vector.contract form"),
89       llvm::cl::init(false)};
90   Option<bool> testTilePattern{*this, "test-tile-pattern",
91                                llvm::cl::desc("Test tile pattern"),
92                                llvm::cl::init(false)};
93   Option<bool> testTileScalarizeDynamicDims{
94       *this, "test-tile-scalarize-dynamic-dims",
95       llvm::cl::desc("Test tiling of dynamic dims by 1"),
96       llvm::cl::init(false)};
97   Option<bool> testTransformPadTensor{
98       *this, "test-transform-pad-tensor",
99       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
100       llvm::cl::init(false)};
101   Option<bool> testGeneralizePadTensor{
102       *this, "test-generalize-pad-tensor",
103       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
104       llvm::cl::init(false)};
105   Option<bool> testSwapSubTensorPadTensor{
106       *this, "test-swap-subtensor-padtensor",
107       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
108                      "pad_tensor(subtensor)"),
109       llvm::cl::init(false)};
110   ListOption<int64_t> peeledLoops{
111       *this, "peeled-loops",
112       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
113       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
114   ListOption<int64_t> tileSizes{
115       *this, "tile-sizes",
116       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
117       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
118   ListOption<unsigned> testTiledLoopPeeling{
119       *this, "test-tiled-loop-peeling",
120       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
121       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
122   Option<bool> skipPartial{
123       *this, "skip-partial",
124       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
125       llvm::cl::init(false)};
126   Option<std::string> loopType{
127       *this, "loop-type",
128       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
129                      "tiled_loop"),
130       llvm::cl::init("for")};
131   Option<bool> testDecomposeConvolutionPattern{
132       *this, "test-decompose-convolution-patterns",
133       llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops "
134                      "into low-D ones"),
135       llvm::cl::init(false)};
136 };
137 } // end anonymous namespace
138 
139 static void applyPatterns(FuncOp funcOp) {
140   MLIRContext *ctx = funcOp.getContext();
141   RewritePatternSet patterns(ctx);
142 
143   //===--------------------------------------------------------------------===//
144   // Linalg tiling patterns.
145   //===--------------------------------------------------------------------===//
146   patterns.add<LinalgTilingPattern<MatmulOp>>(
147       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
148       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
149                                  StringAttr::get(ctx, "L3")));
150   patterns.add<LinalgTilingPattern<MatmulOp>>(
151       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
152       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
153                                  StringAttr::get(ctx, "L2")));
154   patterns.add<LinalgTilingPattern<MatmulOp>>(
155       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
156       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
157                                  StringAttr::get(ctx, "L1")));
158   patterns.add<LinalgTilingPattern<MatmulOp>>(
159       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
160       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
161                                  StringAttr::get(ctx, "REG")));
162 
163   patterns.add<LinalgTilingPattern<MatvecOp>>(
164       ctx,
165       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
166           LinalgTilingLoopType::ParallelLoops),
167       LinalgTransformationFilter(ArrayRef<StringAttr>{},
168                                  StringAttr::get(ctx, "L1")));
169 
170   patterns.add<LinalgTilingPattern<DotOp>>(
171       ctx, LinalgTilingOptions().setTileSizes(8000),
172       LinalgTransformationFilter(
173           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
174                                StringAttr::get(ctx, "L3"),
175                                StringAttr::get(ctx, "L2")},
176           StringAttr::get(ctx, "REG")));
177 
178   //===--------------------------------------------------------------------===//
179   // Linalg tiling and permutation patterns.
180   //===--------------------------------------------------------------------===//
181   patterns.add<LinalgTilingPattern<MatmulOp>>(
182       ctx,
183       LinalgTilingOptions()
184           .setTileSizes({2000, 3000, 4000})
185           .setInterchange({1, 2, 0}),
186       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
187                                  StringAttr::get(ctx, "L2__with_perm__")));
188   patterns.add<LinalgTilingPattern<MatmulOp>>(
189       ctx,
190       LinalgTilingOptions()
191           .setTileSizes({200, 300, 400})
192           .setInterchange({1, 0, 2}),
193       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
194                                  StringAttr::get(ctx, "L1__with_perm__")));
195   patterns.add<LinalgTilingPattern<MatmulOp>>(
196       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
197       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
198                                  StringAttr::get(ctx, "REG__with_perm__")));
199 
200   patterns.add<LinalgTilingPattern<MatvecOp>>(
201       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
202       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
203                                  StringAttr::get(ctx, "L1__with_perm__")));
204 
205   patterns.add<LinalgTilingPattern<MatmulOp>>(
206       ctx,
207       LinalgTilingOptions()
208           .setTileSizes({16, 8, 4})
209           .setInterchange({1, 2, 0})
210           .setLoopType(LinalgTilingLoopType::ParallelLoops),
211       LinalgTransformationFilter(
212           StringAttr::get(ctx, "par__with_perm__"),
213           StringAttr::get(ctx, "after_par__with_perm__")));
214 
215   //===--------------------------------------------------------------------===//
216   // Linalg to loops patterns.
217   //===--------------------------------------------------------------------===//
218   patterns.add<LinalgLoweringPattern<DotOp>>(
219       ctx,
220       /*loweringType=*/LinalgLoweringType::Loops,
221       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
222 
223   //===--------------------------------------------------------------------===//
224   // Linalg distribution patterns.
225   //===--------------------------------------------------------------------===//
226   LinalgLoopDistributionOptions distributionOptions;
227 
228   //===--------------------------------------------------------------------===//
229   // Linalg to vector contraction patterns.
230   //===--------------------------------------------------------------------===//
231   patterns.add<LinalgVectorizationPattern>(
232       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
233                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
234 
235   //===--------------------------------------------------------------------===//
236   // Linalg generic interchange pattern.
237   //===--------------------------------------------------------------------===//
238   patterns.add<GenericOpInterchangePattern>(
239       ctx,
240       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
241       LinalgTransformationFilter(ArrayRef<StringAttr>{},
242                                  StringAttr::get(ctx, "PERMUTED")));
243 
244   //===--------------------------------------------------------------------===//
245   // Linalg subview operands promotion.
246   //===--------------------------------------------------------------------===//
247   patterns.add<LinalgPromotionPattern<MatmulOp>>(
248       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
249       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
250                                  StringAttr::get(ctx, "_views_promoted_")));
251   patterns.add<LinalgPromotionPattern<MatmulOp>>(
252       ctx,
253       LinalgPromotionOptions()
254           .setOperandsToPromote({0})
255           .setUseFullTileBuffersByDefault(true),
256       LinalgTransformationFilter(
257           StringAttr::get(ctx, "_promote_first_view_"),
258           StringAttr::get(ctx, "_first_view_promoted_")));
259   patterns.add<LinalgPromotionPattern<FillOp>>(
260       ctx,
261       LinalgPromotionOptions()
262           .setOperandsToPromote({1})
263           .setUseFullTileBuffers({false, true})
264           .setAlignment(32),
265       LinalgTransformationFilter(
266           StringAttr::get(ctx, "_promote_views_aligned_"),
267           StringAttr::get(ctx, "_views_aligned_promoted_")));
268 
269   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
270 
271   // Drop the marker.
272   funcOp.walk([](LinalgOp op) {
273     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
274   });
275 }
276 
277 static void fillL1TilingAndMatmulToVectorPatterns(
278     FuncOp funcOp, StringRef startMarker,
279     SmallVectorImpl<RewritePatternSet> &patternsVector) {
280   MLIRContext *ctx = funcOp.getContext();
281   patternsVector.emplace_back(
282       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
283                ctx,
284                LinalgTilingOptions()
285                    .setTileSizes({8, 12, 16})
286                    .setInterchange({1, 0, 2}),
287                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
288                                           StringAttr::get(ctx, "L1"))));
289 
290   patternsVector.emplace_back(
291       ctx,
292       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
293           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
294           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
295                                      StringAttr::get(ctx, "VEC"))));
296 
297   patternsVector.emplace_back(
298       ctx, std::make_unique<LinalgVectorizationPattern>(
299                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
300                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
301   patternsVector.back().add<LinalgVectorizationPattern>(
302       ctx, LinalgTransformationFilter().addFilter(
303                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Test promotion callbacks
308 //===----------------------------------------------------------------------===//
309 
310 // Allocation call back
311 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
312                                        ArrayRef<Value> boundingSubViewSize,
313                                        DataLayout &layout) {
314   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
315   return b
316       .create<memref::AllocOp>(
317           subView.getLoc(),
318           MemRefType::get(shape, subView.getType().getElementType(),
319                           /*affineMapComposition =*/{}, 3),
320           boundingSubViewSize)
321       .getResult();
322 }
323 
324 // Deallocation callback
325 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
326   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
327   return success();
328 }
329 
330 // Copy in call back
331 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
332                                     bool isOutput) {
333   auto floatType = src.getType().cast<MemRefType>().getElementType();
334   if (!floatType.isa<FloatType>())
335     return failure();
336   if (!isOutput) {
337     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
338                                             FloatAttr::get(floatType, 42.0));
339     b.create<FillOp>(src.getLoc(), cst, dst);
340   }
341   b.create<CopyOp>(src.getLoc(), src, dst);
342   return success();
343 }
344 
345 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
346                                           RewritePatternSet &patterns) {
347   patterns.add<LinalgTilingPattern<MatmulOp>>(
348       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
349       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
350                                  StringAttr::get(ctx, "PROMOTE")));
351   patterns.add<LinalgPromotionPattern<MatmulOp>>(
352       ctx,
353       LinalgPromotionOptions()
354           .setOperandsToPromote({0, 2})
355           .setUseFullTileBuffers({false, false})
356           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
357           .setCopyInOutFns(
358               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
359                 return copyCallBackFn(b, src, dst, false);
360               },
361               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
362                 return copyCallBackFn(b, src, dst, true);
363               }),
364       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
365 }
366 
367 template <typename IdOp, typename NProcsOp>
368 static SmallVector<ProcInfo, 2>
369 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
370   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
371   SmallVector<ProcInfo, 2> procInfo(count);
372   const char *xyz[] = {"x", "y", "z"};
373   Type indexType = b.getIndexType();
374   for (unsigned i = 0; i < count; ++i) {
375     procInfo[count - 1 - i] = {
376         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
377         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
378   }
379   return procInfo;
380 }
381 
382 static void fillTileAndDistributePatterns(MLIRContext *context,
383                                           RewritePatternSet &patterns) {
384   {
385     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
386     cyclicNprocsEqNiters.distributionMethod.resize(
387         2, DistributionMethod::CyclicNumProcsEqNumIters);
388     cyclicNprocsEqNiters.procInfo =
389         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
390     patterns.add<LinalgTilingPattern<MatmulOp>>(
391         context,
392         LinalgTilingOptions()
393             .setTileSizes({8, 8, 4})
394             .setLoopType(LinalgTilingLoopType::ParallelLoops)
395             .setDistributionOptions(cyclicNprocsEqNiters),
396         LinalgTransformationFilter(
397             StringAttr::get(context, "distribute1"),
398             StringAttr::get(context, "after_distribute1")));
399   }
400 
401   {
402     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
403     cyclicNprocsGeNiters.distributionMethod.resize(
404         2, DistributionMethod::CyclicNumProcsGeNumIters);
405     cyclicNprocsGeNiters.procInfo =
406         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
407     patterns.add<LinalgTilingPattern<MatmulOp>>(
408         context,
409         LinalgTilingOptions()
410             .setTileSizes({8, 8, 4})
411             .setLoopType(LinalgTilingLoopType::ParallelLoops)
412             .setDistributionOptions(cyclicNprocsGeNiters),
413         LinalgTransformationFilter(
414             StringAttr::get(context, "distribute2"),
415             StringAttr::get(context, "after_distribute2")));
416   }
417 
418   {
419     LinalgLoopDistributionOptions cyclicNprocsDefault;
420     cyclicNprocsDefault.distributionMethod.resize(2,
421                                                   DistributionMethod::Cyclic);
422     cyclicNprocsDefault.procInfo =
423         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
424     patterns.add<LinalgTilingPattern<MatmulOp>>(
425         context,
426         LinalgTilingOptions()
427             .setTileSizes({8, 8, 4})
428             .setLoopType(LinalgTilingLoopType::ParallelLoops)
429             .setDistributionOptions(cyclicNprocsDefault),
430         LinalgTransformationFilter(
431             StringAttr::get(context, "distribute3"),
432             StringAttr::get(context, "after_distribute3")));
433   }
434 
435   {
436     LinalgLoopDistributionOptions cyclicNprocsMixed1;
437     cyclicNprocsMixed1.distributionMethod = {
438         DistributionMethod::CyclicNumProcsEqNumIters,
439         DistributionMethod::CyclicNumProcsGeNumIters};
440     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
441     patterns.add<LinalgTilingPattern<MatmulOp>>(
442         context,
443         LinalgTilingOptions()
444             .setTileSizes({8, 8, 4})
445             .setLoopType(LinalgTilingLoopType::ParallelLoops)
446             .setDistributionOptions(cyclicNprocsMixed1),
447         LinalgTransformationFilter(
448             StringAttr::get(context, "distribute4"),
449             StringAttr::get(context, "after_distribute4")));
450   }
451 
452   {
453     LinalgLoopDistributionOptions cyclicNprocsMixed2;
454     cyclicNprocsMixed2.distributionMethod = {
455         DistributionMethod::CyclicNumProcsGeNumIters,
456         DistributionMethod::Cyclic};
457     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
458     patterns.add<LinalgTilingPattern<MatmulOp>>(
459         context,
460         LinalgTilingOptions()
461             .setTileSizes({8, 8, 4})
462             .setLoopType(LinalgTilingLoopType::ParallelLoops)
463             .setDistributionOptions(cyclicNprocsMixed2),
464         LinalgTransformationFilter(
465             StringAttr::get(context, "distribute5"),
466             StringAttr::get(context, "after_distribute5")));
467   }
468 
469   {
470     LinalgLoopDistributionOptions cyclicNprocsMixed3;
471     cyclicNprocsMixed3.distributionMethod = {
472         DistributionMethod::Cyclic,
473         DistributionMethod::CyclicNumProcsEqNumIters};
474     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
475 
476     patterns.add<LinalgTilingPattern<MatmulOp>>(
477         context,
478         LinalgTilingOptions()
479             .setTileSizes({8, 8, 4})
480             .setLoopType(LinalgTilingLoopType::ParallelLoops)
481             .setDistributionOptions(cyclicNprocsMixed3),
482         LinalgTransformationFilter(
483             StringAttr::get(context, "distribute6"),
484             StringAttr::get(context, "after_distribute6")));
485   }
486 
487   {
488     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
489     cyclicNprocsEqNiters.distributionMethod.resize(2,
490                                                    DistributionMethod::Cyclic);
491     cyclicNprocsEqNiters.procInfo =
492         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
493     patterns.add<LinalgTilingPattern<MatmulOp>>(
494         context,
495         LinalgTilingOptions()
496             .setTileSizes({8, 8, 4})
497             .setLoopType(LinalgTilingLoopType::Loops)
498             .setDistributionOptions(cyclicNprocsEqNiters),
499         LinalgTransformationFilter(
500             StringAttr::get(context, "tensors_distribute1"),
501             StringAttr::get(context, "tensors_after_distribute1")));
502   }
503 }
504 
505 static void
506 applyMatmulToVectorPatterns(FuncOp funcOp,
507                             bool testMatmulToVectorPatterns1dTiling,
508                             bool testMatmulToVectorPatterns2dTiling) {
509   MLIRContext *ctx = funcOp.getContext();
510   SmallVector<RewritePatternSet, 4> stage1Patterns;
511   if (testMatmulToVectorPatterns1dTiling) {
512     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
513   } else if (testMatmulToVectorPatterns2dTiling) {
514     stage1Patterns.emplace_back(
515         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
516                  ctx,
517                  LinalgTilingOptions()
518                      .setTileSizes({768, 264, 768})
519                      .setInterchange({1, 2, 0}),
520                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
521                                             StringAttr::get(ctx, "L2"))));
522     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
523   }
524   {
525     // Canonicalization patterns
526     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
527     vector::populateVectorTransferPermutationMapLoweringPatterns(
528         canonicalizationPatterns);
529     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
530     stage1Patterns.push_back(std::move(canonicalizationPatterns));
531   }
532   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
533   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
534   FrozenRewritePatternSet stage2Patterns =
535       getLinalgTilingCanonicalizationPatterns(ctx);
536   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
537                             std::move(stage2Patterns));
538 }
539 
540 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
541   RewritePatternSet forwardPattern(funcOp.getContext());
542   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
543   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
544   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
545 }
546 
547 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
548   RewritePatternSet patterns(funcOp.getContext());
549   patterns.add<LinalgVectorizationPattern>(
550       funcOp.getContext(),
551       LinalgTransformationFilter()
552           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
553   populatePadTensorOpVectorizationPatterns(patterns);
554   populateConvolutionVectorizationPatterns(patterns);
555   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
556 }
557 
558 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
559   RewritePatternSet patterns(funcOp.getContext());
560   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
561   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
562 }
563 
564 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
565   RewritePatternSet patterns(funcOp.getContext());
566   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
567   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
568 }
569 
570 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
571   RewritePatternSet patterns(funcOp.getContext());
572   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
573   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
574 }
575 
576 static void applyTilePattern(FuncOp funcOp, std::string loopType,
577                              ArrayRef<int64_t> tileSizes,
578                              ArrayRef<int64_t> peeledLoops,
579                              bool scalarizeDynamicDims) {
580   MLIRContext *context = funcOp.getContext();
581   RewritePatternSet tilingPattern(context);
582   LinalgTilingLoopType type =
583       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
584           .Case("for", LinalgTilingLoopType::Loops)
585           .Case("affine", LinalgTilingLoopType::AffineLoops)
586           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
587           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
588   auto linalgTilingOptions = linalg::LinalgTilingOptions()
589                                  .setPeeledLoops(peeledLoops)
590                                  .setLoopType(type);
591   if (scalarizeDynamicDims) {
592     linalgTilingOptions.scalarizeDynamicDims();
593     assert(tileSizes.empty() &&
594            "tileSizes and scalarizeDynamicDims is mutually exclusive");
595   } else {
596     linalgTilingOptions.setTileSizes(tileSizes);
597   }
598   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
599                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
600       context, linalgTilingOptions,
601       linalg::LinalgTransformationFilter(StringAttr::get(context, "tile")));
602   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
603 }
604 
/// Name of the attribute in which TiledLoopPeelingPattern records, as an i64
/// array, the loop indices it has already peeled on a TiledLoopOp.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
/// Name of the unit attribute that TiledLoopPeelingPattern sets on the loop
/// returned by the peeling utility; with `skipPartial`, loops carrying it are
/// not peeled again.
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
607 
608 namespace {
609 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
610 /// `idx`-th loop contains only "full" iterations and a second loop for the
611 /// remaining partial iteration (if any).
612 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
613   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
614       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
615   }
616 
617   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
618                                 PatternRewriter &rewriter) const override {
619     SmallVector<int64_t> peeledLoops;
620     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
621       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
622       peeledLoops =
623           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
624             return attr.cast<IntegerAttr>().getInt();
625           }));
626       // Check if the loop was already peeled.
627       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
628         return failure();
629     }
630     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
631       // No peeling of loop nests with a partial iteration.
632       return failure();
633 
634     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
635       return failure();
636 
637     // Peel loop and canonicalize.
638     TiledLoopOp result;
639     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
640                                                     result)))
641       return failure();
642 
643     // Apply label, so that the same loop is not rewritten a second time.
644     peeledLoops.push_back(idx);
645     rewriter.updateRootInPlace(loopOp, [&]() {
646       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
647     });
648     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
649     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
650 
651     return success();
652   }
653 
654   /// Index of loop to peel.
655   int64_t idx;
656 
657   /// If set to true, do not peel TiledLoopOps with a partial iteration.
658   bool skipPartial;
659 };
660 } // namespace
661 
662 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
663                                          ArrayRef<unsigned> loops,
664                                          bool skipPartial) {
665   MLIRContext *ctx = funcOp.getContext();
666   RewritePatternSet patterns(ctx);
667   for (unsigned idx : loops)
668     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
669   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
670 
671   // Drop the markers.
672   funcOp.walk([](TiledLoopOp op) {
673     op->removeAttr(kPeeledLoopsLabel);
674     op->removeAttr(kPartialIterationLabel);
675   });
676 }
677 
678 /// Apply transformations specified as patterns.
679 void TestLinalgTransforms::runOnFunction() {
680   auto lambda = [&](void *) {
681     getFunction().walk([](LinalgOp op) {
682       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
683     });
684   };
685   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
686 
687   if (testPromotionOptions) {
688     RewritePatternSet patterns(&getContext());
689     fillPromotionCallBackPatterns(&getContext(), patterns);
690     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
691     return;
692   }
693   if (testTileAndDistributionOptions) {
694     RewritePatternSet patterns(&getContext());
695     fillTileAndDistributePatterns(&getContext(), patterns);
696     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
697     return;
698   }
699   if (testPatterns)
700     return applyPatterns(getFunction());
701   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
702     return applyMatmulToVectorPatterns(getFunction(),
703                                        testMatmulToVectorPatterns1dTiling,
704                                        testMatmulToVectorPatterns2dTiling);
705   if (testVectorTransferForwardingPatterns)
706     return applyVectorTransferForwardingPatterns(getFunction());
707   if (testGenericToVectorPattern)
708     return applyLinalgToVectorPatterns(getFunction());
709   if (testTransformPadTensor)
710     return applyPadTensorToGenericPatterns(getFunction());
711   if (testGeneralizePadTensor)
712     return applyGeneralizePadTensorPatterns(getFunction());
713   if (testSwapSubTensorPadTensor)
714     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
715   if (testTiledLoopPeeling.hasValue())
716     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
717                                         skipPartial);
718   if (testTilePattern)
719     return applyTilePattern(getFunction(), loopType, tileSizes, peeledLoops,
720                             /*scalarizeDynamicDims=*/false);
721   if (testTileScalarizeDynamicDims)
722     return applyTilePattern(getFunction(), loopType, tileSizes,
723                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
724   if (testDecomposeConvolutionPattern) {
725     // TODO: thread all tests through LinalgStrategy passes.
726     OpPassManager dynamicPM("builtin.func");
727     dynamicPM.addPass(createLinalgStrategyDecomposePass());
728     if (failed(runPipeline(dynamicPM, getFunction())))
729       return signalPassFailure();
730   }
731 }
732 
namespace mlir {
namespace test {
/// Registers the TestLinalgTransforms pass, whose command-line argument is
/// "test-linalg-transform-patterns".
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir
740