1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct TestLinalgTransforms
35     : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> {
36   TestLinalgTransforms() = default;
37   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
38 
39   void getDependentDialects(DialectRegistry &registry) const override {
40     // clang-format off
41     registry.insert<AffineDialect,
42                     memref::MemRefDialect,
43                     scf::SCFDialect,
44                     StandardOpsDialect,
45                     linalg::LinalgDialect,
46                     vector::VectorDialect,
47                     gpu::GPUDialect>();
48     // clang-format on
49   }
50   StringRef getArgument() const final {
51     return "test-linalg-transform-patterns";
52   }
53   StringRef getDescription() const final {
54     return "Test Linalg transformation patterns by applying them greedily.";
55   }
56 
57   void runOnOperation() override;
58 
59   Option<bool> testPatterns{*this, "test-patterns",
60                             llvm::cl::desc("Test a mixed set of patterns"),
61                             llvm::cl::init(false)};
62   Option<bool> testMatmulToVectorPatterns1dTiling{
63       *this, "test-matmul-to-vector-patterns-tile-1d",
64       llvm::cl::desc(
65           "Test a fused pass that applies patterns from matmul to vectors via "
66           "1-d tiling"),
67       llvm::cl::init(false)};
68   Option<bool> testMatmulToVectorPatterns2dTiling{
69       *this, "test-matmul-to-vector-patterns-tile-2d",
70       llvm::cl::desc(
71           "Test a fused pass that applies patterns from matmul to vectors via "
72           "2-d tiling"),
73       llvm::cl::init(false)};
74   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
75                                     llvm::cl::desc("Test promotion options"),
76                                     llvm::cl::init(false)};
77   Option<bool> testTileAndDistributionOptions{
78       *this, "test-tile-and-distribute-options",
79       llvm::cl::desc("Test tile and distribute options"),
80       llvm::cl::init(false)};
81   Option<bool> testTileFuseAndDistributionOptions{
82       *this, "test-tile-fuse-and-distribute-options",
83       llvm::cl::desc("Test tile, fuse and distribute options"),
84       llvm::cl::init(false)};
85   Option<bool> testVectorTransferForwardingPatterns{
86       *this, "test-vector-transfer-forwarding-patterns",
87       llvm::cl::desc(
88           "Test a fused pass that forwards memref.copy to vector.transfer"),
89       llvm::cl::init(false)};
90   Option<bool> testGenericToVectorPattern{
91       *this, "test-linalg-to-vector-patterns",
92       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
93                      "in vector.contract form"),
94       llvm::cl::init(false)};
95   Option<bool> testTilePattern{*this, "test-tile-pattern",
96                                llvm::cl::desc("Test tile pattern"),
97                                llvm::cl::init(false)};
98   Option<bool> testTileScalarizeDynamicDims{
99       *this, "test-tile-scalarize-dynamic-dims",
100       llvm::cl::desc("Test tiling of dynamic dims by 1"),
101       llvm::cl::init(false)};
102   Option<bool> testTransformPadTensor{
103       *this, "test-transform-pad-tensor",
104       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
105       llvm::cl::init(false)};
106   Option<bool> testGeneralizePadTensor{
107       *this, "test-generalize-pad-tensor",
108       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
109       llvm::cl::init(false)};
110   Option<bool> testSwapSubTensorPadTensor{
111       *this, "test-swap-subtensor-padtensor",
112       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
113                      "pad_tensor(subtensor)"),
114       llvm::cl::init(false)};
115   ListOption<int64_t> peeledLoops{
116       *this, "peeled-loops",
117       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
118       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
119   ListOption<int64_t> tileSizes{
120       *this, "tile-sizes",
121       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
122       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
123   ListOption<unsigned> testTiledLoopPeeling{
124       *this, "test-tiled-loop-peeling",
125       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
126       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
127   Option<bool> skipPartial{
128       *this, "skip-partial",
129       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
130       llvm::cl::init(false)};
131   Option<std::string> loopType{
132       *this, "loop-type",
133       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
134                      "tiled_loop"),
135       llvm::cl::init("for")};
136 };
137 } // namespace
138 
139 static void applyPatterns(FuncOp funcOp) {
140   MLIRContext *ctx = funcOp.getContext();
141   RewritePatternSet patterns(ctx);
142 
143   //===--------------------------------------------------------------------===//
144   // Linalg tiling patterns.
145   //===--------------------------------------------------------------------===//
146   patterns.add<LinalgTilingPattern>(
147       MatmulOp::getOperationName(), ctx,
148       LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
149       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
150                                  StringAttr::get(ctx, "L3")));
151   patterns.add<LinalgTilingPattern>(
152       MatmulOp::getOperationName(), ctx,
153       LinalgTilingOptions().setTileSizes({200, 300, 400}),
154       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
155                                  StringAttr::get(ctx, "L2")));
156   patterns.add<LinalgTilingPattern>(
157       MatmulOp::getOperationName(), ctx,
158       LinalgTilingOptions().setTileSizes({20, 30, 40}),
159       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
160                                  StringAttr::get(ctx, "L1")));
161   patterns.add<LinalgTilingPattern>(
162       MatmulOp::getOperationName(), ctx,
163       LinalgTilingOptions().setTileSizes({2, 3, 4}),
164       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
165                                  StringAttr::get(ctx, "REG")));
166 
167   patterns.add<LinalgTilingPattern>(
168       MatvecOp::getOperationName(), ctx,
169       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
170           LinalgTilingLoopType::ParallelLoops),
171       LinalgTransformationFilter(ArrayRef<StringAttr>{},
172                                  StringAttr::get(ctx, "L1")));
173 
174   patterns.add<LinalgTilingPattern>(
175       DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
176       LinalgTransformationFilter(
177           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
178                                StringAttr::get(ctx, "L3"),
179                                StringAttr::get(ctx, "L2")},
180           StringAttr::get(ctx, "REG")));
181 
182   //===--------------------------------------------------------------------===//
183   // Linalg tiling and permutation patterns.
184   //===--------------------------------------------------------------------===//
185   patterns.add<LinalgTilingPattern>(
186       MatmulOp::getOperationName(), ctx,
187       LinalgTilingOptions()
188           .setTileSizes({2000, 3000, 4000})
189           .setInterchange({1, 2, 0}),
190       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
191                                  StringAttr::get(ctx, "L2__with_perm__")));
192   patterns.add<LinalgTilingPattern>(
193       MatmulOp::getOperationName(), ctx,
194       LinalgTilingOptions()
195           .setTileSizes({200, 300, 400})
196           .setInterchange({1, 0, 2}),
197       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
198                                  StringAttr::get(ctx, "L1__with_perm__")));
199   patterns.add<LinalgTilingPattern>(
200       MatmulOp::getOperationName(), ctx,
201       LinalgTilingOptions().setTileSizes({20, 30, 40}),
202       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
203                                  StringAttr::get(ctx, "REG__with_perm__")));
204 
205   patterns.add<LinalgTilingPattern>(
206       MatvecOp::getOperationName(), ctx,
207       LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
208       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
209                                  StringAttr::get(ctx, "L1__with_perm__")));
210 
211   patterns.add<LinalgTilingPattern>(
212       MatmulOp::getOperationName(), ctx,
213       LinalgTilingOptions()
214           .setTileSizes({16, 8, 4})
215           .setInterchange({1, 2, 0})
216           .setLoopType(LinalgTilingLoopType::ParallelLoops),
217       LinalgTransformationFilter(
218           StringAttr::get(ctx, "par__with_perm__"),
219           StringAttr::get(ctx, "after_par__with_perm__")));
220 
221   //===--------------------------------------------------------------------===//
222   // Linalg to loops patterns.
223   //===--------------------------------------------------------------------===//
224   patterns.add<LinalgLoweringPattern<DotOp>>(
225       ctx,
226       /*loweringType=*/LinalgLoweringType::Loops,
227       LinalgTransformationFilter(StringAttr::get(ctx, "REG")));
228 
229   //===--------------------------------------------------------------------===//
230   // Linalg distribution patterns.
231   //===--------------------------------------------------------------------===//
232   LinalgLoopDistributionOptions distributionOptions;
233 
234   //===--------------------------------------------------------------------===//
235   // Linalg to vector contraction patterns.
236   //===--------------------------------------------------------------------===//
237   patterns.add<LinalgVectorizationPattern>(
238       ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
239                .addOpFilter<MatmulOp, FillOp, GenericOp>());
240   patterns.add<CopyVectorizationPattern>(ctx);
241 
242   //===--------------------------------------------------------------------===//
243   // Linalg generic interchange pattern.
244   //===--------------------------------------------------------------------===//
245   patterns.add<GenericOpInterchangePattern>(
246       ctx,
247       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
248       LinalgTransformationFilter(ArrayRef<StringAttr>{},
249                                  StringAttr::get(ctx, "PERMUTED")));
250 
251   //===--------------------------------------------------------------------===//
252   // Linalg subview operands promotion.
253   //===--------------------------------------------------------------------===//
254   patterns.add<LinalgPromotionPattern<MatmulOp>>(
255       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
256       LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
257                                  StringAttr::get(ctx, "_views_promoted_")));
258   patterns.add<LinalgPromotionPattern<MatmulOp>>(
259       ctx,
260       LinalgPromotionOptions()
261           .setOperandsToPromote({0})
262           .setUseFullTileBuffersByDefault(true),
263       LinalgTransformationFilter(
264           StringAttr::get(ctx, "_promote_first_view_"),
265           StringAttr::get(ctx, "_first_view_promoted_")));
266   patterns.add<LinalgPromotionPattern<FillOp>>(
267       ctx,
268       LinalgPromotionOptions()
269           .setOperandsToPromote({1})
270           .setUseFullTileBuffers({false, true})
271           .setAlignment(32),
272       LinalgTransformationFilter(
273           StringAttr::get(ctx, "_promote_views_aligned_"),
274           StringAttr::get(ctx, "_views_aligned_promoted_")));
275 
276   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
277 
278   // Drop the marker.
279   funcOp.walk([](LinalgOp op) {
280     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
281   });
282 }
283 
284 static void fillL1TilingAndMatmulToVectorPatterns(
285     FuncOp funcOp, StringRef startMarker,
286     SmallVectorImpl<RewritePatternSet> &patternsVector) {
287   MLIRContext *ctx = funcOp.getContext();
288   patternsVector.emplace_back(
289       ctx, std::make_unique<LinalgTilingPattern>(
290                MatmulOp::getOperationName(), ctx,
291                LinalgTilingOptions()
292                    .setTileSizes({8, 12, 16})
293                    .setInterchange({1, 0, 2}),
294                LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
295                                           StringAttr::get(ctx, "L1"))));
296 
297   patternsVector.emplace_back(
298       ctx,
299       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
300           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
301           LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
302                                      StringAttr::get(ctx, "VEC"))));
303 
304   patternsVector.emplace_back(
305       ctx, std::make_unique<LinalgVectorizationPattern>(
306                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
307                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
308   patternsVector.back().add<LinalgVectorizationPattern>(
309       ctx, LinalgTransformationFilter().addOpFilter<FillOp>());
310   patternsVector.back().add<CopyVectorizationPattern>(ctx);
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // Test promotion callbacks
315 //===----------------------------------------------------------------------===//
316 
317 // Allocation call back
318 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
319                                        ArrayRef<Value> boundingSubViewSize,
320                                        DataLayout &layout) {
321   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
322   return b
323       .create<memref::AllocOp>(
324           subView.getLoc(),
325           MemRefType::get(shape, subView.getType().getElementType(),
326                           /*affineMapComposition =*/{}, 3),
327           boundingSubViewSize)
328       .getResult();
329 }
330 
331 // Deallocation callback
332 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
333   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
334   return success();
335 }
336 
337 // Copy in call back
338 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
339                                     bool isOutput) {
340   auto floatType = src.getType().cast<MemRefType>().getElementType();
341   if (!floatType.isa<FloatType>())
342     return failure();
343   if (!isOutput) {
344     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
345                                             FloatAttr::get(floatType, 42.0));
346     b.create<FillOp>(src.getLoc(), cst, dst);
347   }
348   b.create<memref::CopyOp>(src.getLoc(), src, dst);
349   return success();
350 }
351 
352 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
353                                           RewritePatternSet &patterns) {
354   patterns.add<LinalgTilingPattern>(
355       MatmulOp::getOperationName(), ctx,
356       LinalgTilingOptions().setTileSizes({16, 16, 16}),
357       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
358                                  StringAttr::get(ctx, "PROMOTE")));
359   patterns.add<LinalgPromotionPattern<MatmulOp>>(
360       ctx,
361       LinalgPromotionOptions()
362           .setOperandsToPromote({0, 2})
363           .setUseFullTileBuffers({false, false})
364           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
365           .setCopyInOutFns(
366               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
367                 return copyCallBackFn(b, src, dst, false);
368               },
369               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
370                 return copyCallBackFn(b, src, dst, true);
371               }),
372       LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
373 }
374 
375 template <typename IdOp, typename NProcsOp>
376 static SmallVector<ProcInfo, 2>
377 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
378   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
379   SmallVector<ProcInfo, 2> procInfo(count);
380   Type indexType = b.getIndexType();
381   for (unsigned i = 0; i < count; ++i) {
382     gpu::Dimension dim = *gpu::symbolizeDimension(i);
383     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
384                                b.create<NProcsOp>(loc, indexType, dim)};
385   }
386   return procInfo;
387 }
388 
389 static void fillTileAndDistributePatterns(MLIRContext *context,
390                                           RewritePatternSet &patterns) {
391   {
392     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
393     cyclicNprocsEqNiters.distributionMethod.resize(
394         2, DistributionMethod::CyclicNumProcsEqNumIters);
395     cyclicNprocsEqNiters.procInfo =
396         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
397     patterns.add<LinalgTilingPattern>(
398         MatmulOp::getOperationName(), context,
399         LinalgTilingOptions()
400             .setTileSizes({8, 8, 4})
401             .setLoopType(LinalgTilingLoopType::ParallelLoops)
402             .setDistributionOptions(cyclicNprocsEqNiters),
403         LinalgTransformationFilter(
404             StringAttr::get(context, "distribute1"),
405             StringAttr::get(context, "after_distribute1")));
406   }
407 
408   {
409     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
410     cyclicNprocsGeNiters.distributionMethod.resize(
411         2, DistributionMethod::CyclicNumProcsGeNumIters);
412     cyclicNprocsGeNiters.procInfo =
413         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
414     patterns.add<LinalgTilingPattern>(
415         MatmulOp::getOperationName(), context,
416         LinalgTilingOptions()
417             .setTileSizes({8, 8, 4})
418             .setLoopType(LinalgTilingLoopType::ParallelLoops)
419             .setDistributionOptions(cyclicNprocsGeNiters),
420         LinalgTransformationFilter(
421             StringAttr::get(context, "distribute2"),
422             StringAttr::get(context, "after_distribute2")));
423   }
424 
425   {
426     LinalgLoopDistributionOptions cyclicNprocsDefault;
427     cyclicNprocsDefault.distributionMethod.resize(2,
428                                                   DistributionMethod::Cyclic);
429     cyclicNprocsDefault.procInfo =
430         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
431     patterns.add<LinalgTilingPattern>(
432         MatmulOp::getOperationName(), context,
433         LinalgTilingOptions()
434             .setTileSizes({8, 8, 4})
435             .setLoopType(LinalgTilingLoopType::ParallelLoops)
436             .setDistributionOptions(cyclicNprocsDefault),
437         LinalgTransformationFilter(
438             StringAttr::get(context, "distribute3"),
439             StringAttr::get(context, "after_distribute3")));
440   }
441 
442   {
443     LinalgLoopDistributionOptions cyclicNprocsMixed1;
444     cyclicNprocsMixed1.distributionMethod = {
445         DistributionMethod::CyclicNumProcsEqNumIters,
446         DistributionMethod::CyclicNumProcsGeNumIters};
447     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
448     patterns.add<LinalgTilingPattern>(
449         MatmulOp::getOperationName(), context,
450         LinalgTilingOptions()
451             .setTileSizes({8, 8, 4})
452             .setLoopType(LinalgTilingLoopType::ParallelLoops)
453             .setDistributionOptions(cyclicNprocsMixed1),
454         LinalgTransformationFilter(
455             StringAttr::get(context, "distribute4"),
456             StringAttr::get(context, "after_distribute4")));
457   }
458 
459   {
460     LinalgLoopDistributionOptions cyclicNprocsMixed2;
461     cyclicNprocsMixed2.distributionMethod = {
462         DistributionMethod::CyclicNumProcsGeNumIters,
463         DistributionMethod::Cyclic};
464     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
465     patterns.add<LinalgTilingPattern>(
466         MatmulOp::getOperationName(), context,
467         LinalgTilingOptions()
468             .setTileSizes({8, 8, 4})
469             .setLoopType(LinalgTilingLoopType::ParallelLoops)
470             .setDistributionOptions(cyclicNprocsMixed2),
471         LinalgTransformationFilter(
472             StringAttr::get(context, "distribute5"),
473             StringAttr::get(context, "after_distribute5")));
474   }
475 
476   {
477     LinalgLoopDistributionOptions cyclicNprocsMixed3;
478     cyclicNprocsMixed3.distributionMethod = {
479         DistributionMethod::Cyclic,
480         DistributionMethod::CyclicNumProcsEqNumIters};
481     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
482 
483     patterns.add<LinalgTilingPattern>(
484         MatmulOp::getOperationName(), context,
485         LinalgTilingOptions()
486             .setTileSizes({8, 8, 4})
487             .setLoopType(LinalgTilingLoopType::ParallelLoops)
488             .setDistributionOptions(cyclicNprocsMixed3),
489         LinalgTransformationFilter(
490             StringAttr::get(context, "distribute6"),
491             StringAttr::get(context, "after_distribute6")));
492   }
493 
494   {
495     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
496     cyclicNprocsEqNiters.distributionMethod.resize(2,
497                                                    DistributionMethod::Cyclic);
498     cyclicNprocsEqNiters.procInfo =
499         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
500     patterns.add<LinalgTilingPattern>(
501         MatmulOp::getOperationName(), context,
502         LinalgTilingOptions()
503             .setTileSizes({8, 8, 4})
504             .setLoopType(LinalgTilingLoopType::Loops)
505             .setDistributionOptions(cyclicNprocsEqNiters),
506         LinalgTransformationFilter(
507             StringAttr::get(context, "tensors_distribute1"),
508             StringAttr::get(context, "tensors_after_distribute1")));
509   }
510 }
511 
512 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
513                                               RewritePatternSet &patterns) {
514   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
515   cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
516   cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
517   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
518       MatmulOp::getOperationName(), context,
519       LinalgTilingAndFusionOptions()
520           .setTileSizes({8, 8, 4})
521           .setDistributionOptions(cyclicNprocsEqNiters),
522       LinalgTransformationFilter(
523           StringAttr::get(context, "tensors_fuse_distribute1"),
524           StringAttr::get(context, "tensors_after_fuse_distribute1")));
525 }
526 
527 static void
528 applyMatmulToVectorPatterns(FuncOp funcOp,
529                             bool testMatmulToVectorPatterns1dTiling,
530                             bool testMatmulToVectorPatterns2dTiling) {
531   MLIRContext *ctx = funcOp.getContext();
532   SmallVector<RewritePatternSet, 4> stage1Patterns;
533   if (testMatmulToVectorPatterns1dTiling) {
534     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
535   } else if (testMatmulToVectorPatterns2dTiling) {
536     stage1Patterns.emplace_back(
537         ctx, std::make_unique<LinalgTilingPattern>(
538                  MatmulOp::getOperationName(), ctx,
539                  LinalgTilingOptions()
540                      .setTileSizes({768, 264, 768})
541                      .setInterchange({1, 2, 0}),
542                  LinalgTransformationFilter(StringAttr::get(ctx, "START"),
543                                             StringAttr::get(ctx, "L2"))));
544     fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
545   }
546   {
547     // Canonicalization patterns
548     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
549     vector::populateVectorTransferPermutationMapLoweringPatterns(
550         canonicalizationPatterns);
551     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
552     stage1Patterns.push_back(std::move(canonicalizationPatterns));
553   }
554   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
555   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
556   FrozenRewritePatternSet stage2Patterns =
557       getLinalgTilingCanonicalizationPatterns(ctx);
558   (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns);
559 }
560 
561 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
562   RewritePatternSet forwardPattern(funcOp.getContext());
563   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
564   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
565   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
566 }
567 
568 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
569   RewritePatternSet patterns(funcOp.getContext());
570   auto *ctx = funcOp.getContext();
571   patterns.add<LinalgVectorizationPattern>(
572       ctx, LinalgTransformationFilter()
573                .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
574   patterns.add<CopyVectorizationPattern>(ctx);
575   populatePadOpVectorizationPatterns(patterns);
576   populateConvolutionVectorizationPatterns(patterns);
577   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
578 }
579 
580 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
581   RewritePatternSet patterns(funcOp.getContext());
582   patterns.add<PadOpTransformationPattern>(funcOp.getContext());
583   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
584 }
585 
586 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
587   RewritePatternSet patterns(funcOp.getContext());
588   patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
589   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
590 }
591 
592 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
593   RewritePatternSet patterns(funcOp.getContext());
594   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
595   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
596 }
597 
598 static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
599                              ArrayRef<int64_t> tileSizes,
600                              ArrayRef<int64_t> peeledLoops,
601                              bool scalarizeDynamicDims) {
602   MLIRContext *context = funcOp.getContext();
603   RewritePatternSet tilingPattern(context);
604   LinalgTilingLoopType type =
605       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
606           .Case("for", LinalgTilingLoopType::Loops)
607           .Case("affine", LinalgTilingLoopType::AffineLoops)
608           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
609           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
610   auto linalgTilingOptions = linalg::LinalgTilingOptions()
611                                  .setPeeledLoops(peeledLoops)
612                                  .setLoopType(type);
613   if (scalarizeDynamicDims) {
614     linalgTilingOptions.scalarizeDynamicDims();
615     assert(tileSizes.empty() &&
616            "tileSizes and scalarizeDynamicDims is mutually exclusive");
617   } else {
618     linalgTilingOptions.setTileSizes(tileSizes);
619   }
620   linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
621   TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
622       tilingPattern, linalgTilingOptions, f);
623   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
624 }
625 
626 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
627 static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
628 
629 namespace {
630 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
631 /// `idx`-th loop contains only "full" iterations and a second loop for the
632 /// remaining partial iteration (if any).
633 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
634   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
635       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
636   }
637 
638   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
639                                 PatternRewriter &rewriter) const override {
640     SmallVector<int64_t> peeledLoops;
641     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
642       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
643       peeledLoops =
644           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
645             return attr.cast<IntegerAttr>().getInt();
646           }));
647       // Check if the loop was already peeled.
648       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
649         return failure();
650     }
651     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
652       // No peeling of loop nests with a partial iteration.
653       return failure();
654 
655     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
656       return failure();
657 
658     // Peel loop and canonicalize.
659     TiledLoopOp result;
660     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
661                                                     result)))
662       return failure();
663 
664     // Apply label, so that the same loop is not rewritten a second time.
665     peeledLoops.push_back(idx);
666     rewriter.updateRootInPlace(loopOp, [&]() {
667       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
668     });
669     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
670     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
671 
672     return success();
673   }
674 
675   /// Index of loop to peel.
676   int64_t idx;
677 
678   /// If set to true, do not peel TiledLoopOps with a partial iteration.
679   bool skipPartial;
680 };
681 } // namespace
682 
683 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
684                                          ArrayRef<unsigned> loops,
685                                          bool skipPartial) {
686   MLIRContext *ctx = funcOp.getContext();
687   RewritePatternSet patterns(ctx);
688   for (unsigned idx : loops)
689     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
690   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
691 
692   // Drop the markers.
693   funcOp.walk([](TiledLoopOp op) {
694     op->removeAttr(kPeeledLoopsLabel);
695     op->removeAttr(kPartialIterationLabel);
696   });
697 }
698 
699 /// Apply transformations specified as patterns.
700 void TestLinalgTransforms::runOnOperation() {
701   auto lambda = [&](void *) {
702     getOperation().walk([](LinalgOp op) {
703       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
704     });
705   };
706   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
707 
708   if (testPromotionOptions) {
709     RewritePatternSet patterns(&getContext());
710     fillPromotionCallBackPatterns(&getContext(), patterns);
711     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
712     return;
713   }
714   if (testTileAndDistributionOptions) {
715     RewritePatternSet patterns(&getContext());
716     fillTileAndDistributePatterns(&getContext(), patterns);
717     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
718     return;
719   }
720   if (testTileFuseAndDistributionOptions) {
721     RewritePatternSet patterns(&getContext());
722     fillTileFuseAndDistributePatterns(&getContext(), patterns);
723     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
724     return;
725   }
726   if (testPatterns)
727     return applyPatterns(getOperation());
728   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
729     return applyMatmulToVectorPatterns(getOperation(),
730                                        testMatmulToVectorPatterns1dTiling,
731                                        testMatmulToVectorPatterns2dTiling);
732   if (testVectorTransferForwardingPatterns)
733     return applyVectorTransferForwardingPatterns(getOperation());
734   if (testGenericToVectorPattern)
735     return applyLinalgToVectorPatterns(getOperation());
736   if (testTransformPadTensor)
737     return applyPadTensorToGenericPatterns(getOperation());
738   if (testGeneralizePadTensor)
739     return applyGeneralizePadTensorPatterns(getOperation());
740   if (testSwapSubTensorPadTensor)
741     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
742   if (testTiledLoopPeeling.hasValue())
743     return applyTiledLoopPeelingPattern(getOperation(), testTiledLoopPeeling,
744                                         skipPartial);
745   if (testTilePattern)
746     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
747                             /*scalarizeDynamicDims=*/false);
748   if (testTileScalarizeDynamicDims)
749     return applyTilePattern(getOperation(), loopType, tileSizes,
750                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
751 }
752 
753 namespace mlir {
754 namespace test {
755 void registerTestLinalgTransforms() {
756   PassRegistration<TestLinalgTransforms>();
757 }
758 } // namespace test
759 } // namespace mlir
760