1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 #include "llvm/ADT/SetVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 namespace {
30 struct TestLinalgTransforms
31     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
32   TestLinalgTransforms() = default;
33   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
34 
35   void getDependentDialects(DialectRegistry &registry) const override {
36     // clang-format off
37     registry.insert<AffineDialect,
38                     memref::MemRefDialect,
39                     scf::SCFDialect,
40                     StandardOpsDialect,
41                     vector::VectorDialect,
42                     gpu::GPUDialect>();
43     // clang-format on
44   }
45   StringRef getArgument() const final {
46     return "test-linalg-transform-patterns";
47   }
48   StringRef getDescription() const final {
49     return "Test Linalg transformation patterns by applying them greedily.";
50   }
51 
52   void runOnFunction() override;
53 
54   Option<bool> testPatterns{*this, "test-patterns",
55                             llvm::cl::desc("Test a mixed set of patterns"),
56                             llvm::cl::init(false)};
57   Option<bool> testMatmulToVectorPatterns1dTiling{
58       *this, "test-matmul-to-vector-patterns-tile-1d",
59       llvm::cl::desc(
60           "Test a fused pass that applies patterns from matmul to vectors via "
61           "1-d tiling"),
62       llvm::cl::init(false)};
63   Option<bool> testMatmulToVectorPatterns2dTiling{
64       *this, "test-matmul-to-vector-patterns-tile-2d",
65       llvm::cl::desc(
66           "Test a fused pass that applies patterns from matmul to vectors via "
67           "2-d tiling"),
68       llvm::cl::init(false)};
69   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
70                                     llvm::cl::desc("Test promotion options"),
71                                     llvm::cl::init(false)};
72   Option<bool> testTileAndDistributionOptions{
73       *this, "test-tile-and-distribute-options",
74       llvm::cl::desc("Test tile and distribute options"),
75       llvm::cl::init(false)};
76   Option<bool> testVectorTransferForwardingPatterns{
77       *this, "test-vector-transfer-forwarding-patterns",
78       llvm::cl::desc(
79           "Test a fused pass that forwards linalg.copy to vector.transfer"),
80       llvm::cl::init(false)};
81   Option<bool> testGenericToVectorPattern{
82       *this, "test-linalg-to-vector-patterns",
83       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
84                      "in vector.contract form"),
85       llvm::cl::init(false)};
86   Option<bool> testAffineMinSCFCanonicalizationPatterns{
87       *this, "test-affine-min-scf-canonicalization-patterns",
88       llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
89       llvm::cl::init(false)};
90   Option<bool> testTileAndPadPattern{
91       *this, "test-tile-and-pad-pattern",
92       llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
93   Option<int> testHoistPadding{*this, "test-hoist-padding",
94                                llvm::cl::desc("Test hoist padding"),
95                                llvm::cl::init(0)};
96   Option<bool> testTransformPadTensor{
97       *this, "test-transform-pad-tensor",
98       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
99       llvm::cl::init(false)};
100   Option<bool> testGeneralizePadTensor{
101       *this, "test-generalize-pad-tensor",
102       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
103       llvm::cl::init(false)};
104   Option<bool> testSwapSubTensorPadTensor{
105       *this, "test-swap-subtensor-padtensor",
106       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
107                      "pad_tensor(subtensor)"),
108       llvm::cl::init(false)};
109   ListOption<int64_t> tileSizesForPadding{
110       *this, "tile-sizes-for-padding",
111       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
112       llvm::cl::MiscFlags::CommaSeparated};
113   ListOption<unsigned> testInterchangePattern{
114       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
115       llvm::cl::desc("Test the interchange pattern.")};
116 };
117 } // end anonymous namespace
118 
119 static void applyPatterns(FuncOp funcOp) {
120   MLIRContext *ctx = funcOp.getContext();
121   RewritePatternSet patterns(ctx);
122 
123   //===--------------------------------------------------------------------===//
124   // Linalg tiling patterns.
125   //===--------------------------------------------------------------------===//
126   patterns.add<LinalgTilingPattern<MatmulOp>>(
127       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
128       LinalgTransformationFilter(Identifier::get("MEM", ctx),
129                                  Identifier::get("L3", ctx)));
130   patterns.add<LinalgTilingPattern<MatmulOp>>(
131       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
132       LinalgTransformationFilter(Identifier::get("L3", ctx),
133                                  Identifier::get("L2", ctx)));
134   patterns.add<LinalgTilingPattern<MatmulOp>>(
135       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
136       LinalgTransformationFilter(Identifier::get("L2", ctx),
137                                  Identifier::get("L1", ctx)));
138   patterns.add<LinalgTilingPattern<MatmulOp>>(
139       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
140       LinalgTransformationFilter(Identifier::get("L1", ctx),
141                                  Identifier::get("REG", ctx)));
142 
143   patterns.add<LinalgTilingPattern<MatvecOp>>(
144       ctx,
145       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
146           LinalgTilingLoopType::ParallelLoops),
147       LinalgTransformationFilter(ArrayRef<Identifier>{},
148                                  Identifier::get("L1", ctx)));
149 
150   patterns.add<LinalgTilingPattern<DotOp>>(
151       ctx, LinalgTilingOptions().setTileSizes(8000),
152       LinalgTransformationFilter(
153           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
154                                Identifier::get("L3", ctx),
155                                Identifier::get("L2", ctx)},
156           Identifier::get("REG", ctx)));
157 
158   //===--------------------------------------------------------------------===//
159   // Linalg tiling and permutation patterns.
160   //===--------------------------------------------------------------------===//
161   patterns.add<LinalgTilingPattern<MatmulOp>>(
162       ctx,
163       LinalgTilingOptions()
164           .setTileSizes({2000, 3000, 4000})
165           .setInterchange({1, 2, 0}),
166       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
167                                  Identifier::get("L2__with_perm__", ctx)));
168   patterns.add<LinalgTilingPattern<MatmulOp>>(
169       ctx,
170       LinalgTilingOptions()
171           .setTileSizes({200, 300, 400})
172           .setInterchange({1, 0, 2}),
173       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
174                                  Identifier::get("L1__with_perm__", ctx)));
175   patterns.add<LinalgTilingPattern<MatmulOp>>(
176       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
177       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
178                                  Identifier::get("REG__with_perm__", ctx)));
179 
180   patterns.add<LinalgTilingPattern<MatvecOp>>(
181       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
182       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
183                                  Identifier::get("L1__with_perm__", ctx)));
184 
185   patterns.add<LinalgTilingPattern<MatmulOp>>(
186       ctx,
187       LinalgTilingOptions()
188           .setTileSizes({16, 8, 4})
189           .setInterchange({1, 2, 0})
190           .setLoopType(LinalgTilingLoopType::ParallelLoops),
191       LinalgTransformationFilter(
192           Identifier::get("par__with_perm__", ctx),
193           Identifier::get("after_par__with_perm__", ctx)));
194 
195   //===--------------------------------------------------------------------===//
196   // Linalg to loops patterns.
197   //===--------------------------------------------------------------------===//
198   patterns.add<LinalgLoweringPattern<DotOp>>(
199       ctx,
200       /*loweringType=*/LinalgLoweringType::Loops,
201       LinalgTransformationFilter(Identifier::get("REG", ctx)));
202 
203   //===--------------------------------------------------------------------===//
204   // Linalg distribution patterns.
205   //===--------------------------------------------------------------------===//
206   LinalgLoopDistributionOptions distributionOptions;
207 
208   //===--------------------------------------------------------------------===//
209   // Linalg to vector contraction patterns.
210   //===--------------------------------------------------------------------===//
211   patterns.add<LinalgVectorizationPattern>(
212       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
213                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
214 
215   //===--------------------------------------------------------------------===//
216   // Linalg generic interchange pattern.
217   //===--------------------------------------------------------------------===//
218   patterns.add<GenericOpInterchangePattern>(
219       ctx,
220       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
221       LinalgTransformationFilter(ArrayRef<Identifier>{},
222                                  Identifier::get("PERMUTED", ctx)));
223 
224   //===--------------------------------------------------------------------===//
225   // Linalg subview operands promotion.
226   //===--------------------------------------------------------------------===//
227   patterns.add<LinalgPromotionPattern<MatmulOp>>(
228       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
229       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
230                                  Identifier::get("_views_promoted_", ctx)));
231   patterns.add<LinalgPromotionPattern<MatmulOp>>(
232       ctx,
233       LinalgPromotionOptions()
234           .setOperandsToPromote({0})
235           .setUseFullTileBuffersByDefault(true),
236       LinalgTransformationFilter(
237           Identifier::get("_promote_first_view_", ctx),
238           Identifier::get("_first_view_promoted_", ctx)));
239   patterns.add<LinalgPromotionPattern<FillOp>>(
240       ctx,
241       LinalgPromotionOptions()
242           .setOperandsToPromote({1})
243           .setUseFullTileBuffers({false, true})
244           .setAlignment(32),
245       LinalgTransformationFilter(
246           Identifier::get("_promote_views_aligned_", ctx),
247           Identifier::get("_views_aligned_promoted_", ctx)));
248 
249   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
250 
251   // Drop the marker.
252   funcOp.walk([](LinalgOp op) {
253     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
254   });
255 }
256 
257 static void fillL1TilingAndMatmulToVectorPatterns(
258     FuncOp funcOp, StringRef startMarker,
259     SmallVectorImpl<RewritePatternSet> &patternsVector) {
260   MLIRContext *ctx = funcOp.getContext();
261   patternsVector.emplace_back(
262       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
263                ctx,
264                LinalgTilingOptions()
265                    .setTileSizes({8, 12, 16})
266                    .setInterchange({1, 0, 2}),
267                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
268                                           Identifier::get("L1", ctx))));
269 
270   patternsVector.emplace_back(
271       ctx,
272       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
273           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
274           LinalgTransformationFilter(Identifier::get("L1", ctx),
275                                      Identifier::get("VEC", ctx))));
276 
277   patternsVector.emplace_back(
278       ctx, std::make_unique<LinalgVectorizationPattern>(
279                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
280                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
281   patternsVector.back().add<LinalgVectorizationPattern>(
282       ctx, LinalgTransformationFilter().addFilter(
283                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // Test promotion callbacks
288 //===----------------------------------------------------------------------===//
289 
290 // Allocation call back
291 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
292                                        ArrayRef<Value> boundingSubViewSize,
293                                        DataLayout &layout) {
294   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
295   return b
296       .create<memref::AllocOp>(
297           subView.getLoc(),
298           MemRefType::get(shape, subView.getType().getElementType(),
299                           /*affineMapComposition =*/{}, 3),
300           boundingSubViewSize)
301       .getResult();
302 }
303 
304 // Deallocation callback
305 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
306   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
307   return success();
308 }
309 
310 // Copy in call back
311 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
312                                     bool isOutput) {
313   auto floatType = src.getType().cast<MemRefType>().getElementType();
314   if (!floatType.isa<FloatType>())
315     return failure();
316   if (!isOutput) {
317     Value cst =
318         b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0));
319     b.create<FillOp>(src.getLoc(), cst, dst);
320   }
321   b.create<CopyOp>(src.getLoc(), src, dst);
322   return success();
323 }
324 
325 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
326                                           RewritePatternSet &patterns) {
327   patterns.add<LinalgTilingPattern<MatmulOp>>(
328       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
329       LinalgTransformationFilter(Identifier::get("START", ctx),
330                                  Identifier::get("PROMOTE", ctx)));
331   patterns.add<LinalgPromotionPattern<MatmulOp>>(
332       ctx,
333       LinalgPromotionOptions()
334           .setOperandsToPromote({0, 2})
335           .setUseFullTileBuffers({false, false})
336           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
337           .setCopyInOutFns(
338               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
339                 return copyCallBackFn(b, src, dst, false);
340               },
341               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
342                 return copyCallBackFn(b, src, dst, true);
343               }),
344       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
345 }
346 
347 template <typename IdOp, typename NProcsOp>
348 static SmallVector<ProcInfo, 2>
349 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
350   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
351   SmallVector<ProcInfo, 2> procInfo(count);
352   const char *xyz[] = {"x", "y", "z"};
353   Type indexType = b.getIndexType();
354   for (unsigned i = 0; i < count; ++i) {
355     procInfo[count - 1 - i] = {
356         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
357         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
358   }
359   return procInfo;
360 }
361 
362 static void fillTileAndDistributePatterns(MLIRContext *context,
363                                           RewritePatternSet &patterns) {
364   {
365     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
366     cyclicNprocsEqNiters.distributionMethod.resize(
367         2, DistributionMethod::CyclicNumProcsEqNumIters);
368     cyclicNprocsEqNiters.procInfo =
369         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
370     patterns.add<LinalgTilingPattern<MatmulOp>>(
371         context,
372         LinalgTilingOptions()
373             .setTileSizes({8, 8, 4})
374             .setLoopType(LinalgTilingLoopType::ParallelLoops)
375             .setDistributionOptions(cyclicNprocsEqNiters),
376         LinalgTransformationFilter(
377             Identifier::get("distribute1", context),
378             Identifier::get("after_distribute1", context)));
379   }
380 
381   {
382     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
383     cyclicNprocsGeNiters.distributionMethod.resize(
384         2, DistributionMethod::CyclicNumProcsGeNumIters);
385     cyclicNprocsGeNiters.procInfo =
386         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
387     patterns.add<LinalgTilingPattern<MatmulOp>>(
388         context,
389         LinalgTilingOptions()
390             .setTileSizes({8, 8, 4})
391             .setLoopType(LinalgTilingLoopType::ParallelLoops)
392             .setDistributionOptions(cyclicNprocsGeNiters),
393         LinalgTransformationFilter(
394             Identifier::get("distribute2", context),
395             Identifier::get("after_distribute2", context)));
396   }
397 
398   {
399     LinalgLoopDistributionOptions cyclicNprocsDefault;
400     cyclicNprocsDefault.distributionMethod.resize(2,
401                                                   DistributionMethod::Cyclic);
402     cyclicNprocsDefault.procInfo =
403         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
404     patterns.add<LinalgTilingPattern<MatmulOp>>(
405         context,
406         LinalgTilingOptions()
407             .setTileSizes({8, 8, 4})
408             .setLoopType(LinalgTilingLoopType::ParallelLoops)
409             .setDistributionOptions(cyclicNprocsDefault),
410         LinalgTransformationFilter(
411             Identifier::get("distribute3", context),
412             Identifier::get("after_distribute3", context)));
413   }
414 
415   {
416     LinalgLoopDistributionOptions cyclicNprocsMixed1;
417     cyclicNprocsMixed1.distributionMethod = {
418         DistributionMethod::CyclicNumProcsEqNumIters,
419         DistributionMethod::CyclicNumProcsGeNumIters};
420     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
421     patterns.add<LinalgTilingPattern<MatmulOp>>(
422         context,
423         LinalgTilingOptions()
424             .setTileSizes({8, 8, 4})
425             .setLoopType(LinalgTilingLoopType::ParallelLoops)
426             .setDistributionOptions(cyclicNprocsMixed1),
427         LinalgTransformationFilter(
428             Identifier::get("distribute4", context),
429             Identifier::get("after_distribute4", context)));
430   }
431 
432   {
433     LinalgLoopDistributionOptions cyclicNprocsMixed2;
434     cyclicNprocsMixed2.distributionMethod = {
435         DistributionMethod::CyclicNumProcsGeNumIters,
436         DistributionMethod::Cyclic};
437     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
438     patterns.add<LinalgTilingPattern<MatmulOp>>(
439         context,
440         LinalgTilingOptions()
441             .setTileSizes({8, 8, 4})
442             .setLoopType(LinalgTilingLoopType::ParallelLoops)
443             .setDistributionOptions(cyclicNprocsMixed2),
444         LinalgTransformationFilter(
445             Identifier::get("distribute5", context),
446             Identifier::get("after_distribute5", context)));
447   }
448 
449   {
450     LinalgLoopDistributionOptions cyclicNprocsMixed3;
451     cyclicNprocsMixed3.distributionMethod = {
452         DistributionMethod::Cyclic,
453         DistributionMethod::CyclicNumProcsEqNumIters};
454     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
455 
456     patterns.add<LinalgTilingPattern<MatmulOp>>(
457         context,
458         LinalgTilingOptions()
459             .setTileSizes({8, 8, 4})
460             .setLoopType(LinalgTilingLoopType::ParallelLoops)
461             .setDistributionOptions(cyclicNprocsMixed3),
462         LinalgTransformationFilter(
463             Identifier::get("distribute6", context),
464             Identifier::get("after_distribute6", context)));
465   }
466 
467   {
468     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
469     cyclicNprocsEqNiters.distributionMethod.resize(2,
470                                                    DistributionMethod::Cyclic);
471     cyclicNprocsEqNiters.procInfo =
472         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
473     patterns.add<LinalgTilingPattern<MatmulOp>>(
474         context,
475         LinalgTilingOptions()
476             .setTileSizes({8, 8, 4})
477             .setLoopType(LinalgTilingLoopType::Loops)
478             .setDistributionOptions(cyclicNprocsEqNiters),
479         LinalgTransformationFilter(
480             Identifier::get("tensors_distribute1", context),
481             Identifier::get("tensors_after_distribute1", context)));
482   }
483 }
484 
485 static void
486 applyMatmulToVectorPatterns(FuncOp funcOp,
487                             bool testMatmulToVectorPatterns1dTiling,
488                             bool testMatmulToVectorPatterns2dTiling) {
489   MLIRContext *ctx = funcOp.getContext();
490   SmallVector<RewritePatternSet, 4> stage1Patterns;
491   if (testMatmulToVectorPatterns1dTiling) {
492     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
493                                           stage1Patterns);
494   } else if (testMatmulToVectorPatterns2dTiling) {
495     stage1Patterns.emplace_back(
496         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
497                  ctx,
498                  LinalgTilingOptions()
499                      .setTileSizes({768, 264, 768})
500                      .setInterchange({1, 2, 0}),
501                  LinalgTransformationFilter(Identifier::get("START", ctx),
502                                             Identifier::get("L2", ctx))));
503     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
504                                           stage1Patterns);
505   }
506   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
507   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
508   FrozenRewritePatternSet stage2Patterns =
509       getLinalgTilingCanonicalizationPatterns(ctx);
510   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
511                             std::move(stage2Patterns));
512 }
513 
514 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
515   RewritePatternSet forwardPattern(funcOp.getContext());
516   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
517   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
518   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
519 }
520 
521 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
522   RewritePatternSet patterns(funcOp.getContext());
523   patterns.add<LinalgVectorizationPattern>(
524       funcOp.getContext(),
525       LinalgTransformationFilter()
526           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
527   populatePadTensorOpVectorizationPatterns(patterns);
528   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
529 }
530 
531 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
532   RewritePatternSet patterns(funcOp.getContext());
533   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
534   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
535 }
536 
537 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
538   RewritePatternSet patterns(funcOp.getContext());
539   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
540   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
541 }
542 
543 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
544   RewritePatternSet patterns(funcOp.getContext());
545   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
546   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
547 }
548 
549 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
550   RewritePatternSet foldPattern(funcOp.getContext());
551   foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
552   FrozenRewritePatternSet frozenPatterns(std::move(foldPattern));
553 
554   // Explicitly walk and apply the pattern locally to avoid more general folding
555   // on the rest of the IR.
556   funcOp.walk([&frozenPatterns](AffineMinOp minOp) {
557     (void)applyOpPatternsAndFold(minOp, frozenPatterns);
558   });
559 }
560 
561 // For now, just assume it is the zero of type.
562 // In the future, it should be the zero of type + op.
563 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
564   auto t = getElementTypeOrSelf(op.get());
565   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
566 }
567 
568 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
569   MLIRContext *context = funcOp.getContext();
570   RewritePatternSet tilingPattern(context);
571   auto linalgTilingOptions =
572       linalg::LinalgTilingOptions()
573           .setTileSizes(tileSizes)
574           .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
575   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
576                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
577       context, linalgTilingOptions,
578       linalg::LinalgTransformationFilter(
579           Identifier::get("tile-and-pad", context)));
580   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
581 }
582 
583 static void applyInterchangePattern(FuncOp funcOp,
584                                     ArrayRef<unsigned> interchangeVector) {
585   MLIRContext *context = funcOp.getContext();
586   RewritePatternSet interchangePattern(context);
587   interchangePattern.add<GenericOpInterchangePattern>(
588       context, interchangeVector,
589       LinalgTransformationFilter(ArrayRef<Identifier>{},
590                                  Identifier::get("interchange", context)));
591   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
592 }
593 
594 /// Apply transformations specified as patterns.
595 void TestLinalgTransforms::runOnFunction() {
596   auto lambda = [&](void *) {
597     getFunction().walk([](LinalgOp op) {
598       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
599     });
600   };
601   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
602 
603   if (testPromotionOptions) {
604     RewritePatternSet patterns(&getContext());
605     fillPromotionCallBackPatterns(&getContext(), patterns);
606     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
607     return;
608   }
609   if (testTileAndDistributionOptions) {
610     RewritePatternSet patterns(&getContext());
611     fillTileAndDistributePatterns(&getContext(), patterns);
612     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
613     return;
614   }
615   if (testPatterns)
616     return applyPatterns(getFunction());
617   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
618     return applyMatmulToVectorPatterns(getFunction(),
619                                        testMatmulToVectorPatterns1dTiling,
620                                        testMatmulToVectorPatterns2dTiling);
621   if (testVectorTransferForwardingPatterns)
622     return applyVectorTransferForwardingPatterns(getFunction());
623   if (testGenericToVectorPattern)
624     return applyLinalgToVectorPatterns(getFunction());
625   if (testTransformPadTensor)
626     return applyPadTensorToGenericPatterns(getFunction());
627   if (testGeneralizePadTensor)
628     return applyGeneralizePadTensorPatterns(getFunction());
629   if (testSwapSubTensorPadTensor)
630     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
631   if (testAffineMinSCFCanonicalizationPatterns)
632     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
633   if (testTileAndPadPattern)
634     return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
635   if (testHoistPadding) {
636     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
637       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
638     });
639   }
640   if (testInterchangePattern.hasValue())
641     return applyInterchangePattern(getFunction(), testInterchangePattern);
642 }
643 
644 namespace mlir {
645 namespace test {
646 void registerTestLinalgTransforms() {
647   PassRegistration<TestLinalgTransforms>();
648 }
649 } // namespace test
650 } // namespace mlir
651