1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 namespace {
33 struct TestLinalgTransforms
34     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
35   TestLinalgTransforms() = default;
36   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
37 
38   void getDependentDialects(DialectRegistry &registry) const override {
39     // clang-format off
40     registry.insert<AffineDialect,
41                     memref::MemRefDialect,
42                     scf::SCFDialect,
43                     StandardOpsDialect,
44                     vector::VectorDialect,
45                     gpu::GPUDialect>();
46     // clang-format on
47   }
48   StringRef getArgument() const final {
49     return "test-linalg-transform-patterns";
50   }
51   StringRef getDescription() const final {
52     return "Test Linalg transformation patterns by applying them greedily.";
53   }
54 
55   void runOnFunction() override;
56 
57   Option<bool> testPatterns{*this, "test-patterns",
58                             llvm::cl::desc("Test a mixed set of patterns"),
59                             llvm::cl::init(false)};
60   Option<bool> testMatmulToVectorPatterns1dTiling{
61       *this, "test-matmul-to-vector-patterns-tile-1d",
62       llvm::cl::desc(
63           "Test a fused pass that applies patterns from matmul to vectors via "
64           "1-d tiling"),
65       llvm::cl::init(false)};
66   Option<bool> testMatmulToVectorPatterns2dTiling{
67       *this, "test-matmul-to-vector-patterns-tile-2d",
68       llvm::cl::desc(
69           "Test a fused pass that applies patterns from matmul to vectors via "
70           "2-d tiling"),
71       llvm::cl::init(false)};
72   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
73                                     llvm::cl::desc("Test promotion options"),
74                                     llvm::cl::init(false)};
75   Option<bool> testTileAndDistributionOptions{
76       *this, "test-tile-and-distribute-options",
77       llvm::cl::desc("Test tile and distribute options"),
78       llvm::cl::init(false)};
79   Option<bool> testVectorTransferForwardingPatterns{
80       *this, "test-vector-transfer-forwarding-patterns",
81       llvm::cl::desc(
82           "Test a fused pass that forwards linalg.copy to vector.transfer"),
83       llvm::cl::init(false)};
84   Option<bool> testGenericToVectorPattern{
85       *this, "test-linalg-to-vector-patterns",
86       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
87                      "in vector.contract form"),
88       llvm::cl::init(false)};
89   Option<bool> testTilePattern{*this, "test-tile-pattern",
90                                llvm::cl::desc("Test tile pattern"),
91                                llvm::cl::init(false)};
92   Option<bool> testTileScalarizeDynamicDims{
93       *this, "test-tile-scalarize-dynamic-dims",
94       llvm::cl::desc("Test tiling of dynamic dims by 1"),
95       llvm::cl::init(false)};
96   Option<int> testHoistPadding{*this, "test-hoist-padding",
97                                llvm::cl::desc("Test hoist padding"),
98                                llvm::cl::init(0)};
99   Option<bool> testTransformPadTensor{
100       *this, "test-transform-pad-tensor",
101       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
102       llvm::cl::init(false)};
103   Option<bool> testGeneralizePadTensor{
104       *this, "test-generalize-pad-tensor",
105       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
106       llvm::cl::init(false)};
107   Option<bool> testSwapSubTensorPadTensor{
108       *this, "test-swap-subtensor-padtensor",
109       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
110                      "pad_tensor(subtensor)"),
111       llvm::cl::init(false)};
112   ListOption<int64_t> paddedOperands{
113       *this, "padded-operands",
114       llvm::cl::desc("Operands to pad when test-tile-pattern"),
115       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
116   ListOption<int64_t> peeledLoops{
117       *this, "peeled-loops",
118       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
119       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
120   ListOption<int64_t> tileSizes{
121       *this, "tile-sizes",
122       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
123       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
124   ListOption<unsigned> testInterchangePattern{
125       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
126       llvm::cl::desc("Test the interchange pattern.")};
127   ListOption<unsigned> testTiledLoopPeeling{
128       *this, "test-tiled-loop-peeling",
129       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
130       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
131   Option<bool> skipPartial{
132       *this, "skip-partial",
133       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
134       llvm::cl::init(false)};
135   Option<std::string> loopType{
136       *this, "loop-type",
137       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
138                      "tiled_loop"),
139       llvm::cl::init("for")};
140 };
141 } // end anonymous namespace
142 
143 static void applyPatterns(FuncOp funcOp) {
144   MLIRContext *ctx = funcOp.getContext();
145   RewritePatternSet patterns(ctx);
146 
147   //===--------------------------------------------------------------------===//
148   // Linalg tiling patterns.
149   //===--------------------------------------------------------------------===//
150   patterns.add<LinalgTilingPattern<MatmulOp>>(
151       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
152       LinalgTransformationFilter(Identifier::get("MEM", ctx),
153                                  Identifier::get("L3", ctx)));
154   patterns.add<LinalgTilingPattern<MatmulOp>>(
155       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
156       LinalgTransformationFilter(Identifier::get("L3", ctx),
157                                  Identifier::get("L2", ctx)));
158   patterns.add<LinalgTilingPattern<MatmulOp>>(
159       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
160       LinalgTransformationFilter(Identifier::get("L2", ctx),
161                                  Identifier::get("L1", ctx)));
162   patterns.add<LinalgTilingPattern<MatmulOp>>(
163       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
164       LinalgTransformationFilter(Identifier::get("L1", ctx),
165                                  Identifier::get("REG", ctx)));
166 
167   patterns.add<LinalgTilingPattern<MatvecOp>>(
168       ctx,
169       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
170           LinalgTilingLoopType::ParallelLoops),
171       LinalgTransformationFilter(ArrayRef<Identifier>{},
172                                  Identifier::get("L1", ctx)));
173 
174   patterns.add<LinalgTilingPattern<DotOp>>(
175       ctx, LinalgTilingOptions().setTileSizes(8000),
176       LinalgTransformationFilter(
177           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
178                                Identifier::get("L3", ctx),
179                                Identifier::get("L2", ctx)},
180           Identifier::get("REG", ctx)));
181 
182   //===--------------------------------------------------------------------===//
183   // Linalg tiling and permutation patterns.
184   //===--------------------------------------------------------------------===//
185   patterns.add<LinalgTilingPattern<MatmulOp>>(
186       ctx,
187       LinalgTilingOptions()
188           .setTileSizes({2000, 3000, 4000})
189           .setInterchange({1, 2, 0}),
190       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
191                                  Identifier::get("L2__with_perm__", ctx)));
192   patterns.add<LinalgTilingPattern<MatmulOp>>(
193       ctx,
194       LinalgTilingOptions()
195           .setTileSizes({200, 300, 400})
196           .setInterchange({1, 0, 2}),
197       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
198                                  Identifier::get("L1__with_perm__", ctx)));
199   patterns.add<LinalgTilingPattern<MatmulOp>>(
200       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
201       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
202                                  Identifier::get("REG__with_perm__", ctx)));
203 
204   patterns.add<LinalgTilingPattern<MatvecOp>>(
205       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
206       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
207                                  Identifier::get("L1__with_perm__", ctx)));
208 
209   patterns.add<LinalgTilingPattern<MatmulOp>>(
210       ctx,
211       LinalgTilingOptions()
212           .setTileSizes({16, 8, 4})
213           .setInterchange({1, 2, 0})
214           .setLoopType(LinalgTilingLoopType::ParallelLoops),
215       LinalgTransformationFilter(
216           Identifier::get("par__with_perm__", ctx),
217           Identifier::get("after_par__with_perm__", ctx)));
218 
219   //===--------------------------------------------------------------------===//
220   // Linalg to loops patterns.
221   //===--------------------------------------------------------------------===//
222   patterns.add<LinalgLoweringPattern<DotOp>>(
223       ctx,
224       /*loweringType=*/LinalgLoweringType::Loops,
225       LinalgTransformationFilter(Identifier::get("REG", ctx)));
226 
227   //===--------------------------------------------------------------------===//
228   // Linalg distribution patterns.
229   //===--------------------------------------------------------------------===//
230   LinalgLoopDistributionOptions distributionOptions;
231 
232   //===--------------------------------------------------------------------===//
233   // Linalg to vector contraction patterns.
234   //===--------------------------------------------------------------------===//
235   patterns.add<LinalgVectorizationPattern>(
236       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
237                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
238 
239   //===--------------------------------------------------------------------===//
240   // Linalg generic interchange pattern.
241   //===--------------------------------------------------------------------===//
242   patterns.add<GenericOpInterchangePattern>(
243       ctx,
244       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
245       LinalgTransformationFilter(ArrayRef<Identifier>{},
246                                  Identifier::get("PERMUTED", ctx)));
247 
248   //===--------------------------------------------------------------------===//
249   // Linalg subview operands promotion.
250   //===--------------------------------------------------------------------===//
251   patterns.add<LinalgPromotionPattern<MatmulOp>>(
252       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
253       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
254                                  Identifier::get("_views_promoted_", ctx)));
255   patterns.add<LinalgPromotionPattern<MatmulOp>>(
256       ctx,
257       LinalgPromotionOptions()
258           .setOperandsToPromote({0})
259           .setUseFullTileBuffersByDefault(true),
260       LinalgTransformationFilter(
261           Identifier::get("_promote_first_view_", ctx),
262           Identifier::get("_first_view_promoted_", ctx)));
263   patterns.add<LinalgPromotionPattern<FillOp>>(
264       ctx,
265       LinalgPromotionOptions()
266           .setOperandsToPromote({1})
267           .setUseFullTileBuffers({false, true})
268           .setAlignment(32),
269       LinalgTransformationFilter(
270           Identifier::get("_promote_views_aligned_", ctx),
271           Identifier::get("_views_aligned_promoted_", ctx)));
272 
273   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
274 
275   // Drop the marker.
276   funcOp.walk([](LinalgOp op) {
277     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
278   });
279 }
280 
281 static void fillL1TilingAndMatmulToVectorPatterns(
282     FuncOp funcOp, StringRef startMarker,
283     SmallVectorImpl<RewritePatternSet> &patternsVector) {
284   MLIRContext *ctx = funcOp.getContext();
285   patternsVector.emplace_back(
286       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
287                ctx,
288                LinalgTilingOptions()
289                    .setTileSizes({8, 12, 16})
290                    .setInterchange({1, 0, 2}),
291                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
292                                           Identifier::get("L1", ctx))));
293 
294   patternsVector.emplace_back(
295       ctx,
296       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
297           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
298           LinalgTransformationFilter(Identifier::get("L1", ctx),
299                                      Identifier::get("VEC", ctx))));
300 
301   patternsVector.emplace_back(
302       ctx, std::make_unique<LinalgVectorizationPattern>(
303                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
304                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
305   patternsVector.back().add<LinalgVectorizationPattern>(
306       ctx, LinalgTransformationFilter().addFilter(
307                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // Test promotion callbacks
312 //===----------------------------------------------------------------------===//
313 
314 // Allocation call back
315 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
316                                        ArrayRef<Value> boundingSubViewSize,
317                                        DataLayout &layout) {
318   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
319   return b
320       .create<memref::AllocOp>(
321           subView.getLoc(),
322           MemRefType::get(shape, subView.getType().getElementType(),
323                           /*affineMapComposition =*/{}, 3),
324           boundingSubViewSize)
325       .getResult();
326 }
327 
328 // Deallocation callback
329 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
330   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
331   return success();
332 }
333 
334 // Copy in call back
335 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
336                                     bool isOutput) {
337   auto floatType = src.getType().cast<MemRefType>().getElementType();
338   if (!floatType.isa<FloatType>())
339     return failure();
340   if (!isOutput) {
341     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
342                                             FloatAttr::get(floatType, 42.0));
343     b.create<FillOp>(src.getLoc(), cst, dst);
344   }
345   b.create<CopyOp>(src.getLoc(), src, dst);
346   return success();
347 }
348 
349 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
350                                           RewritePatternSet &patterns) {
351   patterns.add<LinalgTilingPattern<MatmulOp>>(
352       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
353       LinalgTransformationFilter(Identifier::get("START", ctx),
354                                  Identifier::get("PROMOTE", ctx)));
355   patterns.add<LinalgPromotionPattern<MatmulOp>>(
356       ctx,
357       LinalgPromotionOptions()
358           .setOperandsToPromote({0, 2})
359           .setUseFullTileBuffers({false, false})
360           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
361           .setCopyInOutFns(
362               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
363                 return copyCallBackFn(b, src, dst, false);
364               },
365               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
366                 return copyCallBackFn(b, src, dst, true);
367               }),
368       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
369 }
370 
371 template <typename IdOp, typename NProcsOp>
372 static SmallVector<ProcInfo, 2>
373 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
374   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
375   SmallVector<ProcInfo, 2> procInfo(count);
376   const char *xyz[] = {"x", "y", "z"};
377   Type indexType = b.getIndexType();
378   for (unsigned i = 0; i < count; ++i) {
379     procInfo[count - 1 - i] = {
380         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
381         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
382   }
383   return procInfo;
384 }
385 
386 static void fillTileAndDistributePatterns(MLIRContext *context,
387                                           RewritePatternSet &patterns) {
388   {
389     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
390     cyclicNprocsEqNiters.distributionMethod.resize(
391         2, DistributionMethod::CyclicNumProcsEqNumIters);
392     cyclicNprocsEqNiters.procInfo =
393         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
394     patterns.add<LinalgTilingPattern<MatmulOp>>(
395         context,
396         LinalgTilingOptions()
397             .setTileSizes({8, 8, 4})
398             .setLoopType(LinalgTilingLoopType::ParallelLoops)
399             .setDistributionOptions(cyclicNprocsEqNiters),
400         LinalgTransformationFilter(
401             Identifier::get("distribute1", context),
402             Identifier::get("after_distribute1", context)));
403   }
404 
405   {
406     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
407     cyclicNprocsGeNiters.distributionMethod.resize(
408         2, DistributionMethod::CyclicNumProcsGeNumIters);
409     cyclicNprocsGeNiters.procInfo =
410         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
411     patterns.add<LinalgTilingPattern<MatmulOp>>(
412         context,
413         LinalgTilingOptions()
414             .setTileSizes({8, 8, 4})
415             .setLoopType(LinalgTilingLoopType::ParallelLoops)
416             .setDistributionOptions(cyclicNprocsGeNiters),
417         LinalgTransformationFilter(
418             Identifier::get("distribute2", context),
419             Identifier::get("after_distribute2", context)));
420   }
421 
422   {
423     LinalgLoopDistributionOptions cyclicNprocsDefault;
424     cyclicNprocsDefault.distributionMethod.resize(2,
425                                                   DistributionMethod::Cyclic);
426     cyclicNprocsDefault.procInfo =
427         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
428     patterns.add<LinalgTilingPattern<MatmulOp>>(
429         context,
430         LinalgTilingOptions()
431             .setTileSizes({8, 8, 4})
432             .setLoopType(LinalgTilingLoopType::ParallelLoops)
433             .setDistributionOptions(cyclicNprocsDefault),
434         LinalgTransformationFilter(
435             Identifier::get("distribute3", context),
436             Identifier::get("after_distribute3", context)));
437   }
438 
439   {
440     LinalgLoopDistributionOptions cyclicNprocsMixed1;
441     cyclicNprocsMixed1.distributionMethod = {
442         DistributionMethod::CyclicNumProcsEqNumIters,
443         DistributionMethod::CyclicNumProcsGeNumIters};
444     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
445     patterns.add<LinalgTilingPattern<MatmulOp>>(
446         context,
447         LinalgTilingOptions()
448             .setTileSizes({8, 8, 4})
449             .setLoopType(LinalgTilingLoopType::ParallelLoops)
450             .setDistributionOptions(cyclicNprocsMixed1),
451         LinalgTransformationFilter(
452             Identifier::get("distribute4", context),
453             Identifier::get("after_distribute4", context)));
454   }
455 
456   {
457     LinalgLoopDistributionOptions cyclicNprocsMixed2;
458     cyclicNprocsMixed2.distributionMethod = {
459         DistributionMethod::CyclicNumProcsGeNumIters,
460         DistributionMethod::Cyclic};
461     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
462     patterns.add<LinalgTilingPattern<MatmulOp>>(
463         context,
464         LinalgTilingOptions()
465             .setTileSizes({8, 8, 4})
466             .setLoopType(LinalgTilingLoopType::ParallelLoops)
467             .setDistributionOptions(cyclicNprocsMixed2),
468         LinalgTransformationFilter(
469             Identifier::get("distribute5", context),
470             Identifier::get("after_distribute5", context)));
471   }
472 
473   {
474     LinalgLoopDistributionOptions cyclicNprocsMixed3;
475     cyclicNprocsMixed3.distributionMethod = {
476         DistributionMethod::Cyclic,
477         DistributionMethod::CyclicNumProcsEqNumIters};
478     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
479 
480     patterns.add<LinalgTilingPattern<MatmulOp>>(
481         context,
482         LinalgTilingOptions()
483             .setTileSizes({8, 8, 4})
484             .setLoopType(LinalgTilingLoopType::ParallelLoops)
485             .setDistributionOptions(cyclicNprocsMixed3),
486         LinalgTransformationFilter(
487             Identifier::get("distribute6", context),
488             Identifier::get("after_distribute6", context)));
489   }
490 
491   {
492     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
493     cyclicNprocsEqNiters.distributionMethod.resize(2,
494                                                    DistributionMethod::Cyclic);
495     cyclicNprocsEqNiters.procInfo =
496         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
497     patterns.add<LinalgTilingPattern<MatmulOp>>(
498         context,
499         LinalgTilingOptions()
500             .setTileSizes({8, 8, 4})
501             .setLoopType(LinalgTilingLoopType::Loops)
502             .setDistributionOptions(cyclicNprocsEqNiters),
503         LinalgTransformationFilter(
504             Identifier::get("tensors_distribute1", context),
505             Identifier::get("tensors_after_distribute1", context)));
506   }
507 }
508 
509 static void
510 applyMatmulToVectorPatterns(FuncOp funcOp,
511                             bool testMatmulToVectorPatterns1dTiling,
512                             bool testMatmulToVectorPatterns2dTiling) {
513   MLIRContext *ctx = funcOp.getContext();
514   SmallVector<RewritePatternSet, 4> stage1Patterns;
515   if (testMatmulToVectorPatterns1dTiling) {
516     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
517                                           stage1Patterns);
518   } else if (testMatmulToVectorPatterns2dTiling) {
519     stage1Patterns.emplace_back(
520         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
521                  ctx,
522                  LinalgTilingOptions()
523                      .setTileSizes({768, 264, 768})
524                      .setInterchange({1, 2, 0}),
525                  LinalgTransformationFilter(Identifier::get("START", ctx),
526                                             Identifier::get("L2", ctx))));
527     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
528                                           stage1Patterns);
529   }
530   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
531   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
532   FrozenRewritePatternSet stage2Patterns =
533       getLinalgTilingCanonicalizationPatterns(ctx);
534   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
535                             std::move(stage2Patterns));
536 }
537 
538 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
539   RewritePatternSet forwardPattern(funcOp.getContext());
540   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
541   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
542   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
543 }
544 
545 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
546   RewritePatternSet patterns(funcOp.getContext());
547   patterns.add<LinalgVectorizationPattern>(
548       funcOp.getContext(),
549       LinalgTransformationFilter()
550           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
551   populatePadTensorOpVectorizationPatterns(patterns);
552   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
553 }
554 
555 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
556   RewritePatternSet patterns(funcOp.getContext());
557   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
558   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
559 }
560 
561 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
562   RewritePatternSet patterns(funcOp.getContext());
563   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
564   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
565 }
566 
567 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
568   RewritePatternSet patterns(funcOp.getContext());
569   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
570   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
571 }
572 
573 // For now, just assume it is the zero of type.
574 // In the future, it should be the zero of type + op.
575 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
576   auto t = getElementTypeOrSelf(op.get());
577   return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
578                                      b.getZeroAttr(t));
579 }
580 
581 static void applyTilePattern(FuncOp funcOp, std::string loopType,
582                              ArrayRef<int64_t> tileSizes,
583                              ArrayRef<int64_t> paddedOperands,
584                              ArrayRef<int64_t> peeledLoops,
585                              bool scalarizeDynamicDims) {
586   MLIRContext *context = funcOp.getContext();
587   RewritePatternSet tilingPattern(context);
588   LinalgTilingLoopType type =
589       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
590           .Case("for", LinalgTilingLoopType::Loops)
591           .Case("affine", LinalgTilingLoopType::AffineLoops)
592           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
593           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
594   auto linalgTilingOptions = linalg::LinalgTilingOptions()
595                                  .setPeeledLoops(peeledLoops)
596                                  .setLoopType(type);
597   if (scalarizeDynamicDims) {
598     linalgTilingOptions.scalarizeDynamicDims();
599     assert(tileSizes.empty() &&
600            "tileSizes and scalarizeDynamicDims is mutually exclusive");
601   } else {
602     linalgTilingOptions.setTileSizes(tileSizes);
603   }
604   if (!paddedOperands.empty()) {
605     auto paddingFunc = [&](OpBuilder &b,
606                            OpOperand &opOperand) -> FailureOr<Value> {
607       if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0)
608         return failure();
609       return getNeutralOfLinalgOp(b, opOperand);
610     };
611     linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
612   }
613   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
614                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
615       context, linalgTilingOptions,
616       linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
617   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
618 }
619 
620 static void applyInterchangePattern(FuncOp funcOp,
621                                     ArrayRef<unsigned> interchangeVector) {
622   MLIRContext *context = funcOp.getContext();
623   RewritePatternSet interchangePattern(context);
624   interchangePattern.add<GenericOpInterchangePattern>(
625       context, interchangeVector,
626       LinalgTransformationFilter(ArrayRef<Identifier>{},
627                                  Identifier::get("interchange", context)));
628   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
629 }
630 
/// Attribute (an integer array) that TiledLoopPeelingPattern sets on a
/// TiledLoopOp to record which loop dimensions have already been peeled.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
/// Unit attribute that TiledLoopPeelingPattern sets on the loop it creates for
/// the partial iteration.
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
633 
634 namespace {
635 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
636 /// `idx`-th loop contains only "full" iterations and a second loop for the
637 /// remaining partial iteration (if any).
638 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
639   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
640       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
641   }
642 
643   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
644                                 PatternRewriter &rewriter) const override {
645     SmallVector<int64_t> peeledLoops;
646     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
647       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
648       peeledLoops =
649           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
650             return attr.cast<IntegerAttr>().getInt();
651           }));
652       // Check if the loop was already peeled.
653       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
654         return failure();
655     }
656     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
657       // No peeling of loop nests with a partial iteration.
658       return failure();
659 
660     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
661       return failure();
662 
663     // Peel loop and canonicalize.
664     TiledLoopOp result;
665     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
666                                                     result)))
667       return failure();
668 
669     // Apply label, so that the same loop is not rewritten a second time.
670     peeledLoops.push_back(idx);
671     rewriter.updateRootInPlace(loopOp, [&]() {
672       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
673     });
674     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
675     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
676 
677     return success();
678   }
679 
680   /// Index of loop to peel.
681   int64_t idx;
682 
683   /// If set to true, do not peel TiledLoopOps with a partial iteration.
684   bool skipPartial;
685 };
686 } // namespace
687 
688 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
689                                          ArrayRef<unsigned> loops,
690                                          bool skipPartial) {
691   MLIRContext *ctx = funcOp.getContext();
692   RewritePatternSet patterns(ctx);
693   for (unsigned idx : loops)
694     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
695   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
696 
697   // Drop the markers.
698   funcOp.walk([](TiledLoopOp op) {
699     op->removeAttr(kPeeledLoopsLabel);
700     op->removeAttr(kPartialIterationLabel);
701   });
702 }
703 
704 /// Apply transformations specified as patterns.
705 void TestLinalgTransforms::runOnFunction() {
706   auto lambda = [&](void *) {
707     getFunction().walk([](LinalgOp op) {
708       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
709     });
710   };
711   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
712 
713   if (testPromotionOptions) {
714     RewritePatternSet patterns(&getContext());
715     fillPromotionCallBackPatterns(&getContext(), patterns);
716     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
717     return;
718   }
719   if (testTileAndDistributionOptions) {
720     RewritePatternSet patterns(&getContext());
721     fillTileAndDistributePatterns(&getContext(), patterns);
722     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
723     return;
724   }
725   if (testPatterns)
726     return applyPatterns(getFunction());
727   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
728     return applyMatmulToVectorPatterns(getFunction(),
729                                        testMatmulToVectorPatterns1dTiling,
730                                        testMatmulToVectorPatterns2dTiling);
731   if (testVectorTransferForwardingPatterns)
732     return applyVectorTransferForwardingPatterns(getFunction());
733   if (testGenericToVectorPattern)
734     return applyLinalgToVectorPatterns(getFunction());
735   if (testTransformPadTensor)
736     return applyPadTensorToGenericPatterns(getFunction());
737   if (testGeneralizePadTensor)
738     return applyGeneralizePadTensorPatterns(getFunction());
739   if (testSwapSubTensorPadTensor)
740     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
741   if (testTiledLoopPeeling.hasValue())
742     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
743                                         skipPartial);
744   if (testTilePattern)
745     return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
746                             peeledLoops, /*scalarizeDynamicDims=*/false);
747   if (testTileScalarizeDynamicDims)
748     return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
749                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
750   if (testHoistPadding) {
751     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
752       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
753     });
754   }
755   if (testInterchangePattern.hasValue())
756     return applyInterchangePattern(getFunction(), testInterchangePattern);
757 }
758 
759 namespace mlir {
760 namespace test {
761 void registerTestLinalgTransforms() {
762   PassRegistration<TestLinalgTransforms>();
763 }
764 } // namespace test
765 } // namespace mlir
766