1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 #include "llvm/ADT/SetVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 namespace {
30 struct TestLinalgTransforms
31     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
32   TestLinalgTransforms() = default;
33   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
34 
35   void getDependentDialects(DialectRegistry &registry) const override {
36     // clang-format off
37     registry.insert<AffineDialect,
38                     memref::MemRefDialect,
39                     scf::SCFDialect,
40                     StandardOpsDialect,
41                     vector::VectorDialect,
42                     gpu::GPUDialect>();
43     // clang-format on
44   }
45 
46   void runOnFunction() override;
47 
48   Option<bool> testPatterns{*this, "test-patterns",
49                             llvm::cl::desc("Test a mixed set of patterns"),
50                             llvm::cl::init(false)};
51   Option<bool> testMatmulToVectorPatterns1dTiling{
52       *this, "test-matmul-to-vector-patterns-tile-1d",
53       llvm::cl::desc(
54           "Test a fused pass that applies patterns from matmul to vectors via "
55           "1-d tiling"),
56       llvm::cl::init(false)};
57   Option<bool> testMatmulToVectorPatterns2dTiling{
58       *this, "test-matmul-to-vector-patterns-tile-2d",
59       llvm::cl::desc(
60           "Test a fused pass that applies patterns from matmul to vectors via "
61           "2-d tiling"),
62       llvm::cl::init(false)};
63   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
64                                     llvm::cl::desc("Test promotion options"),
65                                     llvm::cl::init(false)};
66   Option<bool> testTileAndDistributionOptions{
67       *this, "test-tile-and-distribute-options",
68       llvm::cl::desc("Test tile and distribute options"),
69       llvm::cl::init(false)};
70   Option<bool> testVectorTransferForwardingPatterns{
71       *this, "test-vector-transfer-forwarding-patterns",
72       llvm::cl::desc(
73           "Test a fused pass that forwards linalg.copy to vector.transfer"),
74       llvm::cl::init(false)};
75   Option<bool> testGenericToVectorPattern{
76       *this, "test-linalg-to-vector-patterns",
77       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
78                      "in vector.contract form"),
79       llvm::cl::init(false)};
80   Option<bool> testAffineMinSCFCanonicalizationPatterns{
81       *this, "test-affine-min-scf-canonicalization-patterns",
82       llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
83       llvm::cl::init(false)};
84   Option<bool> testTileAndPadPattern{
85       *this, "test-tile-and-pad-pattern",
86       llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
87   Option<int> testHoistPadding{*this, "test-hoist-padding",
88                                llvm::cl::desc("Test hoist padding"),
89                                llvm::cl::init(0)};
90   Option<bool> testTransformPadTensor{
91       *this, "test-transform-pad-tensor",
92       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
93       llvm::cl::init(false)};
94   ListOption<int64_t> tileSizesForPadding{
95       *this, "tile-sizes-for-padding",
96       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
97       llvm::cl::MiscFlags::CommaSeparated};
98   ListOption<unsigned> testInterchangePattern{
99       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
100       llvm::cl::desc("Test the interchange pattern.")};
101 };
102 } // end anonymous namespace
103 
104 static void applyPatterns(FuncOp funcOp) {
105   MLIRContext *ctx = funcOp.getContext();
106   RewritePatternSet patterns(ctx);
107 
108   //===--------------------------------------------------------------------===//
109   // Linalg tiling patterns.
110   //===--------------------------------------------------------------------===//
111   patterns.add<LinalgTilingPattern<MatmulOp>>(
112       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
113       LinalgTransformationFilter(Identifier::get("MEM", ctx),
114                                  Identifier::get("L3", ctx)));
115   patterns.add<LinalgTilingPattern<MatmulOp>>(
116       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
117       LinalgTransformationFilter(Identifier::get("L3", ctx),
118                                  Identifier::get("L2", ctx)));
119   patterns.add<LinalgTilingPattern<MatmulOp>>(
120       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
121       LinalgTransformationFilter(Identifier::get("L2", ctx),
122                                  Identifier::get("L1", ctx)));
123   patterns.add<LinalgTilingPattern<MatmulOp>>(
124       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
125       LinalgTransformationFilter(Identifier::get("L1", ctx),
126                                  Identifier::get("REG", ctx)));
127 
128   patterns.add<LinalgTilingPattern<MatvecOp>>(
129       ctx,
130       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
131           LinalgTilingLoopType::ParallelLoops),
132       LinalgTransformationFilter(ArrayRef<Identifier>{},
133                                  Identifier::get("L1", ctx)));
134 
135   patterns.add<LinalgTilingPattern<DotOp>>(
136       ctx, LinalgTilingOptions().setTileSizes(8000),
137       LinalgTransformationFilter(
138           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
139                                Identifier::get("L3", ctx),
140                                Identifier::get("L2", ctx)},
141           Identifier::get("REG", ctx)));
142 
143   //===--------------------------------------------------------------------===//
144   // Linalg tiling and permutation patterns.
145   //===--------------------------------------------------------------------===//
146   patterns.add<LinalgTilingPattern<MatmulOp>>(
147       ctx,
148       LinalgTilingOptions()
149           .setTileSizes({2000, 3000, 4000})
150           .setInterchange({1, 2, 0}),
151       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
152                                  Identifier::get("L2__with_perm__", ctx)));
153   patterns.add<LinalgTilingPattern<MatmulOp>>(
154       ctx,
155       LinalgTilingOptions()
156           .setTileSizes({200, 300, 400})
157           .setInterchange({1, 0, 2}),
158       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
159                                  Identifier::get("L1__with_perm__", ctx)));
160   patterns.add<LinalgTilingPattern<MatmulOp>>(
161       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
162       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
163                                  Identifier::get("REG__with_perm__", ctx)));
164 
165   patterns.add<LinalgTilingPattern<MatvecOp>>(
166       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
167       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
168                                  Identifier::get("L1__with_perm__", ctx)));
169 
170   patterns.add<LinalgTilingPattern<MatmulOp>>(
171       ctx,
172       LinalgTilingOptions()
173           .setTileSizes({16, 8, 4})
174           .setInterchange({1, 2, 0})
175           .setLoopType(LinalgTilingLoopType::ParallelLoops),
176       LinalgTransformationFilter(
177           Identifier::get("par__with_perm__", ctx),
178           Identifier::get("after_par__with_perm__", ctx)));
179 
180   //===--------------------------------------------------------------------===//
181   // Linalg to loops patterns.
182   //===--------------------------------------------------------------------===//
183   patterns.add<LinalgLoweringPattern<DotOp>>(
184       ctx,
185       /*loweringType=*/LinalgLoweringType::Loops,
186       LinalgTransformationFilter(Identifier::get("REG", ctx)));
187 
188   //===--------------------------------------------------------------------===//
189   // Linalg distribution patterns.
190   //===--------------------------------------------------------------------===//
191   LinalgLoopDistributionOptions distributionOptions;
192 
193   //===--------------------------------------------------------------------===//
194   // Linalg to vector contraction patterns.
195   //===--------------------------------------------------------------------===//
196   patterns.add<LinalgVectorizationPattern>(
197       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
198                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
199 
200   //===--------------------------------------------------------------------===//
201   // Linalg generic interchange pattern.
202   //===--------------------------------------------------------------------===//
203   patterns.add<GenericOpInterchangePattern>(
204       ctx,
205       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
206       LinalgTransformationFilter(ArrayRef<Identifier>{},
207                                  Identifier::get("PERMUTED", ctx)));
208 
209   //===--------------------------------------------------------------------===//
210   // Linalg subview operands promotion.
211   //===--------------------------------------------------------------------===//
212   patterns.add<LinalgPromotionPattern<MatmulOp>>(
213       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
214       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
215                                  Identifier::get("_views_promoted_", ctx)));
216   patterns.add<LinalgPromotionPattern<MatmulOp>>(
217       ctx,
218       LinalgPromotionOptions()
219           .setOperandsToPromote({0})
220           .setUseFullTileBuffersByDefault(true),
221       LinalgTransformationFilter(
222           Identifier::get("_promote_first_view_", ctx),
223           Identifier::get("_first_view_promoted_", ctx)));
224   patterns.add<LinalgPromotionPattern<FillOp>>(
225       ctx,
226       LinalgPromotionOptions()
227           .setOperandsToPromote({0})
228           .setUseFullTileBuffers({true})
229           .setAlignment(32),
230       LinalgTransformationFilter(
231           Identifier::get("_promote_views_aligned_", ctx),
232           Identifier::get("_views_aligned_promoted_", ctx)));
233 
234   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
235 
236   // Drop the marker.
237   funcOp.walk([](LinalgOp op) {
238     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
239   });
240 }
241 
242 static void fillL1TilingAndMatmulToVectorPatterns(
243     FuncOp funcOp, StringRef startMarker,
244     SmallVectorImpl<RewritePatternSet> &patternsVector) {
245   MLIRContext *ctx = funcOp.getContext();
246   patternsVector.emplace_back(
247       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
248                ctx,
249                LinalgTilingOptions()
250                    .setTileSizes({8, 12, 16})
251                    .setInterchange({1, 0, 2}),
252                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
253                                           Identifier::get("L1", ctx))));
254 
255   patternsVector.emplace_back(
256       ctx,
257       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
258           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
259           LinalgTransformationFilter(Identifier::get("L1", ctx),
260                                      Identifier::get("VEC", ctx))));
261 
262   patternsVector.emplace_back(
263       ctx, std::make_unique<LinalgVectorizationPattern>(
264                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
265                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
266   patternsVector.back().add<LinalgVectorizationPattern>(
267       ctx, LinalgTransformationFilter().addFilter(
268                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // Test promotion callbacks
273 //===----------------------------------------------------------------------===//
274 
275 // Allocation call back
276 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
277                                        ArrayRef<Value> boundingSubViewSize,
278                                        DataLayout &layout) {
279   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
280   return b
281       .create<memref::AllocOp>(
282           subView.getLoc(),
283           MemRefType::get(shape, subView.getType().getElementType(),
284                           /*affineMapComposition =*/{}, 3),
285           boundingSubViewSize)
286       .getResult();
287 }
288 
289 // Deallocation callback
290 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
291   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
292   return success();
293 }
294 
295 // Copy in call back
296 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
297                                     bool isOutput) {
298   auto floatType = src.getType().cast<MemRefType>().getElementType();
299   if (!floatType.isa<FloatType>())
300     return failure();
301   if (!isOutput)
302     b.create<FillOp>(
303         src.getLoc(), dst,
304         b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)));
305   b.create<CopyOp>(src.getLoc(), src, dst);
306   return success();
307 }
308 
309 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
310                                           RewritePatternSet &patterns) {
311   patterns.add<LinalgTilingPattern<MatmulOp>>(
312       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
313       LinalgTransformationFilter(Identifier::get("START", ctx),
314                                  Identifier::get("PROMOTE", ctx)));
315   patterns.add<LinalgPromotionPattern<MatmulOp>>(
316       ctx,
317       LinalgPromotionOptions()
318           .setOperandsToPromote({0, 2})
319           .setUseFullTileBuffers({false, false})
320           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
321           .setCopyInOutFns(
322               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
323                 return copyCallBackFn(b, src, dst, false);
324               },
325               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
326                 return copyCallBackFn(b, src, dst, true);
327               }),
328       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
329 }
330 
331 template <typename IdOp, typename NProcsOp>
332 static SmallVector<ProcInfo, 2>
333 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
334   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
335   SmallVector<ProcInfo, 2> procInfo(count);
336   const char *xyz[] = {"x", "y", "z"};
337   Type indexType = b.getIndexType();
338   for (unsigned i = 0; i < count; ++i) {
339     procInfo[count - 1 - i] = {
340         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
341         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
342   }
343   return procInfo;
344 }
345 
346 static void fillTileAndDistributePatterns(MLIRContext *context,
347                                           RewritePatternSet &patterns) {
348   {
349     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
350     cyclicNprocsEqNiters.distributionMethod.resize(
351         2, DistributionMethod::CyclicNumProcsEqNumIters);
352     cyclicNprocsEqNiters.procInfo =
353         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
354     patterns.add<LinalgTilingPattern<MatmulOp>>(
355         context,
356         LinalgTilingOptions()
357             .setTileSizes({8, 8, 4})
358             .setLoopType(LinalgTilingLoopType::ParallelLoops)
359             .setDistributionOptions(cyclicNprocsEqNiters),
360         LinalgTransformationFilter(
361             Identifier::get("distribute1", context),
362             Identifier::get("after_distribute1", context)));
363   }
364 
365   {
366     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
367     cyclicNprocsGeNiters.distributionMethod.resize(
368         2, DistributionMethod::CyclicNumProcsGeNumIters);
369     cyclicNprocsGeNiters.procInfo =
370         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
371     patterns.add<LinalgTilingPattern<MatmulOp>>(
372         context,
373         LinalgTilingOptions()
374             .setTileSizes({8, 8, 4})
375             .setLoopType(LinalgTilingLoopType::ParallelLoops)
376             .setDistributionOptions(cyclicNprocsGeNiters),
377         LinalgTransformationFilter(
378             Identifier::get("distribute2", context),
379             Identifier::get("after_distribute2", context)));
380   }
381 
382   {
383     LinalgLoopDistributionOptions cyclicNprocsDefault;
384     cyclicNprocsDefault.distributionMethod.resize(2,
385                                                   DistributionMethod::Cyclic);
386     cyclicNprocsDefault.procInfo =
387         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
388     patterns.add<LinalgTilingPattern<MatmulOp>>(
389         context,
390         LinalgTilingOptions()
391             .setTileSizes({8, 8, 4})
392             .setLoopType(LinalgTilingLoopType::ParallelLoops)
393             .setDistributionOptions(cyclicNprocsDefault),
394         LinalgTransformationFilter(
395             Identifier::get("distribute3", context),
396             Identifier::get("after_distribute3", context)));
397   }
398 
399   {
400     LinalgLoopDistributionOptions cyclicNprocsMixed1;
401     cyclicNprocsMixed1.distributionMethod = {
402         DistributionMethod::CyclicNumProcsEqNumIters,
403         DistributionMethod::CyclicNumProcsGeNumIters};
404     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
405     patterns.add<LinalgTilingPattern<MatmulOp>>(
406         context,
407         LinalgTilingOptions()
408             .setTileSizes({8, 8, 4})
409             .setLoopType(LinalgTilingLoopType::ParallelLoops)
410             .setDistributionOptions(cyclicNprocsMixed1),
411         LinalgTransformationFilter(
412             Identifier::get("distribute4", context),
413             Identifier::get("after_distribute4", context)));
414   }
415 
416   {
417     LinalgLoopDistributionOptions cyclicNprocsMixed2;
418     cyclicNprocsMixed2.distributionMethod = {
419         DistributionMethod::CyclicNumProcsGeNumIters,
420         DistributionMethod::Cyclic};
421     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
422     patterns.add<LinalgTilingPattern<MatmulOp>>(
423         context,
424         LinalgTilingOptions()
425             .setTileSizes({8, 8, 4})
426             .setLoopType(LinalgTilingLoopType::ParallelLoops)
427             .setDistributionOptions(cyclicNprocsMixed2),
428         LinalgTransformationFilter(
429             Identifier::get("distribute5", context),
430             Identifier::get("after_distribute5", context)));
431   }
432 
433   {
434     LinalgLoopDistributionOptions cyclicNprocsMixed3;
435     cyclicNprocsMixed3.distributionMethod = {
436         DistributionMethod::Cyclic,
437         DistributionMethod::CyclicNumProcsEqNumIters};
438     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
439 
440     patterns.add<LinalgTilingPattern<MatmulOp>>(
441         context,
442         LinalgTilingOptions()
443             .setTileSizes({8, 8, 4})
444             .setLoopType(LinalgTilingLoopType::ParallelLoops)
445             .setDistributionOptions(cyclicNprocsMixed3),
446         LinalgTransformationFilter(
447             Identifier::get("distribute6", context),
448             Identifier::get("after_distribute6", context)));
449   }
450 
451   {
452     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
453     cyclicNprocsEqNiters.distributionMethod.resize(2,
454                                                    DistributionMethod::Cyclic);
455     cyclicNprocsEqNiters.procInfo =
456         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
457     patterns.add<LinalgTilingPattern<MatmulOp>>(
458         context,
459         LinalgTilingOptions()
460             .setTileSizes({8, 8, 4})
461             .setLoopType(LinalgTilingLoopType::Loops)
462             .setDistributionOptions(cyclicNprocsEqNiters),
463         LinalgTransformationFilter(
464             Identifier::get("tensors_distribute1", context),
465             Identifier::get("tensors_after_distribute1", context)));
466   }
467 }
468 
469 static void
470 applyMatmulToVectorPatterns(FuncOp funcOp,
471                             bool testMatmulToVectorPatterns1dTiling,
472                             bool testMatmulToVectorPatterns2dTiling) {
473   MLIRContext *ctx = funcOp.getContext();
474   SmallVector<RewritePatternSet, 4> stage1Patterns;
475   if (testMatmulToVectorPatterns1dTiling) {
476     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
477                                           stage1Patterns);
478   } else if (testMatmulToVectorPatterns2dTiling) {
479     stage1Patterns.emplace_back(
480         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
481                  ctx,
482                  LinalgTilingOptions()
483                      .setTileSizes({768, 264, 768})
484                      .setInterchange({1, 2, 0}),
485                  LinalgTransformationFilter(Identifier::get("START", ctx),
486                                             Identifier::get("L2", ctx))));
487     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
488                                           stage1Patterns);
489   }
490   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
491   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
492   FrozenRewritePatternSet stage2Patterns =
493       getLinalgTilingCanonicalizationPatterns(ctx);
494   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
495                             std::move(stage2Patterns));
496 }
497 
498 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
499   RewritePatternSet forwardPattern(funcOp.getContext());
500   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
501   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
502   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
503 }
504 
505 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
506   RewritePatternSet patterns(funcOp.getContext());
507   patterns.add<LinalgVectorizationPattern>(
508       funcOp.getContext(),
509       LinalgTransformationFilter()
510           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
511   populatePadTensorOpVectorizationPatterns(patterns);
512   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
513 }
514 
515 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
516   RewritePatternSet patterns(funcOp.getContext());
517   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
518   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
519 }
520 
521 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
522   RewritePatternSet foldPattern(funcOp.getContext());
523   foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
524   FrozenRewritePatternSet frozenPatterns(std::move(foldPattern));
525 
526   // Explicitly walk and apply the pattern locally to avoid more general folding
527   // on the rest of the IR.
528   funcOp.walk([&frozenPatterns](AffineMinOp minOp) {
529     (void)applyOpPatternsAndFold(minOp, frozenPatterns);
530   });
531 }
532 
533 // For now, just assume it is the zero of type.
534 // In the future, it should be the zero of type + op.
535 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
536   auto t = getElementTypeOrSelf(op.get().getType());
537   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
538 }
539 
540 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
541   MLIRContext *context = funcOp.getContext();
542   RewritePatternSet tilingPattern(context);
543   auto linalgTilingOptions =
544       linalg::LinalgTilingOptions()
545           .setTileSizes(tileSizes)
546           .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
547   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
548       context, linalgTilingOptions,
549       linalg::LinalgTransformationFilter(
550           Identifier::get("tile-and-pad", context)));
551   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
552 }
553 
554 static void applyInterchangePattern(FuncOp funcOp,
555                                     ArrayRef<unsigned> interchangeVector) {
556   MLIRContext *context = funcOp.getContext();
557   RewritePatternSet interchangePattern(context);
558   interchangePattern.add<GenericOpInterchangePattern>(
559       context, interchangeVector,
560       LinalgTransformationFilter(ArrayRef<Identifier>{},
561                                  Identifier::get("interchange", context)));
562   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
563 }
564 
565 /// Apply transformations specified as patterns.
566 void TestLinalgTransforms::runOnFunction() {
567   auto lambda = [&](void *) {
568     getFunction().walk([](LinalgOp op) {
569       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
570     });
571   };
572   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
573 
574   if (testPromotionOptions) {
575     RewritePatternSet patterns(&getContext());
576     fillPromotionCallBackPatterns(&getContext(), patterns);
577     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
578     return;
579   }
580   if (testTileAndDistributionOptions) {
581     RewritePatternSet patterns(&getContext());
582     fillTileAndDistributePatterns(&getContext(), patterns);
583     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
584     return;
585   }
586   if (testPatterns)
587     return applyPatterns(getFunction());
588   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
589     return applyMatmulToVectorPatterns(getFunction(),
590                                        testMatmulToVectorPatterns1dTiling,
591                                        testMatmulToVectorPatterns2dTiling);
592   if (testVectorTransferForwardingPatterns)
593     return applyVectorTransferForwardingPatterns(getFunction());
594   if (testGenericToVectorPattern)
595     return applyLinalgToVectorPatterns(getFunction());
596   if (testTransformPadTensor)
597     return applyPadTensorToGenericPatterns(getFunction());
598   if (testAffineMinSCFCanonicalizationPatterns)
599     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
600   if (testTileAndPadPattern)
601     return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
602   if (testHoistPadding) {
603     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
604       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
605     });
606   }
607   if (testInterchangePattern.hasValue())
608     return applyInterchangePattern(getFunction(), testInterchangePattern);
609 }
610 
611 namespace mlir {
612 namespace test {
613 void registerTestLinalgTransforms() {
614   PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
615       "test-linalg-transform-patterns",
616       "Test Linalg transformation patterns by applying them greedily.");
617 }
618 } // namespace test
619 } // namespace mlir
620