1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/GPU/GPUDialect.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
18 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Vector/VectorOps.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 namespace {
33 struct TestLinalgTransforms
34     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
35   TestLinalgTransforms() = default;
36   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
37 
38   void getDependentDialects(DialectRegistry &registry) const override {
39     // clang-format off
40     registry.insert<AffineDialect,
41                     memref::MemRefDialect,
42                     scf::SCFDialect,
43                     StandardOpsDialect,
44                     vector::VectorDialect,
45                     gpu::GPUDialect>();
46     // clang-format on
47   }
48   StringRef getArgument() const final {
49     return "test-linalg-transform-patterns";
50   }
51   StringRef getDescription() const final {
52     return "Test Linalg transformation patterns by applying them greedily.";
53   }
54 
55   void runOnFunction() override;
56 
57   Option<bool> testPatterns{*this, "test-patterns",
58                             llvm::cl::desc("Test a mixed set of patterns"),
59                             llvm::cl::init(false)};
60   Option<bool> testMatmulToVectorPatterns1dTiling{
61       *this, "test-matmul-to-vector-patterns-tile-1d",
62       llvm::cl::desc(
63           "Test a fused pass that applies patterns from matmul to vectors via "
64           "1-d tiling"),
65       llvm::cl::init(false)};
66   Option<bool> testMatmulToVectorPatterns2dTiling{
67       *this, "test-matmul-to-vector-patterns-tile-2d",
68       llvm::cl::desc(
69           "Test a fused pass that applies patterns from matmul to vectors via "
70           "2-d tiling"),
71       llvm::cl::init(false)};
72   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
73                                     llvm::cl::desc("Test promotion options"),
74                                     llvm::cl::init(false)};
75   Option<bool> testTileAndDistributionOptions{
76       *this, "test-tile-and-distribute-options",
77       llvm::cl::desc("Test tile and distribute options"),
78       llvm::cl::init(false)};
79   Option<bool> testVectorTransferForwardingPatterns{
80       *this, "test-vector-transfer-forwarding-patterns",
81       llvm::cl::desc(
82           "Test a fused pass that forwards linalg.copy to vector.transfer"),
83       llvm::cl::init(false)};
84   Option<bool> testGenericToVectorPattern{
85       *this, "test-linalg-to-vector-patterns",
86       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
87                      "in vector.contract form"),
88       llvm::cl::init(false)};
89   Option<bool> testTilePattern{*this, "test-tile-pattern",
90                                llvm::cl::desc("Test tile pattern"),
91                                llvm::cl::init(false)};
92   Option<bool> testTileScalarizeDynamicDims{
93       *this, "test-tile-scalarize-dynamic-dims",
94       llvm::cl::desc("Test tiling of dynamic dims by 1"),
95       llvm::cl::init(false)};
96   Option<int> testHoistPadding{*this, "test-hoist-padding",
97                                llvm::cl::desc("Test hoist padding"),
98                                llvm::cl::init(0)};
99   Option<bool> testTransformPadTensor{
100       *this, "test-transform-pad-tensor",
101       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
102       llvm::cl::init(false)};
103   Option<bool> testGeneralizePadTensor{
104       *this, "test-generalize-pad-tensor",
105       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
106       llvm::cl::init(false)};
107   Option<bool> testSwapSubTensorPadTensor{
108       *this, "test-swap-subtensor-padtensor",
109       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
110                      "pad_tensor(subtensor)"),
111       llvm::cl::init(false)};
112   ListOption<int64_t> paddedOperands{
113       *this, "padded-operands",
114       llvm::cl::desc("Operands to pad when test-tile-pattern"),
115       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
116   ListOption<int64_t> nofoldOperands{
117       *this, "nofold-operands",
118       llvm::cl::desc("Operands to set nofold when test-tile-pattern"),
119       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
120   ListOption<int64_t> peeledLoops{
121       *this, "peeled-loops",
122       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
123       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
124   ListOption<int64_t> tileSizes{
125       *this, "tile-sizes",
126       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
127       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
128   ListOption<unsigned> testInterchangePattern{
129       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
130       llvm::cl::desc("Test the interchange pattern.")};
131   ListOption<unsigned> testTiledLoopPeeling{
132       *this, "test-tiled-loop-peeling",
133       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
134       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
135   Option<bool> skipPartial{
136       *this, "skip-partial",
137       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
138       llvm::cl::init(false)};
139   Option<std::string> loopType{
140       *this, "loop-type",
141       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
142                      "tiled_loop"),
143       llvm::cl::init("for")};
144 };
145 } // end anonymous namespace
146 
147 static void applyPatterns(FuncOp funcOp) {
148   MLIRContext *ctx = funcOp.getContext();
149   RewritePatternSet patterns(ctx);
150 
151   //===--------------------------------------------------------------------===//
152   // Linalg tiling patterns.
153   //===--------------------------------------------------------------------===//
154   patterns.add<LinalgTilingPattern<MatmulOp>>(
155       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
156       LinalgTransformationFilter(Identifier::get("MEM", ctx),
157                                  Identifier::get("L3", ctx)));
158   patterns.add<LinalgTilingPattern<MatmulOp>>(
159       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
160       LinalgTransformationFilter(Identifier::get("L3", ctx),
161                                  Identifier::get("L2", ctx)));
162   patterns.add<LinalgTilingPattern<MatmulOp>>(
163       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
164       LinalgTransformationFilter(Identifier::get("L2", ctx),
165                                  Identifier::get("L1", ctx)));
166   patterns.add<LinalgTilingPattern<MatmulOp>>(
167       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
168       LinalgTransformationFilter(Identifier::get("L1", ctx),
169                                  Identifier::get("REG", ctx)));
170 
171   patterns.add<LinalgTilingPattern<MatvecOp>>(
172       ctx,
173       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
174           LinalgTilingLoopType::ParallelLoops),
175       LinalgTransformationFilter(ArrayRef<Identifier>{},
176                                  Identifier::get("L1", ctx)));
177 
178   patterns.add<LinalgTilingPattern<DotOp>>(
179       ctx, LinalgTilingOptions().setTileSizes(8000),
180       LinalgTransformationFilter(
181           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
182                                Identifier::get("L3", ctx),
183                                Identifier::get("L2", ctx)},
184           Identifier::get("REG", ctx)));
185 
186   //===--------------------------------------------------------------------===//
187   // Linalg tiling and permutation patterns.
188   //===--------------------------------------------------------------------===//
189   patterns.add<LinalgTilingPattern<MatmulOp>>(
190       ctx,
191       LinalgTilingOptions()
192           .setTileSizes({2000, 3000, 4000})
193           .setInterchange({1, 2, 0}),
194       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
195                                  Identifier::get("L2__with_perm__", ctx)));
196   patterns.add<LinalgTilingPattern<MatmulOp>>(
197       ctx,
198       LinalgTilingOptions()
199           .setTileSizes({200, 300, 400})
200           .setInterchange({1, 0, 2}),
201       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
202                                  Identifier::get("L1__with_perm__", ctx)));
203   patterns.add<LinalgTilingPattern<MatmulOp>>(
204       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
205       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
206                                  Identifier::get("REG__with_perm__", ctx)));
207 
208   patterns.add<LinalgTilingPattern<MatvecOp>>(
209       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
210       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
211                                  Identifier::get("L1__with_perm__", ctx)));
212 
213   patterns.add<LinalgTilingPattern<MatmulOp>>(
214       ctx,
215       LinalgTilingOptions()
216           .setTileSizes({16, 8, 4})
217           .setInterchange({1, 2, 0})
218           .setLoopType(LinalgTilingLoopType::ParallelLoops),
219       LinalgTransformationFilter(
220           Identifier::get("par__with_perm__", ctx),
221           Identifier::get("after_par__with_perm__", ctx)));
222 
223   //===--------------------------------------------------------------------===//
224   // Linalg to loops patterns.
225   //===--------------------------------------------------------------------===//
226   patterns.add<LinalgLoweringPattern<DotOp>>(
227       ctx,
228       /*loweringType=*/LinalgLoweringType::Loops,
229       LinalgTransformationFilter(Identifier::get("REG", ctx)));
230 
231   //===--------------------------------------------------------------------===//
232   // Linalg distribution patterns.
233   //===--------------------------------------------------------------------===//
234   LinalgLoopDistributionOptions distributionOptions;
235 
236   //===--------------------------------------------------------------------===//
237   // Linalg to vector contraction patterns.
238   //===--------------------------------------------------------------------===//
239   patterns.add<LinalgVectorizationPattern>(
240       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
241                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
242 
243   //===--------------------------------------------------------------------===//
244   // Linalg generic interchange pattern.
245   //===--------------------------------------------------------------------===//
246   patterns.add<GenericOpInterchangePattern>(
247       ctx,
248       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
249       LinalgTransformationFilter(ArrayRef<Identifier>{},
250                                  Identifier::get("PERMUTED", ctx)));
251 
252   //===--------------------------------------------------------------------===//
253   // Linalg subview operands promotion.
254   //===--------------------------------------------------------------------===//
255   patterns.add<LinalgPromotionPattern<MatmulOp>>(
256       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
257       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
258                                  Identifier::get("_views_promoted_", ctx)));
259   patterns.add<LinalgPromotionPattern<MatmulOp>>(
260       ctx,
261       LinalgPromotionOptions()
262           .setOperandsToPromote({0})
263           .setUseFullTileBuffersByDefault(true),
264       LinalgTransformationFilter(
265           Identifier::get("_promote_first_view_", ctx),
266           Identifier::get("_first_view_promoted_", ctx)));
267   patterns.add<LinalgPromotionPattern<FillOp>>(
268       ctx,
269       LinalgPromotionOptions()
270           .setOperandsToPromote({1})
271           .setUseFullTileBuffers({false, true})
272           .setAlignment(32),
273       LinalgTransformationFilter(
274           Identifier::get("_promote_views_aligned_", ctx),
275           Identifier::get("_views_aligned_promoted_", ctx)));
276 
277   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
278 
279   // Drop the marker.
280   funcOp.walk([](LinalgOp op) {
281     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
282   });
283 }
284 
285 static void fillL1TilingAndMatmulToVectorPatterns(
286     FuncOp funcOp, StringRef startMarker,
287     SmallVectorImpl<RewritePatternSet> &patternsVector) {
288   MLIRContext *ctx = funcOp.getContext();
289   patternsVector.emplace_back(
290       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
291                ctx,
292                LinalgTilingOptions()
293                    .setTileSizes({8, 12, 16})
294                    .setInterchange({1, 0, 2}),
295                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
296                                           Identifier::get("L1", ctx))));
297 
298   patternsVector.emplace_back(
299       ctx,
300       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
301           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
302           LinalgTransformationFilter(Identifier::get("L1", ctx),
303                                      Identifier::get("VEC", ctx))));
304 
305   patternsVector.emplace_back(
306       ctx, std::make_unique<LinalgVectorizationPattern>(
307                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
308                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
309   patternsVector.back().add<LinalgVectorizationPattern>(
310       ctx, LinalgTransformationFilter().addFilter(
311                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Test promotion callbacks
316 //===----------------------------------------------------------------------===//
317 
318 // Allocation call back
319 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
320                                        ArrayRef<Value> boundingSubViewSize,
321                                        DataLayout &layout) {
322   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
323   return b
324       .create<memref::AllocOp>(
325           subView.getLoc(),
326           MemRefType::get(shape, subView.getType().getElementType(),
327                           /*affineMapComposition =*/{}, 3),
328           boundingSubViewSize)
329       .getResult();
330 }
331 
332 // Deallocation callback
333 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
334   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
335   return success();
336 }
337 
338 // Copy in call back
339 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
340                                     bool isOutput) {
341   auto floatType = src.getType().cast<MemRefType>().getElementType();
342   if (!floatType.isa<FloatType>())
343     return failure();
344   if (!isOutput) {
345     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
346                                             FloatAttr::get(floatType, 42.0));
347     b.create<FillOp>(src.getLoc(), cst, dst);
348   }
349   b.create<CopyOp>(src.getLoc(), src, dst);
350   return success();
351 }
352 
353 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
354                                           RewritePatternSet &patterns) {
355   patterns.add<LinalgTilingPattern<MatmulOp>>(
356       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
357       LinalgTransformationFilter(Identifier::get("START", ctx),
358                                  Identifier::get("PROMOTE", ctx)));
359   patterns.add<LinalgPromotionPattern<MatmulOp>>(
360       ctx,
361       LinalgPromotionOptions()
362           .setOperandsToPromote({0, 2})
363           .setUseFullTileBuffers({false, false})
364           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
365           .setCopyInOutFns(
366               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
367                 return copyCallBackFn(b, src, dst, false);
368               },
369               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
370                 return copyCallBackFn(b, src, dst, true);
371               }),
372       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
373 }
374 
375 template <typename IdOp, typename NProcsOp>
376 static SmallVector<ProcInfo, 2>
377 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
378   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
379   SmallVector<ProcInfo, 2> procInfo(count);
380   const char *xyz[] = {"x", "y", "z"};
381   Type indexType = b.getIndexType();
382   for (unsigned i = 0; i < count; ++i) {
383     procInfo[count - 1 - i] = {
384         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
385         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
386   }
387   return procInfo;
388 }
389 
390 static void fillTileAndDistributePatterns(MLIRContext *context,
391                                           RewritePatternSet &patterns) {
392   {
393     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
394     cyclicNprocsEqNiters.distributionMethod.resize(
395         2, DistributionMethod::CyclicNumProcsEqNumIters);
396     cyclicNprocsEqNiters.procInfo =
397         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
398     patterns.add<LinalgTilingPattern<MatmulOp>>(
399         context,
400         LinalgTilingOptions()
401             .setTileSizes({8, 8, 4})
402             .setLoopType(LinalgTilingLoopType::ParallelLoops)
403             .setDistributionOptions(cyclicNprocsEqNiters),
404         LinalgTransformationFilter(
405             Identifier::get("distribute1", context),
406             Identifier::get("after_distribute1", context)));
407   }
408 
409   {
410     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
411     cyclicNprocsGeNiters.distributionMethod.resize(
412         2, DistributionMethod::CyclicNumProcsGeNumIters);
413     cyclicNprocsGeNiters.procInfo =
414         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
415     patterns.add<LinalgTilingPattern<MatmulOp>>(
416         context,
417         LinalgTilingOptions()
418             .setTileSizes({8, 8, 4})
419             .setLoopType(LinalgTilingLoopType::ParallelLoops)
420             .setDistributionOptions(cyclicNprocsGeNiters),
421         LinalgTransformationFilter(
422             Identifier::get("distribute2", context),
423             Identifier::get("after_distribute2", context)));
424   }
425 
426   {
427     LinalgLoopDistributionOptions cyclicNprocsDefault;
428     cyclicNprocsDefault.distributionMethod.resize(2,
429                                                   DistributionMethod::Cyclic);
430     cyclicNprocsDefault.procInfo =
431         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
432     patterns.add<LinalgTilingPattern<MatmulOp>>(
433         context,
434         LinalgTilingOptions()
435             .setTileSizes({8, 8, 4})
436             .setLoopType(LinalgTilingLoopType::ParallelLoops)
437             .setDistributionOptions(cyclicNprocsDefault),
438         LinalgTransformationFilter(
439             Identifier::get("distribute3", context),
440             Identifier::get("after_distribute3", context)));
441   }
442 
443   {
444     LinalgLoopDistributionOptions cyclicNprocsMixed1;
445     cyclicNprocsMixed1.distributionMethod = {
446         DistributionMethod::CyclicNumProcsEqNumIters,
447         DistributionMethod::CyclicNumProcsGeNumIters};
448     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
449     patterns.add<LinalgTilingPattern<MatmulOp>>(
450         context,
451         LinalgTilingOptions()
452             .setTileSizes({8, 8, 4})
453             .setLoopType(LinalgTilingLoopType::ParallelLoops)
454             .setDistributionOptions(cyclicNprocsMixed1),
455         LinalgTransformationFilter(
456             Identifier::get("distribute4", context),
457             Identifier::get("after_distribute4", context)));
458   }
459 
460   {
461     LinalgLoopDistributionOptions cyclicNprocsMixed2;
462     cyclicNprocsMixed2.distributionMethod = {
463         DistributionMethod::CyclicNumProcsGeNumIters,
464         DistributionMethod::Cyclic};
465     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
466     patterns.add<LinalgTilingPattern<MatmulOp>>(
467         context,
468         LinalgTilingOptions()
469             .setTileSizes({8, 8, 4})
470             .setLoopType(LinalgTilingLoopType::ParallelLoops)
471             .setDistributionOptions(cyclicNprocsMixed2),
472         LinalgTransformationFilter(
473             Identifier::get("distribute5", context),
474             Identifier::get("after_distribute5", context)));
475   }
476 
477   {
478     LinalgLoopDistributionOptions cyclicNprocsMixed3;
479     cyclicNprocsMixed3.distributionMethod = {
480         DistributionMethod::Cyclic,
481         DistributionMethod::CyclicNumProcsEqNumIters};
482     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
483 
484     patterns.add<LinalgTilingPattern<MatmulOp>>(
485         context,
486         LinalgTilingOptions()
487             .setTileSizes({8, 8, 4})
488             .setLoopType(LinalgTilingLoopType::ParallelLoops)
489             .setDistributionOptions(cyclicNprocsMixed3),
490         LinalgTransformationFilter(
491             Identifier::get("distribute6", context),
492             Identifier::get("after_distribute6", context)));
493   }
494 
495   {
496     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
497     cyclicNprocsEqNiters.distributionMethod.resize(2,
498                                                    DistributionMethod::Cyclic);
499     cyclicNprocsEqNiters.procInfo =
500         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
501     patterns.add<LinalgTilingPattern<MatmulOp>>(
502         context,
503         LinalgTilingOptions()
504             .setTileSizes({8, 8, 4})
505             .setLoopType(LinalgTilingLoopType::Loops)
506             .setDistributionOptions(cyclicNprocsEqNiters),
507         LinalgTransformationFilter(
508             Identifier::get("tensors_distribute1", context),
509             Identifier::get("tensors_after_distribute1", context)));
510   }
511 }
512 
513 static void
514 applyMatmulToVectorPatterns(FuncOp funcOp,
515                             bool testMatmulToVectorPatterns1dTiling,
516                             bool testMatmulToVectorPatterns2dTiling) {
517   MLIRContext *ctx = funcOp.getContext();
518   SmallVector<RewritePatternSet, 4> stage1Patterns;
519   if (testMatmulToVectorPatterns1dTiling) {
520     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
521                                           stage1Patterns);
522   } else if (testMatmulToVectorPatterns2dTiling) {
523     stage1Patterns.emplace_back(
524         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
525                  ctx,
526                  LinalgTilingOptions()
527                      .setTileSizes({768, 264, 768})
528                      .setInterchange({1, 2, 0}),
529                  LinalgTransformationFilter(Identifier::get("START", ctx),
530                                             Identifier::get("L2", ctx))));
531     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
532                                           stage1Patterns);
533   }
534   {
535     // Canonicalization patterns
536     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
537     vector::populateVectorTransferPermutationMapLoweringPatterns(
538         canonicalizationPatterns);
539     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
540     stage1Patterns.push_back(std::move(canonicalizationPatterns));
541   }
542   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
543   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
544   FrozenRewritePatternSet stage2Patterns =
545       getLinalgTilingCanonicalizationPatterns(ctx);
546   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
547                             std::move(stage2Patterns));
548 }
549 
550 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
551   RewritePatternSet forwardPattern(funcOp.getContext());
552   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
553   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
554   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
555 }
556 
557 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
558   RewritePatternSet patterns(funcOp.getContext());
559   patterns.add<LinalgVectorizationPattern>(
560       funcOp.getContext(),
561       LinalgTransformationFilter()
562           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
563   populatePadTensorOpVectorizationPatterns(patterns);
564   populateConvolutionVectorizationPatterns(patterns);
565   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
566 }
567 
568 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
569   RewritePatternSet patterns(funcOp.getContext());
570   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
571   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
572 }
573 
574 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
575   RewritePatternSet patterns(funcOp.getContext());
576   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
577   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
578 }
579 
580 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
581   RewritePatternSet patterns(funcOp.getContext());
582   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
583   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
584 }
585 
586 // For now, just assume it is the zero of type.
587 // In the future, it should be the zero of type + op.
588 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
589   auto t = getElementTypeOrSelf(op.get());
590   return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
591                                      b.getZeroAttr(t));
592 }
593 
594 static void applyTilePattern(FuncOp funcOp, std::string loopType,
595                              ArrayRef<int64_t> tileSizes,
596                              ArrayRef<int64_t> paddedOperands,
597                              ArrayRef<int64_t> nofoldOperands,
598                              ArrayRef<int64_t> peeledLoops,
599                              bool scalarizeDynamicDims) {
600   MLIRContext *context = funcOp.getContext();
601   RewritePatternSet tilingPattern(context);
602   LinalgTilingLoopType type =
603       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
604           .Case("for", LinalgTilingLoopType::Loops)
605           .Case("affine", LinalgTilingLoopType::AffineLoops)
606           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
607           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
608   auto linalgTilingOptions = linalg::LinalgTilingOptions()
609                                  .setPeeledLoops(peeledLoops)
610                                  .setLoopType(type);
611   if (scalarizeDynamicDims) {
612     linalgTilingOptions.scalarizeDynamicDims();
613     assert(tileSizes.empty() &&
614            "tileSizes and scalarizeDynamicDims is mutually exclusive");
615   } else {
616     linalgTilingOptions.setTileSizes(tileSizes);
617   }
618   if (!paddedOperands.empty()) {
619     auto paddingFunc = [&](OpBuilder &b,
620                            OpOperand &opOperand) -> FailureOr<Value> {
621       if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0)
622         return failure();
623       return getNeutralOfLinalgOp(b, opOperand);
624     };
625     auto nofoldFunc = [&](OpOperand &opOperand) {
626       if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0)
627         return true;
628       return false;
629     };
630     linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
631     linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc);
632   }
633   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
634                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
635       context, linalgTilingOptions,
636       linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
637   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
638 }
639 
640 static void applyInterchangePattern(FuncOp funcOp,
641                                     ArrayRef<unsigned> interchangeVector) {
642   MLIRContext *context = funcOp.getContext();
643   RewritePatternSet interchangePattern(context);
644   interchangePattern.add<GenericOpInterchangePattern>(
645       context, interchangeVector,
646       LinalgTransformationFilter(ArrayRef<Identifier>{},
647                                  Identifier::get("interchange", context)));
648   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
649 }
650 
/// Attribute name under which TiledLoopPeelingPattern records, as an array of
/// integers, the loop indices that were already peeled on a TiledLoopOp.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
/// Name of the unit attribute that TiledLoopPeelingPattern attaches to the
/// loop returned by peeling (the partial-iteration loop).
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
653 
654 namespace {
655 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
656 /// `idx`-th loop contains only "full" iterations and a second loop for the
657 /// remaining partial iteration (if any).
658 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
659   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
660       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
661   }
662 
663   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
664                                 PatternRewriter &rewriter) const override {
665     SmallVector<int64_t> peeledLoops;
666     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
667       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
668       peeledLoops =
669           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
670             return attr.cast<IntegerAttr>().getInt();
671           }));
672       // Check if the loop was already peeled.
673       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
674         return failure();
675     }
676     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
677       // No peeling of loop nests with a partial iteration.
678       return failure();
679 
680     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
681       return failure();
682 
683     // Peel loop and canonicalize.
684     TiledLoopOp result;
685     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
686                                                     result)))
687       return failure();
688 
689     // Apply label, so that the same loop is not rewritten a second time.
690     peeledLoops.push_back(idx);
691     rewriter.updateRootInPlace(loopOp, [&]() {
692       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
693     });
694     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
695     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
696 
697     return success();
698   }
699 
700   /// Index of loop to peel.
701   int64_t idx;
702 
703   /// If set to true, do not peel TiledLoopOps with a partial iteration.
704   bool skipPartial;
705 };
706 } // namespace
707 
708 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
709                                          ArrayRef<unsigned> loops,
710                                          bool skipPartial) {
711   MLIRContext *ctx = funcOp.getContext();
712   RewritePatternSet patterns(ctx);
713   for (unsigned idx : loops)
714     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
715   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
716 
717   // Drop the markers.
718   funcOp.walk([](TiledLoopOp op) {
719     op->removeAttr(kPeeledLoopsLabel);
720     op->removeAttr(kPartialIterationLabel);
721   });
722 }
723 
724 /// Apply transformations specified as patterns.
725 void TestLinalgTransforms::runOnFunction() {
726   auto lambda = [&](void *) {
727     getFunction().walk([](LinalgOp op) {
728       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
729     });
730   };
731   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
732 
733   if (testPromotionOptions) {
734     RewritePatternSet patterns(&getContext());
735     fillPromotionCallBackPatterns(&getContext(), patterns);
736     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
737     return;
738   }
739   if (testTileAndDistributionOptions) {
740     RewritePatternSet patterns(&getContext());
741     fillTileAndDistributePatterns(&getContext(), patterns);
742     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
743     return;
744   }
745   if (testPatterns)
746     return applyPatterns(getFunction());
747   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
748     return applyMatmulToVectorPatterns(getFunction(),
749                                        testMatmulToVectorPatterns1dTiling,
750                                        testMatmulToVectorPatterns2dTiling);
751   if (testVectorTransferForwardingPatterns)
752     return applyVectorTransferForwardingPatterns(getFunction());
753   if (testGenericToVectorPattern)
754     return applyLinalgToVectorPatterns(getFunction());
755   if (testTransformPadTensor)
756     return applyPadTensorToGenericPatterns(getFunction());
757   if (testGeneralizePadTensor)
758     return applyGeneralizePadTensorPatterns(getFunction());
759   if (testSwapSubTensorPadTensor)
760     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
761   if (testTiledLoopPeeling.hasValue())
762     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
763                                         skipPartial);
764   if (testTilePattern)
765     return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
766                             nofoldOperands, peeledLoops,
767                             /*scalarizeDynamicDims=*/false);
768   if (testTileScalarizeDynamicDims)
769     return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
770                             nofoldOperands,
771                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
772   if (testHoistPadding) {
773     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
774       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
775     });
776   }
777   if (testInterchangePattern.hasValue())
778     return applyInterchangePattern(getFunction(), testInterchangePattern);
779 }
780 
781 namespace mlir {
782 namespace test {
783 void registerTestLinalgTransforms() {
784   PassRegistration<TestLinalgTransforms>();
785 }
786 } // namespace test
787 } // namespace mlir
788