1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 namespace {
31 struct TestLinalgTransforms
32     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
33   TestLinalgTransforms() = default;
34   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
35 
36   void getDependentDialects(DialectRegistry &registry) const override {
37     // clang-format off
38     registry.insert<AffineDialect,
39                     memref::MemRefDialect,
40                     scf::SCFDialect,
41                     StandardOpsDialect,
42                     vector::VectorDialect,
43                     gpu::GPUDialect>();
44     // clang-format on
45   }
46   StringRef getArgument() const final {
47     return "test-linalg-transform-patterns";
48   }
49   StringRef getDescription() const final {
50     return "Test Linalg transformation patterns by applying them greedily.";
51   }
52 
53   void runOnFunction() override;
54 
55   Option<bool> testPatterns{*this, "test-patterns",
56                             llvm::cl::desc("Test a mixed set of patterns"),
57                             llvm::cl::init(false)};
58   Option<bool> testMatmulToVectorPatterns1dTiling{
59       *this, "test-matmul-to-vector-patterns-tile-1d",
60       llvm::cl::desc(
61           "Test a fused pass that applies patterns from matmul to vectors via "
62           "1-d tiling"),
63       llvm::cl::init(false)};
64   Option<bool> testMatmulToVectorPatterns2dTiling{
65       *this, "test-matmul-to-vector-patterns-tile-2d",
66       llvm::cl::desc(
67           "Test a fused pass that applies patterns from matmul to vectors via "
68           "2-d tiling"),
69       llvm::cl::init(false)};
70   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
71                                     llvm::cl::desc("Test promotion options"),
72                                     llvm::cl::init(false)};
73   Option<bool> testTileAndDistributionOptions{
74       *this, "test-tile-and-distribute-options",
75       llvm::cl::desc("Test tile and distribute options"),
76       llvm::cl::init(false)};
77   Option<bool> testVectorTransferForwardingPatterns{
78       *this, "test-vector-transfer-forwarding-patterns",
79       llvm::cl::desc(
80           "Test a fused pass that forwards linalg.copy to vector.transfer"),
81       llvm::cl::init(false)};
82   Option<bool> testGenericToVectorPattern{
83       *this, "test-linalg-to-vector-patterns",
84       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
85                      "in vector.contract form"),
86       llvm::cl::init(false)};
87   Option<bool> testTilePattern{*this, "test-tile-pattern",
88                                llvm::cl::desc("Test tile pattern"),
89                                llvm::cl::init(false)};
90   Option<bool> testTileScalarizeDynamicDims{
91       *this, "test-tile-scalarize-dynamic-dims",
92       llvm::cl::desc("Test tiling of dynamic dims by 1"),
93       llvm::cl::init(false)};
94   Option<int> testHoistPadding{*this, "test-hoist-padding",
95                                llvm::cl::desc("Test hoist padding"),
96                                llvm::cl::init(0)};
97   Option<bool> testTransformPadTensor{
98       *this, "test-transform-pad-tensor",
99       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
100       llvm::cl::init(false)};
101   Option<bool> testGeneralizePadTensor{
102       *this, "test-generalize-pad-tensor",
103       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
104       llvm::cl::init(false)};
105   Option<bool> testSwapSubTensorPadTensor{
106       *this, "test-swap-subtensor-padtensor",
107       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
108                      "pad_tensor(subtensor)"),
109       llvm::cl::init(false)};
110   Option<bool> padTiles{*this, "pad-tiles",
111                         llvm::cl::desc("Pad tiles when test-tile-pattern"),
112                         llvm::cl::init(false)};
113   ListOption<int64_t> peeledLoops{
114       *this, "peeled-loops",
115       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
116       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
117   ListOption<int64_t> tileSizes{
118       *this, "tile-sizes",
119       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
120       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
121   ListOption<unsigned> testInterchangePattern{
122       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
123       llvm::cl::desc("Test the interchange pattern.")};
124   ListOption<unsigned> testTiledLoopPeeling{
125       *this, "test-tiled-loop-peeling",
126       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
127       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
128   Option<bool> skipPartial{
129       *this, "skip-partial",
130       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
131       llvm::cl::init(false)};
132   Option<std::string> loopType{
133       *this, "loop-type",
134       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
135                      "tiled_loop"),
136       llvm::cl::init("for")};
137 };
138 } // end anonymous namespace
139 
140 static void applyPatterns(FuncOp funcOp) {
141   MLIRContext *ctx = funcOp.getContext();
142   RewritePatternSet patterns(ctx);
143 
144   //===--------------------------------------------------------------------===//
145   // Linalg tiling patterns.
146   //===--------------------------------------------------------------------===//
147   patterns.add<LinalgTilingPattern<MatmulOp>>(
148       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
149       LinalgTransformationFilter(Identifier::get("MEM", ctx),
150                                  Identifier::get("L3", ctx)));
151   patterns.add<LinalgTilingPattern<MatmulOp>>(
152       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
153       LinalgTransformationFilter(Identifier::get("L3", ctx),
154                                  Identifier::get("L2", ctx)));
155   patterns.add<LinalgTilingPattern<MatmulOp>>(
156       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
157       LinalgTransformationFilter(Identifier::get("L2", ctx),
158                                  Identifier::get("L1", ctx)));
159   patterns.add<LinalgTilingPattern<MatmulOp>>(
160       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
161       LinalgTransformationFilter(Identifier::get("L1", ctx),
162                                  Identifier::get("REG", ctx)));
163 
164   patterns.add<LinalgTilingPattern<MatvecOp>>(
165       ctx,
166       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
167           LinalgTilingLoopType::ParallelLoops),
168       LinalgTransformationFilter(ArrayRef<Identifier>{},
169                                  Identifier::get("L1", ctx)));
170 
171   patterns.add<LinalgTilingPattern<DotOp>>(
172       ctx, LinalgTilingOptions().setTileSizes(8000),
173       LinalgTransformationFilter(
174           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
175                                Identifier::get("L3", ctx),
176                                Identifier::get("L2", ctx)},
177           Identifier::get("REG", ctx)));
178 
179   //===--------------------------------------------------------------------===//
180   // Linalg tiling and permutation patterns.
181   //===--------------------------------------------------------------------===//
182   patterns.add<LinalgTilingPattern<MatmulOp>>(
183       ctx,
184       LinalgTilingOptions()
185           .setTileSizes({2000, 3000, 4000})
186           .setInterchange({1, 2, 0}),
187       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
188                                  Identifier::get("L2__with_perm__", ctx)));
189   patterns.add<LinalgTilingPattern<MatmulOp>>(
190       ctx,
191       LinalgTilingOptions()
192           .setTileSizes({200, 300, 400})
193           .setInterchange({1, 0, 2}),
194       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
195                                  Identifier::get("L1__with_perm__", ctx)));
196   patterns.add<LinalgTilingPattern<MatmulOp>>(
197       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
198       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
199                                  Identifier::get("REG__with_perm__", ctx)));
200 
201   patterns.add<LinalgTilingPattern<MatvecOp>>(
202       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
203       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
204                                  Identifier::get("L1__with_perm__", ctx)));
205 
206   patterns.add<LinalgTilingPattern<MatmulOp>>(
207       ctx,
208       LinalgTilingOptions()
209           .setTileSizes({16, 8, 4})
210           .setInterchange({1, 2, 0})
211           .setLoopType(LinalgTilingLoopType::ParallelLoops),
212       LinalgTransformationFilter(
213           Identifier::get("par__with_perm__", ctx),
214           Identifier::get("after_par__with_perm__", ctx)));
215 
216   //===--------------------------------------------------------------------===//
217   // Linalg to loops patterns.
218   //===--------------------------------------------------------------------===//
219   patterns.add<LinalgLoweringPattern<DotOp>>(
220       ctx,
221       /*loweringType=*/LinalgLoweringType::Loops,
222       LinalgTransformationFilter(Identifier::get("REG", ctx)));
223 
224   //===--------------------------------------------------------------------===//
225   // Linalg distribution patterns.
226   //===--------------------------------------------------------------------===//
227   LinalgLoopDistributionOptions distributionOptions;
228 
229   //===--------------------------------------------------------------------===//
230   // Linalg to vector contraction patterns.
231   //===--------------------------------------------------------------------===//
232   patterns.add<LinalgVectorizationPattern>(
233       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
234                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
235 
236   //===--------------------------------------------------------------------===//
237   // Linalg generic interchange pattern.
238   //===--------------------------------------------------------------------===//
239   patterns.add<GenericOpInterchangePattern>(
240       ctx,
241       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
242       LinalgTransformationFilter(ArrayRef<Identifier>{},
243                                  Identifier::get("PERMUTED", ctx)));
244 
245   //===--------------------------------------------------------------------===//
246   // Linalg subview operands promotion.
247   //===--------------------------------------------------------------------===//
248   patterns.add<LinalgPromotionPattern<MatmulOp>>(
249       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
250       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
251                                  Identifier::get("_views_promoted_", ctx)));
252   patterns.add<LinalgPromotionPattern<MatmulOp>>(
253       ctx,
254       LinalgPromotionOptions()
255           .setOperandsToPromote({0})
256           .setUseFullTileBuffersByDefault(true),
257       LinalgTransformationFilter(
258           Identifier::get("_promote_first_view_", ctx),
259           Identifier::get("_first_view_promoted_", ctx)));
260   patterns.add<LinalgPromotionPattern<FillOp>>(
261       ctx,
262       LinalgPromotionOptions()
263           .setOperandsToPromote({1})
264           .setUseFullTileBuffers({false, true})
265           .setAlignment(32),
266       LinalgTransformationFilter(
267           Identifier::get("_promote_views_aligned_", ctx),
268           Identifier::get("_views_aligned_promoted_", ctx)));
269 
270   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
271 
272   // Drop the marker.
273   funcOp.walk([](LinalgOp op) {
274     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
275   });
276 }
277 
278 static void fillL1TilingAndMatmulToVectorPatterns(
279     FuncOp funcOp, StringRef startMarker,
280     SmallVectorImpl<RewritePatternSet> &patternsVector) {
281   MLIRContext *ctx = funcOp.getContext();
282   patternsVector.emplace_back(
283       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
284                ctx,
285                LinalgTilingOptions()
286                    .setTileSizes({8, 12, 16})
287                    .setInterchange({1, 0, 2}),
288                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
289                                           Identifier::get("L1", ctx))));
290 
291   patternsVector.emplace_back(
292       ctx,
293       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
294           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
295           LinalgTransformationFilter(Identifier::get("L1", ctx),
296                                      Identifier::get("VEC", ctx))));
297 
298   patternsVector.emplace_back(
299       ctx, std::make_unique<LinalgVectorizationPattern>(
300                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
301                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
302   patternsVector.back().add<LinalgVectorizationPattern>(
303       ctx, LinalgTransformationFilter().addFilter(
304                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // Test promotion callbacks
309 //===----------------------------------------------------------------------===//
310 
311 // Allocation call back
312 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
313                                        ArrayRef<Value> boundingSubViewSize,
314                                        DataLayout &layout) {
315   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
316   return b
317       .create<memref::AllocOp>(
318           subView.getLoc(),
319           MemRefType::get(shape, subView.getType().getElementType(),
320                           /*affineMapComposition =*/{}, 3),
321           boundingSubViewSize)
322       .getResult();
323 }
324 
325 // Deallocation callback
326 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
327   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
328   return success();
329 }
330 
331 // Copy in call back
332 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
333                                     bool isOutput) {
334   auto floatType = src.getType().cast<MemRefType>().getElementType();
335   if (!floatType.isa<FloatType>())
336     return failure();
337   if (!isOutput) {
338     Value cst =
339         b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0));
340     b.create<FillOp>(src.getLoc(), cst, dst);
341   }
342   b.create<CopyOp>(src.getLoc(), src, dst);
343   return success();
344 }
345 
346 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
347                                           RewritePatternSet &patterns) {
348   patterns.add<LinalgTilingPattern<MatmulOp>>(
349       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
350       LinalgTransformationFilter(Identifier::get("START", ctx),
351                                  Identifier::get("PROMOTE", ctx)));
352   patterns.add<LinalgPromotionPattern<MatmulOp>>(
353       ctx,
354       LinalgPromotionOptions()
355           .setOperandsToPromote({0, 2})
356           .setUseFullTileBuffers({false, false})
357           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
358           .setCopyInOutFns(
359               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
360                 return copyCallBackFn(b, src, dst, false);
361               },
362               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
363                 return copyCallBackFn(b, src, dst, true);
364               }),
365       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
366 }
367 
368 template <typename IdOp, typename NProcsOp>
369 static SmallVector<ProcInfo, 2>
370 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
371   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
372   SmallVector<ProcInfo, 2> procInfo(count);
373   const char *xyz[] = {"x", "y", "z"};
374   Type indexType = b.getIndexType();
375   for (unsigned i = 0; i < count; ++i) {
376     procInfo[count - 1 - i] = {
377         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
378         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
379   }
380   return procInfo;
381 }
382 
383 static void fillTileAndDistributePatterns(MLIRContext *context,
384                                           RewritePatternSet &patterns) {
385   {
386     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
387     cyclicNprocsEqNiters.distributionMethod.resize(
388         2, DistributionMethod::CyclicNumProcsEqNumIters);
389     cyclicNprocsEqNiters.procInfo =
390         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
391     patterns.add<LinalgTilingPattern<MatmulOp>>(
392         context,
393         LinalgTilingOptions()
394             .setTileSizes({8, 8, 4})
395             .setLoopType(LinalgTilingLoopType::ParallelLoops)
396             .setDistributionOptions(cyclicNprocsEqNiters),
397         LinalgTransformationFilter(
398             Identifier::get("distribute1", context),
399             Identifier::get("after_distribute1", context)));
400   }
401 
402   {
403     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
404     cyclicNprocsGeNiters.distributionMethod.resize(
405         2, DistributionMethod::CyclicNumProcsGeNumIters);
406     cyclicNprocsGeNiters.procInfo =
407         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
408     patterns.add<LinalgTilingPattern<MatmulOp>>(
409         context,
410         LinalgTilingOptions()
411             .setTileSizes({8, 8, 4})
412             .setLoopType(LinalgTilingLoopType::ParallelLoops)
413             .setDistributionOptions(cyclicNprocsGeNiters),
414         LinalgTransformationFilter(
415             Identifier::get("distribute2", context),
416             Identifier::get("after_distribute2", context)));
417   }
418 
419   {
420     LinalgLoopDistributionOptions cyclicNprocsDefault;
421     cyclicNprocsDefault.distributionMethod.resize(2,
422                                                   DistributionMethod::Cyclic);
423     cyclicNprocsDefault.procInfo =
424         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
425     patterns.add<LinalgTilingPattern<MatmulOp>>(
426         context,
427         LinalgTilingOptions()
428             .setTileSizes({8, 8, 4})
429             .setLoopType(LinalgTilingLoopType::ParallelLoops)
430             .setDistributionOptions(cyclicNprocsDefault),
431         LinalgTransformationFilter(
432             Identifier::get("distribute3", context),
433             Identifier::get("after_distribute3", context)));
434   }
435 
436   {
437     LinalgLoopDistributionOptions cyclicNprocsMixed1;
438     cyclicNprocsMixed1.distributionMethod = {
439         DistributionMethod::CyclicNumProcsEqNumIters,
440         DistributionMethod::CyclicNumProcsGeNumIters};
441     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
442     patterns.add<LinalgTilingPattern<MatmulOp>>(
443         context,
444         LinalgTilingOptions()
445             .setTileSizes({8, 8, 4})
446             .setLoopType(LinalgTilingLoopType::ParallelLoops)
447             .setDistributionOptions(cyclicNprocsMixed1),
448         LinalgTransformationFilter(
449             Identifier::get("distribute4", context),
450             Identifier::get("after_distribute4", context)));
451   }
452 
453   {
454     LinalgLoopDistributionOptions cyclicNprocsMixed2;
455     cyclicNprocsMixed2.distributionMethod = {
456         DistributionMethod::CyclicNumProcsGeNumIters,
457         DistributionMethod::Cyclic};
458     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
459     patterns.add<LinalgTilingPattern<MatmulOp>>(
460         context,
461         LinalgTilingOptions()
462             .setTileSizes({8, 8, 4})
463             .setLoopType(LinalgTilingLoopType::ParallelLoops)
464             .setDistributionOptions(cyclicNprocsMixed2),
465         LinalgTransformationFilter(
466             Identifier::get("distribute5", context),
467             Identifier::get("after_distribute5", context)));
468   }
469 
470   {
471     LinalgLoopDistributionOptions cyclicNprocsMixed3;
472     cyclicNprocsMixed3.distributionMethod = {
473         DistributionMethod::Cyclic,
474         DistributionMethod::CyclicNumProcsEqNumIters};
475     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
476 
477     patterns.add<LinalgTilingPattern<MatmulOp>>(
478         context,
479         LinalgTilingOptions()
480             .setTileSizes({8, 8, 4})
481             .setLoopType(LinalgTilingLoopType::ParallelLoops)
482             .setDistributionOptions(cyclicNprocsMixed3),
483         LinalgTransformationFilter(
484             Identifier::get("distribute6", context),
485             Identifier::get("after_distribute6", context)));
486   }
487 
488   {
489     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
490     cyclicNprocsEqNiters.distributionMethod.resize(2,
491                                                    DistributionMethod::Cyclic);
492     cyclicNprocsEqNiters.procInfo =
493         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
494     patterns.add<LinalgTilingPattern<MatmulOp>>(
495         context,
496         LinalgTilingOptions()
497             .setTileSizes({8, 8, 4})
498             .setLoopType(LinalgTilingLoopType::Loops)
499             .setDistributionOptions(cyclicNprocsEqNiters),
500         LinalgTransformationFilter(
501             Identifier::get("tensors_distribute1", context),
502             Identifier::get("tensors_after_distribute1", context)));
503   }
504 }
505 
506 static void
507 applyMatmulToVectorPatterns(FuncOp funcOp,
508                             bool testMatmulToVectorPatterns1dTiling,
509                             bool testMatmulToVectorPatterns2dTiling) {
510   MLIRContext *ctx = funcOp.getContext();
511   SmallVector<RewritePatternSet, 4> stage1Patterns;
512   if (testMatmulToVectorPatterns1dTiling) {
513     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
514                                           stage1Patterns);
515   } else if (testMatmulToVectorPatterns2dTiling) {
516     stage1Patterns.emplace_back(
517         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
518                  ctx,
519                  LinalgTilingOptions()
520                      .setTileSizes({768, 264, 768})
521                      .setInterchange({1, 2, 0}),
522                  LinalgTransformationFilter(Identifier::get("START", ctx),
523                                             Identifier::get("L2", ctx))));
524     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
525                                           stage1Patterns);
526   }
527   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
528   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
529   FrozenRewritePatternSet stage2Patterns =
530       getLinalgTilingCanonicalizationPatterns(ctx);
531   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
532                             std::move(stage2Patterns));
533 }
534 
535 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
536   RewritePatternSet forwardPattern(funcOp.getContext());
537   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
538   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
539   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
540 }
541 
542 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
543   RewritePatternSet patterns(funcOp.getContext());
544   patterns.add<LinalgVectorizationPattern>(
545       funcOp.getContext(),
546       LinalgTransformationFilter()
547           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
548   populatePadTensorOpVectorizationPatterns(patterns);
549   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
550 }
551 
552 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
553   RewritePatternSet patterns(funcOp.getContext());
554   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
555   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
556 }
557 
558 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
559   RewritePatternSet patterns(funcOp.getContext());
560   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
561   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
562 }
563 
564 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
565   RewritePatternSet patterns(funcOp.getContext());
566   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
567   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
568 }
569 
570 // For now, just assume it is the zero of type.
571 // In the future, it should be the zero of type + op.
572 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
573   auto t = getElementTypeOrSelf(op.get());
574   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
575 }
576 
577 static void applyTilePattern(FuncOp funcOp, std::string loopType,
578                              ArrayRef<int64_t> tileSizes, bool padTiles,
579                              ArrayRef<int64_t> peeledLoops,
580                              bool scalarizeDynamicDims) {
581   MLIRContext *context = funcOp.getContext();
582   RewritePatternSet tilingPattern(context);
583   LinalgTilingLoopType type =
584       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
585           .Case("for", LinalgTilingLoopType::Loops)
586           .Case("affine", LinalgTilingLoopType::AffineLoops)
587           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
588           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
589   auto linalgTilingOptions = linalg::LinalgTilingOptions()
590                                  .setPeeledLoops(peeledLoops)
591                                  .setLoopType(type);
592   if (scalarizeDynamicDims) {
593     linalgTilingOptions.scalarizeDynamicDims();
594     assert(tileSizes.empty() &&
595            "tileSizes and scalarizeDynamicDims is mutually exclusive");
596   } else {
597     linalgTilingOptions.setTileSizes(tileSizes);
598   }
599   if (padTiles)
600     linalgTilingOptions.setPaddingValueComputationFunction(
601         getNeutralOfLinalgOp);
602 
603   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
604                     linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
605                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
606       context, linalgTilingOptions,
607       linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
608   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
609 }
610 
611 static void applyInterchangePattern(FuncOp funcOp,
612                                     ArrayRef<unsigned> interchangeVector) {
613   MLIRContext *context = funcOp.getContext();
614   RewritePatternSet interchangePattern(context);
615   interchangePattern.add<GenericOpInterchangePattern>(
616       context, interchangeVector,
617       LinalgTransformationFilter(ArrayRef<Identifier>{},
618                                  Identifier::get("interchange", context)));
619   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
620 }
621 
// Attribute holding the list of loop indices already peeled on a TiledLoopOp.
static constexpr char kPeeledLoopsLabel[] = {"__peeled_loops__"};
// Unit attribute marking the TiledLoopOp created for a partial iteration.
static constexpr char kPartialIterationLabel[] = {"__partial_iteration__"};
624 
625 namespace {
626 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
627 /// `idx`-th loop contains only "full" iterations and a second loop for the
628 /// remaining partial iteration (if any).
629 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
630   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
631       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
632   }
633 
634   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
635                                 PatternRewriter &rewriter) const override {
636     SmallVector<int64_t> peeledLoops;
637     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
638       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
639       peeledLoops =
640           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
641             return attr.cast<IntegerAttr>().getInt();
642           }));
643       // Check if the loop was already peeled.
644       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
645         return failure();
646     }
647     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
648       // No peeling of loop nests with a partial iteration.
649       return failure();
650 
651     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
652       return failure();
653 
654     // Peel loop and canonicalize.
655     TiledLoopOp result;
656     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
657                                                     result)))
658       return failure();
659 
660     // Apply label, so that the same loop is not rewritten a second time.
661     peeledLoops.push_back(idx);
662     rewriter.updateRootInPlace(loopOp, [&]() {
663       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
664     });
665     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
666     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
667 
668     return success();
669   }
670 
671   /// Index of loop to peel.
672   int64_t idx;
673 
674   /// If set to true, do not peel TiledLoopOps with a partial iteration.
675   bool skipPartial;
676 };
677 } // namespace
678 
679 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
680                                          ArrayRef<unsigned> loops,
681                                          bool skipPartial) {
682   MLIRContext *ctx = funcOp.getContext();
683   RewritePatternSet patterns(ctx);
684   for (unsigned idx : loops)
685     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
686   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
687 
688   // Drop the markers.
689   funcOp.walk([](TiledLoopOp op) {
690     op->removeAttr(kPeeledLoopsLabel);
691     op->removeAttr(kPartialIterationLabel);
692   });
693 }
694 
695 /// Apply transformations specified as patterns.
696 void TestLinalgTransforms::runOnFunction() {
697   auto lambda = [&](void *) {
698     getFunction().walk([](LinalgOp op) {
699       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
700     });
701   };
702   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
703 
704   if (testPromotionOptions) {
705     RewritePatternSet patterns(&getContext());
706     fillPromotionCallBackPatterns(&getContext(), patterns);
707     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
708     return;
709   }
710   if (testTileAndDistributionOptions) {
711     RewritePatternSet patterns(&getContext());
712     fillTileAndDistributePatterns(&getContext(), patterns);
713     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
714     return;
715   }
716   if (testPatterns)
717     return applyPatterns(getFunction());
718   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
719     return applyMatmulToVectorPatterns(getFunction(),
720                                        testMatmulToVectorPatterns1dTiling,
721                                        testMatmulToVectorPatterns2dTiling);
722   if (testVectorTransferForwardingPatterns)
723     return applyVectorTransferForwardingPatterns(getFunction());
724   if (testGenericToVectorPattern)
725     return applyLinalgToVectorPatterns(getFunction());
726   if (testTransformPadTensor)
727     return applyPadTensorToGenericPatterns(getFunction());
728   if (testGeneralizePadTensor)
729     return applyGeneralizePadTensorPatterns(getFunction());
730   if (testSwapSubTensorPadTensor)
731     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
732   if (testTiledLoopPeeling.hasValue())
733     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
734                                         skipPartial);
735   if (testTilePattern)
736     return applyTilePattern(getFunction(), loopType, tileSizes, padTiles,
737                             peeledLoops, /*scalarizeDynamicDims=*/false);
738   if (testTileScalarizeDynamicDims)
739     return applyTilePattern(getFunction(), loopType, tileSizes, padTiles,
740                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
741   if (testHoistPadding) {
742     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
743       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
744     });
745   }
746   if (testInterchangePattern.hasValue())
747     return applyInterchangePattern(getFunction(), testInterchangePattern);
748 }
749 
750 namespace mlir {
751 namespace test {
752 void registerTestLinalgTransforms() {
753   PassRegistration<TestLinalgTransforms>();
754 }
755 } // namespace test
756 } // namespace mlir
757