1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 #include "llvm/ADT/SetVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 namespace {
30 struct TestLinalgTransforms
31     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
32   TestLinalgTransforms() = default;
33   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
34 
35   void getDependentDialects(DialectRegistry &registry) const override {
36     // clang-format off
37     registry.insert<AffineDialect,
38                     memref::MemRefDialect,
39                     scf::SCFDialect,
40                     StandardOpsDialect,
41                     vector::VectorDialect,
42                     gpu::GPUDialect>();
43     // clang-format on
44   }
45   StringRef getArgument() const final {
46     return "test-linalg-transform-patterns";
47   }
48   StringRef getDescription() const final {
49     return "Test Linalg transformation patterns by applying them greedily.";
50   }
51 
52   void runOnFunction() override;
53 
54   Option<bool> testPatterns{*this, "test-patterns",
55                             llvm::cl::desc("Test a mixed set of patterns"),
56                             llvm::cl::init(false)};
57   Option<bool> testMatmulToVectorPatterns1dTiling{
58       *this, "test-matmul-to-vector-patterns-tile-1d",
59       llvm::cl::desc(
60           "Test a fused pass that applies patterns from matmul to vectors via "
61           "1-d tiling"),
62       llvm::cl::init(false)};
63   Option<bool> testMatmulToVectorPatterns2dTiling{
64       *this, "test-matmul-to-vector-patterns-tile-2d",
65       llvm::cl::desc(
66           "Test a fused pass that applies patterns from matmul to vectors via "
67           "2-d tiling"),
68       llvm::cl::init(false)};
69   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
70                                     llvm::cl::desc("Test promotion options"),
71                                     llvm::cl::init(false)};
72   Option<bool> testTileAndDistributionOptions{
73       *this, "test-tile-and-distribute-options",
74       llvm::cl::desc("Test tile and distribute options"),
75       llvm::cl::init(false)};
76   Option<bool> testVectorTransferForwardingPatterns{
77       *this, "test-vector-transfer-forwarding-patterns",
78       llvm::cl::desc(
79           "Test a fused pass that forwards linalg.copy to vector.transfer"),
80       llvm::cl::init(false)};
81   Option<bool> testGenericToVectorPattern{
82       *this, "test-linalg-to-vector-patterns",
83       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
84                      "in vector.contract form"),
85       llvm::cl::init(false)};
86   Option<bool> testTileAndPadPattern{
87       *this, "test-tile-and-pad-pattern",
88       llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
89   Option<int> testHoistPadding{*this, "test-hoist-padding",
90                                llvm::cl::desc("Test hoist padding"),
91                                llvm::cl::init(0)};
92   Option<bool> testTransformPadTensor{
93       *this, "test-transform-pad-tensor",
94       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
95       llvm::cl::init(false)};
96   Option<bool> testGeneralizePadTensor{
97       *this, "test-generalize-pad-tensor",
98       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
99       llvm::cl::init(false)};
100   Option<bool> testSwapSubTensorPadTensor{
101       *this, "test-swap-subtensor-padtensor",
102       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
103                      "pad_tensor(subtensor)"),
104       llvm::cl::init(false)};
105   ListOption<int64_t> tileSizesForPadding{
106       *this, "tile-sizes-for-padding",
107       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
108       llvm::cl::MiscFlags::CommaSeparated};
109   ListOption<unsigned> testInterchangePattern{
110       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
111       llvm::cl::desc("Test the interchange pattern.")};
112 };
113 } // end anonymous namespace
114 
115 static void applyPatterns(FuncOp funcOp) {
116   MLIRContext *ctx = funcOp.getContext();
117   RewritePatternSet patterns(ctx);
118 
119   //===--------------------------------------------------------------------===//
120   // Linalg tiling patterns.
121   //===--------------------------------------------------------------------===//
122   patterns.add<LinalgTilingPattern<MatmulOp>>(
123       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
124       LinalgTransformationFilter(Identifier::get("MEM", ctx),
125                                  Identifier::get("L3", ctx)));
126   patterns.add<LinalgTilingPattern<MatmulOp>>(
127       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
128       LinalgTransformationFilter(Identifier::get("L3", ctx),
129                                  Identifier::get("L2", ctx)));
130   patterns.add<LinalgTilingPattern<MatmulOp>>(
131       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
132       LinalgTransformationFilter(Identifier::get("L2", ctx),
133                                  Identifier::get("L1", ctx)));
134   patterns.add<LinalgTilingPattern<MatmulOp>>(
135       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
136       LinalgTransformationFilter(Identifier::get("L1", ctx),
137                                  Identifier::get("REG", ctx)));
138 
139   patterns.add<LinalgTilingPattern<MatvecOp>>(
140       ctx,
141       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
142           LinalgTilingLoopType::ParallelLoops),
143       LinalgTransformationFilter(ArrayRef<Identifier>{},
144                                  Identifier::get("L1", ctx)));
145 
146   patterns.add<LinalgTilingPattern<DotOp>>(
147       ctx, LinalgTilingOptions().setTileSizes(8000),
148       LinalgTransformationFilter(
149           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
150                                Identifier::get("L3", ctx),
151                                Identifier::get("L2", ctx)},
152           Identifier::get("REG", ctx)));
153 
154   //===--------------------------------------------------------------------===//
155   // Linalg tiling and permutation patterns.
156   //===--------------------------------------------------------------------===//
157   patterns.add<LinalgTilingPattern<MatmulOp>>(
158       ctx,
159       LinalgTilingOptions()
160           .setTileSizes({2000, 3000, 4000})
161           .setInterchange({1, 2, 0}),
162       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
163                                  Identifier::get("L2__with_perm__", ctx)));
164   patterns.add<LinalgTilingPattern<MatmulOp>>(
165       ctx,
166       LinalgTilingOptions()
167           .setTileSizes({200, 300, 400})
168           .setInterchange({1, 0, 2}),
169       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
170                                  Identifier::get("L1__with_perm__", ctx)));
171   patterns.add<LinalgTilingPattern<MatmulOp>>(
172       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
173       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
174                                  Identifier::get("REG__with_perm__", ctx)));
175 
176   patterns.add<LinalgTilingPattern<MatvecOp>>(
177       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
178       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
179                                  Identifier::get("L1__with_perm__", ctx)));
180 
181   patterns.add<LinalgTilingPattern<MatmulOp>>(
182       ctx,
183       LinalgTilingOptions()
184           .setTileSizes({16, 8, 4})
185           .setInterchange({1, 2, 0})
186           .setLoopType(LinalgTilingLoopType::ParallelLoops),
187       LinalgTransformationFilter(
188           Identifier::get("par__with_perm__", ctx),
189           Identifier::get("after_par__with_perm__", ctx)));
190 
191   //===--------------------------------------------------------------------===//
192   // Linalg to loops patterns.
193   //===--------------------------------------------------------------------===//
194   patterns.add<LinalgLoweringPattern<DotOp>>(
195       ctx,
196       /*loweringType=*/LinalgLoweringType::Loops,
197       LinalgTransformationFilter(Identifier::get("REG", ctx)));
198 
199   //===--------------------------------------------------------------------===//
200   // Linalg distribution patterns.
201   //===--------------------------------------------------------------------===//
202   LinalgLoopDistributionOptions distributionOptions;
203 
204   //===--------------------------------------------------------------------===//
205   // Linalg to vector contraction patterns.
206   //===--------------------------------------------------------------------===//
207   patterns.add<LinalgVectorizationPattern>(
208       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
209                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
210 
211   //===--------------------------------------------------------------------===//
212   // Linalg generic interchange pattern.
213   //===--------------------------------------------------------------------===//
214   patterns.add<GenericOpInterchangePattern>(
215       ctx,
216       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
217       LinalgTransformationFilter(ArrayRef<Identifier>{},
218                                  Identifier::get("PERMUTED", ctx)));
219 
220   //===--------------------------------------------------------------------===//
221   // Linalg subview operands promotion.
222   //===--------------------------------------------------------------------===//
223   patterns.add<LinalgPromotionPattern<MatmulOp>>(
224       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
225       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
226                                  Identifier::get("_views_promoted_", ctx)));
227   patterns.add<LinalgPromotionPattern<MatmulOp>>(
228       ctx,
229       LinalgPromotionOptions()
230           .setOperandsToPromote({0})
231           .setUseFullTileBuffersByDefault(true),
232       LinalgTransformationFilter(
233           Identifier::get("_promote_first_view_", ctx),
234           Identifier::get("_first_view_promoted_", ctx)));
235   patterns.add<LinalgPromotionPattern<FillOp>>(
236       ctx,
237       LinalgPromotionOptions()
238           .setOperandsToPromote({1})
239           .setUseFullTileBuffers({false, true})
240           .setAlignment(32),
241       LinalgTransformationFilter(
242           Identifier::get("_promote_views_aligned_", ctx),
243           Identifier::get("_views_aligned_promoted_", ctx)));
244 
245   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
246 
247   // Drop the marker.
248   funcOp.walk([](LinalgOp op) {
249     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
250   });
251 }
252 
253 static void fillL1TilingAndMatmulToVectorPatterns(
254     FuncOp funcOp, StringRef startMarker,
255     SmallVectorImpl<RewritePatternSet> &patternsVector) {
256   MLIRContext *ctx = funcOp.getContext();
257   patternsVector.emplace_back(
258       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
259                ctx,
260                LinalgTilingOptions()
261                    .setTileSizes({8, 12, 16})
262                    .setInterchange({1, 0, 2}),
263                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
264                                           Identifier::get("L1", ctx))));
265 
266   patternsVector.emplace_back(
267       ctx,
268       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
269           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
270           LinalgTransformationFilter(Identifier::get("L1", ctx),
271                                      Identifier::get("VEC", ctx))));
272 
273   patternsVector.emplace_back(
274       ctx, std::make_unique<LinalgVectorizationPattern>(
275                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
276                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
277   patternsVector.back().add<LinalgVectorizationPattern>(
278       ctx, LinalgTransformationFilter().addFilter(
279                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Test promotion callbacks
284 //===----------------------------------------------------------------------===//
285 
286 // Allocation call back
287 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
288                                        ArrayRef<Value> boundingSubViewSize,
289                                        DataLayout &layout) {
290   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
291   return b
292       .create<memref::AllocOp>(
293           subView.getLoc(),
294           MemRefType::get(shape, subView.getType().getElementType(),
295                           /*affineMapComposition =*/{}, 3),
296           boundingSubViewSize)
297       .getResult();
298 }
299 
300 // Deallocation callback
301 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
302   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
303   return success();
304 }
305 
306 // Copy in call back
307 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
308                                     bool isOutput) {
309   auto floatType = src.getType().cast<MemRefType>().getElementType();
310   if (!floatType.isa<FloatType>())
311     return failure();
312   if (!isOutput) {
313     Value cst =
314         b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0));
315     b.create<FillOp>(src.getLoc(), cst, dst);
316   }
317   b.create<CopyOp>(src.getLoc(), src, dst);
318   return success();
319 }
320 
321 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
322                                           RewritePatternSet &patterns) {
323   patterns.add<LinalgTilingPattern<MatmulOp>>(
324       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
325       LinalgTransformationFilter(Identifier::get("START", ctx),
326                                  Identifier::get("PROMOTE", ctx)));
327   patterns.add<LinalgPromotionPattern<MatmulOp>>(
328       ctx,
329       LinalgPromotionOptions()
330           .setOperandsToPromote({0, 2})
331           .setUseFullTileBuffers({false, false})
332           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
333           .setCopyInOutFns(
334               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
335                 return copyCallBackFn(b, src, dst, false);
336               },
337               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
338                 return copyCallBackFn(b, src, dst, true);
339               }),
340       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
341 }
342 
343 template <typename IdOp, typename NProcsOp>
344 static SmallVector<ProcInfo, 2>
345 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
346   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
347   SmallVector<ProcInfo, 2> procInfo(count);
348   const char *xyz[] = {"x", "y", "z"};
349   Type indexType = b.getIndexType();
350   for (unsigned i = 0; i < count; ++i) {
351     procInfo[count - 1 - i] = {
352         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
353         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
354   }
355   return procInfo;
356 }
357 
358 static void fillTileAndDistributePatterns(MLIRContext *context,
359                                           RewritePatternSet &patterns) {
360   {
361     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
362     cyclicNprocsEqNiters.distributionMethod.resize(
363         2, DistributionMethod::CyclicNumProcsEqNumIters);
364     cyclicNprocsEqNiters.procInfo =
365         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
366     patterns.add<LinalgTilingPattern<MatmulOp>>(
367         context,
368         LinalgTilingOptions()
369             .setTileSizes({8, 8, 4})
370             .setLoopType(LinalgTilingLoopType::ParallelLoops)
371             .setDistributionOptions(cyclicNprocsEqNiters),
372         LinalgTransformationFilter(
373             Identifier::get("distribute1", context),
374             Identifier::get("after_distribute1", context)));
375   }
376 
377   {
378     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
379     cyclicNprocsGeNiters.distributionMethod.resize(
380         2, DistributionMethod::CyclicNumProcsGeNumIters);
381     cyclicNprocsGeNiters.procInfo =
382         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
383     patterns.add<LinalgTilingPattern<MatmulOp>>(
384         context,
385         LinalgTilingOptions()
386             .setTileSizes({8, 8, 4})
387             .setLoopType(LinalgTilingLoopType::ParallelLoops)
388             .setDistributionOptions(cyclicNprocsGeNiters),
389         LinalgTransformationFilter(
390             Identifier::get("distribute2", context),
391             Identifier::get("after_distribute2", context)));
392   }
393 
394   {
395     LinalgLoopDistributionOptions cyclicNprocsDefault;
396     cyclicNprocsDefault.distributionMethod.resize(2,
397                                                   DistributionMethod::Cyclic);
398     cyclicNprocsDefault.procInfo =
399         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
400     patterns.add<LinalgTilingPattern<MatmulOp>>(
401         context,
402         LinalgTilingOptions()
403             .setTileSizes({8, 8, 4})
404             .setLoopType(LinalgTilingLoopType::ParallelLoops)
405             .setDistributionOptions(cyclicNprocsDefault),
406         LinalgTransformationFilter(
407             Identifier::get("distribute3", context),
408             Identifier::get("after_distribute3", context)));
409   }
410 
411   {
412     LinalgLoopDistributionOptions cyclicNprocsMixed1;
413     cyclicNprocsMixed1.distributionMethod = {
414         DistributionMethod::CyclicNumProcsEqNumIters,
415         DistributionMethod::CyclicNumProcsGeNumIters};
416     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
417     patterns.add<LinalgTilingPattern<MatmulOp>>(
418         context,
419         LinalgTilingOptions()
420             .setTileSizes({8, 8, 4})
421             .setLoopType(LinalgTilingLoopType::ParallelLoops)
422             .setDistributionOptions(cyclicNprocsMixed1),
423         LinalgTransformationFilter(
424             Identifier::get("distribute4", context),
425             Identifier::get("after_distribute4", context)));
426   }
427 
428   {
429     LinalgLoopDistributionOptions cyclicNprocsMixed2;
430     cyclicNprocsMixed2.distributionMethod = {
431         DistributionMethod::CyclicNumProcsGeNumIters,
432         DistributionMethod::Cyclic};
433     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
434     patterns.add<LinalgTilingPattern<MatmulOp>>(
435         context,
436         LinalgTilingOptions()
437             .setTileSizes({8, 8, 4})
438             .setLoopType(LinalgTilingLoopType::ParallelLoops)
439             .setDistributionOptions(cyclicNprocsMixed2),
440         LinalgTransformationFilter(
441             Identifier::get("distribute5", context),
442             Identifier::get("after_distribute5", context)));
443   }
444 
445   {
446     LinalgLoopDistributionOptions cyclicNprocsMixed3;
447     cyclicNprocsMixed3.distributionMethod = {
448         DistributionMethod::Cyclic,
449         DistributionMethod::CyclicNumProcsEqNumIters};
450     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
451 
452     patterns.add<LinalgTilingPattern<MatmulOp>>(
453         context,
454         LinalgTilingOptions()
455             .setTileSizes({8, 8, 4})
456             .setLoopType(LinalgTilingLoopType::ParallelLoops)
457             .setDistributionOptions(cyclicNprocsMixed3),
458         LinalgTransformationFilter(
459             Identifier::get("distribute6", context),
460             Identifier::get("after_distribute6", context)));
461   }
462 
463   {
464     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
465     cyclicNprocsEqNiters.distributionMethod.resize(2,
466                                                    DistributionMethod::Cyclic);
467     cyclicNprocsEqNiters.procInfo =
468         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
469     patterns.add<LinalgTilingPattern<MatmulOp>>(
470         context,
471         LinalgTilingOptions()
472             .setTileSizes({8, 8, 4})
473             .setLoopType(LinalgTilingLoopType::Loops)
474             .setDistributionOptions(cyclicNprocsEqNiters),
475         LinalgTransformationFilter(
476             Identifier::get("tensors_distribute1", context),
477             Identifier::get("tensors_after_distribute1", context)));
478   }
479 }
480 
481 static void
482 applyMatmulToVectorPatterns(FuncOp funcOp,
483                             bool testMatmulToVectorPatterns1dTiling,
484                             bool testMatmulToVectorPatterns2dTiling) {
485   MLIRContext *ctx = funcOp.getContext();
486   SmallVector<RewritePatternSet, 4> stage1Patterns;
487   if (testMatmulToVectorPatterns1dTiling) {
488     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
489                                           stage1Patterns);
490   } else if (testMatmulToVectorPatterns2dTiling) {
491     stage1Patterns.emplace_back(
492         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
493                  ctx,
494                  LinalgTilingOptions()
495                      .setTileSizes({768, 264, 768})
496                      .setInterchange({1, 2, 0}),
497                  LinalgTransformationFilter(Identifier::get("START", ctx),
498                                             Identifier::get("L2", ctx))));
499     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
500                                           stage1Patterns);
501   }
502   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
503   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
504   FrozenRewritePatternSet stage2Patterns =
505       getLinalgTilingCanonicalizationPatterns(ctx);
506   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
507                             std::move(stage2Patterns));
508 }
509 
510 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
511   RewritePatternSet forwardPattern(funcOp.getContext());
512   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
513   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
514   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
515 }
516 
517 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
518   RewritePatternSet patterns(funcOp.getContext());
519   patterns.add<LinalgVectorizationPattern>(
520       funcOp.getContext(),
521       LinalgTransformationFilter()
522           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
523   populatePadTensorOpVectorizationPatterns(patterns);
524   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
525 }
526 
527 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
528   RewritePatternSet patterns(funcOp.getContext());
529   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
530   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
531 }
532 
533 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
534   RewritePatternSet patterns(funcOp.getContext());
535   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
536   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
537 }
538 
539 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
540   RewritePatternSet patterns(funcOp.getContext());
541   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
542   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
543 }
544 
545 // For now, just assume it is the zero of type.
546 // In the future, it should be the zero of type + op.
547 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
548   auto t = getElementTypeOrSelf(op.get());
549   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
550 }
551 
552 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
553   MLIRContext *context = funcOp.getContext();
554   RewritePatternSet tilingPattern(context);
555   auto linalgTilingOptions =
556       linalg::LinalgTilingOptions()
557           .setTileSizes(tileSizes)
558           .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
559   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
560                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
561       context, linalgTilingOptions,
562       linalg::LinalgTransformationFilter(
563           Identifier::get("tile-and-pad", context)));
564   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
565 }
566 
567 static void applyInterchangePattern(FuncOp funcOp,
568                                     ArrayRef<unsigned> interchangeVector) {
569   MLIRContext *context = funcOp.getContext();
570   RewritePatternSet interchangePattern(context);
571   interchangePattern.add<GenericOpInterchangePattern>(
572       context, interchangeVector,
573       LinalgTransformationFilter(ArrayRef<Identifier>{},
574                                  Identifier::get("interchange", context)));
575   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
576 }
577 
578 /// Apply transformations specified as patterns.
579 void TestLinalgTransforms::runOnFunction() {
580   auto lambda = [&](void *) {
581     getFunction().walk([](LinalgOp op) {
582       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
583     });
584   };
585   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
586 
587   if (testPromotionOptions) {
588     RewritePatternSet patterns(&getContext());
589     fillPromotionCallBackPatterns(&getContext(), patterns);
590     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
591     return;
592   }
593   if (testTileAndDistributionOptions) {
594     RewritePatternSet patterns(&getContext());
595     fillTileAndDistributePatterns(&getContext(), patterns);
596     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
597     return;
598   }
599   if (testPatterns)
600     return applyPatterns(getFunction());
601   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
602     return applyMatmulToVectorPatterns(getFunction(),
603                                        testMatmulToVectorPatterns1dTiling,
604                                        testMatmulToVectorPatterns2dTiling);
605   if (testVectorTransferForwardingPatterns)
606     return applyVectorTransferForwardingPatterns(getFunction());
607   if (testGenericToVectorPattern)
608     return applyLinalgToVectorPatterns(getFunction());
609   if (testTransformPadTensor)
610     return applyPadTensorToGenericPatterns(getFunction());
611   if (testGeneralizePadTensor)
612     return applyGeneralizePadTensorPatterns(getFunction());
613   if (testSwapSubTensorPadTensor)
614     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
615   if (testTileAndPadPattern)
616     return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
617   if (testHoistPadding) {
618     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
619       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
620     });
621   }
622   if (testInterchangePattern.hasValue())
623     return applyInterchangePattern(getFunction(), testInterchangePattern);
624 }
625 
626 namespace mlir {
627 namespace test {
628 void registerTestLinalgTransforms() {
629   PassRegistration<TestLinalgTransforms>();
630 }
631 } // namespace test
632 } // namespace mlir
633