1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> {
36   TestLinalgTransforms() = default;
37   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
38 
39   void getDependentDialects(DialectRegistry &registry) const override {
40     // clang-format off
41     registry.insert<AffineDialect,
42                     memref::MemRefDialect,
43                     scf::SCFDialect,
44                     linalg::LinalgDialect,
45                     vector::VectorDialect,
46                     gpu::GPUDialect>();
47     // clang-format on
48   }
49   StringRef getArgument() const final {
50     return "test-linalg-transform-patterns";
51   }
52   StringRef getDescription() const final {
53     return "Test Linalg transformation patterns by applying them greedily.";
54   }
55 
56   void runOnOperation() override;
57 
58   Option<bool> testPatterns{*this, "test-patterns",
59                             llvm::cl::desc("Test a mixed set of patterns"),
60                             llvm::cl::init(false)};
61   Option<bool> testMatmulToVectorPatterns1dTiling{
62       *this, "test-matmul-to-vector-patterns-tile-1d",
63       llvm::cl::desc(
64           "Test a fused pass that applies patterns from matmul to vectors via "
65           "1-d tiling"),
66       llvm::cl::init(false)};
67   Option<bool> testMatmulToVectorPatterns2dTiling{
68       *this, "test-matmul-to-vector-patterns-tile-2d",
69       llvm::cl::desc(
70           "Test a fused pass that applies patterns from matmul to vectors via "
71           "2-d tiling"),
72       llvm::cl::init(false)};
73   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
74                                     llvm::cl::desc("Test promotion options"),
75                                     llvm::cl::init(false)};
76   Option<bool> testTileAndDistributionOptions{
77       *this, "test-tile-and-distribute-options",
78       llvm::cl::desc("Test tile and distribute options"),
79       llvm::cl::init(false)};
80   Option<bool> testTileFuseAndDistributionOptions{
81       *this, "test-tile-fuse-and-distribute-options",
82       llvm::cl::desc("Test tile, fuse and distribute options"),
83       llvm::cl::init(false)};
84   Option<bool> testVectorTransferForwardingPatterns{
85       *this, "test-vector-transfer-forwarding-patterns",
86       llvm::cl::desc(
87           "Test a fused pass that forwards memref.copy to vector.transfer"),
88       llvm::cl::init(false)};
89   Option<bool> testGenericToVectorPattern{
90       *this, "test-linalg-to-vector-patterns",
91       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
92                      "in vector.contract form"),
93       llvm::cl::init(false)};
94   Option<bool> testTilePattern{*this, "test-tile-pattern",
95                                llvm::cl::desc("Test tile pattern"),
96                                llvm::cl::init(false)};
97   Option<bool> testTileScalarizeDynamicDims{
98       *this, "test-tile-scalarize-dynamic-dims",
99       llvm::cl::desc("Test tiling of dynamic dims by 1"),
100       llvm::cl::init(false)};
101   Option<bool> testTransformPadTensor{
102       *this, "test-transform-pad-tensor",
103       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
104       llvm::cl::init(false)};
105   Option<bool> testGeneralizePadTensor{
106       *this, "test-generalize-pad-tensor",
107       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
108       llvm::cl::init(false)};
109   Option<bool> testSwapSubTensorPadTensor{
110       *this, "test-swap-subtensor-padtensor",
111       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
112                      "pad_tensor(subtensor)"),
113       llvm::cl::init(false)};
114   Option<bool> testSplitReduction{
115       *this, "test-split-reduction",
116       llvm::cl::desc("Test split reduction transformation"),
117       llvm::cl::init(false)};
118   ListOption<int64_t> peeledLoops{
119       *this, "peeled-loops",
120       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
121       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
122   ListOption<int64_t> tileSizes{
123       *this, "tile-sizes",
124       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
125       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
126   Option<bool> skipPartial{
127       *this, "skip-partial",
128       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
129       llvm::cl::init(false)};
130   Option<std::string> loopType{
131       *this, "loop-type",
132       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
133                      "tiled_loop"),
134       llvm::cl::init("for")};
135   Option<bool> testBubbleUpExtractSliceOpPattern{
136       *this, "test-bubble-up-extract-slice-op-pattern",
137       llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
138                      "extract_slice + linalgOp"),
139       llvm::cl::init(false)};
140 };
141 } // namespace
142 
143 static void applyPatterns(FuncOp funcOp) {
144   MLIRContext *ctx = funcOp.getContext();
145   RewritePatternSet patterns(ctx);
146 
147   //===--------------------------------------------------------------------===//
148   // Linalg tiling patterns.
149   //===--------------------------------------------------------------------===//
150   patterns.add<LinalgTilingPattern>(
151       MatmulOp::getOperationName(), ctx,
152       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
153       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
154                                  StringAttr::get(ctx, "L3")));
155   patterns.add<LinalgTilingPattern>(
156       MatmulOp::getOperationName(), ctx,
157       LinalgTilingOptions().setTileSizes({200, 300, 400}),
158       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
159                                  StringAttr::get(ctx, "L2")));
160   patterns.add<LinalgTilingPattern>(
161       MatmulOp::getOperationName(), ctx,
162       LinalgTilingOptions().setTileSizes({20, 30, 40}),
163       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
164                                  StringAttr::get(ctx, "L1")));
165   patterns.add<LinalgTilingPattern>(
166       MatmulOp::getOperationName(), ctx,
167       LinalgTilingOptions().setTileSizes({2, 3, 4}),
168       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
169                                  StringAttr::get(ctx, "REG")));
170 
171   patterns.add<LinalgTilingPattern>(
172       MatvecOp::getOperationName(), ctx,
173       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
174           LinalgTilingLoopType::ParallelLoops),
175       LinalgTransformationFilter(ArrayRef<StringAttr>{},
176                                  StringAttr::get(ctx, "L1")));
177 
178   patterns.add<LinalgTilingPattern>(
179       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
180       LinalgTransformationFilter(
181           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
182                                StringAttr::get(ctx, "L3"),
183                                StringAttr::get(ctx, "L2")},
184           StringAttr::get(ctx, "REG")));
185 
186   //===--------------------------------------------------------------------===//
187   // Linalg tiling and permutation patterns.
188   //===--------------------------------------------------------------------===//
189   patterns.add<LinalgTilingPattern>(
190       MatmulOp::getOperationName(), ctx,
191       LinalgTilingOptions()
192           .setTileSizes({2000, 3000, 4000})
193           .setInterchange({1, 2, 0}),
194       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
195                                  StringAttr::get(ctx, "L2__with_perm__")));
196   patterns.add<LinalgTilingPattern>(
197       MatmulOp::getOperationName(), ctx,
198       LinalgTilingOptions()
199           .setTileSizes({200, 300, 400})
200           .setInterchange({1, 0, 2}),
201       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
202                                  StringAttr::get(ctx, "L1__with_perm__")));
203   patterns.add<LinalgTilingPattern>(
204       MatmulOp::getOperationName(), ctx,
205       LinalgTilingOptions().setTileSizes({20, 30, 40}),
206       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
207                                  StringAttr::get(ctx, "REG__with_perm__")));
208 
209   patterns.add<LinalgTilingPattern>(
210       MatvecOp::getOperationName(), ctx,
211       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
212       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
213                                  StringAttr::get(ctx, "L1__with_perm__")));
214 
215   patterns.add<LinalgTilingPattern>(
216       MatmulOp::getOperationName(), ctx,
217       LinalgTilingOptions()
218           .setTileSizes({16, 8, 4})
219           .setInterchange({1, 2, 0})
220           .setLoopType(LinalgTilingLoopType::ParallelLoops),
221       LinalgTransformationFilter(
222           StringAttr::get(ctx, "par__with_perm__"),
223           StringAttr::get(ctx, "after_par__with_perm__")));
224 
225   //===--------------------------------------------------------------------===//
226   // Linalg to loops patterns.
227   //===--------------------------------------------------------------------===//
228   patterns.add<LinalgLoweringPattern<DotOp>>(
229       ctx,
230       /*loweringType=*/LinalgLoweringType::Loops,
231       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
232 
233   //===--------------------------------------------------------------------===//
234   // Linalg distribution patterns.
235   //===--------------------------------------------------------------------===//
236   LinalgLoopDistributionOptions distributionOptions;
237 
238   //===--------------------------------------------------------------------===//
239   // Linalg to vector contraction patterns.
240   //===--------------------------------------------------------------------===//
241   patterns.add<LinalgVectorizationPattern>(
242       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
243                .addOpFilter<MatmulOp, FillOp, GenericOp>());
244   patterns.add<CopyVectorizationPattern>(ctx);
245 
246   //===--------------------------------------------------------------------===//
247   // Linalg generic interchange pattern.
248   //===--------------------------------------------------------------------===//
249   patterns.add<GenericOpInterchangePattern>(
250       ctx,
251       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
252       LinalgTransformationFilter(ArrayRef<StringAttr>{},
253                                  StringAttr::get(ctx, "PERMUTED")));
254 
255   //===--------------------------------------------------------------------===//
256   // Linalg subview operands promotion.
257   //===--------------------------------------------------------------------===//
258   patterns.add<LinalgPromotionPattern<MatmulOp>>(
259       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
260       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
261                                  StringAttr::get(ctx, "_views_promoted_")));
262   patterns.add<LinalgPromotionPattern<MatmulOp>>(
263       ctx,
264       LinalgPromotionOptions()
265           .setOperandsToPromote({0})
266           .setUseFullTileBuffersByDefault(true),
267       LinalgTransformationFilter(
268           StringAttr::get(ctx, "_promote_first_view_"),
269           StringAttr::get(ctx, "_first_view_promoted_")));
270   patterns.add<LinalgPromotionPattern<FillOp>>(
271       ctx,
272       LinalgPromotionOptions()
273           .setOperandsToPromote({1})
274           .setUseFullTileBuffers({false, true})
275           .setAlignment(32),
276       LinalgTransformationFilter(
277           StringAttr::get(ctx, "_promote_views_aligned_"),
278           StringAttr::get(ctx, "_views_aligned_promoted_")));
279 
280   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
281 
282   // Drop the marker.
283   funcOp.walk([](LinalgOp op) {
284     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
285   });
286 }
287 
288 static void fillL1TilingAndMatmulToVectorPatterns(
289     FuncOp funcOp, StringRef startMarker,
290     SmallVectorImpl<RewritePatternSet> &patternsVector) {
291   MLIRContext *ctx = funcOp.getContext();
292   patternsVector.emplace_back(
293       ctx, std::make_unique<LinalgTilingPattern>(
294                MatmulOp::getOperationName(), ctx,
295                LinalgTilingOptions()
296                    .setTileSizes({8, 12, 16})
297                    .setInterchange({1, 0, 2}),
298                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
299                                           StringAttr::get(ctx, "L1"))));
300 
301   patternsVector.emplace_back(
302       ctx,
303       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
304           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
305           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
306                                      StringAttr::get(ctx, "VEC"))));
307 
308   patternsVector.emplace_back(
309       ctx, std::make_unique<LinalgVectorizationPattern>(
310                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
311                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
312   patternsVector.back().add<LinalgVectorizationPattern>(
313       ctx, LinalgTransformationFilter().addOpFilter<FillOp>());
314   patternsVector.back().add<CopyVectorizationPattern>(ctx);
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // Test promotion callbacks
319 //===----------------------------------------------------------------------===//
320 
321 // Allocation call back
322 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
323                                        ArrayRef<Value> boundingSubViewSize,
324                                        DataLayout &layout) {
325   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
326   return b
327       .create<memref::AllocOp>(
328           subView.getLoc(),
329           MemRefType::get(shape, subView.getType().getElementType(),
330                           /*affineMapComposition =*/{}, 3),
331           boundingSubViewSize)
332       .getResult();
333 }
334 
335 // Deallocation callback
336 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
337   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
338   return success();
339 }
340 
341 // Copy in call back
342 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
343                                     bool isOutput) {
344   auto floatType = src.getType().cast<MemRefType>().getElementType();
345   if (!floatType.isa<FloatType>())
346     return failure();
347   if (!isOutput) {
348     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
349                                             FloatAttr::get(floatType, 42.0));
350     b.create<FillOp>(src.getLoc(), cst, dst);
351   }
352   b.create<memref::CopyOp>(src.getLoc(), src, dst);
353   return success();
354 }
355 
356 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
357                                           RewritePatternSet &patterns) {
358   patterns.add<LinalgTilingPattern>(
359       MatmulOp::getOperationName(), ctx,
360       LinalgTilingOptions().setTileSizes({16, 16, 16}),
361       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
362                                  StringAttr::get(ctx, "PROMOTE")));
363   patterns.add<LinalgPromotionPattern<MatmulOp>>(
364       ctx,
365       LinalgPromotionOptions()
366           .setOperandsToPromote({0, 2})
367           .setUseFullTileBuffers({false, false})
368           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
369           .setCopyInOutFns(
370               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
371                 return copyCallBackFn(b, src, dst, false);
372               },
373               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
374                 return copyCallBackFn(b, src, dst, true);
375               }),
376       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
377 }
378 
379 template <typename IdOp, typename NProcsOp>
380 static SmallVector<ProcInfo, 2>
381 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
382   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
383   SmallVector<ProcInfo, 2> procInfo(count);
384   Type indexType = b.getIndexType();
385   for (unsigned i = 0; i < count; ++i) {
386     gpu::Dimension dim = *gpu::symbolizeDimension(i);
387     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
388                                b.create<NProcsOp>(loc, indexType, dim)};
389   }
390   return procInfo;
391 }
392 
393 static void fillTileAndDistributePatterns(MLIRContext *context,
394                                           RewritePatternSet &patterns) {
395   {
396     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
397     cyclicNprocsEqNiters.distributionMethod.resize(
398         2, DistributionMethod::CyclicNumProcsEqNumIters);
399     cyclicNprocsEqNiters.procInfo =
400         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
401     patterns.add<LinalgTilingPattern>(
402         MatmulOp::getOperationName(), context,
403         LinalgTilingOptions()
404             .setTileSizes({8, 8, 4})
405             .setLoopType(LinalgTilingLoopType::ParallelLoops)
406             .setDistributionOptions(cyclicNprocsEqNiters),
407         LinalgTransformationFilter(
408             StringAttr::get(context, "distribute1"),
409             StringAttr::get(context, "after_distribute1")));
410   }
411 
412   {
413     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
414     cyclicNprocsGeNiters.distributionMethod.resize(
415         2, DistributionMethod::CyclicNumProcsGeNumIters);
416     cyclicNprocsGeNiters.procInfo =
417         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
418     patterns.add<LinalgTilingPattern>(
419         MatmulOp::getOperationName(), context,
420         LinalgTilingOptions()
421             .setTileSizes({8, 8, 4})
422             .setLoopType(LinalgTilingLoopType::ParallelLoops)
423             .setDistributionOptions(cyclicNprocsGeNiters),
424         LinalgTransformationFilter(
425             StringAttr::get(context, "distribute2"),
426             StringAttr::get(context, "after_distribute2")));
427   }
428 
429   {
430     LinalgLoopDistributionOptions cyclicNprocsDefault;
431     cyclicNprocsDefault.distributionMethod.resize(2,
432                                                   DistributionMethod::Cyclic);
433     cyclicNprocsDefault.procInfo =
434         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
435     patterns.add<LinalgTilingPattern>(
436         MatmulOp::getOperationName(), context,
437         LinalgTilingOptions()
438             .setTileSizes({8, 8, 4})
439             .setLoopType(LinalgTilingLoopType::ParallelLoops)
440             .setDistributionOptions(cyclicNprocsDefault),
441         LinalgTransformationFilter(
442             StringAttr::get(context, "distribute3"),
443             StringAttr::get(context, "after_distribute3")));
444   }
445 
446   {
447     LinalgLoopDistributionOptions cyclicNprocsMixed1;
448     cyclicNprocsMixed1.distributionMethod = {
449         DistributionMethod::CyclicNumProcsEqNumIters,
450         DistributionMethod::CyclicNumProcsGeNumIters};
451     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
452     patterns.add<LinalgTilingPattern>(
453         MatmulOp::getOperationName(), context,
454         LinalgTilingOptions()
455             .setTileSizes({8, 8, 4})
456             .setLoopType(LinalgTilingLoopType::ParallelLoops)
457             .setDistributionOptions(cyclicNprocsMixed1),
458         LinalgTransformationFilter(
459             StringAttr::get(context, "distribute4"),
460             StringAttr::get(context, "after_distribute4")));
461   }
462 
463   {
464     LinalgLoopDistributionOptions cyclicNprocsMixed2;
465     cyclicNprocsMixed2.distributionMethod = {
466         DistributionMethod::CyclicNumProcsGeNumIters,
467         DistributionMethod::Cyclic};
468     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
469     patterns.add<LinalgTilingPattern>(
470         MatmulOp::getOperationName(), context,
471         LinalgTilingOptions()
472             .setTileSizes({8, 8, 4})
473             .setLoopType(LinalgTilingLoopType::ParallelLoops)
474             .setDistributionOptions(cyclicNprocsMixed2),
475         LinalgTransformationFilter(
476             StringAttr::get(context, "distribute5"),
477             StringAttr::get(context, "after_distribute5")));
478   }
479 
480   {
481     LinalgLoopDistributionOptions cyclicNprocsMixed3;
482     cyclicNprocsMixed3.distributionMethod = {
483         DistributionMethod::Cyclic,
484         DistributionMethod::CyclicNumProcsEqNumIters};
485     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
486 
487     patterns.add<LinalgTilingPattern>(
488         MatmulOp::getOperationName(), context,
489         LinalgTilingOptions()
490             .setTileSizes({8, 8, 4})
491             .setLoopType(LinalgTilingLoopType::ParallelLoops)
492             .setDistributionOptions(cyclicNprocsMixed3),
493         LinalgTransformationFilter(
494             StringAttr::get(context, "distribute6"),
495             StringAttr::get(context, "after_distribute6")));
496   }
497 
498   {
499     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
500     cyclicNprocsEqNiters.distributionMethod.resize(2,
501                                                    DistributionMethod::Cyclic);
502     cyclicNprocsEqNiters.procInfo =
503         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
504     patterns.add<LinalgTilingPattern>(
505         MatmulOp::getOperationName(), context,
506         LinalgTilingOptions()
507             .setTileSizes({8, 8, 4})
508             .setLoopType(LinalgTilingLoopType::Loops)
509             .setDistributionOptions(cyclicNprocsEqNiters),
510         LinalgTransformationFilter(
511             StringAttr::get(context, "tensors_distribute1"),
512             StringAttr::get(context, "tensors_after_distribute1")));
513   }
514 }
515 
516 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
517                                               RewritePatternSet &patterns) {
518   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
519   cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
520   cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
521   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
522       MatmulOp::getOperationName(), context,
523       LinalgTilingAndFusionOptions()
524           .setTileSizes({8, 8, 4})
525           .setDistributionOptions(cyclicNprocsEqNiters),
526       LinalgTransformationFilter(
527           StringAttr::get(context, "tensors_fuse_distribute1"),
528           StringAttr::get(context, "tensors_after_fuse_distribute1")));
529 }
530 
531 static void
532 applyMatmulToVectorPatterns(FuncOp funcOp,
533                             bool testMatmulToVectorPatterns1dTiling,
534                             bool testMatmulToVectorPatterns2dTiling) {
535   MLIRContext *ctx = funcOp.getContext();
536   SmallVector<RewritePatternSet, 4> stage1Patterns;
537   if (testMatmulToVectorPatterns1dTiling) {
538     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
539   } else if (testMatmulToVectorPatterns2dTiling) {
540     stage1Patterns.emplace_back(
541         ctx, std::make_unique<LinalgTilingPattern>(
542                  MatmulOp::getOperationName(), ctx,
543                  LinalgTilingOptions()
544                      .setTileSizes({768, 264, 768})
545                      .setInterchange({1, 2, 0}),
546                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
547                                             StringAttr::get(ctx, "L2"))));
548     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
549   }
550   {
551     // Canonicalization patterns
552     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
553     vector::populateVectorTransferPermutationMapLoweringPatterns(
554         canonicalizationPatterns);
555     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
556     stage1Patterns.push_back(std::move(canonicalizationPatterns));
557   }
558   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
559   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
560   FrozenRewritePatternSet stage2Patterns =
561       getLinalgTilingCanonicalizationPatterns(ctx);
562   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
563 }
564 
565 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
566   RewritePatternSet forwardPattern(funcOp.getContext());
567   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
568   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
569   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
570 }
571 
572 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
573   RewritePatternSet patterns(funcOp.getContext());
574   auto *ctx = funcOp.getContext();
575   patterns.add<LinalgVectorizationPattern>(
576       ctx, LinalgTransformationFilter()
577                .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
578   patterns.add<CopyVectorizationPattern>(ctx);
579   populatePadOpVectorizationPatterns(patterns);
580   populateConvolutionVectorizationPatterns(patterns);
581   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
582 }
583 
584 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
585   RewritePatternSet patterns(funcOp.getContext());
586   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
587   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
588 }
589 
590 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
591   RewritePatternSet patterns(funcOp.getContext());
592   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
593   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
594 }
595 
596 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
597   RewritePatternSet patterns(funcOp.getContext());
598   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
599   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
600 }
601 
602 static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
603                              ArrayRef<int64_t> tileSizes,
604                              ArrayRef<int64_t> peeledLoops,
605                              bool scalarizeDynamicDims) {
606   MLIRContext *context = funcOp.getContext();
607   RewritePatternSet tilingPattern(context);
608   LinalgTilingLoopType type =
609       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
610           .Case("for", LinalgTilingLoopType::Loops)
611           .Case("affine", LinalgTilingLoopType::AffineLoops)
612           .Case("parallel", LinalgTilingLoopType::ParallelLoops);
613   auto linalgTilingOptions = linalg::LinalgTilingOptions()
614                                  .setPeeledLoops(peeledLoops)
615                                  .setLoopType(type);
616   if (scalarizeDynamicDims) {
617     linalgTilingOptions.scalarizeDynamicDims();
618     assert(tileSizes.empty() &&
619            "tileSizes and scalarizeDynamicDims is mutually exclusive");
620   } else {
621     linalgTilingOptions.setTileSizes(tileSizes);
622   }
623   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
624   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
625       tilingPattern, linalgTilingOptions, f);
626   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
627 }
628 
629 static void applySplitReduction(FuncOp funcOp) {
630   RewritePatternSet patterns(funcOp.getContext());
631   linalg::populateSplitReductionPattern(
632       patterns,
633       [](LinalgOp op) {
634         unsigned insertDimIndex = op.getNumLoops() - 1;
635         return std::make_pair(4, insertDimIndex);
636       },
637       LinalgTransformationFilter(
638           ArrayRef<StringAttr>{},
639           StringAttr::get(funcOp.getContext(), "SPLIT")));
640   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
641 }
642 
643 static void applyBubbleUpExtractSliceOpPattern(FuncOp funcOp) {
644   RewritePatternSet patterns(funcOp.getContext());
645   populateBubbleUpExtractSliceOpPatterns(patterns);
646   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
647 }
648 
649 /// Apply transformations specified as patterns.
650 void TestLinalgTransforms::runOnOperation() {
651   auto lambda = [&](void *) {
652     getOperation().walk([](LinalgOp op) {
653       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
654     });
655   };
656   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
657 
658   if (testPromotionOptions) {
659     RewritePatternSet patterns(&getContext());
660     fillPromotionCallBackPatterns(&getContext(), patterns);
661     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
662     return;
663   }
664   if (testTileAndDistributionOptions) {
665     RewritePatternSet patterns(&getContext());
666     fillTileAndDistributePatterns(&getContext(), patterns);
667     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
668     return;
669   }
670   if (testTileFuseAndDistributionOptions) {
671     RewritePatternSet patterns(&getContext());
672     fillTileFuseAndDistributePatterns(&getContext(), patterns);
673     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
674     return;
675   }
676   if (testPatterns)
677     return applyPatterns(getOperation());
678   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
679     return applyMatmulToVectorPatterns(getOperation(),
680                                        testMatmulToVectorPatterns1dTiling,
681                                        testMatmulToVectorPatterns2dTiling);
682   if (testVectorTransferForwardingPatterns)
683     return applyVectorTransferForwardingPatterns(getOperation());
684   if (testGenericToVectorPattern)
685     return applyLinalgToVectorPatterns(getOperation());
686   if (testTransformPadTensor)
687     return applyPadTensorToGenericPatterns(getOperation());
688   if (testGeneralizePadTensor)
689     return applyGeneralizePadTensorPatterns(getOperation());
690   if (testSwapSubTensorPadTensor)
691     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
692   if (testTilePattern)
693     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
694                             /*scalarizeDynamicDims=*/false);
695   if (testTileScalarizeDynamicDims)
696     return applyTilePattern(getOperation(), loopType, tileSizes,
697                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
698   if (testSplitReduction)
699     return applySplitReduction(getOperation());
700   if (testBubbleUpExtractSliceOpPattern)
701     return applyBubbleUpExtractSliceOpPattern(getOperation());
702 }
703 
704 namespace mlir {
705 namespace test {
706 void registerTestLinalgTransforms() {
707   PassRegistration<TestLinalgTransforms>();
708 }
709 } // namespace test
710 } // namespace mlir
711