1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> {
36   TestLinalgTransforms() = default;
37   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
38 
39   void getDependentDialects(DialectRegistry &registry) const override {
40     // clang-format off
41     registry.insert<AffineDialect,
42                     memref::MemRefDialect,
43                     scf::SCFDialect,
44                     StandardOpsDialect,
45                     linalg::LinalgDialect,
46                     vector::VectorDialect,
47                     gpu::GPUDialect>();
48     // clang-format on
49   }
50   StringRef getArgument() const final {
51     return "test-linalg-transform-patterns";
52   }
53   StringRef getDescription() const final {
54     return "Test Linalg transformation patterns by applying them greedily.";
55   }
56 
57   void runOnOperation() override;
58 
59   Option<bool> testPatterns{*this, "test-patterns",
60                             llvm::cl::desc("Test a mixed set of patterns"),
61                             llvm::cl::init(false)};
62   Option<bool> testMatmulToVectorPatterns1dTiling{
63       *this, "test-matmul-to-vector-patterns-tile-1d",
64       llvm::cl::desc(
65           "Test a fused pass that applies patterns from matmul to vectors via "
66           "1-d tiling"),
67       llvm::cl::init(false)};
68   Option<bool> testMatmulToVectorPatterns2dTiling{
69       *this, "test-matmul-to-vector-patterns-tile-2d",
70       llvm::cl::desc(
71           "Test a fused pass that applies patterns from matmul to vectors via "
72           "2-d tiling"),
73       llvm::cl::init(false)};
74   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
75                                     llvm::cl::desc("Test promotion options"),
76                                     llvm::cl::init(false)};
77   Option<bool> testTileAndDistributionOptions{
78       *this, "test-tile-and-distribute-options",
79       llvm::cl::desc("Test tile and distribute options"),
80       llvm::cl::init(false)};
81   Option<bool> testVectorTransferForwardingPatterns{
82       *this, "test-vector-transfer-forwarding-patterns",
83       llvm::cl::desc(
84           "Test a fused pass that forwards linalg.copy to vector.transfer"),
85       llvm::cl::init(false)};
86   Option<bool> testGenericToVectorPattern{
87       *this, "test-linalg-to-vector-patterns",
88       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
89                      "in vector.contract form"),
90       llvm::cl::init(false)};
91   Option<bool> testTilePattern{*this, "test-tile-pattern",
92                                llvm::cl::desc("Test tile pattern"),
93                                llvm::cl::init(false)};
94   Option<bool> testTileScalarizeDynamicDims{
95       *this, "test-tile-scalarize-dynamic-dims",
96       llvm::cl::desc("Test tiling of dynamic dims by 1"),
97       llvm::cl::init(false)};
98   Option<bool> testTransformPadTensor{
99       *this, "test-transform-pad-tensor",
100       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
101       llvm::cl::init(false)};
102   Option<bool> testGeneralizePadTensor{
103       *this, "test-generalize-pad-tensor",
104       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
105       llvm::cl::init(false)};
106   Option<bool> testSwapSubTensorPadTensor{
107       *this, "test-swap-subtensor-padtensor",
108       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
109                      "pad_tensor(subtensor)"),
110       llvm::cl::init(false)};
111   ListOption<int64_t> peeledLoops{
112       *this, "peeled-loops",
113       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
114       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
115   ListOption<int64_t> tileSizes{
116       *this, "tile-sizes",
117       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
118       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
119   ListOption<unsigned> testTiledLoopPeeling{
120       *this, "test-tiled-loop-peeling",
121       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
122       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
123   Option<bool> skipPartial{
124       *this, "skip-partial",
125       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
126       llvm::cl::init(false)};
127   Option<std::string> loopType{
128       *this, "loop-type",
129       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
130                      "tiled_loop"),
131       llvm::cl::init("for")};
132 };
133 } // namespace
134 
135 static void applyPatterns(FuncOp funcOp) {
136   MLIRContext *ctx = funcOp.getContext();
137   RewritePatternSet patterns(ctx);
138 
139   //===--------------------------------------------------------------------===//
140   // Linalg tiling patterns.
141   //===--------------------------------------------------------------------===//
142   patterns.add<LinalgTilingPattern>(
143       MatmulOp::getOperationName(), ctx,
144       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
145       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
146                                  StringAttr::get(ctx, "L3")));
147   patterns.add<LinalgTilingPattern>(
148       MatmulOp::getOperationName(), ctx,
149       LinalgTilingOptions().setTileSizes({200, 300, 400}),
150       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
151                                  StringAttr::get(ctx, "L2")));
152   patterns.add<LinalgTilingPattern>(
153       MatmulOp::getOperationName(), ctx,
154       LinalgTilingOptions().setTileSizes({20, 30, 40}),
155       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
156                                  StringAttr::get(ctx, "L1")));
157   patterns.add<LinalgTilingPattern>(
158       MatmulOp::getOperationName(), ctx,
159       LinalgTilingOptions().setTileSizes({2, 3, 4}),
160       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
161                                  StringAttr::get(ctx, "REG")));
162 
163   patterns.add<LinalgTilingPattern>(
164       MatvecOp::getOperationName(), ctx,
165       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
166           LinalgTilingLoopType::ParallelLoops),
167       LinalgTransformationFilter(ArrayRef<StringAttr>{},
168                                  StringAttr::get(ctx, "L1")));
169 
170   patterns.add<LinalgTilingPattern>(
171       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
172       LinalgTransformationFilter(
173           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
174                                StringAttr::get(ctx, "L3"),
175                                StringAttr::get(ctx, "L2")},
176           StringAttr::get(ctx, "REG")));
177 
178   //===--------------------------------------------------------------------===//
179   // Linalg tiling and permutation patterns.
180   //===--------------------------------------------------------------------===//
181   patterns.add<LinalgTilingPattern>(
182       MatmulOp::getOperationName(), ctx,
183       LinalgTilingOptions()
184           .setTileSizes({2000, 3000, 4000})
185           .setInterchange({1, 2, 0}),
186       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
187                                  StringAttr::get(ctx, "L2__with_perm__")));
188   patterns.add<LinalgTilingPattern>(
189       MatmulOp::getOperationName(), ctx,
190       LinalgTilingOptions()
191           .setTileSizes({200, 300, 400})
192           .setInterchange({1, 0, 2}),
193       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
194                                  StringAttr::get(ctx, "L1__with_perm__")));
195   patterns.add<LinalgTilingPattern>(
196       MatmulOp::getOperationName(), ctx,
197       LinalgTilingOptions().setTileSizes({20, 30, 40}),
198       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
199                                  StringAttr::get(ctx, "REG__with_perm__")));
200 
201   patterns.add<LinalgTilingPattern>(
202       MatvecOp::getOperationName(), ctx,
203       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
204       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
205                                  StringAttr::get(ctx, "L1__with_perm__")));
206 
207   patterns.add<LinalgTilingPattern>(
208       MatmulOp::getOperationName(), ctx,
209       LinalgTilingOptions()
210           .setTileSizes({16, 8, 4})
211           .setInterchange({1, 2, 0})
212           .setLoopType(LinalgTilingLoopType::ParallelLoops),
213       LinalgTransformationFilter(
214           StringAttr::get(ctx, "par__with_perm__"),
215           StringAttr::get(ctx, "after_par__with_perm__")));
216 
217   //===--------------------------------------------------------------------===//
218   // Linalg to loops patterns.
219   //===--------------------------------------------------------------------===//
220   patterns.add<LinalgLoweringPattern<DotOp>>(
221       ctx,
222       /*loweringType=*/LinalgLoweringType::Loops,
223       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
224 
225   //===--------------------------------------------------------------------===//
226   // Linalg distribution patterns.
227   //===--------------------------------------------------------------------===//
228   LinalgLoopDistributionOptions distributionOptions;
229 
230   //===--------------------------------------------------------------------===//
231   // Linalg to vector contraction patterns.
232   //===--------------------------------------------------------------------===//
233   patterns.add<LinalgVectorizationPattern>(
234       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
235                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
236 
237   //===--------------------------------------------------------------------===//
238   // Linalg generic interchange pattern.
239   //===--------------------------------------------------------------------===//
240   patterns.add<GenericOpInterchangePattern>(
241       ctx,
242       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
243       LinalgTransformationFilter(ArrayRef<StringAttr>{},
244                                  StringAttr::get(ctx, "PERMUTED")));
245 
246   //===--------------------------------------------------------------------===//
247   // Linalg subview operands promotion.
248   //===--------------------------------------------------------------------===//
249   patterns.add<LinalgPromotionPattern<MatmulOp>>(
250       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
251       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
252                                  StringAttr::get(ctx, "_views_promoted_")));
253   patterns.add<LinalgPromotionPattern<MatmulOp>>(
254       ctx,
255       LinalgPromotionOptions()
256           .setOperandsToPromote({0})
257           .setUseFullTileBuffersByDefault(true),
258       LinalgTransformationFilter(
259           StringAttr::get(ctx, "_promote_first_view_"),
260           StringAttr::get(ctx, "_first_view_promoted_")));
261   patterns.add<LinalgPromotionPattern<FillOp>>(
262       ctx,
263       LinalgPromotionOptions()
264           .setOperandsToPromote({1})
265           .setUseFullTileBuffers({false, true})
266           .setAlignment(32),
267       LinalgTransformationFilter(
268           StringAttr::get(ctx, "_promote_views_aligned_"),
269           StringAttr::get(ctx, "_views_aligned_promoted_")));
270 
271   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
272 
273   // Drop the marker.
274   funcOp.walk([](LinalgOp op) {
275     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
276   });
277 }
278 
279 static void fillL1TilingAndMatmulToVectorPatterns(
280     FuncOp funcOp, StringRef startMarker,
281     SmallVectorImpl<RewritePatternSet> &patternsVector) {
282   MLIRContext *ctx = funcOp.getContext();
283   patternsVector.emplace_back(
284       ctx, std::make_unique<LinalgTilingPattern>(
285                MatmulOp::getOperationName(), ctx,
286                LinalgTilingOptions()
287                    .setTileSizes({8, 12, 16})
288                    .setInterchange({1, 0, 2}),
289                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
290                                           StringAttr::get(ctx, "L1"))));
291 
292   patternsVector.emplace_back(
293       ctx,
294       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
295           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
296           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
297                                      StringAttr::get(ctx, "VEC"))));
298 
299   patternsVector.emplace_back(
300       ctx, std::make_unique<LinalgVectorizationPattern>(
301                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
302                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
303   patternsVector.back().add<LinalgVectorizationPattern>(
304       ctx, LinalgTransformationFilter().addOpFilter<FillOp, CopyOp>());
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // Test promotion callbacks
309 //===----------------------------------------------------------------------===//
310 
311 // Allocation call back
312 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
313                                        ArrayRef<Value> boundingSubViewSize,
314                                        DataLayout &layout) {
315   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
316   return b
317       .create<memref::AllocOp>(
318           subView.getLoc(),
319           MemRefType::get(shape, subView.getType().getElementType(),
320                           /*affineMapComposition =*/{}, 3),
321           boundingSubViewSize)
322       .getResult();
323 }
324 
325 // Deallocation callback
326 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
327   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
328   return success();
329 }
330 
331 // Copy in call back
332 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
333                                     bool isOutput) {
334   auto floatType = src.getType().cast<MemRefType>().getElementType();
335   if (!floatType.isa<FloatType>())
336     return failure();
337   if (!isOutput) {
338     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
339                                             FloatAttr::get(floatType, 42.0));
340     b.create<FillOp>(src.getLoc(), cst, dst);
341   }
342   b.create<CopyOp>(src.getLoc(), src, dst);
343   return success();
344 }
345 
346 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
347                                           RewritePatternSet &patterns) {
348   patterns.add<LinalgTilingPattern>(
349       MatmulOp::getOperationName(), ctx,
350       LinalgTilingOptions().setTileSizes({16, 16, 16}),
351       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
352                                  StringAttr::get(ctx, "PROMOTE")));
353   patterns.add<LinalgPromotionPattern<MatmulOp>>(
354       ctx,
355       LinalgPromotionOptions()
356           .setOperandsToPromote({0, 2})
357           .setUseFullTileBuffers({false, false})
358           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
359           .setCopyInOutFns(
360               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
361                 return copyCallBackFn(b, src, dst, false);
362               },
363               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
364                 return copyCallBackFn(b, src, dst, true);
365               }),
366       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
367 }
368 
369 template <typename IdOp, typename NProcsOp>
370 static SmallVector<ProcInfo, 2>
371 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
372   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
373   SmallVector<ProcInfo, 2> procInfo(count);
374   Type indexType = b.getIndexType();
375   for (unsigned i = 0; i < count; ++i) {
376     gpu::Dimension dim = *gpu::symbolizeDimension(i);
377     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
378                                b.create<NProcsOp>(loc, indexType, dim)};
379   }
380   return procInfo;
381 }
382 
383 static void fillTileAndDistributePatterns(MLIRContext *context,
384                                           RewritePatternSet &patterns) {
385   {
386     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
387     cyclicNprocsEqNiters.distributionMethod.resize(
388         2, DistributionMethod::CyclicNumProcsEqNumIters);
389     cyclicNprocsEqNiters.procInfo =
390         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
391     patterns.add<LinalgTilingPattern>(
392         MatmulOp::getOperationName(), context,
393         LinalgTilingOptions()
394             .setTileSizes({8, 8, 4})
395             .setLoopType(LinalgTilingLoopType::ParallelLoops)
396             .setDistributionOptions(cyclicNprocsEqNiters),
397         LinalgTransformationFilter(
398             StringAttr::get(context, "distribute1"),
399             StringAttr::get(context, "after_distribute1")));
400   }
401 
402   {
403     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
404     cyclicNprocsGeNiters.distributionMethod.resize(
405         2, DistributionMethod::CyclicNumProcsGeNumIters);
406     cyclicNprocsGeNiters.procInfo =
407         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
408     patterns.add<LinalgTilingPattern>(
409         MatmulOp::getOperationName(), context,
410         LinalgTilingOptions()
411             .setTileSizes({8, 8, 4})
412             .setLoopType(LinalgTilingLoopType::ParallelLoops)
413             .setDistributionOptions(cyclicNprocsGeNiters),
414         LinalgTransformationFilter(
415             StringAttr::get(context, "distribute2"),
416             StringAttr::get(context, "after_distribute2")));
417   }
418 
419   {
420     LinalgLoopDistributionOptions cyclicNprocsDefault;
421     cyclicNprocsDefault.distributionMethod.resize(2,
422                                                   DistributionMethod::Cyclic);
423     cyclicNprocsDefault.procInfo =
424         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
425     patterns.add<LinalgTilingPattern>(
426         MatmulOp::getOperationName(), context,
427         LinalgTilingOptions()
428             .setTileSizes({8, 8, 4})
429             .setLoopType(LinalgTilingLoopType::ParallelLoops)
430             .setDistributionOptions(cyclicNprocsDefault),
431         LinalgTransformationFilter(
432             StringAttr::get(context, "distribute3"),
433             StringAttr::get(context, "after_distribute3")));
434   }
435 
436   {
437     LinalgLoopDistributionOptions cyclicNprocsMixed1;
438     cyclicNprocsMixed1.distributionMethod = {
439         DistributionMethod::CyclicNumProcsEqNumIters,
440         DistributionMethod::CyclicNumProcsGeNumIters};
441     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
442     patterns.add<LinalgTilingPattern>(
443         MatmulOp::getOperationName(), context,
444         LinalgTilingOptions()
445             .setTileSizes({8, 8, 4})
446             .setLoopType(LinalgTilingLoopType::ParallelLoops)
447             .setDistributionOptions(cyclicNprocsMixed1),
448         LinalgTransformationFilter(
449             StringAttr::get(context, "distribute4"),
450             StringAttr::get(context, "after_distribute4")));
451   }
452 
453   {
454     LinalgLoopDistributionOptions cyclicNprocsMixed2;
455     cyclicNprocsMixed2.distributionMethod = {
456         DistributionMethod::CyclicNumProcsGeNumIters,
457         DistributionMethod::Cyclic};
458     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
459     patterns.add<LinalgTilingPattern>(
460         MatmulOp::getOperationName(), context,
461         LinalgTilingOptions()
462             .setTileSizes({8, 8, 4})
463             .setLoopType(LinalgTilingLoopType::ParallelLoops)
464             .setDistributionOptions(cyclicNprocsMixed2),
465         LinalgTransformationFilter(
466             StringAttr::get(context, "distribute5"),
467             StringAttr::get(context, "after_distribute5")));
468   }
469 
470   {
471     LinalgLoopDistributionOptions cyclicNprocsMixed3;
472     cyclicNprocsMixed3.distributionMethod = {
473         DistributionMethod::Cyclic,
474         DistributionMethod::CyclicNumProcsEqNumIters};
475     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
476 
477     patterns.add<LinalgTilingPattern>(
478         MatmulOp::getOperationName(), context,
479         LinalgTilingOptions()
480             .setTileSizes({8, 8, 4})
481             .setLoopType(LinalgTilingLoopType::ParallelLoops)
482             .setDistributionOptions(cyclicNprocsMixed3),
483         LinalgTransformationFilter(
484             StringAttr::get(context, "distribute6"),
485             StringAttr::get(context, "after_distribute6")));
486   }
487 
488   {
489     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
490     cyclicNprocsEqNiters.distributionMethod.resize(2,
491                                                    DistributionMethod::Cyclic);
492     cyclicNprocsEqNiters.procInfo =
493         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
494     patterns.add<LinalgTilingPattern>(
495         MatmulOp::getOperationName(), context,
496         LinalgTilingOptions()
497             .setTileSizes({8, 8, 4})
498             .setLoopType(LinalgTilingLoopType::Loops)
499             .setDistributionOptions(cyclicNprocsEqNiters),
500         LinalgTransformationFilter(
501             StringAttr::get(context, "tensors_distribute1"),
502             StringAttr::get(context, "tensors_after_distribute1")));
503   }
504 }
505 
506 static void
507 applyMatmulToVectorPatterns(FuncOp funcOp,
508                             bool testMatmulToVectorPatterns1dTiling,
509                             bool testMatmulToVectorPatterns2dTiling) {
510   MLIRContext *ctx = funcOp.getContext();
511   SmallVector<RewritePatternSet, 4> stage1Patterns;
512   if (testMatmulToVectorPatterns1dTiling) {
513     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
514   } else if (testMatmulToVectorPatterns2dTiling) {
515     stage1Patterns.emplace_back(
516         ctx, std::make_unique<LinalgTilingPattern>(
517                  MatmulOp::getOperationName(), ctx,
518                  LinalgTilingOptions()
519                      .setTileSizes({768, 264, 768})
520                      .setInterchange({1, 2, 0}),
521                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
522                                             StringAttr::get(ctx, "L2"))));
523     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
524   }
525   {
526     // Canonicalization patterns
527     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
528     vector::populateVectorTransferPermutationMapLoweringPatterns(
529         canonicalizationPatterns);
530     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
531     stage1Patterns.push_back(std::move(canonicalizationPatterns));
532   }
533   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
534   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
535   FrozenRewritePatternSet stage2Patterns =
536       getLinalgTilingCanonicalizationPatterns(ctx);
537   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
538 }
539 
540 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
541   RewritePatternSet forwardPattern(funcOp.getContext());
542   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
543   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
544   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
545 }
546 
547 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
548   RewritePatternSet patterns(funcOp.getContext());
549   patterns.add<LinalgVectorizationPattern>(
550       funcOp.getContext(),
551       LinalgTransformationFilter()
552           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
553   populatePadOpVectorizationPatterns(patterns);
554   populateConvolutionVectorizationPatterns(patterns);
555   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
556 }
557 
558 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
559   RewritePatternSet patterns(funcOp.getContext());
560   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
561   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
562 }
563 
564 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
565   RewritePatternSet patterns(funcOp.getContext());
566   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
567   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
568 }
569 
570 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
571   RewritePatternSet patterns(funcOp.getContext());
572   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
573   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
574 }
575 
576 static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
577                              ArrayRef<int64_t> tileSizes,
578                              ArrayRef<int64_t> peeledLoops,
579                              bool scalarizeDynamicDims) {
580   MLIRContext *context = funcOp.getContext();
581   RewritePatternSet tilingPattern(context);
582   LinalgTilingLoopType type =
583       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
584           .Case("for", LinalgTilingLoopType::Loops)
585           .Case("affine", LinalgTilingLoopType::AffineLoops)
586           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
587           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
588   auto linalgTilingOptions = linalg::LinalgTilingOptions()
589                                  .setPeeledLoops(peeledLoops)
590                                  .setLoopType(type);
591   if (scalarizeDynamicDims) {
592     linalgTilingOptions.scalarizeDynamicDims();
593     assert(tileSizes.empty() &&
594            "tileSizes and scalarizeDynamicDims is mutually exclusive");
595   } else {
596     linalgTilingOptions.setTileSizes(tileSizes);
597   }
598   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
599   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
600       tilingPattern, linalgTilingOptions, f);
601   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
602 }
603 
/// Name of the attribute in which TiledLoopPeelingPattern records, as an I64
/// array, the loop indices that were already peeled on a TiledLoopOp.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
/// Name of the unit attribute that TiledLoopPeelingPattern sets on the loop
/// returned through the `result` argument of peelAndCanonicalizeTiledLoop
/// (presumably the loop holding the partial iteration, as the name suggests).
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
606 
607 namespace {
608 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
609 /// `idx`-th loop contains only "full" iterations and a second loop for the
610 /// remaining partial iteration (if any).
611 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
612   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
613       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
614   }
615 
616   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
617                                 PatternRewriter &rewriter) const override {
618     SmallVector<int64_t> peeledLoops;
619     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
620       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
621       peeledLoops =
622           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
623             return attr.cast<IntegerAttr>().getInt();
624           }));
625       // Check if the loop was already peeled.
626       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
627         return failure();
628     }
629     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
630       // No peeling of loop nests with a partial iteration.
631       return failure();
632 
633     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
634       return failure();
635 
636     // Peel loop and canonicalize.
637     TiledLoopOp result;
638     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
639                                                     result)))
640       return failure();
641 
642     // Apply label, so that the same loop is not rewritten a second time.
643     peeledLoops.push_back(idx);
644     rewriter.updateRootInPlace(loopOp, [&]() {
645       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
646     });
647     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
648     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
649 
650     return success();
651   }
652 
653   /// Index of loop to peel.
654   int64_t idx;
655 
656   /// If set to true, do not peel TiledLoopOps with a partial iteration.
657   bool skipPartial;
658 };
659 } // namespace
660 
661 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
662                                          ArrayRef<unsigned> loops,
663                                          bool skipPartial) {
664   MLIRContext *ctx = funcOp.getContext();
665   RewritePatternSet patterns(ctx);
666   for (unsigned idx : loops)
667     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
668   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
669 
670   // Drop the markers.
671   funcOp.walk([](TiledLoopOp op) {
672     op->removeAttr(kPeeledLoopsLabel);
673     op->removeAttr(kPartialIterationLabel);
674   });
675 }
676 
677 /// Apply transformations specified as patterns.
678 void TestLinalgTransforms::runOnOperation() {
679   auto lambda = [&](void *) {
680     getOperation().walk([](LinalgOp op) {
681       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
682     });
683   };
684   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
685 
686   if (testPromotionOptions) {
687     RewritePatternSet patterns(&getContext());
688     fillPromotionCallBackPatterns(&getContext(), patterns);
689     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
690     return;
691   }
692   if (testTileAndDistributionOptions) {
693     RewritePatternSet patterns(&getContext());
694     fillTileAndDistributePatterns(&getContext(), patterns);
695     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
696     return;
697   }
698   if (testPatterns)
699     return applyPatterns(getOperation());
700   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
701     return applyMatmulToVectorPatterns(getOperation(),
702                                        testMatmulToVectorPatterns1dTiling,
703                                        testMatmulToVectorPatterns2dTiling);
704   if (testVectorTransferForwardingPatterns)
705     return applyVectorTransferForwardingPatterns(getOperation());
706   if (testGenericToVectorPattern)
707     return applyLinalgToVectorPatterns(getOperation());
708   if (testTransformPadTensor)
709     return applyPadTensorToGenericPatterns(getOperation());
710   if (testGeneralizePadTensor)
711     return applyGeneralizePadTensorPatterns(getOperation());
712   if (testSwapSubTensorPadTensor)
713     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
714   if (testTiledLoopPeeling.hasValue())
715     return applyTiledLoopPeelingPattern(getOperation(), testTiledLoopPeeling,
716                                         skipPartial);
717   if (testTilePattern)
718     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
719                             /*scalarizeDynamicDims=*/false);
720   if (testTileScalarizeDynamicDims)
721     return applyTilePattern(getOperation(), loopType, tileSizes,
722                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
723 }
724 
725 namespace mlir {
726 namespace test {
727 void registerTestLinalgTransforms() {
728   PassRegistration<TestLinalgTransforms>();
729 }
730 } // namespace test
731 } // namespace mlir
732