1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 namespace {
30 struct TestLinalgTransforms
31     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
32   TestLinalgTransforms() = default;
33   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
34 
35   void getDependentDialects(DialectRegistry &registry) const override {
36     // clang-format off
37     registry.insert<AffineDialect,
38                     memref::MemRefDialect,
39                     scf::SCFDialect,
40                     StandardOpsDialect,
41                     vector::VectorDialect,
42                     gpu::GPUDialect>();
43     // clang-format on
44   }
45 
46   void runOnFunction() override;
47 
48   Option<bool> testPatterns{*this, "test-patterns",
49                             llvm::cl::desc("Test a mixed set of patterns"),
50                             llvm::cl::init(false)};
51   Option<bool> testMatmulToVectorPatterns1dTiling{
52       *this, "test-matmul-to-vector-patterns-tile-1d",
53       llvm::cl::desc(
54           "Test a fused pass that applies patterns from matmul to vectors via "
55           "1-d tiling"),
56       llvm::cl::init(false)};
57   Option<bool> testMatmulToVectorPatterns2dTiling{
58       *this, "test-matmul-to-vector-patterns-tile-2d",
59       llvm::cl::desc(
60           "Test a fused pass that applies patterns from matmul to vectors via "
61           "2-d tiling"),
62       llvm::cl::init(false)};
63   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
64                                     llvm::cl::desc("Test promotion options"),
65                                     llvm::cl::init(false)};
66   Option<bool> testTileAndDistributionOptions{
67       *this, "test-tile-and-distribute-options",
68       llvm::cl::desc("Test tile and distribute options"),
69       llvm::cl::init(false)};
70   Option<bool> testVectorTransferForwardingPatterns{
71       *this, "test-vector-transfer-forwarding-patterns",
72       llvm::cl::desc(
73           "Test a fused pass that forwards linalg.copy to vector.transfer"),
74       llvm::cl::init(false)};
75   Option<bool> testGenericToVectorPattern{
76       *this, "test-linalg-to-vector-patterns",
77       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
78                      "in vector.contract form"),
79       llvm::cl::init(false)};
80   Option<bool> testAffineMinSCFCanonicalizationPatterns{
81       *this, "test-affine-min-scf-canonicalization-patterns",
82       llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
83       llvm::cl::init(false)};
84   Option<bool> testTileAndPadPattern{
85       *this, "test-tile-and-pad-pattern",
86       llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
87   Option<int> testHoistPadding{*this, "test-hoist-padding",
88                                llvm::cl::desc("Test hoist padding"),
89                                llvm::cl::init(0)};
90   ListOption<int64_t> tileSizesForPadding{
91       *this, "tile-sizes-for-padding",
92       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
93       llvm::cl::MiscFlags::CommaSeparated};
94   ListOption<unsigned> testInterchangePattern{
95       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
96       llvm::cl::desc("Test the interchange pattern.")};
97 };
98 } // end anonymous namespace
99 
100 static void applyPatterns(FuncOp funcOp) {
101   MLIRContext *ctx = funcOp.getContext();
102   RewritePatternSet patterns(ctx);
103 
104   //===--------------------------------------------------------------------===//
105   // Linalg tiling patterns.
106   //===--------------------------------------------------------------------===//
107   patterns.add<LinalgTilingPattern<MatmulOp>>(
108       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
109       LinalgTransformationFilter(Identifier::get("MEM", ctx),
110                                  Identifier::get("L3", ctx)));
111   patterns.add<LinalgTilingPattern<MatmulOp>>(
112       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
113       LinalgTransformationFilter(Identifier::get("L3", ctx),
114                                  Identifier::get("L2", ctx)));
115   patterns.add<LinalgTilingPattern<MatmulOp>>(
116       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
117       LinalgTransformationFilter(Identifier::get("L2", ctx),
118                                  Identifier::get("L1", ctx)));
119   patterns.add<LinalgTilingPattern<MatmulOp>>(
120       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
121       LinalgTransformationFilter(Identifier::get("L1", ctx),
122                                  Identifier::get("REG", ctx)));
123 
124   patterns.add<LinalgTilingPattern<MatvecOp>>(
125       ctx,
126       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
127           LinalgTilingLoopType::ParallelLoops),
128       LinalgTransformationFilter(ArrayRef<Identifier>{},
129                                  Identifier::get("L1", ctx)));
130 
131   patterns.add<LinalgTilingPattern<DotOp>>(
132       ctx, LinalgTilingOptions().setTileSizes(8000),
133       LinalgTransformationFilter(
134           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
135                                Identifier::get("L3", ctx),
136                                Identifier::get("L2", ctx)},
137           Identifier::get("REG", ctx)));
138 
139   //===--------------------------------------------------------------------===//
140   // Linalg tiling and permutation patterns.
141   //===--------------------------------------------------------------------===//
142   patterns.add<LinalgTilingPattern<MatmulOp>>(
143       ctx,
144       LinalgTilingOptions()
145           .setTileSizes({2000, 3000, 4000})
146           .setInterchange({1, 2, 0}),
147       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
148                                  Identifier::get("L2__with_perm__", ctx)));
149   patterns.add<LinalgTilingPattern<MatmulOp>>(
150       ctx,
151       LinalgTilingOptions()
152           .setTileSizes({200, 300, 400})
153           .setInterchange({1, 0, 2}),
154       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
155                                  Identifier::get("L1__with_perm__", ctx)));
156   patterns.add<LinalgTilingPattern<MatmulOp>>(
157       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
158       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
159                                  Identifier::get("REG__with_perm__", ctx)));
160 
161   patterns.add<LinalgTilingPattern<MatvecOp>>(
162       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
163       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
164                                  Identifier::get("L1__with_perm__", ctx)));
165 
166   patterns.add<LinalgTilingPattern<MatmulOp>>(
167       ctx,
168       LinalgTilingOptions()
169           .setTileSizes({16, 8, 4})
170           .setInterchange({1, 2, 0})
171           .setLoopType(LinalgTilingLoopType::ParallelLoops),
172       LinalgTransformationFilter(
173           Identifier::get("par__with_perm__", ctx),
174           Identifier::get("after_par__with_perm__", ctx)));
175 
176   //===--------------------------------------------------------------------===//
177   // Linalg to loops patterns.
178   //===--------------------------------------------------------------------===//
179   patterns.add<LinalgLoweringPattern<DotOp>>(
180       ctx,
181       /*loweringType=*/LinalgLoweringType::Loops,
182       LinalgTransformationFilter(Identifier::get("REG", ctx)));
183 
184   //===--------------------------------------------------------------------===//
185   // Linalg distribution patterns.
186   //===--------------------------------------------------------------------===//
187   LinalgLoopDistributionOptions distributionOptions;
188 
189   //===--------------------------------------------------------------------===//
190   // Linalg to vector contraction patterns.
191   //===--------------------------------------------------------------------===//
192   patterns.add<LinalgVectorizationPattern>(
193       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
194                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
195 
196   //===--------------------------------------------------------------------===//
197   // Linalg generic interchange pattern.
198   //===--------------------------------------------------------------------===//
199   patterns.add<GenericOpInterchangePattern>(
200       ctx,
201       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
202       LinalgTransformationFilter(ArrayRef<Identifier>{},
203                                  Identifier::get("PERMUTED", ctx)));
204 
205   //===--------------------------------------------------------------------===//
206   // Linalg subview operands promotion.
207   //===--------------------------------------------------------------------===//
208   patterns.add<LinalgPromotionPattern<MatmulOp>>(
209       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
210       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
211                                  Identifier::get("_views_promoted_", ctx)));
212   patterns.add<LinalgPromotionPattern<MatmulOp>>(
213       ctx,
214       LinalgPromotionOptions()
215           .setOperandsToPromote({0})
216           .setUseFullTileBuffersByDefault(true),
217       LinalgTransformationFilter(
218           Identifier::get("_promote_first_view_", ctx),
219           Identifier::get("_first_view_promoted_", ctx)));
220   patterns.add<LinalgPromotionPattern<FillOp>>(
221       ctx,
222       LinalgPromotionOptions()
223           .setOperandsToPromote({0})
224           .setUseFullTileBuffers({true})
225           .setAlignment(32),
226       LinalgTransformationFilter(
227           Identifier::get("_promote_views_aligned_", ctx),
228           Identifier::get("_views_aligned_promoted_", ctx)));
229 
230   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
231 
232   // Drop the marker.
233   funcOp.walk([](LinalgOp op) {
234     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
235   });
236 }
237 
238 static void fillL1TilingAndMatmulToVectorPatterns(
239     FuncOp funcOp, StringRef startMarker,
240     SmallVectorImpl<RewritePatternSet> &patternsVector) {
241   MLIRContext *ctx = funcOp.getContext();
242   patternsVector.emplace_back(
243       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
244                ctx,
245                LinalgTilingOptions()
246                    .setTileSizes({8, 12, 16})
247                    .setInterchange({1, 0, 2}),
248                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
249                                           Identifier::get("L1", ctx))));
250 
251   patternsVector.emplace_back(
252       ctx,
253       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
254           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
255           LinalgTransformationFilter(Identifier::get("L1", ctx),
256                                      Identifier::get("VEC", ctx))));
257 
258   patternsVector.emplace_back(
259       ctx, std::make_unique<LinalgVectorizationPattern>(
260                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
261                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
262   patternsVector.back().add<LinalgVectorizationPattern>(
263       ctx, LinalgTransformationFilter().addFilter(
264                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // Test promotion callbacks
269 //===----------------------------------------------------------------------===//
270 
271 // Allocation call back
272 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
273                                        ArrayRef<Value> boundingSubViewSize,
274                                        DataLayout &layout,
275                                        OperationFolder *folder) {
276   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
277   return b
278       .create<memref::AllocOp>(
279           subView.getLoc(),
280           MemRefType::get(shape, subView.getType().getElementType(),
281                           /*affineMapComposition =*/{}, 3),
282           boundingSubViewSize)
283       .getResult();
284 }
285 
286 // Deallocation callback
287 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
288   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
289   return success();
290 }
291 
292 // Copy in call back
293 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
294                                     bool isOutput) {
295   auto floatType = src.getType().cast<MemRefType>().getElementType();
296   if (!floatType.isa<FloatType>())
297     return failure();
298   if (!isOutput)
299     b.create<FillOp>(
300         src.getLoc(), dst,
301         b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)));
302   b.create<CopyOp>(src.getLoc(), src, dst);
303   return success();
304 }
305 
306 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
307                                           RewritePatternSet &patterns) {
308   patterns.add<LinalgTilingPattern<MatmulOp>>(
309       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
310       LinalgTransformationFilter(Identifier::get("START", ctx),
311                                  Identifier::get("PROMOTE", ctx)));
312   patterns.add<LinalgPromotionPattern<MatmulOp>>(
313       ctx,
314       LinalgPromotionOptions()
315           .setOperandsToPromote({0, 2})
316           .setUseFullTileBuffers({false, false})
317           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
318           .setCopyInOutFns(
319               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
320                 return copyCallBackFn(b, src, dst, false);
321               },
322               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
323                 return copyCallBackFn(b, src, dst, true);
324               }),
325       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
326 }
327 
328 template <typename IdOp, typename NProcsOp>
329 static SmallVector<ProcInfo, 2>
330 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
331   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
332   SmallVector<ProcInfo, 2> procInfo(count);
333   const char *xyz[] = {"x", "y", "z"};
334   Type indexType = b.getIndexType();
335   for (unsigned i = 0; i < count; ++i) {
336     procInfo[count - 1 - i] = {
337         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
338         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
339   }
340   return procInfo;
341 }
342 
343 static void fillTileAndDistributePatterns(MLIRContext *context,
344                                           RewritePatternSet &patterns) {
345   {
346     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
347     cyclicNprocsEqNiters.distributionMethod.resize(
348         2, DistributionMethod::CyclicNumProcsEqNumIters);
349     cyclicNprocsEqNiters.procInfo =
350         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
351     patterns.add<LinalgTilingPattern<MatmulOp>>(
352         context,
353         LinalgTilingOptions()
354             .setTileSizes({8, 8, 4})
355             .setLoopType(LinalgTilingLoopType::ParallelLoops)
356             .setDistributionOptions(cyclicNprocsEqNiters),
357         LinalgTransformationFilter(
358             Identifier::get("distribute1", context),
359             Identifier::get("after_distribute1", context)));
360   }
361 
362   {
363     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
364     cyclicNprocsGeNiters.distributionMethod.resize(
365         2, DistributionMethod::CyclicNumProcsGeNumIters);
366     cyclicNprocsGeNiters.procInfo =
367         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
368     patterns.add<LinalgTilingPattern<MatmulOp>>(
369         context,
370         LinalgTilingOptions()
371             .setTileSizes({8, 8, 4})
372             .setLoopType(LinalgTilingLoopType::ParallelLoops)
373             .setDistributionOptions(cyclicNprocsGeNiters),
374         LinalgTransformationFilter(
375             Identifier::get("distribute2", context),
376             Identifier::get("after_distribute2", context)));
377   }
378 
379   {
380     LinalgLoopDistributionOptions cyclicNprocsDefault;
381     cyclicNprocsDefault.distributionMethod.resize(2,
382                                                   DistributionMethod::Cyclic);
383     cyclicNprocsDefault.procInfo =
384         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
385     patterns.add<LinalgTilingPattern<MatmulOp>>(
386         context,
387         LinalgTilingOptions()
388             .setTileSizes({8, 8, 4})
389             .setLoopType(LinalgTilingLoopType::ParallelLoops)
390             .setDistributionOptions(cyclicNprocsDefault),
391         LinalgTransformationFilter(
392             Identifier::get("distribute3", context),
393             Identifier::get("after_distribute3", context)));
394   }
395 
396   {
397     LinalgLoopDistributionOptions cyclicNprocsMixed1;
398     cyclicNprocsMixed1.distributionMethod = {
399         DistributionMethod::CyclicNumProcsEqNumIters,
400         DistributionMethod::CyclicNumProcsGeNumIters};
401     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
402     patterns.add<LinalgTilingPattern<MatmulOp>>(
403         context,
404         LinalgTilingOptions()
405             .setTileSizes({8, 8, 4})
406             .setLoopType(LinalgTilingLoopType::ParallelLoops)
407             .setDistributionOptions(cyclicNprocsMixed1),
408         LinalgTransformationFilter(
409             Identifier::get("distribute4", context),
410             Identifier::get("after_distribute4", context)));
411   }
412 
413   {
414     LinalgLoopDistributionOptions cyclicNprocsMixed2;
415     cyclicNprocsMixed2.distributionMethod = {
416         DistributionMethod::CyclicNumProcsGeNumIters,
417         DistributionMethod::Cyclic};
418     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
419     patterns.add<LinalgTilingPattern<MatmulOp>>(
420         context,
421         LinalgTilingOptions()
422             .setTileSizes({8, 8, 4})
423             .setLoopType(LinalgTilingLoopType::ParallelLoops)
424             .setDistributionOptions(cyclicNprocsMixed2),
425         LinalgTransformationFilter(
426             Identifier::get("distribute5", context),
427             Identifier::get("after_distribute5", context)));
428   }
429 
430   {
431     LinalgLoopDistributionOptions cyclicNprocsMixed3;
432     cyclicNprocsMixed3.distributionMethod = {
433         DistributionMethod::Cyclic,
434         DistributionMethod::CyclicNumProcsEqNumIters};
435     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
436 
437     patterns.add<LinalgTilingPattern<MatmulOp>>(
438         context,
439         LinalgTilingOptions()
440             .setTileSizes({8, 8, 4})
441             .setLoopType(LinalgTilingLoopType::ParallelLoops)
442             .setDistributionOptions(cyclicNprocsMixed3),
443         LinalgTransformationFilter(
444             Identifier::get("distribute6", context),
445             Identifier::get("after_distribute6", context)));
446   }
447 
448   {
449     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
450     cyclicNprocsEqNiters.distributionMethod.resize(2,
451                                                    DistributionMethod::Cyclic);
452     cyclicNprocsEqNiters.procInfo =
453         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
454     patterns.add<LinalgTilingPattern<MatmulOp>>(
455         context,
456         LinalgTilingOptions()
457             .setTileSizes({8, 8, 4})
458             .setLoopType(LinalgTilingLoopType::Loops)
459             .setDistributionOptions(cyclicNprocsEqNiters),
460         LinalgTransformationFilter(
461             Identifier::get("tensors_distribute1", context),
462             Identifier::get("tensors_after_distribute1", context)));
463   }
464 }
465 
466 static void
467 applyMatmulToVectorPatterns(FuncOp funcOp,
468                             bool testMatmulToVectorPatterns1dTiling,
469                             bool testMatmulToVectorPatterns2dTiling) {
470   MLIRContext *ctx = funcOp.getContext();
471   SmallVector<RewritePatternSet, 4> stage1Patterns;
472   if (testMatmulToVectorPatterns1dTiling) {
473     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
474                                           stage1Patterns);
475   } else if (testMatmulToVectorPatterns2dTiling) {
476     stage1Patterns.emplace_back(
477         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
478                  ctx,
479                  LinalgTilingOptions()
480                      .setTileSizes({768, 264, 768})
481                      .setInterchange({1, 2, 0}),
482                  LinalgTransformationFilter(Identifier::get("START", ctx),
483                                             Identifier::get("L2", ctx))));
484     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
485                                           stage1Patterns);
486   }
487   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
488   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
489   FrozenRewritePatternSet stage2Patterns =
490       getLinalgTilingCanonicalizationPatterns(ctx);
491   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
492                             std::move(stage2Patterns));
493 }
494 
495 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
496   RewritePatternSet forwardPattern(funcOp.getContext());
497   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
498   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
499   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
500 }
501 
502 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
503   RewritePatternSet patterns(funcOp.getContext());
504   patterns.add<LinalgVectorizationPattern>(
505       funcOp.getContext(),
506       LinalgTransformationFilter()
507           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
508   patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());
509   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
510 }
511 
512 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
513   RewritePatternSet foldPattern(funcOp.getContext());
514   foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
515   FrozenRewritePatternSet frozenPatterns(std::move(foldPattern));
516 
517   // Explicitly walk and apply the pattern locally to avoid more general folding
518   // on the rest of the IR.
519   funcOp.walk([&frozenPatterns](AffineMinOp minOp) {
520     (void)applyOpPatternsAndFold(minOp, frozenPatterns);
521   });
522 }
523 
524 // For now, just assume it is the zero of type.
525 // In the future, it should be the zero of type + op.
526 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
527   auto t = getElementTypeOrSelf(op.get().getType());
528   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
529 }
530 
531 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
532   MLIRContext *context = funcOp.getContext();
533   RewritePatternSet tilingPattern(context);
534   auto linalgTilingOptions =
535       linalg::LinalgTilingOptions()
536           .setTileSizes(tileSizes)
537           .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
538   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
539       context, linalgTilingOptions,
540       linalg::LinalgTransformationFilter(
541           Identifier::get("tile-and-pad", context)));
542   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
543 }
544 
545 static void applyInterchangePattern(FuncOp funcOp,
546                                     ArrayRef<unsigned> interchangeVector) {
547   MLIRContext *context = funcOp.getContext();
548   RewritePatternSet interchangePattern(context);
549   interchangePattern.add<GenericOpInterchangePattern>(
550       context, interchangeVector,
551       LinalgTransformationFilter(ArrayRef<Identifier>{},
552                                  Identifier::get("interchange", context)));
553   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
554 }
555 
556 /// Apply transformations specified as patterns.
557 void TestLinalgTransforms::runOnFunction() {
558   auto lambda = [&](void *) {
559     getFunction().walk([](LinalgOp op) {
560       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
561     });
562   };
563   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
564 
565   if (testPromotionOptions) {
566     RewritePatternSet patterns(&getContext());
567     fillPromotionCallBackPatterns(&getContext(), patterns);
568     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
569     return;
570   }
571   if (testTileAndDistributionOptions) {
572     RewritePatternSet patterns(&getContext());
573     fillTileAndDistributePatterns(&getContext(), patterns);
574     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
575     return;
576   }
577   if (testPatterns)
578     return applyPatterns(getFunction());
579   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
580     return applyMatmulToVectorPatterns(getFunction(),
581                                        testMatmulToVectorPatterns1dTiling,
582                                        testMatmulToVectorPatterns2dTiling);
583   if (testVectorTransferForwardingPatterns)
584     return applyVectorTransferForwardingPatterns(getFunction());
585   if (testGenericToVectorPattern)
586     return applyLinalgToVectorPatterns(getFunction());
587   if (testAffineMinSCFCanonicalizationPatterns)
588     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
589   if (testTileAndPadPattern)
590     return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
591   if (testHoistPadding) {
592     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
593       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
594     });
595   }
596   if (testInterchangePattern.hasValue())
597     return applyInterchangePattern(getFunction(), testInterchangePattern);
598 }
599 
600 namespace mlir {
601 namespace test {
602 void registerTestLinalgTransforms() {
603   PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
604       "test-linalg-transform-patterns",
605       "Test Linalg transformation patterns by applying them greedily.");
606 }
607 } // namespace test
608 } // namespace mlir
609