1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> {
36   TestLinalgTransforms() = default;
37   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
38 
39   void getDependentDialects(DialectRegistry &registry) const override {
40     // clang-format off
41     registry.insert<AffineDialect,
42                     memref::MemRefDialect,
43                     scf::SCFDialect,
44                     StandardOpsDialect,
45                     vector::VectorDialect,
46                     gpu::GPUDialect>();
47     // clang-format on
48   }
49   StringRef getArgument() const final {
50     return "test-linalg-transform-patterns";
51   }
52   StringRef getDescription() const final {
53     return "Test Linalg transformation patterns by applying them greedily.";
54   }
55 
56   void runOnOperation() override;
57 
58   Option<bool> testPatterns{*this, "test-patterns",
59                             llvm::cl::desc("Test a mixed set of patterns"),
60                             llvm::cl::init(false)};
61   Option<bool> testMatmulToVectorPatterns1dTiling{
62       *this, "test-matmul-to-vector-patterns-tile-1d",
63       llvm::cl::desc(
64           "Test a fused pass that applies patterns from matmul to vectors via "
65           "1-d tiling"),
66       llvm::cl::init(false)};
67   Option<bool> testMatmulToVectorPatterns2dTiling{
68       *this, "test-matmul-to-vector-patterns-tile-2d",
69       llvm::cl::desc(
70           "Test a fused pass that applies patterns from matmul to vectors via "
71           "2-d tiling"),
72       llvm::cl::init(false)};
73   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
74                                     llvm::cl::desc("Test promotion options"),
75                                     llvm::cl::init(false)};
76   Option<bool> testTileAndDistributionOptions{
77       *this, "test-tile-and-distribute-options",
78       llvm::cl::desc("Test tile and distribute options"),
79       llvm::cl::init(false)};
80   Option<bool> testVectorTransferForwardingPatterns{
81       *this, "test-vector-transfer-forwarding-patterns",
82       llvm::cl::desc(
83           "Test a fused pass that forwards linalg.copy to vector.transfer"),
84       llvm::cl::init(false)};
85   Option<bool> testGenericToVectorPattern{
86       *this, "test-linalg-to-vector-patterns",
87       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
88                      "in vector.contract form"),
89       llvm::cl::init(false)};
90   Option<bool> testTilePattern{*this, "test-tile-pattern",
91                                llvm::cl::desc("Test tile pattern"),
92                                llvm::cl::init(false)};
93   Option<bool> testTileScalarizeDynamicDims{
94       *this, "test-tile-scalarize-dynamic-dims",
95       llvm::cl::desc("Test tiling of dynamic dims by 1"),
96       llvm::cl::init(false)};
97   Option<bool> testTransformPadTensor{
98       *this, "test-transform-pad-tensor",
99       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
100       llvm::cl::init(false)};
101   Option<bool> testGeneralizePadTensor{
102       *this, "test-generalize-pad-tensor",
103       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
104       llvm::cl::init(false)};
105   Option<bool> testSwapSubTensorPadTensor{
106       *this, "test-swap-subtensor-padtensor",
107       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
108                      "pad_tensor(subtensor)"),
109       llvm::cl::init(false)};
110   ListOption<int64_t> peeledLoops{
111       *this, "peeled-loops",
112       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
113       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
114   ListOption<int64_t> tileSizes{
115       *this, "tile-sizes",
116       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
117       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
118   ListOption<unsigned> testTiledLoopPeeling{
119       *this, "test-tiled-loop-peeling",
120       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
121       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
122   Option<bool> skipPartial{
123       *this, "skip-partial",
124       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
125       llvm::cl::init(false)};
126   Option<std::string> loopType{
127       *this, "loop-type",
128       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
129                      "tiled_loop"),
130       llvm::cl::init("for")};
131 };
132 } // namespace
133 
134 static void applyPatterns(FuncOp funcOp) {
135   MLIRContext *ctx = funcOp.getContext();
136   RewritePatternSet patterns(ctx);
137 
138   //===--------------------------------------------------------------------===//
139   // Linalg tiling patterns.
140   //===--------------------------------------------------------------------===//
141   patterns.add<LinalgTilingPattern>(
142       MatmulOp::getOperationName(), ctx,
143       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
144       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
145                                  StringAttr::get(ctx, "L3")));
146   patterns.add<LinalgTilingPattern>(
147       MatmulOp::getOperationName(), ctx,
148       LinalgTilingOptions().setTileSizes({200, 300, 400}),
149       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
150                                  StringAttr::get(ctx, "L2")));
151   patterns.add<LinalgTilingPattern>(
152       MatmulOp::getOperationName(), ctx,
153       LinalgTilingOptions().setTileSizes({20, 30, 40}),
154       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
155                                  StringAttr::get(ctx, "L1")));
156   patterns.add<LinalgTilingPattern>(
157       MatmulOp::getOperationName(), ctx,
158       LinalgTilingOptions().setTileSizes({2, 3, 4}),
159       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
160                                  StringAttr::get(ctx, "REG")));
161 
162   patterns.add<LinalgTilingPattern>(
163       MatvecOp::getOperationName(), ctx,
164       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
165           LinalgTilingLoopType::ParallelLoops),
166       LinalgTransformationFilter(ArrayRef<StringAttr>{},
167                                  StringAttr::get(ctx, "L1")));
168 
169   patterns.add<LinalgTilingPattern>(
170       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
171       LinalgTransformationFilter(
172           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
173                                StringAttr::get(ctx, "L3"),
174                                StringAttr::get(ctx, "L2")},
175           StringAttr::get(ctx, "REG")));
176 
177   //===--------------------------------------------------------------------===//
178   // Linalg tiling and permutation patterns.
179   //===--------------------------------------------------------------------===//
180   patterns.add<LinalgTilingPattern>(
181       MatmulOp::getOperationName(), ctx,
182       LinalgTilingOptions()
183           .setTileSizes({2000, 3000, 4000})
184           .setInterchange({1, 2, 0}),
185       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
186                                  StringAttr::get(ctx, "L2__with_perm__")));
187   patterns.add<LinalgTilingPattern>(
188       MatmulOp::getOperationName(), ctx,
189       LinalgTilingOptions()
190           .setTileSizes({200, 300, 400})
191           .setInterchange({1, 0, 2}),
192       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
193                                  StringAttr::get(ctx, "L1__with_perm__")));
194   patterns.add<LinalgTilingPattern>(
195       MatmulOp::getOperationName(), ctx,
196       LinalgTilingOptions().setTileSizes({20, 30, 40}),
197       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
198                                  StringAttr::get(ctx, "REG__with_perm__")));
199 
200   patterns.add<LinalgTilingPattern>(
201       MatvecOp::getOperationName(), ctx,
202       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
203       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
204                                  StringAttr::get(ctx, "L1__with_perm__")));
205 
206   patterns.add<LinalgTilingPattern>(
207       MatmulOp::getOperationName(), ctx,
208       LinalgTilingOptions()
209           .setTileSizes({16, 8, 4})
210           .setInterchange({1, 2, 0})
211           .setLoopType(LinalgTilingLoopType::ParallelLoops),
212       LinalgTransformationFilter(
213           StringAttr::get(ctx, "par__with_perm__"),
214           StringAttr::get(ctx, "after_par__with_perm__")));
215 
216   //===--------------------------------------------------------------------===//
217   // Linalg to loops patterns.
218   //===--------------------------------------------------------------------===//
219   patterns.add<LinalgLoweringPattern<DotOp>>(
220       ctx,
221       /*loweringType=*/LinalgLoweringType::Loops,
222       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
223 
224   //===--------------------------------------------------------------------===//
225   // Linalg distribution patterns.
226   //===--------------------------------------------------------------------===//
227   LinalgLoopDistributionOptions distributionOptions;
228 
229   //===--------------------------------------------------------------------===//
230   // Linalg to vector contraction patterns.
231   //===--------------------------------------------------------------------===//
232   patterns.add<LinalgVectorizationPattern>(
233       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
234                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
235 
236   //===--------------------------------------------------------------------===//
237   // Linalg generic interchange pattern.
238   //===--------------------------------------------------------------------===//
239   patterns.add<GenericOpInterchangePattern>(
240       ctx,
241       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
242       LinalgTransformationFilter(ArrayRef<StringAttr>{},
243                                  StringAttr::get(ctx, "PERMUTED")));
244 
245   //===--------------------------------------------------------------------===//
246   // Linalg subview operands promotion.
247   //===--------------------------------------------------------------------===//
248   patterns.add<LinalgPromotionPattern<MatmulOp>>(
249       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
250       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
251                                  StringAttr::get(ctx, "_views_promoted_")));
252   patterns.add<LinalgPromotionPattern<MatmulOp>>(
253       ctx,
254       LinalgPromotionOptions()
255           .setOperandsToPromote({0})
256           .setUseFullTileBuffersByDefault(true),
257       LinalgTransformationFilter(
258           StringAttr::get(ctx, "_promote_first_view_"),
259           StringAttr::get(ctx, "_first_view_promoted_")));
260   patterns.add<LinalgPromotionPattern<FillOp>>(
261       ctx,
262       LinalgPromotionOptions()
263           .setOperandsToPromote({1})
264           .setUseFullTileBuffers({false, true})
265           .setAlignment(32),
266       LinalgTransformationFilter(
267           StringAttr::get(ctx, "_promote_views_aligned_"),
268           StringAttr::get(ctx, "_views_aligned_promoted_")));
269 
270   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
271 
272   // Drop the marker.
273   funcOp.walk([](LinalgOp op) {
274     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
275   });
276 }
277 
278 static void fillL1TilingAndMatmulToVectorPatterns(
279     FuncOp funcOp, StringRef startMarker,
280     SmallVectorImpl<RewritePatternSet> &patternsVector) {
281   MLIRContext *ctx = funcOp.getContext();
282   patternsVector.emplace_back(
283       ctx, std::make_unique<LinalgTilingPattern>(
284                MatmulOp::getOperationName(), ctx,
285                LinalgTilingOptions()
286                    .setTileSizes({8, 12, 16})
287                    .setInterchange({1, 0, 2}),
288                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
289                                           StringAttr::get(ctx, "L1"))));
290 
291   patternsVector.emplace_back(
292       ctx,
293       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
294           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
295           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
296                                      StringAttr::get(ctx, "VEC"))));
297 
298   patternsVector.emplace_back(
299       ctx, std::make_unique<LinalgVectorizationPattern>(
300                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
301                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
302   patternsVector.back().add<LinalgVectorizationPattern>(
303       ctx, LinalgTransformationFilter().addOpFilter<FillOp, CopyOp>());
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Test promotion callbacks
308 //===----------------------------------------------------------------------===//
309 
310 // Allocation call back
311 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
312                                        ArrayRef<Value> boundingSubViewSize,
313                                        DataLayout &layout) {
314   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
315   return b
316       .create<memref::AllocOp>(
317           subView.getLoc(),
318           MemRefType::get(shape, subView.getType().getElementType(),
319                           /*affineMapComposition =*/{}, 3),
320           boundingSubViewSize)
321       .getResult();
322 }
323 
324 // Deallocation callback
325 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
326   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
327   return success();
328 }
329 
330 // Copy in call back
331 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
332                                     bool isOutput) {
333   auto floatType = src.getType().cast<MemRefType>().getElementType();
334   if (!floatType.isa<FloatType>())
335     return failure();
336   if (!isOutput) {
337     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
338                                             FloatAttr::get(floatType, 42.0));
339     b.create<FillOp>(src.getLoc(), cst, dst);
340   }
341   b.create<CopyOp>(src.getLoc(), src, dst);
342   return success();
343 }
344 
345 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
346                                           RewritePatternSet &patterns) {
347   patterns.add<LinalgTilingPattern>(
348       MatmulOp::getOperationName(), ctx,
349       LinalgTilingOptions().setTileSizes({16, 16, 16}),
350       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
351                                  StringAttr::get(ctx, "PROMOTE")));
352   patterns.add<LinalgPromotionPattern<MatmulOp>>(
353       ctx,
354       LinalgPromotionOptions()
355           .setOperandsToPromote({0, 2})
356           .setUseFullTileBuffers({false, false})
357           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
358           .setCopyInOutFns(
359               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
360                 return copyCallBackFn(b, src, dst, false);
361               },
362               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
363                 return copyCallBackFn(b, src, dst, true);
364               }),
365       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
366 }
367 
368 template <typename IdOp, typename NProcsOp>
369 static SmallVector<ProcInfo, 2>
370 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
371   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
372   SmallVector<ProcInfo, 2> procInfo(count);
373   Type indexType = b.getIndexType();
374   for (unsigned i = 0; i < count; ++i) {
375     gpu::Dimension dim = *gpu::symbolizeDimension(i);
376     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
377                                b.create<NProcsOp>(loc, indexType, dim)};
378   }
379   return procInfo;
380 }
381 
382 static void fillTileAndDistributePatterns(MLIRContext *context,
383                                           RewritePatternSet &patterns) {
384   {
385     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
386     cyclicNprocsEqNiters.distributionMethod.resize(
387         2, DistributionMethod::CyclicNumProcsEqNumIters);
388     cyclicNprocsEqNiters.procInfo =
389         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
390     patterns.add<LinalgTilingPattern>(
391         MatmulOp::getOperationName(), context,
392         LinalgTilingOptions()
393             .setTileSizes({8, 8, 4})
394             .setLoopType(LinalgTilingLoopType::ParallelLoops)
395             .setDistributionOptions(cyclicNprocsEqNiters),
396         LinalgTransformationFilter(
397             StringAttr::get(context, "distribute1"),
398             StringAttr::get(context, "after_distribute1")));
399   }
400 
401   {
402     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
403     cyclicNprocsGeNiters.distributionMethod.resize(
404         2, DistributionMethod::CyclicNumProcsGeNumIters);
405     cyclicNprocsGeNiters.procInfo =
406         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
407     patterns.add<LinalgTilingPattern>(
408         MatmulOp::getOperationName(), context,
409         LinalgTilingOptions()
410             .setTileSizes({8, 8, 4})
411             .setLoopType(LinalgTilingLoopType::ParallelLoops)
412             .setDistributionOptions(cyclicNprocsGeNiters),
413         LinalgTransformationFilter(
414             StringAttr::get(context, "distribute2"),
415             StringAttr::get(context, "after_distribute2")));
416   }
417 
418   {
419     LinalgLoopDistributionOptions cyclicNprocsDefault;
420     cyclicNprocsDefault.distributionMethod.resize(2,
421                                                   DistributionMethod::Cyclic);
422     cyclicNprocsDefault.procInfo =
423         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
424     patterns.add<LinalgTilingPattern>(
425         MatmulOp::getOperationName(), context,
426         LinalgTilingOptions()
427             .setTileSizes({8, 8, 4})
428             .setLoopType(LinalgTilingLoopType::ParallelLoops)
429             .setDistributionOptions(cyclicNprocsDefault),
430         LinalgTransformationFilter(
431             StringAttr::get(context, "distribute3"),
432             StringAttr::get(context, "after_distribute3")));
433   }
434 
435   {
436     LinalgLoopDistributionOptions cyclicNprocsMixed1;
437     cyclicNprocsMixed1.distributionMethod = {
438         DistributionMethod::CyclicNumProcsEqNumIters,
439         DistributionMethod::CyclicNumProcsGeNumIters};
440     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
441     patterns.add<LinalgTilingPattern>(
442         MatmulOp::getOperationName(), context,
443         LinalgTilingOptions()
444             .setTileSizes({8, 8, 4})
445             .setLoopType(LinalgTilingLoopType::ParallelLoops)
446             .setDistributionOptions(cyclicNprocsMixed1),
447         LinalgTransformationFilter(
448             StringAttr::get(context, "distribute4"),
449             StringAttr::get(context, "after_distribute4")));
450   }
451 
452   {
453     LinalgLoopDistributionOptions cyclicNprocsMixed2;
454     cyclicNprocsMixed2.distributionMethod = {
455         DistributionMethod::CyclicNumProcsGeNumIters,
456         DistributionMethod::Cyclic};
457     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
458     patterns.add<LinalgTilingPattern>(
459         MatmulOp::getOperationName(), context,
460         LinalgTilingOptions()
461             .setTileSizes({8, 8, 4})
462             .setLoopType(LinalgTilingLoopType::ParallelLoops)
463             .setDistributionOptions(cyclicNprocsMixed2),
464         LinalgTransformationFilter(
465             StringAttr::get(context, "distribute5"),
466             StringAttr::get(context, "after_distribute5")));
467   }
468 
469   {
470     LinalgLoopDistributionOptions cyclicNprocsMixed3;
471     cyclicNprocsMixed3.distributionMethod = {
472         DistributionMethod::Cyclic,
473         DistributionMethod::CyclicNumProcsEqNumIters};
474     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
475 
476     patterns.add<LinalgTilingPattern>(
477         MatmulOp::getOperationName(), context,
478         LinalgTilingOptions()
479             .setTileSizes({8, 8, 4})
480             .setLoopType(LinalgTilingLoopType::ParallelLoops)
481             .setDistributionOptions(cyclicNprocsMixed3),
482         LinalgTransformationFilter(
483             StringAttr::get(context, "distribute6"),
484             StringAttr::get(context, "after_distribute6")));
485   }
486 
487   {
488     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
489     cyclicNprocsEqNiters.distributionMethod.resize(2,
490                                                    DistributionMethod::Cyclic);
491     cyclicNprocsEqNiters.procInfo =
492         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
493     patterns.add<LinalgTilingPattern>(
494         MatmulOp::getOperationName(), context,
495         LinalgTilingOptions()
496             .setTileSizes({8, 8, 4})
497             .setLoopType(LinalgTilingLoopType::Loops)
498             .setDistributionOptions(cyclicNprocsEqNiters),
499         LinalgTransformationFilter(
500             StringAttr::get(context, "tensors_distribute1"),
501             StringAttr::get(context, "tensors_after_distribute1")));
502   }
503 }
504 
505 static void
506 applyMatmulToVectorPatterns(FuncOp funcOp,
507                             bool testMatmulToVectorPatterns1dTiling,
508                             bool testMatmulToVectorPatterns2dTiling) {
509   MLIRContext *ctx = funcOp.getContext();
510   SmallVector<RewritePatternSet, 4> stage1Patterns;
511   if (testMatmulToVectorPatterns1dTiling) {
512     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
513   } else if (testMatmulToVectorPatterns2dTiling) {
514     stage1Patterns.emplace_back(
515         ctx, std::make_unique<LinalgTilingPattern>(
516                  MatmulOp::getOperationName(), ctx,
517                  LinalgTilingOptions()
518                      .setTileSizes({768, 264, 768})
519                      .setInterchange({1, 2, 0}),
520                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
521                                             StringAttr::get(ctx, "L2"))));
522     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
523   }
524   {
525     // Canonicalization patterns
526     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
527     vector::populateVectorTransferPermutationMapLoweringPatterns(
528         canonicalizationPatterns);
529     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
530     stage1Patterns.push_back(std::move(canonicalizationPatterns));
531   }
532   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
533   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
534   FrozenRewritePatternSet stage2Patterns =
535       getLinalgTilingCanonicalizationPatterns(ctx);
536   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
537 }
538 
539 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
540   RewritePatternSet forwardPattern(funcOp.getContext());
541   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
542   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
543   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
544 }
545 
546 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
547   RewritePatternSet patterns(funcOp.getContext());
548   patterns.add<LinalgVectorizationPattern>(
549       funcOp.getContext(),
550       LinalgTransformationFilter()
551           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
552   populatePadTensorOpVectorizationPatterns(patterns);
553   populateConvolutionVectorizationPatterns(patterns);
554   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
555 }
556 
557 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
558   RewritePatternSet patterns(funcOp.getContext());
559   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
560   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
561 }
562 
563 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
564   RewritePatternSet patterns(funcOp.getContext());
565   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
566   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
567 }
568 
569 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
570   RewritePatternSet patterns(funcOp.getContext());
571   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
572   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
573 }
574 
575 static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
576                              ArrayRef<int64_t> tileSizes,
577                              ArrayRef<int64_t> peeledLoops,
578                              bool scalarizeDynamicDims) {
579   MLIRContext *context = funcOp.getContext();
580   RewritePatternSet tilingPattern(context);
581   LinalgTilingLoopType type =
582       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
583           .Case("for", LinalgTilingLoopType::Loops)
584           .Case("affine", LinalgTilingLoopType::AffineLoops)
585           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
586           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
587   auto linalgTilingOptions = linalg::LinalgTilingOptions()
588                                  .setPeeledLoops(peeledLoops)
589                                  .setLoopType(type);
590   if (scalarizeDynamicDims) {
591     linalgTilingOptions.scalarizeDynamicDims();
592     assert(tileSizes.empty() &&
593            "tileSizes and scalarizeDynamicDims is mutually exclusive");
594   } else {
595     linalgTilingOptions.setTileSizes(tileSizes);
596   }
597   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
598   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
599       tilingPattern, linalgTilingOptions, f);
600   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
601 }
602 
// Name of the attribute in which TiledLoopPeelingPattern records, as an array
// of integers, the loop indices it has already peeled on a TiledLoopOp.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
// Name of the unit attribute that TiledLoopPeelingPattern sets on the loop
// returned through the `result` argument of peelAndCanonicalizeTiledLoop.
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
605 
606 namespace {
607 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
608 /// `idx`-th loop contains only "full" iterations and a second loop for the
609 /// remaining partial iteration (if any).
610 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
611   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
612       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
613   }
614 
615   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
616                                 PatternRewriter &rewriter) const override {
617     SmallVector<int64_t> peeledLoops;
618     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
619       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
620       peeledLoops =
621           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
622             return attr.cast<IntegerAttr>().getInt();
623           }));
624       // Check if the loop was already peeled.
625       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
626         return failure();
627     }
628     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
629       // No peeling of loop nests with a partial iteration.
630       return failure();
631 
632     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
633       return failure();
634 
635     // Peel loop and canonicalize.
636     TiledLoopOp result;
637     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
638                                                     result)))
639       return failure();
640 
641     // Apply label, so that the same loop is not rewritten a second time.
642     peeledLoops.push_back(idx);
643     rewriter.updateRootInPlace(loopOp, [&]() {
644       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
645     });
646     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
647     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
648 
649     return success();
650   }
651 
652   /// Index of loop to peel.
653   int64_t idx;
654 
655   /// If set to true, do not peel TiledLoopOps with a partial iteration.
656   bool skipPartial;
657 };
658 } // namespace
659 
660 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
661                                          ArrayRef<unsigned> loops,
662                                          bool skipPartial) {
663   MLIRContext *ctx = funcOp.getContext();
664   RewritePatternSet patterns(ctx);
665   for (unsigned idx : loops)
666     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
667   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
668 
669   // Drop the markers.
670   funcOp.walk([](TiledLoopOp op) {
671     op->removeAttr(kPeeledLoopsLabel);
672     op->removeAttr(kPartialIterationLabel);
673   });
674 }
675 
676 /// Apply transformations specified as patterns.
677 void TestLinalgTransforms::runOnOperation() {
678   auto lambda = [&](void *) {
679     getOperation().walk([](LinalgOp op) {
680       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
681     });
682   };
683   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
684 
685   if (testPromotionOptions) {
686     RewritePatternSet patterns(&getContext());
687     fillPromotionCallBackPatterns(&getContext(), patterns);
688     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
689     return;
690   }
691   if (testTileAndDistributionOptions) {
692     RewritePatternSet patterns(&getContext());
693     fillTileAndDistributePatterns(&getContext(), patterns);
694     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
695     return;
696   }
697   if (testPatterns)
698     return applyPatterns(getOperation());
699   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
700     return applyMatmulToVectorPatterns(getOperation(),
701                                        testMatmulToVectorPatterns1dTiling,
702                                        testMatmulToVectorPatterns2dTiling);
703   if (testVectorTransferForwardingPatterns)
704     return applyVectorTransferForwardingPatterns(getOperation());
705   if (testGenericToVectorPattern)
706     return applyLinalgToVectorPatterns(getOperation());
707   if (testTransformPadTensor)
708     return applyPadTensorToGenericPatterns(getOperation());
709   if (testGeneralizePadTensor)
710     return applyGeneralizePadTensorPatterns(getOperation());
711   if (testSwapSubTensorPadTensor)
712     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
713   if (testTiledLoopPeeling.hasValue())
714     return applyTiledLoopPeelingPattern(getOperation(), testTiledLoopPeeling,
715                                         skipPartial);
716   if (testTilePattern)
717     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
718                             /*scalarizeDynamicDims=*/false);
719   if (testTileScalarizeDynamicDims)
720     return applyTilePattern(getOperation(), loopType, tileSizes,
721                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
722 }
723 
namespace mlir {
namespace test {
/// Registers the TestLinalgTransforms pass (command-line argument
/// "test-linalg-transform-patterns") by instantiating a PassRegistration for
/// it.
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir
731