1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/GPU/GPUDialect.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> {
36   TestLinalgTransforms() = default;
37   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
38 
39   void getDependentDialects(DialectRegistry &registry) const override {
40     // clang-format off
41     registry.insert<AffineDialect,
42                     memref::MemRefDialect,
43                     scf::SCFDialect,
44                     linalg::LinalgDialect,
45                     vector::VectorDialect,
46                     gpu::GPUDialect>();
47     // clang-format on
48   }
49   StringRef getArgument() const final {
50     return "test-linalg-transform-patterns";
51   }
52   StringRef getDescription() const final {
53     return "Test Linalg transformation patterns by applying them greedily.";
54   }
55 
56   void runOnOperation() override;
57 
58   Option<bool> testPatterns{*this, "test-patterns",
59                             llvm::cl::desc("Test a mixed set of patterns"),
60                             llvm::cl::init(false)};
61   Option<bool> testMatmulToVectorPatterns1dTiling{
62       *this, "test-matmul-to-vector-patterns-tile-1d",
63       llvm::cl::desc(
64           "Test a fused pass that applies patterns from matmul to vectors via "
65           "1-d tiling"),
66       llvm::cl::init(false)};
67   Option<bool> testMatmulToVectorPatterns2dTiling{
68       *this, "test-matmul-to-vector-patterns-tile-2d",
69       llvm::cl::desc(
70           "Test a fused pass that applies patterns from matmul to vectors via "
71           "2-d tiling"),
72       llvm::cl::init(false)};
73   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
74                                     llvm::cl::desc("Test promotion options"),
75                                     llvm::cl::init(false)};
76   Option<bool> testTileAndDistributionOptions{
77       *this, "test-tile-and-distribute-options",
78       llvm::cl::desc("Test tile and distribute options"),
79       llvm::cl::init(false)};
80   Option<bool> testTileFuseAndDistributionOptions{
81       *this, "test-tile-fuse-and-distribute-options",
82       llvm::cl::desc("Test tile, fuse and distribute options"),
83       llvm::cl::init(false)};
84   Option<bool> testVectorTransferForwardingPatterns{
85       *this, "test-vector-transfer-forwarding-patterns",
86       llvm::cl::desc(
87           "Test a fused pass that forwards memref.copy to vector.transfer"),
88       llvm::cl::init(false)};
89   Option<bool> testGenericToVectorPattern{
90       *this, "test-linalg-to-vector-patterns",
91       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
92                      "in vector.contract form"),
93       llvm::cl::init(false)};
94   Option<bool> testTilePattern{*this, "test-tile-pattern",
95                                llvm::cl::desc("Test tile pattern"),
96                                llvm::cl::init(false)};
97   Option<bool> testTileScalarizeDynamicDims{
98       *this, "test-tile-scalarize-dynamic-dims",
99       llvm::cl::desc("Test tiling of dynamic dims by 1"),
100       llvm::cl::init(false)};
101   Option<bool> testTransformPadTensor{
102       *this, "test-transform-pad-tensor",
103       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
104       llvm::cl::init(false)};
105   Option<bool> testGeneralizePadTensor{
106       *this, "test-generalize-pad-tensor",
107       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
108       llvm::cl::init(false)};
109   Option<bool> testSwapSubTensorPadTensor{
110       *this, "test-swap-subtensor-padtensor",
111       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
112                      "pad_tensor(subtensor)"),
113       llvm::cl::init(false)};
114   Option<bool> testSplitReduction{
115       *this, "test-split-reduction",
116       llvm::cl::desc("Test split reduction transformation"),
117       llvm::cl::init(false)};
118   ListOption<int64_t> peeledLoops{
119       *this, "peeled-loops",
120       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
121       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
122   ListOption<int64_t> tileSizes{
123       *this, "tile-sizes",
124       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
125       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
126   Option<bool> skipPartial{
127       *this, "skip-partial",
128       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
129       llvm::cl::init(false)};
130   Option<std::string> loopType{
131       *this, "loop-type",
132       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
133                      "tiled_loop"),
134       llvm::cl::init("for")};
135 };
136 } // namespace
137 
138 static void applyPatterns(FuncOp funcOp) {
139   MLIRContext *ctx = funcOp.getContext();
140   RewritePatternSet patterns(ctx);
141 
142   //===--------------------------------------------------------------------===//
143   // Linalg tiling patterns.
144   //===--------------------------------------------------------------------===//
145   patterns.add<LinalgTilingPattern>(
146       MatmulOp::getOperationName(), ctx,
147       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
148       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
149                                  StringAttr::get(ctx, "L3")));
150   patterns.add<LinalgTilingPattern>(
151       MatmulOp::getOperationName(), ctx,
152       LinalgTilingOptions().setTileSizes({200, 300, 400}),
153       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
154                                  StringAttr::get(ctx, "L2")));
155   patterns.add<LinalgTilingPattern>(
156       MatmulOp::getOperationName(), ctx,
157       LinalgTilingOptions().setTileSizes({20, 30, 40}),
158       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
159                                  StringAttr::get(ctx, "L1")));
160   patterns.add<LinalgTilingPattern>(
161       MatmulOp::getOperationName(), ctx,
162       LinalgTilingOptions().setTileSizes({2, 3, 4}),
163       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
164                                  StringAttr::get(ctx, "REG")));
165 
166   patterns.add<LinalgTilingPattern>(
167       MatvecOp::getOperationName(), ctx,
168       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
169           LinalgTilingLoopType::ParallelLoops),
170       LinalgTransformationFilter(ArrayRef<StringAttr>{},
171                                  StringAttr::get(ctx, "L1")));
172 
173   patterns.add<LinalgTilingPattern>(
174       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
175       LinalgTransformationFilter(
176           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
177                                StringAttr::get(ctx, "L3"),
178                                StringAttr::get(ctx, "L2")},
179           StringAttr::get(ctx, "REG")));
180 
181   //===--------------------------------------------------------------------===//
182   // Linalg tiling and permutation patterns.
183   //===--------------------------------------------------------------------===//
184   patterns.add<LinalgTilingPattern>(
185       MatmulOp::getOperationName(), ctx,
186       LinalgTilingOptions()
187           .setTileSizes({2000, 3000, 4000})
188           .setInterchange({1, 2, 0}),
189       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
190                                  StringAttr::get(ctx, "L2__with_perm__")));
191   patterns.add<LinalgTilingPattern>(
192       MatmulOp::getOperationName(), ctx,
193       LinalgTilingOptions()
194           .setTileSizes({200, 300, 400})
195           .setInterchange({1, 0, 2}),
196       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
197                                  StringAttr::get(ctx, "L1__with_perm__")));
198   patterns.add<LinalgTilingPattern>(
199       MatmulOp::getOperationName(), ctx,
200       LinalgTilingOptions().setTileSizes({20, 30, 40}),
201       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
202                                  StringAttr::get(ctx, "REG__with_perm__")));
203 
204   patterns.add<LinalgTilingPattern>(
205       MatvecOp::getOperationName(), ctx,
206       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
207       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
208                                  StringAttr::get(ctx, "L1__with_perm__")));
209 
210   patterns.add<LinalgTilingPattern>(
211       MatmulOp::getOperationName(), ctx,
212       LinalgTilingOptions()
213           .setTileSizes({16, 8, 4})
214           .setInterchange({1, 2, 0})
215           .setLoopType(LinalgTilingLoopType::ParallelLoops),
216       LinalgTransformationFilter(
217           StringAttr::get(ctx, "par__with_perm__"),
218           StringAttr::get(ctx, "after_par__with_perm__")));
219 
220   //===--------------------------------------------------------------------===//
221   // Linalg to loops patterns.
222   //===--------------------------------------------------------------------===//
223   patterns.add<LinalgLoweringPattern<DotOp>>(
224       ctx,
225       /*loweringType=*/LinalgLoweringType::Loops,
226       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
227 
228   //===--------------------------------------------------------------------===//
229   // Linalg distribution patterns.
230   //===--------------------------------------------------------------------===//
231   LinalgLoopDistributionOptions distributionOptions;
232 
233   //===--------------------------------------------------------------------===//
234   // Linalg to vector contraction patterns.
235   //===--------------------------------------------------------------------===//
236   patterns.add<LinalgVectorizationPattern>(
237       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
238                .addOpFilter<MatmulOp, FillOp, GenericOp>());
239   patterns.add<CopyVectorizationPattern>(ctx);
240 
241   //===--------------------------------------------------------------------===//
242   // Linalg generic interchange pattern.
243   //===--------------------------------------------------------------------===//
244   patterns.add<GenericOpInterchangePattern>(
245       ctx,
246       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
247       LinalgTransformationFilter(ArrayRef<StringAttr>{},
248                                  StringAttr::get(ctx, "PERMUTED")));
249 
250   //===--------------------------------------------------------------------===//
251   // Linalg subview operands promotion.
252   //===--------------------------------------------------------------------===//
253   patterns.add<LinalgPromotionPattern<MatmulOp>>(
254       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
255       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
256                                  StringAttr::get(ctx, "_views_promoted_")));
257   patterns.add<LinalgPromotionPattern<MatmulOp>>(
258       ctx,
259       LinalgPromotionOptions()
260           .setOperandsToPromote({0})
261           .setUseFullTileBuffersByDefault(true),
262       LinalgTransformationFilter(
263           StringAttr::get(ctx, "_promote_first_view_"),
264           StringAttr::get(ctx, "_first_view_promoted_")));
265   patterns.add<LinalgPromotionPattern<FillOp>>(
266       ctx,
267       LinalgPromotionOptions()
268           .setOperandsToPromote({1})
269           .setUseFullTileBuffers({false, true})
270           .setAlignment(32),
271       LinalgTransformationFilter(
272           StringAttr::get(ctx, "_promote_views_aligned_"),
273           StringAttr::get(ctx, "_views_aligned_promoted_")));
274 
275   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
276 
277   // Drop the marker.
278   funcOp.walk([](LinalgOp op) {
279     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
280   });
281 }
282 
283 static void fillL1TilingAndMatmulToVectorPatterns(
284     FuncOp funcOp, StringRef startMarker,
285     SmallVectorImpl<RewritePatternSet> &patternsVector) {
286   MLIRContext *ctx = funcOp.getContext();
287   patternsVector.emplace_back(
288       ctx, std::make_unique<LinalgTilingPattern>(
289                MatmulOp::getOperationName(), ctx,
290                LinalgTilingOptions()
291                    .setTileSizes({8, 12, 16})
292                    .setInterchange({1, 0, 2}),
293                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
294                                           StringAttr::get(ctx, "L1"))));
295 
296   patternsVector.emplace_back(
297       ctx,
298       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
299           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
300           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
301                                      StringAttr::get(ctx, "VEC"))));
302 
303   patternsVector.emplace_back(
304       ctx, std::make_unique<LinalgVectorizationPattern>(
305                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
306                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
307   patternsVector.back().add<LinalgVectorizationPattern>(
308       ctx, LinalgTransformationFilter().addOpFilter<FillOp>());
309   patternsVector.back().add<CopyVectorizationPattern>(ctx);
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // Test promotion callbacks
314 //===----------------------------------------------------------------------===//
315 
316 // Allocation call back
317 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
318                                        ArrayRef<Value> boundingSubViewSize,
319                                        DataLayout &layout) {
320   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
321   return b
322       .create<memref::AllocOp>(
323           subView.getLoc(),
324           MemRefType::get(shape, subView.getType().getElementType(),
325                           /*affineMapComposition =*/{}, 3),
326           boundingSubViewSize)
327       .getResult();
328 }
329 
330 // Deallocation callback
331 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
332   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
333   return success();
334 }
335 
336 // Copy in call back
337 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
338                                     bool isOutput) {
339   auto floatType = src.getType().cast<MemRefType>().getElementType();
340   if (!floatType.isa<FloatType>())
341     return failure();
342   if (!isOutput) {
343     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
344                                             FloatAttr::get(floatType, 42.0));
345     b.create<FillOp>(src.getLoc(), cst, dst);
346   }
347   b.create<memref::CopyOp>(src.getLoc(), src, dst);
348   return success();
349 }
350 
351 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
352                                           RewritePatternSet &patterns) {
353   patterns.add<LinalgTilingPattern>(
354       MatmulOp::getOperationName(), ctx,
355       LinalgTilingOptions().setTileSizes({16, 16, 16}),
356       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
357                                  StringAttr::get(ctx, "PROMOTE")));
358   patterns.add<LinalgPromotionPattern<MatmulOp>>(
359       ctx,
360       LinalgPromotionOptions()
361           .setOperandsToPromote({0, 2})
362           .setUseFullTileBuffers({false, false})
363           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
364           .setCopyInOutFns(
365               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
366                 return copyCallBackFn(b, src, dst, false);
367               },
368               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
369                 return copyCallBackFn(b, src, dst, true);
370               }),
371       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
372 }
373 
374 template <typename IdOp, typename NProcsOp>
375 static SmallVector<ProcInfo, 2>
376 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
377   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
378   SmallVector<ProcInfo, 2> procInfo(count);
379   Type indexType = b.getIndexType();
380   for (unsigned i = 0; i < count; ++i) {
381     gpu::Dimension dim = *gpu::symbolizeDimension(i);
382     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
383                                b.create<NProcsOp>(loc, indexType, dim)};
384   }
385   return procInfo;
386 }
387 
388 static void fillTileAndDistributePatterns(MLIRContext *context,
389                                           RewritePatternSet &patterns) {
390   {
391     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
392     cyclicNprocsEqNiters.distributionMethod.resize(
393         2, DistributionMethod::CyclicNumProcsEqNumIters);
394     cyclicNprocsEqNiters.procInfo =
395         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
396     patterns.add<LinalgTilingPattern>(
397         MatmulOp::getOperationName(), context,
398         LinalgTilingOptions()
399             .setTileSizes({8, 8, 4})
400             .setLoopType(LinalgTilingLoopType::ParallelLoops)
401             .setDistributionOptions(cyclicNprocsEqNiters),
402         LinalgTransformationFilter(
403             StringAttr::get(context, "distribute1"),
404             StringAttr::get(context, "after_distribute1")));
405   }
406 
407   {
408     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
409     cyclicNprocsGeNiters.distributionMethod.resize(
410         2, DistributionMethod::CyclicNumProcsGeNumIters);
411     cyclicNprocsGeNiters.procInfo =
412         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
413     patterns.add<LinalgTilingPattern>(
414         MatmulOp::getOperationName(), context,
415         LinalgTilingOptions()
416             .setTileSizes({8, 8, 4})
417             .setLoopType(LinalgTilingLoopType::ParallelLoops)
418             .setDistributionOptions(cyclicNprocsGeNiters),
419         LinalgTransformationFilter(
420             StringAttr::get(context, "distribute2"),
421             StringAttr::get(context, "after_distribute2")));
422   }
423 
424   {
425     LinalgLoopDistributionOptions cyclicNprocsDefault;
426     cyclicNprocsDefault.distributionMethod.resize(2,
427                                                   DistributionMethod::Cyclic);
428     cyclicNprocsDefault.procInfo =
429         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
430     patterns.add<LinalgTilingPattern>(
431         MatmulOp::getOperationName(), context,
432         LinalgTilingOptions()
433             .setTileSizes({8, 8, 4})
434             .setLoopType(LinalgTilingLoopType::ParallelLoops)
435             .setDistributionOptions(cyclicNprocsDefault),
436         LinalgTransformationFilter(
437             StringAttr::get(context, "distribute3"),
438             StringAttr::get(context, "after_distribute3")));
439   }
440 
441   {
442     LinalgLoopDistributionOptions cyclicNprocsMixed1;
443     cyclicNprocsMixed1.distributionMethod = {
444         DistributionMethod::CyclicNumProcsEqNumIters,
445         DistributionMethod::CyclicNumProcsGeNumIters};
446     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
447     patterns.add<LinalgTilingPattern>(
448         MatmulOp::getOperationName(), context,
449         LinalgTilingOptions()
450             .setTileSizes({8, 8, 4})
451             .setLoopType(LinalgTilingLoopType::ParallelLoops)
452             .setDistributionOptions(cyclicNprocsMixed1),
453         LinalgTransformationFilter(
454             StringAttr::get(context, "distribute4"),
455             StringAttr::get(context, "after_distribute4")));
456   }
457 
458   {
459     LinalgLoopDistributionOptions cyclicNprocsMixed2;
460     cyclicNprocsMixed2.distributionMethod = {
461         DistributionMethod::CyclicNumProcsGeNumIters,
462         DistributionMethod::Cyclic};
463     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
464     patterns.add<LinalgTilingPattern>(
465         MatmulOp::getOperationName(), context,
466         LinalgTilingOptions()
467             .setTileSizes({8, 8, 4})
468             .setLoopType(LinalgTilingLoopType::ParallelLoops)
469             .setDistributionOptions(cyclicNprocsMixed2),
470         LinalgTransformationFilter(
471             StringAttr::get(context, "distribute5"),
472             StringAttr::get(context, "after_distribute5")));
473   }
474 
475   {
476     LinalgLoopDistributionOptions cyclicNprocsMixed3;
477     cyclicNprocsMixed3.distributionMethod = {
478         DistributionMethod::Cyclic,
479         DistributionMethod::CyclicNumProcsEqNumIters};
480     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
481 
482     patterns.add<LinalgTilingPattern>(
483         MatmulOp::getOperationName(), context,
484         LinalgTilingOptions()
485             .setTileSizes({8, 8, 4})
486             .setLoopType(LinalgTilingLoopType::ParallelLoops)
487             .setDistributionOptions(cyclicNprocsMixed3),
488         LinalgTransformationFilter(
489             StringAttr::get(context, "distribute6"),
490             StringAttr::get(context, "after_distribute6")));
491   }
492 
493   {
494     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
495     cyclicNprocsEqNiters.distributionMethod.resize(2,
496                                                    DistributionMethod::Cyclic);
497     cyclicNprocsEqNiters.procInfo =
498         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
499     patterns.add<LinalgTilingPattern>(
500         MatmulOp::getOperationName(), context,
501         LinalgTilingOptions()
502             .setTileSizes({8, 8, 4})
503             .setLoopType(LinalgTilingLoopType::Loops)
504             .setDistributionOptions(cyclicNprocsEqNiters),
505         LinalgTransformationFilter(
506             StringAttr::get(context, "tensors_distribute1"),
507             StringAttr::get(context, "tensors_after_distribute1")));
508   }
509 }
510 
511 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
512                                               RewritePatternSet &patterns) {
513   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
514   cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
515   cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
516   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
517       MatmulOp::getOperationName(), context,
518       LinalgTilingAndFusionOptions()
519           .setTileSizes({8, 8, 4})
520           .setDistributionOptions(cyclicNprocsEqNiters),
521       LinalgTransformationFilter(
522           StringAttr::get(context, "tensors_fuse_distribute1"),
523           StringAttr::get(context, "tensors_after_fuse_distribute1")));
524 }
525 
526 static void
527 applyMatmulToVectorPatterns(FuncOp funcOp,
528                             bool testMatmulToVectorPatterns1dTiling,
529                             bool testMatmulToVectorPatterns2dTiling) {
530   MLIRContext *ctx = funcOp.getContext();
531   SmallVector<RewritePatternSet, 4> stage1Patterns;
532   if (testMatmulToVectorPatterns1dTiling) {
533     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
534   } else if (testMatmulToVectorPatterns2dTiling) {
535     stage1Patterns.emplace_back(
536         ctx, std::make_unique<LinalgTilingPattern>(
537                  MatmulOp::getOperationName(), ctx,
538                  LinalgTilingOptions()
539                      .setTileSizes({768, 264, 768})
540                      .setInterchange({1, 2, 0}),
541                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
542                                             StringAttr::get(ctx, "L2"))));
543     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
544   }
545   {
546     // Canonicalization patterns
547     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
548     vector::populateVectorTransferPermutationMapLoweringPatterns(
549         canonicalizationPatterns);
550     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
551     stage1Patterns.push_back(std::move(canonicalizationPatterns));
552   }
553   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
554   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
555   FrozenRewritePatternSet stage2Patterns =
556       getLinalgTilingCanonicalizationPatterns(ctx);
557   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
558 }
559 
560 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
561   RewritePatternSet forwardPattern(funcOp.getContext());
562   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
563   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
564   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
565 }
566 
567 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
568   RewritePatternSet patterns(funcOp.getContext());
569   auto *ctx = funcOp.getContext();
570   patterns.add<LinalgVectorizationPattern>(
571       ctx, LinalgTransformationFilter()
572                .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
573   patterns.add<CopyVectorizationPattern>(ctx);
574   populatePadOpVectorizationPatterns(patterns);
575   populateConvolutionVectorizationPatterns(patterns);
576   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
577 }
578 
579 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
580   RewritePatternSet patterns(funcOp.getContext());
581   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
582   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
583 }
584 
585 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
586   RewritePatternSet patterns(funcOp.getContext());
587   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
588   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
589 }
590 
591 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
592   RewritePatternSet patterns(funcOp.getContext());
593   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
594   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
595 }
596 
597 static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
598                              ArrayRef<int64_t> tileSizes,
599                              ArrayRef<int64_t> peeledLoops,
600                              bool scalarizeDynamicDims) {
601   MLIRContext *context = funcOp.getContext();
602   RewritePatternSet tilingPattern(context);
603   LinalgTilingLoopType type =
604       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
605           .Case("for", LinalgTilingLoopType::Loops)
606           .Case("affine", LinalgTilingLoopType::AffineLoops)
607           .Case("parallel", LinalgTilingLoopType::ParallelLoops);
608   auto linalgTilingOptions = linalg::LinalgTilingOptions()
609                                  .setPeeledLoops(peeledLoops)
610                                  .setLoopType(type);
611   if (scalarizeDynamicDims) {
612     linalgTilingOptions.scalarizeDynamicDims();
613     assert(tileSizes.empty() &&
614            "tileSizes and scalarizeDynamicDims is mutually exclusive");
615   } else {
616     linalgTilingOptions.setTileSizes(tileSizes);
617   }
618   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
619   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
620       tilingPattern, linalgTilingOptions, f);
621   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
622 }
623 
624 static void applySplitReduction(FuncOp funcOp) {
625   RewritePatternSet patterns(funcOp.getContext());
626   linalg::populateSplitReductionPattern(
627       patterns,
628       [](LinalgOp op) {
629         unsigned insertDimIndex = op.getNumLoops() - 1;
630         return std::make_pair(4, insertDimIndex);
631       },
632       LinalgTransformationFilter(
633           ArrayRef<StringAttr>{},
634           StringAttr::get(funcOp.getContext(), "SPLIT")));
635   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
636 }
637 
638 /// Apply transformations specified as patterns.
639 void TestLinalgTransforms::runOnOperation() {
640   auto lambda = [&](void *) {
641     getOperation().walk([](LinalgOp op) {
642       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
643     });
644   };
645   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
646 
647   if (testPromotionOptions) {
648     RewritePatternSet patterns(&getContext());
649     fillPromotionCallBackPatterns(&getContext(), patterns);
650     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
651     return;
652   }
653   if (testTileAndDistributionOptions) {
654     RewritePatternSet patterns(&getContext());
655     fillTileAndDistributePatterns(&getContext(), patterns);
656     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
657     return;
658   }
659   if (testTileFuseAndDistributionOptions) {
660     RewritePatternSet patterns(&getContext());
661     fillTileFuseAndDistributePatterns(&getContext(), patterns);
662     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
663     return;
664   }
665   if (testPatterns)
666     return applyPatterns(getOperation());
667   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
668     return applyMatmulToVectorPatterns(getOperation(),
669                                        testMatmulToVectorPatterns1dTiling,
670                                        testMatmulToVectorPatterns2dTiling);
671   if (testVectorTransferForwardingPatterns)
672     return applyVectorTransferForwardingPatterns(getOperation());
673   if (testGenericToVectorPattern)
674     return applyLinalgToVectorPatterns(getOperation());
675   if (testTransformPadTensor)
676     return applyPadTensorToGenericPatterns(getOperation());
677   if (testGeneralizePadTensor)
678     return applyGeneralizePadTensorPatterns(getOperation());
679   if (testSwapSubTensorPadTensor)
680     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
681   if (testTilePattern)
682     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
683                             /*scalarizeDynamicDims=*/false);
684   if (testTileScalarizeDynamicDims)
685     return applyTilePattern(getOperation(), loopType, tileSizes,
686                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
687   if (testSplitReduction)
688     return applySplitReduction(getOperation());
689 }
690 
691 namespace mlir {
692 namespace test {
693 void registerTestLinalgTransforms() {
694   PassRegistration<TestLinalgTransforms>();
695 }
696 } // namespace test
697 } // namespace mlir
698