1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
17 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/SmallVector.h"
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 namespace {
32 struct TestLinalgTransforms
33     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
34   TestLinalgTransforms() = default;
35   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
36 
37   void getDependentDialects(DialectRegistry &registry) const override {
38     // clang-format off
39     registry.insert<AffineDialect,
40                     memref::MemRefDialect,
41                     scf::SCFDialect,
42                     StandardOpsDialect,
43                     vector::VectorDialect,
44                     gpu::GPUDialect>();
45     // clang-format on
46   }
47   StringRef getArgument() const final {
48     return "test-linalg-transform-patterns";
49   }
50   StringRef getDescription() const final {
51     return "Test Linalg transformation patterns by applying them greedily.";
52   }
53 
54   void runOnFunction() override;
55 
56   Option<bool> testPatterns{*this, "test-patterns",
57                             llvm::cl::desc("Test a mixed set of patterns"),
58                             llvm::cl::init(false)};
59   Option<bool> testMatmulToVectorPatterns1dTiling{
60       *this, "test-matmul-to-vector-patterns-tile-1d",
61       llvm::cl::desc(
62           "Test a fused pass that applies patterns from matmul to vectors via "
63           "1-d tiling"),
64       llvm::cl::init(false)};
65   Option<bool> testMatmulToVectorPatterns2dTiling{
66       *this, "test-matmul-to-vector-patterns-tile-2d",
67       llvm::cl::desc(
68           "Test a fused pass that applies patterns from matmul to vectors via "
69           "2-d tiling"),
70       llvm::cl::init(false)};
71   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
72                                     llvm::cl::desc("Test promotion options"),
73                                     llvm::cl::init(false)};
74   Option<bool> testTileAndDistributionOptions{
75       *this, "test-tile-and-distribute-options",
76       llvm::cl::desc("Test tile and distribute options"),
77       llvm::cl::init(false)};
78   Option<bool> testVectorTransferForwardingPatterns{
79       *this, "test-vector-transfer-forwarding-patterns",
80       llvm::cl::desc(
81           "Test a fused pass that forwards linalg.copy to vector.transfer"),
82       llvm::cl::init(false)};
83   Option<bool> testGenericToVectorPattern{
84       *this, "test-linalg-to-vector-patterns",
85       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
86                      "in vector.contract form"),
87       llvm::cl::init(false)};
88   Option<bool> testTilePattern{*this, "test-tile-pattern",
89                                llvm::cl::desc("Test tile pattern"),
90                                llvm::cl::init(false)};
91   Option<bool> testTileScalarizeDynamicDims{
92       *this, "test-tile-scalarize-dynamic-dims",
93       llvm::cl::desc("Test tiling of dynamic dims by 1"),
94       llvm::cl::init(false)};
95   Option<int> testHoistPadding{*this, "test-hoist-padding",
96                                llvm::cl::desc("Test hoist padding"),
97                                llvm::cl::init(0)};
98   Option<bool> testTransformPadTensor{
99       *this, "test-transform-pad-tensor",
100       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
101       llvm::cl::init(false)};
102   Option<bool> testGeneralizePadTensor{
103       *this, "test-generalize-pad-tensor",
104       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
105       llvm::cl::init(false)};
106   Option<bool> testSwapSubTensorPadTensor{
107       *this, "test-swap-subtensor-padtensor",
108       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
109                      "pad_tensor(subtensor)"),
110       llvm::cl::init(false)};
111   ListOption<int64_t> paddedOperands{
112       *this, "padded-operands",
113       llvm::cl::desc("Operands to pad when test-tile-pattern"),
114       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
115   ListOption<int64_t> peeledLoops{
116       *this, "peeled-loops",
117       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
118       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
119   ListOption<int64_t> tileSizes{
120       *this, "tile-sizes",
121       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
122       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
123   ListOption<unsigned> testInterchangePattern{
124       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
125       llvm::cl::desc("Test the interchange pattern.")};
126   ListOption<unsigned> testTiledLoopPeeling{
127       *this, "test-tiled-loop-peeling",
128       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
129       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
130   Option<bool> skipPartial{
131       *this, "skip-partial",
132       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
133       llvm::cl::init(false)};
134   Option<std::string> loopType{
135       *this, "loop-type",
136       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
137                      "tiled_loop"),
138       llvm::cl::init("for")};
139 };
140 } // end anonymous namespace
141 
142 static void applyPatterns(FuncOp funcOp) {
143   MLIRContext *ctx = funcOp.getContext();
144   RewritePatternSet patterns(ctx);
145 
146   //===--------------------------------------------------------------------===//
147   // Linalg tiling patterns.
148   //===--------------------------------------------------------------------===//
149   patterns.add<LinalgTilingPattern<MatmulOp>>(
150       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
151       LinalgTransformationFilter(Identifier::get("MEM", ctx),
152                                  Identifier::get("L3", ctx)));
153   patterns.add<LinalgTilingPattern<MatmulOp>>(
154       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
155       LinalgTransformationFilter(Identifier::get("L3", ctx),
156                                  Identifier::get("L2", ctx)));
157   patterns.add<LinalgTilingPattern<MatmulOp>>(
158       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
159       LinalgTransformationFilter(Identifier::get("L2", ctx),
160                                  Identifier::get("L1", ctx)));
161   patterns.add<LinalgTilingPattern<MatmulOp>>(
162       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
163       LinalgTransformationFilter(Identifier::get("L1", ctx),
164                                  Identifier::get("REG", ctx)));
165 
166   patterns.add<LinalgTilingPattern<MatvecOp>>(
167       ctx,
168       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
169           LinalgTilingLoopType::ParallelLoops),
170       LinalgTransformationFilter(ArrayRef<Identifier>{},
171                                  Identifier::get("L1", ctx)));
172 
173   patterns.add<LinalgTilingPattern<DotOp>>(
174       ctx, LinalgTilingOptions().setTileSizes(8000),
175       LinalgTransformationFilter(
176           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
177                                Identifier::get("L3", ctx),
178                                Identifier::get("L2", ctx)},
179           Identifier::get("REG", ctx)));
180 
181   //===--------------------------------------------------------------------===//
182   // Linalg tiling and permutation patterns.
183   //===--------------------------------------------------------------------===//
184   patterns.add<LinalgTilingPattern<MatmulOp>>(
185       ctx,
186       LinalgTilingOptions()
187           .setTileSizes({2000, 3000, 4000})
188           .setInterchange({1, 2, 0}),
189       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
190                                  Identifier::get("L2__with_perm__", ctx)));
191   patterns.add<LinalgTilingPattern<MatmulOp>>(
192       ctx,
193       LinalgTilingOptions()
194           .setTileSizes({200, 300, 400})
195           .setInterchange({1, 0, 2}),
196       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
197                                  Identifier::get("L1__with_perm__", ctx)));
198   patterns.add<LinalgTilingPattern<MatmulOp>>(
199       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
200       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
201                                  Identifier::get("REG__with_perm__", ctx)));
202 
203   patterns.add<LinalgTilingPattern<MatvecOp>>(
204       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
205       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
206                                  Identifier::get("L1__with_perm__", ctx)));
207 
208   patterns.add<LinalgTilingPattern<MatmulOp>>(
209       ctx,
210       LinalgTilingOptions()
211           .setTileSizes({16, 8, 4})
212           .setInterchange({1, 2, 0})
213           .setLoopType(LinalgTilingLoopType::ParallelLoops),
214       LinalgTransformationFilter(
215           Identifier::get("par__with_perm__", ctx),
216           Identifier::get("after_par__with_perm__", ctx)));
217 
218   //===--------------------------------------------------------------------===//
219   // Linalg to loops patterns.
220   //===--------------------------------------------------------------------===//
221   patterns.add<LinalgLoweringPattern<DotOp>>(
222       ctx,
223       /*loweringType=*/LinalgLoweringType::Loops,
224       LinalgTransformationFilter(Identifier::get("REG", ctx)));
225 
226   //===--------------------------------------------------------------------===//
227   // Linalg distribution patterns.
228   //===--------------------------------------------------------------------===//
229   LinalgLoopDistributionOptions distributionOptions;
230 
231   //===--------------------------------------------------------------------===//
232   // Linalg to vector contraction patterns.
233   //===--------------------------------------------------------------------===//
234   patterns.add<LinalgVectorizationPattern>(
235       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
236                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
237 
238   //===--------------------------------------------------------------------===//
239   // Linalg generic interchange pattern.
240   //===--------------------------------------------------------------------===//
241   patterns.add<GenericOpInterchangePattern>(
242       ctx,
243       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
244       LinalgTransformationFilter(ArrayRef<Identifier>{},
245                                  Identifier::get("PERMUTED", ctx)));
246 
247   //===--------------------------------------------------------------------===//
248   // Linalg subview operands promotion.
249   //===--------------------------------------------------------------------===//
250   patterns.add<LinalgPromotionPattern<MatmulOp>>(
251       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
252       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
253                                  Identifier::get("_views_promoted_", ctx)));
254   patterns.add<LinalgPromotionPattern<MatmulOp>>(
255       ctx,
256       LinalgPromotionOptions()
257           .setOperandsToPromote({0})
258           .setUseFullTileBuffersByDefault(true),
259       LinalgTransformationFilter(
260           Identifier::get("_promote_first_view_", ctx),
261           Identifier::get("_first_view_promoted_", ctx)));
262   patterns.add<LinalgPromotionPattern<FillOp>>(
263       ctx,
264       LinalgPromotionOptions()
265           .setOperandsToPromote({1})
266           .setUseFullTileBuffers({false, true})
267           .setAlignment(32),
268       LinalgTransformationFilter(
269           Identifier::get("_promote_views_aligned_", ctx),
270           Identifier::get("_views_aligned_promoted_", ctx)));
271 
272   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
273 
274   // Drop the marker.
275   funcOp.walk([](LinalgOp op) {
276     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
277   });
278 }
279 
280 static void fillL1TilingAndMatmulToVectorPatterns(
281     FuncOp funcOp, StringRef startMarker,
282     SmallVectorImpl<RewritePatternSet> &patternsVector) {
283   MLIRContext *ctx = funcOp.getContext();
284   patternsVector.emplace_back(
285       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
286                ctx,
287                LinalgTilingOptions()
288                    .setTileSizes({8, 12, 16})
289                    .setInterchange({1, 0, 2}),
290                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
291                                           Identifier::get("L1", ctx))));
292 
293   patternsVector.emplace_back(
294       ctx,
295       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
296           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
297           LinalgTransformationFilter(Identifier::get("L1", ctx),
298                                      Identifier::get("VEC", ctx))));
299 
300   patternsVector.emplace_back(
301       ctx, std::make_unique<LinalgVectorizationPattern>(
302                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
303                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
304   patternsVector.back().add<LinalgVectorizationPattern>(
305       ctx, LinalgTransformationFilter().addFilter(
306                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // Test promotion callbacks
311 //===----------------------------------------------------------------------===//
312 
313 // Allocation call back
314 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
315                                        ArrayRef<Value> boundingSubViewSize,
316                                        DataLayout &layout) {
317   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
318   return b
319       .create<memref::AllocOp>(
320           subView.getLoc(),
321           MemRefType::get(shape, subView.getType().getElementType(),
322                           /*affineMapComposition =*/{}, 3),
323           boundingSubViewSize)
324       .getResult();
325 }
326 
327 // Deallocation callback
328 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
329   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
330   return success();
331 }
332 
333 // Copy in call back
334 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
335                                     bool isOutput) {
336   auto floatType = src.getType().cast<MemRefType>().getElementType();
337   if (!floatType.isa<FloatType>())
338     return failure();
339   if (!isOutput) {
340     Value cst =
341         b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0));
342     b.create<FillOp>(src.getLoc(), cst, dst);
343   }
344   b.create<CopyOp>(src.getLoc(), src, dst);
345   return success();
346 }
347 
348 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
349                                           RewritePatternSet &patterns) {
350   patterns.add<LinalgTilingPattern<MatmulOp>>(
351       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
352       LinalgTransformationFilter(Identifier::get("START", ctx),
353                                  Identifier::get("PROMOTE", ctx)));
354   patterns.add<LinalgPromotionPattern<MatmulOp>>(
355       ctx,
356       LinalgPromotionOptions()
357           .setOperandsToPromote({0, 2})
358           .setUseFullTileBuffers({false, false})
359           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
360           .setCopyInOutFns(
361               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
362                 return copyCallBackFn(b, src, dst, false);
363               },
364               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
365                 return copyCallBackFn(b, src, dst, true);
366               }),
367       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
368 }
369 
370 template <typename IdOp, typename NProcsOp>
371 static SmallVector<ProcInfo, 2>
372 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
373   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
374   SmallVector<ProcInfo, 2> procInfo(count);
375   const char *xyz[] = {"x", "y", "z"};
376   Type indexType = b.getIndexType();
377   for (unsigned i = 0; i < count; ++i) {
378     procInfo[count - 1 - i] = {
379         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
380         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
381   }
382   return procInfo;
383 }
384 
385 static void fillTileAndDistributePatterns(MLIRContext *context,
386                                           RewritePatternSet &patterns) {
387   {
388     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
389     cyclicNprocsEqNiters.distributionMethod.resize(
390         2, DistributionMethod::CyclicNumProcsEqNumIters);
391     cyclicNprocsEqNiters.procInfo =
392         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
393     patterns.add<LinalgTilingPattern<MatmulOp>>(
394         context,
395         LinalgTilingOptions()
396             .setTileSizes({8, 8, 4})
397             .setLoopType(LinalgTilingLoopType::ParallelLoops)
398             .setDistributionOptions(cyclicNprocsEqNiters),
399         LinalgTransformationFilter(
400             Identifier::get("distribute1", context),
401             Identifier::get("after_distribute1", context)));
402   }
403 
404   {
405     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
406     cyclicNprocsGeNiters.distributionMethod.resize(
407         2, DistributionMethod::CyclicNumProcsGeNumIters);
408     cyclicNprocsGeNiters.procInfo =
409         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
410     patterns.add<LinalgTilingPattern<MatmulOp>>(
411         context,
412         LinalgTilingOptions()
413             .setTileSizes({8, 8, 4})
414             .setLoopType(LinalgTilingLoopType::ParallelLoops)
415             .setDistributionOptions(cyclicNprocsGeNiters),
416         LinalgTransformationFilter(
417             Identifier::get("distribute2", context),
418             Identifier::get("after_distribute2", context)));
419   }
420 
421   {
422     LinalgLoopDistributionOptions cyclicNprocsDefault;
423     cyclicNprocsDefault.distributionMethod.resize(2,
424                                                   DistributionMethod::Cyclic);
425     cyclicNprocsDefault.procInfo =
426         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
427     patterns.add<LinalgTilingPattern<MatmulOp>>(
428         context,
429         LinalgTilingOptions()
430             .setTileSizes({8, 8, 4})
431             .setLoopType(LinalgTilingLoopType::ParallelLoops)
432             .setDistributionOptions(cyclicNprocsDefault),
433         LinalgTransformationFilter(
434             Identifier::get("distribute3", context),
435             Identifier::get("after_distribute3", context)));
436   }
437 
438   {
439     LinalgLoopDistributionOptions cyclicNprocsMixed1;
440     cyclicNprocsMixed1.distributionMethod = {
441         DistributionMethod::CyclicNumProcsEqNumIters,
442         DistributionMethod::CyclicNumProcsGeNumIters};
443     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
444     patterns.add<LinalgTilingPattern<MatmulOp>>(
445         context,
446         LinalgTilingOptions()
447             .setTileSizes({8, 8, 4})
448             .setLoopType(LinalgTilingLoopType::ParallelLoops)
449             .setDistributionOptions(cyclicNprocsMixed1),
450         LinalgTransformationFilter(
451             Identifier::get("distribute4", context),
452             Identifier::get("after_distribute4", context)));
453   }
454 
455   {
456     LinalgLoopDistributionOptions cyclicNprocsMixed2;
457     cyclicNprocsMixed2.distributionMethod = {
458         DistributionMethod::CyclicNumProcsGeNumIters,
459         DistributionMethod::Cyclic};
460     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
461     patterns.add<LinalgTilingPattern<MatmulOp>>(
462         context,
463         LinalgTilingOptions()
464             .setTileSizes({8, 8, 4})
465             .setLoopType(LinalgTilingLoopType::ParallelLoops)
466             .setDistributionOptions(cyclicNprocsMixed2),
467         LinalgTransformationFilter(
468             Identifier::get("distribute5", context),
469             Identifier::get("after_distribute5", context)));
470   }
471 
472   {
473     LinalgLoopDistributionOptions cyclicNprocsMixed3;
474     cyclicNprocsMixed3.distributionMethod = {
475         DistributionMethod::Cyclic,
476         DistributionMethod::CyclicNumProcsEqNumIters};
477     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
478 
479     patterns.add<LinalgTilingPattern<MatmulOp>>(
480         context,
481         LinalgTilingOptions()
482             .setTileSizes({8, 8, 4})
483             .setLoopType(LinalgTilingLoopType::ParallelLoops)
484             .setDistributionOptions(cyclicNprocsMixed3),
485         LinalgTransformationFilter(
486             Identifier::get("distribute6", context),
487             Identifier::get("after_distribute6", context)));
488   }
489 
490   {
491     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
492     cyclicNprocsEqNiters.distributionMethod.resize(2,
493                                                    DistributionMethod::Cyclic);
494     cyclicNprocsEqNiters.procInfo =
495         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
496     patterns.add<LinalgTilingPattern<MatmulOp>>(
497         context,
498         LinalgTilingOptions()
499             .setTileSizes({8, 8, 4})
500             .setLoopType(LinalgTilingLoopType::Loops)
501             .setDistributionOptions(cyclicNprocsEqNiters),
502         LinalgTransformationFilter(
503             Identifier::get("tensors_distribute1", context),
504             Identifier::get("tensors_after_distribute1", context)));
505   }
506 }
507 
508 static void
509 applyMatmulToVectorPatterns(FuncOp funcOp,
510                             bool testMatmulToVectorPatterns1dTiling,
511                             bool testMatmulToVectorPatterns2dTiling) {
512   MLIRContext *ctx = funcOp.getContext();
513   SmallVector<RewritePatternSet, 4> stage1Patterns;
514   if (testMatmulToVectorPatterns1dTiling) {
515     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
516                                           stage1Patterns);
517   } else if (testMatmulToVectorPatterns2dTiling) {
518     stage1Patterns.emplace_back(
519         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
520                  ctx,
521                  LinalgTilingOptions()
522                      .setTileSizes({768, 264, 768})
523                      .setInterchange({1, 2, 0}),
524                  LinalgTransformationFilter(Identifier::get("START", ctx),
525                                             Identifier::get("L2", ctx))));
526     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
527                                           stage1Patterns);
528   }
529   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
530   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
531   FrozenRewritePatternSet stage2Patterns =
532       getLinalgTilingCanonicalizationPatterns(ctx);
533   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
534                             std::move(stage2Patterns));
535 }
536 
537 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
538   RewritePatternSet forwardPattern(funcOp.getContext());
539   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
540   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
541   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
542 }
543 
544 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
545   RewritePatternSet patterns(funcOp.getContext());
546   patterns.add<LinalgVectorizationPattern>(
547       funcOp.getContext(),
548       LinalgTransformationFilter()
549           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
550   populatePadTensorOpVectorizationPatterns(patterns);
551   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
552 }
553 
554 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
555   RewritePatternSet patterns(funcOp.getContext());
556   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
557   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
558 }
559 
560 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
561   RewritePatternSet patterns(funcOp.getContext());
562   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
563   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
564 }
565 
566 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
567   RewritePatternSet patterns(funcOp.getContext());
568   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
569   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
570 }
571 
572 // For now, just assume it is the zero of type.
573 // In the future, it should be the zero of type + op.
574 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
575   auto t = getElementTypeOrSelf(op.get());
576   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
577 }
578 
579 static void applyTilePattern(FuncOp funcOp, std::string loopType,
580                              ArrayRef<int64_t> tileSizes,
581                              ArrayRef<int64_t> paddedOperands,
582                              ArrayRef<int64_t> peeledLoops,
583                              bool scalarizeDynamicDims) {
584   MLIRContext *context = funcOp.getContext();
585   RewritePatternSet tilingPattern(context);
586   LinalgTilingLoopType type =
587       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
588           .Case("for", LinalgTilingLoopType::Loops)
589           .Case("affine", LinalgTilingLoopType::AffineLoops)
590           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
591           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
592   auto linalgTilingOptions = linalg::LinalgTilingOptions()
593                                  .setPeeledLoops(peeledLoops)
594                                  .setLoopType(type);
595   if (scalarizeDynamicDims) {
596     linalgTilingOptions.scalarizeDynamicDims();
597     assert(tileSizes.empty() &&
598            "tileSizes and scalarizeDynamicDims is mutually exclusive");
599   } else {
600     linalgTilingOptions.setTileSizes(tileSizes);
601   }
602   if (!paddedOperands.empty()) {
603     auto paddingFunc = [&](OpBuilder &b,
604                            OpOperand &opOperand) -> FailureOr<Value> {
605       if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0)
606         return failure();
607       return getNeutralOfLinalgOp(b, opOperand);
608     };
609     linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
610   }
611   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
612                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
613       context, linalgTilingOptions,
614       linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
615   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
616 }
617 
618 static void applyInterchangePattern(FuncOp funcOp,
619                                     ArrayRef<unsigned> interchangeVector) {
620   MLIRContext *context = funcOp.getContext();
621   RewritePatternSet interchangePattern(context);
622   interchangePattern.add<GenericOpInterchangePattern>(
623       context, interchangeVector,
624       LinalgTransformationFilter(ArrayRef<Identifier>{},
625                                  Identifier::get("interchange", context)));
626   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
627 }
628 
// Attribute name under which TiledLoopPeelingPattern records, as an i64 array,
// the loop indices it has already peeled on a TiledLoopOp.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
// Attribute name of the unit attribute that TiledLoopPeelingPattern sets on
// the loop returned by peeling as the partial-iteration loop.
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
631 
632 namespace {
633 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
634 /// `idx`-th loop contains only "full" iterations and a second loop for the
635 /// remaining partial iteration (if any).
636 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
637   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
638       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
639   }
640 
641   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
642                                 PatternRewriter &rewriter) const override {
643     SmallVector<int64_t> peeledLoops;
644     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
645       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
646       peeledLoops =
647           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
648             return attr.cast<IntegerAttr>().getInt();
649           }));
650       // Check if the loop was already peeled.
651       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
652         return failure();
653     }
654     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
655       // No peeling of loop nests with a partial iteration.
656       return failure();
657 
658     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
659       return failure();
660 
661     // Peel loop and canonicalize.
662     TiledLoopOp result;
663     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
664                                                     result)))
665       return failure();
666 
667     // Apply label, so that the same loop is not rewritten a second time.
668     peeledLoops.push_back(idx);
669     rewriter.updateRootInPlace(loopOp, [&]() {
670       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
671     });
672     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
673     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
674 
675     return success();
676   }
677 
678   /// Index of loop to peel.
679   int64_t idx;
680 
681   /// If set to true, do not peel TiledLoopOps with a partial iteration.
682   bool skipPartial;
683 };
684 } // namespace
685 
686 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
687                                          ArrayRef<unsigned> loops,
688                                          bool skipPartial) {
689   MLIRContext *ctx = funcOp.getContext();
690   RewritePatternSet patterns(ctx);
691   for (unsigned idx : loops)
692     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
693   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
694 
695   // Drop the markers.
696   funcOp.walk([](TiledLoopOp op) {
697     op->removeAttr(kPeeledLoopsLabel);
698     op->removeAttr(kPartialIterationLabel);
699   });
700 }
701 
702 /// Apply transformations specified as patterns.
703 void TestLinalgTransforms::runOnFunction() {
704   auto lambda = [&](void *) {
705     getFunction().walk([](LinalgOp op) {
706       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
707     });
708   };
709   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
710 
711   if (testPromotionOptions) {
712     RewritePatternSet patterns(&getContext());
713     fillPromotionCallBackPatterns(&getContext(), patterns);
714     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
715     return;
716   }
717   if (testTileAndDistributionOptions) {
718     RewritePatternSet patterns(&getContext());
719     fillTileAndDistributePatterns(&getContext(), patterns);
720     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
721     return;
722   }
723   if (testPatterns)
724     return applyPatterns(getFunction());
725   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
726     return applyMatmulToVectorPatterns(getFunction(),
727                                        testMatmulToVectorPatterns1dTiling,
728                                        testMatmulToVectorPatterns2dTiling);
729   if (testVectorTransferForwardingPatterns)
730     return applyVectorTransferForwardingPatterns(getFunction());
731   if (testGenericToVectorPattern)
732     return applyLinalgToVectorPatterns(getFunction());
733   if (testTransformPadTensor)
734     return applyPadTensorToGenericPatterns(getFunction());
735   if (testGeneralizePadTensor)
736     return applyGeneralizePadTensorPatterns(getFunction());
737   if (testSwapSubTensorPadTensor)
738     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
739   if (testTiledLoopPeeling.hasValue())
740     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
741                                         skipPartial);
742   if (testTilePattern)
743     return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
744                             peeledLoops, /*scalarizeDynamicDims=*/false);
745   if (testTileScalarizeDynamicDims)
746     return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
747                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
748   if (testHoistPadding) {
749     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
750       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
751     });
752   }
753   if (testInterchangePattern.hasValue())
754     return applyInterchangePattern(getFunction(), testInterchangePattern);
755 }
756 
757 namespace mlir {
758 namespace test {
759 void registerTestLinalgTransforms() {
760   PassRegistration<TestLinalgTransforms>();
761 }
762 } // namespace test
763 } // namespace mlir
764