1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 namespace {
31 struct TestLinalgTransforms
32     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
33   TestLinalgTransforms() = default;
34   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
35 
36   void getDependentDialects(DialectRegistry &registry) const override {
37     // clang-format off
38     registry.insert<AffineDialect,
39                     memref::MemRefDialect,
40                     scf::SCFDialect,
41                     StandardOpsDialect,
42                     vector::VectorDialect,
43                     gpu::GPUDialect>();
44     // clang-format on
45   }
46   StringRef getArgument() const final {
47     return "test-linalg-transform-patterns";
48   }
49   StringRef getDescription() const final {
50     return "Test Linalg transformation patterns by applying them greedily.";
51   }
52 
53   void runOnFunction() override;
54 
55   Option<bool> testPatterns{*this, "test-patterns",
56                             llvm::cl::desc("Test a mixed set of patterns"),
57                             llvm::cl::init(false)};
58   Option<bool> testMatmulToVectorPatterns1dTiling{
59       *this, "test-matmul-to-vector-patterns-tile-1d",
60       llvm::cl::desc(
61           "Test a fused pass that applies patterns from matmul to vectors via "
62           "1-d tiling"),
63       llvm::cl::init(false)};
64   Option<bool> testMatmulToVectorPatterns2dTiling{
65       *this, "test-matmul-to-vector-patterns-tile-2d",
66       llvm::cl::desc(
67           "Test a fused pass that applies patterns from matmul to vectors via "
68           "2-d tiling"),
69       llvm::cl::init(false)};
70   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
71                                     llvm::cl::desc("Test promotion options"),
72                                     llvm::cl::init(false)};
73   Option<bool> testTileAndDistributionOptions{
74       *this, "test-tile-and-distribute-options",
75       llvm::cl::desc("Test tile and distribute options"),
76       llvm::cl::init(false)};
77   Option<bool> testVectorTransferForwardingPatterns{
78       *this, "test-vector-transfer-forwarding-patterns",
79       llvm::cl::desc(
80           "Test a fused pass that forwards linalg.copy to vector.transfer"),
81       llvm::cl::init(false)};
82   Option<bool> testGenericToVectorPattern{
83       *this, "test-linalg-to-vector-patterns",
84       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
85                      "in vector.contract form"),
86       llvm::cl::init(false)};
87   Option<bool> testTileAndPadPattern{
88       *this, "test-tile-and-pad-pattern",
89       llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
90   Option<int> testHoistPadding{*this, "test-hoist-padding",
91                                llvm::cl::desc("Test hoist padding"),
92                                llvm::cl::init(0)};
93   Option<bool> testTransformPadTensor{
94       *this, "test-transform-pad-tensor",
95       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
96       llvm::cl::init(false)};
97   Option<bool> testGeneralizePadTensor{
98       *this, "test-generalize-pad-tensor",
99       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
100       llvm::cl::init(false)};
101   Option<bool> testSwapSubTensorPadTensor{
102       *this, "test-swap-subtensor-padtensor",
103       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
104                      "pad_tensor(subtensor)"),
105       llvm::cl::init(false)};
106   ListOption<int64_t> tileSizesForPadding{
107       *this, "tile-sizes-for-padding",
108       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
109       llvm::cl::MiscFlags::CommaSeparated};
110   ListOption<unsigned> testInterchangePattern{
111       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
112       llvm::cl::desc("Test the interchange pattern.")};
113   ListOption<unsigned> testTiledLoopPeeling{
114       *this, "test-tiled-loop-peeling",
115       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
116       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
117 };
118 } // end anonymous namespace
119 
120 static void applyPatterns(FuncOp funcOp) {
121   MLIRContext *ctx = funcOp.getContext();
122   RewritePatternSet patterns(ctx);
123 
124   //===--------------------------------------------------------------------===//
125   // Linalg tiling patterns.
126   //===--------------------------------------------------------------------===//
127   patterns.add<LinalgTilingPattern<MatmulOp>>(
128       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
129       LinalgTransformationFilter(Identifier::get("MEM", ctx),
130                                  Identifier::get("L3", ctx)));
131   patterns.add<LinalgTilingPattern<MatmulOp>>(
132       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
133       LinalgTransformationFilter(Identifier::get("L3", ctx),
134                                  Identifier::get("L2", ctx)));
135   patterns.add<LinalgTilingPattern<MatmulOp>>(
136       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
137       LinalgTransformationFilter(Identifier::get("L2", ctx),
138                                  Identifier::get("L1", ctx)));
139   patterns.add<LinalgTilingPattern<MatmulOp>>(
140       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
141       LinalgTransformationFilter(Identifier::get("L1", ctx),
142                                  Identifier::get("REG", ctx)));
143 
144   patterns.add<LinalgTilingPattern<MatvecOp>>(
145       ctx,
146       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
147           LinalgTilingLoopType::ParallelLoops),
148       LinalgTransformationFilter(ArrayRef<Identifier>{},
149                                  Identifier::get("L1", ctx)));
150 
151   patterns.add<LinalgTilingPattern<DotOp>>(
152       ctx, LinalgTilingOptions().setTileSizes(8000),
153       LinalgTransformationFilter(
154           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
155                                Identifier::get("L3", ctx),
156                                Identifier::get("L2", ctx)},
157           Identifier::get("REG", ctx)));
158 
159   //===--------------------------------------------------------------------===//
160   // Linalg tiling and permutation patterns.
161   //===--------------------------------------------------------------------===//
162   patterns.add<LinalgTilingPattern<MatmulOp>>(
163       ctx,
164       LinalgTilingOptions()
165           .setTileSizes({2000, 3000, 4000})
166           .setInterchange({1, 2, 0}),
167       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
168                                  Identifier::get("L2__with_perm__", ctx)));
169   patterns.add<LinalgTilingPattern<MatmulOp>>(
170       ctx,
171       LinalgTilingOptions()
172           .setTileSizes({200, 300, 400})
173           .setInterchange({1, 0, 2}),
174       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
175                                  Identifier::get("L1__with_perm__", ctx)));
176   patterns.add<LinalgTilingPattern<MatmulOp>>(
177       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
178       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
179                                  Identifier::get("REG__with_perm__", ctx)));
180 
181   patterns.add<LinalgTilingPattern<MatvecOp>>(
182       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
183       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
184                                  Identifier::get("L1__with_perm__", ctx)));
185 
186   patterns.add<LinalgTilingPattern<MatmulOp>>(
187       ctx,
188       LinalgTilingOptions()
189           .setTileSizes({16, 8, 4})
190           .setInterchange({1, 2, 0})
191           .setLoopType(LinalgTilingLoopType::ParallelLoops),
192       LinalgTransformationFilter(
193           Identifier::get("par__with_perm__", ctx),
194           Identifier::get("after_par__with_perm__", ctx)));
195 
196   //===--------------------------------------------------------------------===//
197   // Linalg to loops patterns.
198   //===--------------------------------------------------------------------===//
199   patterns.add<LinalgLoweringPattern<DotOp>>(
200       ctx,
201       /*loweringType=*/LinalgLoweringType::Loops,
202       LinalgTransformationFilter(Identifier::get("REG", ctx)));
203 
204   //===--------------------------------------------------------------------===//
205   // Linalg distribution patterns.
206   //===--------------------------------------------------------------------===//
207   LinalgLoopDistributionOptions distributionOptions;
208 
209   //===--------------------------------------------------------------------===//
210   // Linalg to vector contraction patterns.
211   //===--------------------------------------------------------------------===//
212   patterns.add<LinalgVectorizationPattern>(
213       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
214                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
215 
216   //===--------------------------------------------------------------------===//
217   // Linalg generic interchange pattern.
218   //===--------------------------------------------------------------------===//
219   patterns.add<GenericOpInterchangePattern>(
220       ctx,
221       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
222       LinalgTransformationFilter(ArrayRef<Identifier>{},
223                                  Identifier::get("PERMUTED", ctx)));
224 
225   //===--------------------------------------------------------------------===//
226   // Linalg subview operands promotion.
227   //===--------------------------------------------------------------------===//
228   patterns.add<LinalgPromotionPattern<MatmulOp>>(
229       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
230       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
231                                  Identifier::get("_views_promoted_", ctx)));
232   patterns.add<LinalgPromotionPattern<MatmulOp>>(
233       ctx,
234       LinalgPromotionOptions()
235           .setOperandsToPromote({0})
236           .setUseFullTileBuffersByDefault(true),
237       LinalgTransformationFilter(
238           Identifier::get("_promote_first_view_", ctx),
239           Identifier::get("_first_view_promoted_", ctx)));
240   patterns.add<LinalgPromotionPattern<FillOp>>(
241       ctx,
242       LinalgPromotionOptions()
243           .setOperandsToPromote({1})
244           .setUseFullTileBuffers({false, true})
245           .setAlignment(32),
246       LinalgTransformationFilter(
247           Identifier::get("_promote_views_aligned_", ctx),
248           Identifier::get("_views_aligned_promoted_", ctx)));
249 
250   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
251 
252   // Drop the marker.
253   funcOp.walk([](LinalgOp op) {
254     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
255   });
256 }
257 
258 static void fillL1TilingAndMatmulToVectorPatterns(
259     FuncOp funcOp, StringRef startMarker,
260     SmallVectorImpl<RewritePatternSet> &patternsVector) {
261   MLIRContext *ctx = funcOp.getContext();
262   patternsVector.emplace_back(
263       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
264                ctx,
265                LinalgTilingOptions()
266                    .setTileSizes({8, 12, 16})
267                    .setInterchange({1, 0, 2}),
268                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
269                                           Identifier::get("L1", ctx))));
270 
271   patternsVector.emplace_back(
272       ctx,
273       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
274           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
275           LinalgTransformationFilter(Identifier::get("L1", ctx),
276                                      Identifier::get("VEC", ctx))));
277 
278   patternsVector.emplace_back(
279       ctx, std::make_unique<LinalgVectorizationPattern>(
280                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
281                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
282   patternsVector.back().add<LinalgVectorizationPattern>(
283       ctx, LinalgTransformationFilter().addFilter(
284                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // Test promotion callbacks
289 //===----------------------------------------------------------------------===//
290 
291 // Allocation call back
292 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
293                                        ArrayRef<Value> boundingSubViewSize,
294                                        DataLayout &layout) {
295   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
296   return b
297       .create<memref::AllocOp>(
298           subView.getLoc(),
299           MemRefType::get(shape, subView.getType().getElementType(),
300                           /*affineMapComposition =*/{}, 3),
301           boundingSubViewSize)
302       .getResult();
303 }
304 
305 // Deallocation callback
306 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
307   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
308   return success();
309 }
310 
311 // Copy in call back
312 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
313                                     bool isOutput) {
314   auto floatType = src.getType().cast<MemRefType>().getElementType();
315   if (!floatType.isa<FloatType>())
316     return failure();
317   if (!isOutput) {
318     Value cst =
319         b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0));
320     b.create<FillOp>(src.getLoc(), cst, dst);
321   }
322   b.create<CopyOp>(src.getLoc(), src, dst);
323   return success();
324 }
325 
326 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
327                                           RewritePatternSet &patterns) {
328   patterns.add<LinalgTilingPattern<MatmulOp>>(
329       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
330       LinalgTransformationFilter(Identifier::get("START", ctx),
331                                  Identifier::get("PROMOTE", ctx)));
332   patterns.add<LinalgPromotionPattern<MatmulOp>>(
333       ctx,
334       LinalgPromotionOptions()
335           .setOperandsToPromote({0, 2})
336           .setUseFullTileBuffers({false, false})
337           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
338           .setCopyInOutFns(
339               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
340                 return copyCallBackFn(b, src, dst, false);
341               },
342               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
343                 return copyCallBackFn(b, src, dst, true);
344               }),
345       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
346 }
347 
348 template <typename IdOp, typename NProcsOp>
349 static SmallVector<ProcInfo, 2>
350 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
351   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
352   SmallVector<ProcInfo, 2> procInfo(count);
353   const char *xyz[] = {"x", "y", "z"};
354   Type indexType = b.getIndexType();
355   for (unsigned i = 0; i < count; ++i) {
356     procInfo[count - 1 - i] = {
357         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
358         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
359   }
360   return procInfo;
361 }
362 
363 static void fillTileAndDistributePatterns(MLIRContext *context,
364                                           RewritePatternSet &patterns) {
365   {
366     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
367     cyclicNprocsEqNiters.distributionMethod.resize(
368         2, DistributionMethod::CyclicNumProcsEqNumIters);
369     cyclicNprocsEqNiters.procInfo =
370         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
371     patterns.add<LinalgTilingPattern<MatmulOp>>(
372         context,
373         LinalgTilingOptions()
374             .setTileSizes({8, 8, 4})
375             .setLoopType(LinalgTilingLoopType::ParallelLoops)
376             .setDistributionOptions(cyclicNprocsEqNiters),
377         LinalgTransformationFilter(
378             Identifier::get("distribute1", context),
379             Identifier::get("after_distribute1", context)));
380   }
381 
382   {
383     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
384     cyclicNprocsGeNiters.distributionMethod.resize(
385         2, DistributionMethod::CyclicNumProcsGeNumIters);
386     cyclicNprocsGeNiters.procInfo =
387         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
388     patterns.add<LinalgTilingPattern<MatmulOp>>(
389         context,
390         LinalgTilingOptions()
391             .setTileSizes({8, 8, 4})
392             .setLoopType(LinalgTilingLoopType::ParallelLoops)
393             .setDistributionOptions(cyclicNprocsGeNiters),
394         LinalgTransformationFilter(
395             Identifier::get("distribute2", context),
396             Identifier::get("after_distribute2", context)));
397   }
398 
399   {
400     LinalgLoopDistributionOptions cyclicNprocsDefault;
401     cyclicNprocsDefault.distributionMethod.resize(2,
402                                                   DistributionMethod::Cyclic);
403     cyclicNprocsDefault.procInfo =
404         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
405     patterns.add<LinalgTilingPattern<MatmulOp>>(
406         context,
407         LinalgTilingOptions()
408             .setTileSizes({8, 8, 4})
409             .setLoopType(LinalgTilingLoopType::ParallelLoops)
410             .setDistributionOptions(cyclicNprocsDefault),
411         LinalgTransformationFilter(
412             Identifier::get("distribute3", context),
413             Identifier::get("after_distribute3", context)));
414   }
415 
416   {
417     LinalgLoopDistributionOptions cyclicNprocsMixed1;
418     cyclicNprocsMixed1.distributionMethod = {
419         DistributionMethod::CyclicNumProcsEqNumIters,
420         DistributionMethod::CyclicNumProcsGeNumIters};
421     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
422     patterns.add<LinalgTilingPattern<MatmulOp>>(
423         context,
424         LinalgTilingOptions()
425             .setTileSizes({8, 8, 4})
426             .setLoopType(LinalgTilingLoopType::ParallelLoops)
427             .setDistributionOptions(cyclicNprocsMixed1),
428         LinalgTransformationFilter(
429             Identifier::get("distribute4", context),
430             Identifier::get("after_distribute4", context)));
431   }
432 
433   {
434     LinalgLoopDistributionOptions cyclicNprocsMixed2;
435     cyclicNprocsMixed2.distributionMethod = {
436         DistributionMethod::CyclicNumProcsGeNumIters,
437         DistributionMethod::Cyclic};
438     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
439     patterns.add<LinalgTilingPattern<MatmulOp>>(
440         context,
441         LinalgTilingOptions()
442             .setTileSizes({8, 8, 4})
443             .setLoopType(LinalgTilingLoopType::ParallelLoops)
444             .setDistributionOptions(cyclicNprocsMixed2),
445         LinalgTransformationFilter(
446             Identifier::get("distribute5", context),
447             Identifier::get("after_distribute5", context)));
448   }
449 
450   {
451     LinalgLoopDistributionOptions cyclicNprocsMixed3;
452     cyclicNprocsMixed3.distributionMethod = {
453         DistributionMethod::Cyclic,
454         DistributionMethod::CyclicNumProcsEqNumIters};
455     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
456 
457     patterns.add<LinalgTilingPattern<MatmulOp>>(
458         context,
459         LinalgTilingOptions()
460             .setTileSizes({8, 8, 4})
461             .setLoopType(LinalgTilingLoopType::ParallelLoops)
462             .setDistributionOptions(cyclicNprocsMixed3),
463         LinalgTransformationFilter(
464             Identifier::get("distribute6", context),
465             Identifier::get("after_distribute6", context)));
466   }
467 
468   {
469     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
470     cyclicNprocsEqNiters.distributionMethod.resize(2,
471                                                    DistributionMethod::Cyclic);
472     cyclicNprocsEqNiters.procInfo =
473         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
474     patterns.add<LinalgTilingPattern<MatmulOp>>(
475         context,
476         LinalgTilingOptions()
477             .setTileSizes({8, 8, 4})
478             .setLoopType(LinalgTilingLoopType::Loops)
479             .setDistributionOptions(cyclicNprocsEqNiters),
480         LinalgTransformationFilter(
481             Identifier::get("tensors_distribute1", context),
482             Identifier::get("tensors_after_distribute1", context)));
483   }
484 }
485 
486 static void
487 applyMatmulToVectorPatterns(FuncOp funcOp,
488                             bool testMatmulToVectorPatterns1dTiling,
489                             bool testMatmulToVectorPatterns2dTiling) {
490   MLIRContext *ctx = funcOp.getContext();
491   SmallVector<RewritePatternSet, 4> stage1Patterns;
492   if (testMatmulToVectorPatterns1dTiling) {
493     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
494                                           stage1Patterns);
495   } else if (testMatmulToVectorPatterns2dTiling) {
496     stage1Patterns.emplace_back(
497         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
498                  ctx,
499                  LinalgTilingOptions()
500                      .setTileSizes({768, 264, 768})
501                      .setInterchange({1, 2, 0}),
502                  LinalgTransformationFilter(Identifier::get("START", ctx),
503                                             Identifier::get("L2", ctx))));
504     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
505                                           stage1Patterns);
506   }
507   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
508   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
509   FrozenRewritePatternSet stage2Patterns =
510       getLinalgTilingCanonicalizationPatterns(ctx);
511   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
512                             std::move(stage2Patterns));
513 }
514 
515 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
516   RewritePatternSet forwardPattern(funcOp.getContext());
517   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
518   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
519   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
520 }
521 
522 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
523   RewritePatternSet patterns(funcOp.getContext());
524   patterns.add<LinalgVectorizationPattern>(
525       funcOp.getContext(),
526       LinalgTransformationFilter()
527           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
528   populatePadTensorOpVectorizationPatterns(patterns);
529   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
530 }
531 
532 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
533   RewritePatternSet patterns(funcOp.getContext());
534   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
535   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
536 }
537 
538 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
539   RewritePatternSet patterns(funcOp.getContext());
540   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
541   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
542 }
543 
544 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
545   RewritePatternSet patterns(funcOp.getContext());
546   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
547   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
548 }
549 
550 // For now, just assume it is the zero of type.
551 // In the future, it should be the zero of type + op.
552 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
553   auto t = getElementTypeOrSelf(op.get());
554   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
555 }
556 
557 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
558   MLIRContext *context = funcOp.getContext();
559   RewritePatternSet tilingPattern(context);
560   auto linalgTilingOptions =
561       linalg::LinalgTilingOptions()
562           .setTileSizes(tileSizes)
563           .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
564   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
565                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
566       context, linalgTilingOptions,
567       linalg::LinalgTransformationFilter(
568           Identifier::get("tile-and-pad", context)));
569   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
570 }
571 
572 static void applyInterchangePattern(FuncOp funcOp,
573                                     ArrayRef<unsigned> interchangeVector) {
574   MLIRContext *context = funcOp.getContext();
575   RewritePatternSet interchangePattern(context);
576   interchangePattern.add<GenericOpInterchangePattern>(
577       context, interchangeVector,
578       LinalgTransformationFilter(ArrayRef<Identifier>{},
579                                  Identifier::get("interchange", context)));
580   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
581 }
582 
583 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
584 
585 namespace {
586 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
587 /// `idx`-th loop contains only "full" iterations and a second loop for the
588 /// remaining partial iteration (if any).
589 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
590   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx)
591       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx) {}
592 
593   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
594                                 PatternRewriter &rewriter) const override {
595     SmallVector<int64_t> peeledLoops;
596     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
597       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
598       peeledLoops =
599           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
600             return attr.cast<IntegerAttr>().getInt();
601           }));
602       // Check if the loop was already peeled.
603       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
604         return failure();
605     }
606 
607     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
608       return failure();
609 
610     // Peel loop and canonicalize.
611     TiledLoopOp result;
612     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
613                                                     result)))
614       return failure();
615     peeledLoops.push_back(idx);
616 
617     // Apply label, so that the same loop is not rewritten a second time.
618     rewriter.updateRootInPlace(loopOp, [&]() {
619       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
620     });
621     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
622     return success();
623   }
624 
625   /// Index of loop to peel.
626   int64_t idx;
627 };
628 } // namespace
629 
630 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
631                                          ArrayRef<unsigned> loops) {
632   MLIRContext *ctx = funcOp.getContext();
633   RewritePatternSet patterns(ctx);
634   for (unsigned idx : loops)
635     patterns.add<TiledLoopPeelingPattern>(ctx, idx);
636   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
637 
638   // Drop the marker.
639   funcOp.walk([](TiledLoopOp op) { op->removeAttr(kPeeledLoopsLabel); });
640 }
641 
642 /// Apply transformations specified as patterns.
643 void TestLinalgTransforms::runOnFunction() {
644   auto lambda = [&](void *) {
645     getFunction().walk([](LinalgOp op) {
646       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
647     });
648   };
649   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
650 
651   if (testPromotionOptions) {
652     RewritePatternSet patterns(&getContext());
653     fillPromotionCallBackPatterns(&getContext(), patterns);
654     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
655     return;
656   }
657   if (testTileAndDistributionOptions) {
658     RewritePatternSet patterns(&getContext());
659     fillTileAndDistributePatterns(&getContext(), patterns);
660     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
661     return;
662   }
663   if (testPatterns)
664     return applyPatterns(getFunction());
665   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
666     return applyMatmulToVectorPatterns(getFunction(),
667                                        testMatmulToVectorPatterns1dTiling,
668                                        testMatmulToVectorPatterns2dTiling);
669   if (testVectorTransferForwardingPatterns)
670     return applyVectorTransferForwardingPatterns(getFunction());
671   if (testGenericToVectorPattern)
672     return applyLinalgToVectorPatterns(getFunction());
673   if (testTransformPadTensor)
674     return applyPadTensorToGenericPatterns(getFunction());
675   if (testGeneralizePadTensor)
676     return applyGeneralizePadTensorPatterns(getFunction());
677   if (testSwapSubTensorPadTensor)
678     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
679   if (testTiledLoopPeeling.hasValue())
680     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling);
681   if (testTileAndPadPattern)
682     return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
683   if (testHoistPadding) {
684     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
685       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
686     });
687   }
688   if (testInterchangePattern.hasValue())
689     return applyInterchangePattern(getFunction(), testInterchangePattern);
690 }
691 
692 namespace mlir {
693 namespace test {
694 void registerTestLinalgTransforms() {
695   PassRegistration<TestLinalgTransforms>();
696 }
697 } // namespace test
698 } // namespace mlir
699