1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 namespace {
33 struct TestLinalgTransforms
34     : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> {
35   TestLinalgTransforms() = default;
36   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
37 
38   void getDependentDialects(DialectRegistry &registry) const override {
39     // clang-format off
40     registry.insert<AffineDialect,
41                     memref::MemRefDialect,
42                     scf::SCFDialect,
43                     linalg::LinalgDialect,
44                     vector::VectorDialect,
45                     gpu::GPUDialect>();
46     // clang-format on
47   }
48   StringRef getArgument() const final {
49     return "test-linalg-transform-patterns";
50   }
51   StringRef getDescription() const final {
52     return "Test Linalg transformation patterns by applying them greedily.";
53   }
54 
55   void runOnOperation() override;
56 
57   Option<bool> testPatterns{*this, "test-patterns",
58                             llvm::cl::desc("Test a mixed set of patterns"),
59                             llvm::cl::init(false)};
60   Option<bool> testMatmulToVectorPatterns1dTiling{
61       *this, "test-matmul-to-vector-patterns-tile-1d",
62       llvm::cl::desc(
63           "Test a fused pass that applies patterns from matmul to vectors via "
64           "1-d tiling"),
65       llvm::cl::init(false)};
66   Option<bool> testMatmulToVectorPatterns2dTiling{
67       *this, "test-matmul-to-vector-patterns-tile-2d",
68       llvm::cl::desc(
69           "Test a fused pass that applies patterns from matmul to vectors via "
70           "2-d tiling"),
71       llvm::cl::init(false)};
72   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
73                                     llvm::cl::desc("Test promotion options"),
74                                     llvm::cl::init(false)};
75   Option<bool> testTileAndDistributionOptions{
76       *this, "test-tile-and-distribute-options",
77       llvm::cl::desc("Test tile and distribute options"),
78       llvm::cl::init(false)};
79   Option<bool> testTileFuseAndDistributionOptions{
80       *this, "test-tile-fuse-and-distribute-options",
81       llvm::cl::desc("Test tile, fuse and distribute options"),
82       llvm::cl::init(false)};
83   Option<bool> testVectorTransferForwardingPatterns{
84       *this, "test-vector-transfer-forwarding-patterns",
85       llvm::cl::desc(
86           "Test a fused pass that forwards memref.copy to vector.transfer"),
87       llvm::cl::init(false)};
88   Option<bool> testGenericToVectorPattern{
89       *this, "test-linalg-to-vector-patterns",
90       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
91                      "in vector.contract form"),
92       llvm::cl::init(false)};
93   Option<bool> testTilePattern{*this, "test-tile-pattern",
94                                llvm::cl::desc("Test tile pattern"),
95                                llvm::cl::init(false)};
96   Option<bool> testTileScalarizeDynamicDims{
97       *this, "test-tile-scalarize-dynamic-dims",
98       llvm::cl::desc("Test tiling of dynamic dims by 1"),
99       llvm::cl::init(false)};
100   Option<bool> testTransformPadTensor{
101       *this, "test-transform-pad-tensor",
102       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
103       llvm::cl::init(false)};
104   Option<bool> testGeneralizePadTensor{
105       *this, "test-generalize-pad-tensor",
106       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
107       llvm::cl::init(false)};
108   Option<bool> testSwapSubTensorPadTensor{
109       *this, "test-swap-subtensor-padtensor",
110       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
111                      "pad_tensor(subtensor)"),
112       llvm::cl::init(false)};
113   ListOption<int64_t> peeledLoops{
114       *this, "peeled-loops",
115       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
116       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
117   ListOption<int64_t> tileSizes{
118       *this, "tile-sizes",
119       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
120       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
121   Option<bool> skipPartial{
122       *this, "skip-partial",
123       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
124       llvm::cl::init(false)};
125   Option<std::string> loopType{
126       *this, "loop-type",
127       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
128                      "tiled_loop"),
129       llvm::cl::init("for")};
130 };
131 } // namespace
132 
133 static void applyPatterns(FuncOp funcOp) {
134   MLIRContext *ctx = funcOp.getContext();
135   RewritePatternSet patterns(ctx);
136 
137   //===--------------------------------------------------------------------===//
138   // Linalg tiling patterns.
139   //===--------------------------------------------------------------------===//
140   patterns.add<LinalgTilingPattern>(
141       MatmulOp::getOperationName(), ctx,
142       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
143       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
144                                  StringAttr::get(ctx, "L3")));
145   patterns.add<LinalgTilingPattern>(
146       MatmulOp::getOperationName(), ctx,
147       LinalgTilingOptions().setTileSizes({200, 300, 400}),
148       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
149                                  StringAttr::get(ctx, "L2")));
150   patterns.add<LinalgTilingPattern>(
151       MatmulOp::getOperationName(), ctx,
152       LinalgTilingOptions().setTileSizes({20, 30, 40}),
153       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
154                                  StringAttr::get(ctx, "L1")));
155   patterns.add<LinalgTilingPattern>(
156       MatmulOp::getOperationName(), ctx,
157       LinalgTilingOptions().setTileSizes({2, 3, 4}),
158       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
159                                  StringAttr::get(ctx, "REG")));
160 
161   patterns.add<LinalgTilingPattern>(
162       MatvecOp::getOperationName(), ctx,
163       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
164           LinalgTilingLoopType::ParallelLoops),
165       LinalgTransformationFilter(ArrayRef<StringAttr>{},
166                                  StringAttr::get(ctx, "L1")));
167 
168   patterns.add<LinalgTilingPattern>(
169       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
170       LinalgTransformationFilter(
171           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
172                                StringAttr::get(ctx, "L3"),
173                                StringAttr::get(ctx, "L2")},
174           StringAttr::get(ctx, "REG")));
175 
176   //===--------------------------------------------------------------------===//
177   // Linalg tiling and permutation patterns.
178   //===--------------------------------------------------------------------===//
179   patterns.add<LinalgTilingPattern>(
180       MatmulOp::getOperationName(), ctx,
181       LinalgTilingOptions()
182           .setTileSizes({2000, 3000, 4000})
183           .setInterchange({1, 2, 0}),
184       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
185                                  StringAttr::get(ctx, "L2__with_perm__")));
186   patterns.add<LinalgTilingPattern>(
187       MatmulOp::getOperationName(), ctx,
188       LinalgTilingOptions()
189           .setTileSizes({200, 300, 400})
190           .setInterchange({1, 0, 2}),
191       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
192                                  StringAttr::get(ctx, "L1__with_perm__")));
193   patterns.add<LinalgTilingPattern>(
194       MatmulOp::getOperationName(), ctx,
195       LinalgTilingOptions().setTileSizes({20, 30, 40}),
196       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
197                                  StringAttr::get(ctx, "REG__with_perm__")));
198 
199   patterns.add<LinalgTilingPattern>(
200       MatvecOp::getOperationName(), ctx,
201       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
202       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
203                                  StringAttr::get(ctx, "L1__with_perm__")));
204 
205   patterns.add<LinalgTilingPattern>(
206       MatmulOp::getOperationName(), ctx,
207       LinalgTilingOptions()
208           .setTileSizes({16, 8, 4})
209           .setInterchange({1, 2, 0})
210           .setLoopType(LinalgTilingLoopType::ParallelLoops),
211       LinalgTransformationFilter(
212           StringAttr::get(ctx, "par__with_perm__"),
213           StringAttr::get(ctx, "after_par__with_perm__")));
214 
215   //===--------------------------------------------------------------------===//
216   // Linalg to loops patterns.
217   //===--------------------------------------------------------------------===//
218   patterns.add<LinalgLoweringPattern<DotOp>>(
219       ctx,
220       /*loweringType=*/LinalgLoweringType::Loops,
221       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
222 
223   //===--------------------------------------------------------------------===//
224   // Linalg distribution patterns.
225   //===--------------------------------------------------------------------===//
226   LinalgLoopDistributionOptions distributionOptions;
227 
228   //===--------------------------------------------------------------------===//
229   // Linalg to vector contraction patterns.
230   //===--------------------------------------------------------------------===//
231   patterns.add<LinalgVectorizationPattern>(
232       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
233                .addOpFilter<MatmulOp, FillOp, GenericOp>());
234   patterns.add<CopyVectorizationPattern>(ctx);
235 
236   //===--------------------------------------------------------------------===//
237   // Linalg generic interchange pattern.
238   //===--------------------------------------------------------------------===//
239   patterns.add<GenericOpInterchangePattern>(
240       ctx,
241       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
242       LinalgTransformationFilter(ArrayRef<StringAttr>{},
243                                  StringAttr::get(ctx, "PERMUTED")));
244 
245   //===--------------------------------------------------------------------===//
246   // Linalg subview operands promotion.
247   //===--------------------------------------------------------------------===//
248   patterns.add<LinalgPromotionPattern<MatmulOp>>(
249       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
250       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
251                                  StringAttr::get(ctx, "_views_promoted_")));
252   patterns.add<LinalgPromotionPattern<MatmulOp>>(
253       ctx,
254       LinalgPromotionOptions()
255           .setOperandsToPromote({0})
256           .setUseFullTileBuffersByDefault(true),
257       LinalgTransformationFilter(
258           StringAttr::get(ctx, "_promote_first_view_"),
259           StringAttr::get(ctx, "_first_view_promoted_")));
260   patterns.add<LinalgPromotionPattern<FillOp>>(
261       ctx,
262       LinalgPromotionOptions()
263           .setOperandsToPromote({1})
264           .setUseFullTileBuffers({false, true})
265           .setAlignment(32),
266       LinalgTransformationFilter(
267           StringAttr::get(ctx, "_promote_views_aligned_"),
268           StringAttr::get(ctx, "_views_aligned_promoted_")));
269 
270   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
271 
272   // Drop the marker.
273   funcOp.walk([](LinalgOp op) {
274     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
275   });
276 }
277 
278 static void fillL1TilingAndMatmulToVectorPatterns(
279     FuncOp funcOp, StringRef startMarker,
280     SmallVectorImpl<RewritePatternSet> &patternsVector) {
281   MLIRContext *ctx = funcOp.getContext();
282   patternsVector.emplace_back(
283       ctx, std::make_unique<LinalgTilingPattern>(
284                MatmulOp::getOperationName(), ctx,
285                LinalgTilingOptions()
286                    .setTileSizes({8, 12, 16})
287                    .setInterchange({1, 0, 2}),
288                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
289                                           StringAttr::get(ctx, "L1"))));
290 
291   patternsVector.emplace_back(
292       ctx,
293       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
294           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
295           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
296                                      StringAttr::get(ctx, "VEC"))));
297 
298   patternsVector.emplace_back(
299       ctx, std::make_unique<LinalgVectorizationPattern>(
300                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
301                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
302   patternsVector.back().add<LinalgVectorizationPattern>(
303       ctx, LinalgTransformationFilter().addOpFilter<FillOp>());
304   patternsVector.back().add<CopyVectorizationPattern>(ctx);
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // Test promotion callbacks
309 //===----------------------------------------------------------------------===//
310 
311 // Allocation call back
312 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
313                                        ArrayRef<Value> boundingSubViewSize,
314                                        DataLayout &layout) {
315   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
316   return b
317       .create<memref::AllocOp>(
318           subView.getLoc(),
319           MemRefType::get(shape, subView.getType().getElementType(),
320                           /*affineMapComposition =*/{}, 3),
321           boundingSubViewSize)
322       .getResult();
323 }
324 
325 // Deallocation callback
326 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
327   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
328   return success();
329 }
330 
331 // Copy in call back
332 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
333                                     bool isOutput) {
334   auto floatType = src.getType().cast<MemRefType>().getElementType();
335   if (!floatType.isa<FloatType>())
336     return failure();
337   if (!isOutput) {
338     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
339                                             FloatAttr::get(floatType, 42.0));
340     b.create<FillOp>(src.getLoc(), cst, dst);
341   }
342   b.create<memref::CopyOp>(src.getLoc(), src, dst);
343   return success();
344 }
345 
346 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
347                                           RewritePatternSet &patterns) {
348   patterns.add<LinalgTilingPattern>(
349       MatmulOp::getOperationName(), ctx,
350       LinalgTilingOptions().setTileSizes({16, 16, 16}),
351       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
352                                  StringAttr::get(ctx, "PROMOTE")));
353   patterns.add<LinalgPromotionPattern<MatmulOp>>(
354       ctx,
355       LinalgPromotionOptions()
356           .setOperandsToPromote({0, 2})
357           .setUseFullTileBuffers({false, false})
358           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
359           .setCopyInOutFns(
360               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
361                 return copyCallBackFn(b, src, dst, false);
362               },
363               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
364                 return copyCallBackFn(b, src, dst, true);
365               }),
366       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
367 }
368 
369 template <typename IdOp, typename NProcsOp>
370 static SmallVector<ProcInfo, 2>
371 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
372   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
373   SmallVector<ProcInfo, 2> procInfo(count);
374   Type indexType = b.getIndexType();
375   for (unsigned i = 0; i < count; ++i) {
376     gpu::Dimension dim = *gpu::symbolizeDimension(i);
377     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
378                                b.create<NProcsOp>(loc, indexType, dim)};
379   }
380   return procInfo;
381 }
382 
383 static void fillTileAndDistributePatterns(MLIRContext *context,
384                                           RewritePatternSet &patterns) {
385   {
386     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
387     cyclicNprocsEqNiters.distributionMethod.resize(
388         2, DistributionMethod::CyclicNumProcsEqNumIters);
389     cyclicNprocsEqNiters.procInfo =
390         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
391     patterns.add<LinalgTilingPattern>(
392         MatmulOp::getOperationName(), context,
393         LinalgTilingOptions()
394             .setTileSizes({8, 8, 4})
395             .setLoopType(LinalgTilingLoopType::ParallelLoops)
396             .setDistributionOptions(cyclicNprocsEqNiters),
397         LinalgTransformationFilter(
398             StringAttr::get(context, "distribute1"),
399             StringAttr::get(context, "after_distribute1")));
400   }
401 
402   {
403     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
404     cyclicNprocsGeNiters.distributionMethod.resize(
405         2, DistributionMethod::CyclicNumProcsGeNumIters);
406     cyclicNprocsGeNiters.procInfo =
407         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
408     patterns.add<LinalgTilingPattern>(
409         MatmulOp::getOperationName(), context,
410         LinalgTilingOptions()
411             .setTileSizes({8, 8, 4})
412             .setLoopType(LinalgTilingLoopType::ParallelLoops)
413             .setDistributionOptions(cyclicNprocsGeNiters),
414         LinalgTransformationFilter(
415             StringAttr::get(context, "distribute2"),
416             StringAttr::get(context, "after_distribute2")));
417   }
418 
419   {
420     LinalgLoopDistributionOptions cyclicNprocsDefault;
421     cyclicNprocsDefault.distributionMethod.resize(2,
422                                                   DistributionMethod::Cyclic);
423     cyclicNprocsDefault.procInfo =
424         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
425     patterns.add<LinalgTilingPattern>(
426         MatmulOp::getOperationName(), context,
427         LinalgTilingOptions()
428             .setTileSizes({8, 8, 4})
429             .setLoopType(LinalgTilingLoopType::ParallelLoops)
430             .setDistributionOptions(cyclicNprocsDefault),
431         LinalgTransformationFilter(
432             StringAttr::get(context, "distribute3"),
433             StringAttr::get(context, "after_distribute3")));
434   }
435 
436   {
437     LinalgLoopDistributionOptions cyclicNprocsMixed1;
438     cyclicNprocsMixed1.distributionMethod = {
439         DistributionMethod::CyclicNumProcsEqNumIters,
440         DistributionMethod::CyclicNumProcsGeNumIters};
441     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
442     patterns.add<LinalgTilingPattern>(
443         MatmulOp::getOperationName(), context,
444         LinalgTilingOptions()
445             .setTileSizes({8, 8, 4})
446             .setLoopType(LinalgTilingLoopType::ParallelLoops)
447             .setDistributionOptions(cyclicNprocsMixed1),
448         LinalgTransformationFilter(
449             StringAttr::get(context, "distribute4"),
450             StringAttr::get(context, "after_distribute4")));
451   }
452 
453   {
454     LinalgLoopDistributionOptions cyclicNprocsMixed2;
455     cyclicNprocsMixed2.distributionMethod = {
456         DistributionMethod::CyclicNumProcsGeNumIters,
457         DistributionMethod::Cyclic};
458     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
459     patterns.add<LinalgTilingPattern>(
460         MatmulOp::getOperationName(), context,
461         LinalgTilingOptions()
462             .setTileSizes({8, 8, 4})
463             .setLoopType(LinalgTilingLoopType::ParallelLoops)
464             .setDistributionOptions(cyclicNprocsMixed2),
465         LinalgTransformationFilter(
466             StringAttr::get(context, "distribute5"),
467             StringAttr::get(context, "after_distribute5")));
468   }
469 
470   {
471     LinalgLoopDistributionOptions cyclicNprocsMixed3;
472     cyclicNprocsMixed3.distributionMethod = {
473         DistributionMethod::Cyclic,
474         DistributionMethod::CyclicNumProcsEqNumIters};
475     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
476 
477     patterns.add<LinalgTilingPattern>(
478         MatmulOp::getOperationName(), context,
479         LinalgTilingOptions()
480             .setTileSizes({8, 8, 4})
481             .setLoopType(LinalgTilingLoopType::ParallelLoops)
482             .setDistributionOptions(cyclicNprocsMixed3),
483         LinalgTransformationFilter(
484             StringAttr::get(context, "distribute6"),
485             StringAttr::get(context, "after_distribute6")));
486   }
487 
488   {
489     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
490     cyclicNprocsEqNiters.distributionMethod.resize(2,
491                                                    DistributionMethod::Cyclic);
492     cyclicNprocsEqNiters.procInfo =
493         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
494     patterns.add<LinalgTilingPattern>(
495         MatmulOp::getOperationName(), context,
496         LinalgTilingOptions()
497             .setTileSizes({8, 8, 4})
498             .setLoopType(LinalgTilingLoopType::Loops)
499             .setDistributionOptions(cyclicNprocsEqNiters),
500         LinalgTransformationFilter(
501             StringAttr::get(context, "tensors_distribute1"),
502             StringAttr::get(context, "tensors_after_distribute1")));
503   }
504 }
505 
506 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
507                                               RewritePatternSet &patterns) {
508   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
509   cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
510   cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
511   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
512       MatmulOp::getOperationName(), context,
513       LinalgTilingAndFusionOptions()
514           .setTileSizes({8, 8, 4})
515           .setDistributionOptions(cyclicNprocsEqNiters),
516       LinalgTransformationFilter(
517           StringAttr::get(context, "tensors_fuse_distribute1"),
518           StringAttr::get(context, "tensors_after_fuse_distribute1")));
519 }
520 
521 static void
522 applyMatmulToVectorPatterns(FuncOp funcOp,
523                             bool testMatmulToVectorPatterns1dTiling,
524                             bool testMatmulToVectorPatterns2dTiling) {
525   MLIRContext *ctx = funcOp.getContext();
526   SmallVector<RewritePatternSet, 4> stage1Patterns;
527   if (testMatmulToVectorPatterns1dTiling) {
528     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
529   } else if (testMatmulToVectorPatterns2dTiling) {
530     stage1Patterns.emplace_back(
531         ctx, std::make_unique<LinalgTilingPattern>(
532                  MatmulOp::getOperationName(), ctx,
533                  LinalgTilingOptions()
534                      .setTileSizes({768, 264, 768})
535                      .setInterchange({1, 2, 0}),
536                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
537                                             StringAttr::get(ctx, "L2"))));
538     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
539   }
540   {
541     // Canonicalization patterns
542     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
543     vector::populateVectorTransferPermutationMapLoweringPatterns(
544         canonicalizationPatterns);
545     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
546     stage1Patterns.push_back(std::move(canonicalizationPatterns));
547   }
548   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
549   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
550   FrozenRewritePatternSet stage2Patterns =
551       getLinalgTilingCanonicalizationPatterns(ctx);
552   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
553 }
554 
555 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
556   RewritePatternSet forwardPattern(funcOp.getContext());
557   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
558   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
559   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
560 }
561 
562 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
563   RewritePatternSet patterns(funcOp.getContext());
564   auto *ctx = funcOp.getContext();
565   patterns.add<LinalgVectorizationPattern>(
566       ctx, LinalgTransformationFilter()
567                .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
568   patterns.add<CopyVectorizationPattern>(ctx);
569   populatePadOpVectorizationPatterns(patterns);
570   populateConvolutionVectorizationPatterns(patterns);
571   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
572 }
573 
574 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
575   RewritePatternSet patterns(funcOp.getContext());
576   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
577   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
578 }
579 
580 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
581   RewritePatternSet patterns(funcOp.getContext());
582   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
583   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
584 }
585 
586 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
587   RewritePatternSet patterns(funcOp.getContext());
588   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
589   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
590 }
591 
592 static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
593                              ArrayRef<int64_t> tileSizes,
594                              ArrayRef<int64_t> peeledLoops,
595                              bool scalarizeDynamicDims) {
596   MLIRContext *context = funcOp.getContext();
597   RewritePatternSet tilingPattern(context);
598   LinalgTilingLoopType type =
599       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
600           .Case("for", LinalgTilingLoopType::Loops)
601           .Case("affine", LinalgTilingLoopType::AffineLoops)
602           .Case("parallel", LinalgTilingLoopType::ParallelLoops);
603   auto linalgTilingOptions = linalg::LinalgTilingOptions()
604                                  .setPeeledLoops(peeledLoops)
605                                  .setLoopType(type);
606   if (scalarizeDynamicDims) {
607     linalgTilingOptions.scalarizeDynamicDims();
608     assert(tileSizes.empty() &&
609            "tileSizes and scalarizeDynamicDims is mutually exclusive");
610   } else {
611     linalgTilingOptions.setTileSizes(tileSizes);
612   }
613   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
614   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
615       tilingPattern, linalgTilingOptions, f);
616   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
617 }
618 
619 /// Apply transformations specified as patterns.
620 void TestLinalgTransforms::runOnOperation() {
621   auto lambda = [&](void *) {
622     getOperation().walk([](LinalgOp op) {
623       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
624     });
625   };
626   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
627 
628   if (testPromotionOptions) {
629     RewritePatternSet patterns(&getContext());
630     fillPromotionCallBackPatterns(&getContext(), patterns);
631     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
632     return;
633   }
634   if (testTileAndDistributionOptions) {
635     RewritePatternSet patterns(&getContext());
636     fillTileAndDistributePatterns(&getContext(), patterns);
637     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
638     return;
639   }
640   if (testTileFuseAndDistributionOptions) {
641     RewritePatternSet patterns(&getContext());
642     fillTileFuseAndDistributePatterns(&getContext(), patterns);
643     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
644     return;
645   }
646   if (testPatterns)
647     return applyPatterns(getOperation());
648   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
649     return applyMatmulToVectorPatterns(getOperation(),
650                                        testMatmulToVectorPatterns1dTiling,
651                                        testMatmulToVectorPatterns2dTiling);
652   if (testVectorTransferForwardingPatterns)
653     return applyVectorTransferForwardingPatterns(getOperation());
654   if (testGenericToVectorPattern)
655     return applyLinalgToVectorPatterns(getOperation());
656   if (testTransformPadTensor)
657     return applyPadTensorToGenericPatterns(getOperation());
658   if (testGeneralizePadTensor)
659     return applyGeneralizePadTensorPatterns(getOperation());
660   if (testSwapSubTensorPadTensor)
661     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
662   if (testTilePattern)
663     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
664                             /*scalarizeDynamicDims=*/false);
665   if (testTileScalarizeDynamicDims)
666     return applyTilePattern(getOperation(), loopType, tileSizes,
667                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
668 }
669 
670 namespace mlir {
671 namespace test {
672 void registerTestLinalgTransforms() {
673   PassRegistration<TestLinalgTransforms>();
674 }
675 } // namespace test
676 } // namespace mlir
677