1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/GPU/GPUDialect.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
18 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Vector/VectorOps.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 namespace {
33 struct TestLinalgTransforms
34     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
35   TestLinalgTransforms() = default;
36   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
37 
38   void getDependentDialects(DialectRegistry &registry) const override {
39     // clang-format off
40     registry.insert<AffineDialect,
41                     memref::MemRefDialect,
42                     scf::SCFDialect,
43                     StandardOpsDialect,
44                     vector::VectorDialect,
45                     gpu::GPUDialect>();
46     // clang-format on
47   }
48   StringRef getArgument() const final {
49     return "test-linalg-transform-patterns";
50   }
51   StringRef getDescription() const final {
52     return "Test Linalg transformation patterns by applying them greedily.";
53   }
54 
55   void runOnFunction() override;
56 
57   Option<bool> testPatterns{*this, "test-patterns",
58                             llvm::cl::desc("Test a mixed set of patterns"),
59                             llvm::cl::init(false)};
60   Option<bool> testMatmulToVectorPatterns1dTiling{
61       *this, "test-matmul-to-vector-patterns-tile-1d",
62       llvm::cl::desc(
63           "Test a fused pass that applies patterns from matmul to vectors via "
64           "1-d tiling"),
65       llvm::cl::init(false)};
66   Option<bool> testMatmulToVectorPatterns2dTiling{
67       *this, "test-matmul-to-vector-patterns-tile-2d",
68       llvm::cl::desc(
69           "Test a fused pass that applies patterns from matmul to vectors via "
70           "2-d tiling"),
71       llvm::cl::init(false)};
72   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
73                                     llvm::cl::desc("Test promotion options"),
74                                     llvm::cl::init(false)};
75   Option<bool> testTileAndDistributionOptions{
76       *this, "test-tile-and-distribute-options",
77       llvm::cl::desc("Test tile and distribute options"),
78       llvm::cl::init(false)};
79   Option<bool> testVectorTransferForwardingPatterns{
80       *this, "test-vector-transfer-forwarding-patterns",
81       llvm::cl::desc(
82           "Test a fused pass that forwards linalg.copy to vector.transfer"),
83       llvm::cl::init(false)};
84   Option<bool> testGenericToVectorPattern{
85       *this, "test-linalg-to-vector-patterns",
86       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
87                      "in vector.contract form"),
88       llvm::cl::init(false)};
89   Option<bool> testTilePattern{*this, "test-tile-pattern",
90                                llvm::cl::desc("Test tile pattern"),
91                                llvm::cl::init(false)};
92   Option<bool> testTileScalarizeDynamicDims{
93       *this, "test-tile-scalarize-dynamic-dims",
94       llvm::cl::desc("Test tiling of dynamic dims by 1"),
95       llvm::cl::init(false)};
96   Option<int> testHoistPadding{*this, "test-hoist-padding",
97                                llvm::cl::desc("Test hoist padding"),
98                                llvm::cl::init(0)};
99   Option<bool> testPadPattern{*this, "test-pad-pattern",
100                               llvm::cl::desc("Test pad pattern"),
101                               llvm::cl::init(false)};
102   Option<bool> testTransformPadTensor{
103       *this, "test-transform-pad-tensor",
104       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
105       llvm::cl::init(false)};
106   Option<bool> testGeneralizePadTensor{
107       *this, "test-generalize-pad-tensor",
108       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
109       llvm::cl::init(false)};
110   Option<bool> testSwapSubTensorPadTensor{
111       *this, "test-swap-subtensor-padtensor",
112       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
113                      "pad_tensor(subtensor)"),
114       llvm::cl::init(false)};
115   ListOption<int64_t> paddedOperands{
116       *this, "padded-operands",
117       llvm::cl::desc("Operands to pad when test-tile-pattern"),
118       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
119   ListOption<int64_t> nofoldOperands{
120       *this, "nofold-operands",
121       llvm::cl::desc("Operands to set nofold when test-tile-pattern"),
122       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
123   ListOption<int64_t> packPaddings{
124       *this, "pack-paddings",
125       llvm::cl::desc("Operand packing flags when test-pad-pattern"),
126       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
127   ListOption<int64_t> hoistPaddings{
128       *this, "hoist-paddings",
129       llvm::cl::desc("Operand hoisting depths when test-pad-pattern"),
130       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
131   ListOption<int64_t> peeledLoops{
132       *this, "peeled-loops",
133       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
134       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
135   ListOption<int64_t> tileSizes{
136       *this, "tile-sizes",
137       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
138       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
139   ListOption<unsigned> testInterchangePattern{
140       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
141       llvm::cl::desc("Test the interchange pattern.")};
142   ListOption<unsigned> testTiledLoopPeeling{
143       *this, "test-tiled-loop-peeling",
144       llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
145       llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
146   Option<bool> skipPartial{
147       *this, "skip-partial",
148       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
149       llvm::cl::init(false)};
150   Option<std::string> loopType{
151       *this, "loop-type",
152       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
153                      "tiled_loop"),
154       llvm::cl::init("for")};
155   Option<bool> testDecomposeConvolutionPattern{
156       *this, "test-decompose-convolution-patterns",
157       llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops "
158                      "into low-D ones"),
159       llvm::cl::init(false)};
160 };
161 } // end anonymous namespace
162 
163 static void applyPatterns(FuncOp funcOp) {
164   MLIRContext *ctx = funcOp.getContext();
165   RewritePatternSet patterns(ctx);
166 
167   //===--------------------------------------------------------------------===//
168   // Linalg tiling patterns.
169   //===--------------------------------------------------------------------===//
170   patterns.add<LinalgTilingPattern<MatmulOp>>(
171       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
172       LinalgTransformationFilter(Identifier::get("MEM", ctx),
173                                  Identifier::get("L3", ctx)));
174   patterns.add<LinalgTilingPattern<MatmulOp>>(
175       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
176       LinalgTransformationFilter(Identifier::get("L3", ctx),
177                                  Identifier::get("L2", ctx)));
178   patterns.add<LinalgTilingPattern<MatmulOp>>(
179       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
180       LinalgTransformationFilter(Identifier::get("L2", ctx),
181                                  Identifier::get("L1", ctx)));
182   patterns.add<LinalgTilingPattern<MatmulOp>>(
183       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
184       LinalgTransformationFilter(Identifier::get("L1", ctx),
185                                  Identifier::get("REG", ctx)));
186 
187   patterns.add<LinalgTilingPattern<MatvecOp>>(
188       ctx,
189       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
190           LinalgTilingLoopType::ParallelLoops),
191       LinalgTransformationFilter(ArrayRef<Identifier>{},
192                                  Identifier::get("L1", ctx)));
193 
194   patterns.add<LinalgTilingPattern<DotOp>>(
195       ctx, LinalgTilingOptions().setTileSizes(8000),
196       LinalgTransformationFilter(
197           ArrayRef<Identifier>{Identifier::get("MEM", ctx),
198                                Identifier::get("L3", ctx),
199                                Identifier::get("L2", ctx)},
200           Identifier::get("REG", ctx)));
201 
202   //===--------------------------------------------------------------------===//
203   // Linalg tiling and permutation patterns.
204   //===--------------------------------------------------------------------===//
205   patterns.add<LinalgTilingPattern<MatmulOp>>(
206       ctx,
207       LinalgTilingOptions()
208           .setTileSizes({2000, 3000, 4000})
209           .setInterchange({1, 2, 0}),
210       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
211                                  Identifier::get("L2__with_perm__", ctx)));
212   patterns.add<LinalgTilingPattern<MatmulOp>>(
213       ctx,
214       LinalgTilingOptions()
215           .setTileSizes({200, 300, 400})
216           .setInterchange({1, 0, 2}),
217       LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
218                                  Identifier::get("L1__with_perm__", ctx)));
219   patterns.add<LinalgTilingPattern<MatmulOp>>(
220       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
221       LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
222                                  Identifier::get("REG__with_perm__", ctx)));
223 
224   patterns.add<LinalgTilingPattern<MatvecOp>>(
225       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
226       LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
227                                  Identifier::get("L1__with_perm__", ctx)));
228 
229   patterns.add<LinalgTilingPattern<MatmulOp>>(
230       ctx,
231       LinalgTilingOptions()
232           .setTileSizes({16, 8, 4})
233           .setInterchange({1, 2, 0})
234           .setLoopType(LinalgTilingLoopType::ParallelLoops),
235       LinalgTransformationFilter(
236           Identifier::get("par__with_perm__", ctx),
237           Identifier::get("after_par__with_perm__", ctx)));
238 
239   //===--------------------------------------------------------------------===//
240   // Linalg to loops patterns.
241   //===--------------------------------------------------------------------===//
242   patterns.add<LinalgLoweringPattern<DotOp>>(
243       ctx,
244       /*loweringType=*/LinalgLoweringType::Loops,
245       LinalgTransformationFilter(Identifier::get("REG", ctx)));
246 
247   //===--------------------------------------------------------------------===//
248   // Linalg distribution patterns.
249   //===--------------------------------------------------------------------===//
250   LinalgLoopDistributionOptions distributionOptions;
251 
252   //===--------------------------------------------------------------------===//
253   // Linalg to vector contraction patterns.
254   //===--------------------------------------------------------------------===//
255   patterns.add<LinalgVectorizationPattern>(
256       ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
257                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
258 
259   //===--------------------------------------------------------------------===//
260   // Linalg generic interchange pattern.
261   //===--------------------------------------------------------------------===//
262   patterns.add<GenericOpInterchangePattern>(
263       ctx,
264       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
265       LinalgTransformationFilter(ArrayRef<Identifier>{},
266                                  Identifier::get("PERMUTED", ctx)));
267 
268   //===--------------------------------------------------------------------===//
269   // Linalg subview operands promotion.
270   //===--------------------------------------------------------------------===//
271   patterns.add<LinalgPromotionPattern<MatmulOp>>(
272       ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
273       LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
274                                  Identifier::get("_views_promoted_", ctx)));
275   patterns.add<LinalgPromotionPattern<MatmulOp>>(
276       ctx,
277       LinalgPromotionOptions()
278           .setOperandsToPromote({0})
279           .setUseFullTileBuffersByDefault(true),
280       LinalgTransformationFilter(
281           Identifier::get("_promote_first_view_", ctx),
282           Identifier::get("_first_view_promoted_", ctx)));
283   patterns.add<LinalgPromotionPattern<FillOp>>(
284       ctx,
285       LinalgPromotionOptions()
286           .setOperandsToPromote({1})
287           .setUseFullTileBuffers({false, true})
288           .setAlignment(32),
289       LinalgTransformationFilter(
290           Identifier::get("_promote_views_aligned_", ctx),
291           Identifier::get("_views_aligned_promoted_", ctx)));
292 
293   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
294 
295   // Drop the marker.
296   funcOp.walk([](LinalgOp op) {
297     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
298   });
299 }
300 
301 static void fillL1TilingAndMatmulToVectorPatterns(
302     FuncOp funcOp, StringRef startMarker,
303     SmallVectorImpl<RewritePatternSet> &patternsVector) {
304   MLIRContext *ctx = funcOp.getContext();
305   patternsVector.emplace_back(
306       ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
307                ctx,
308                LinalgTilingOptions()
309                    .setTileSizes({8, 12, 16})
310                    .setInterchange({1, 0, 2}),
311                LinalgTransformationFilter(Identifier::get(startMarker, ctx),
312                                           Identifier::get("L1", ctx))));
313 
314   patternsVector.emplace_back(
315       ctx,
316       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
317           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
318           LinalgTransformationFilter(Identifier::get("L1", ctx),
319                                      Identifier::get("VEC", ctx))));
320 
321   patternsVector.emplace_back(
322       ctx, std::make_unique<LinalgVectorizationPattern>(
323                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
324                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
325   patternsVector.back().add<LinalgVectorizationPattern>(
326       ctx, LinalgTransformationFilter().addFilter(
327                [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // Test promotion callbacks
332 //===----------------------------------------------------------------------===//
333 
334 // Allocation call back
335 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
336                                        ArrayRef<Value> boundingSubViewSize,
337                                        DataLayout &layout) {
338   SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
339   return b
340       .create<memref::AllocOp>(
341           subView.getLoc(),
342           MemRefType::get(shape, subView.getType().getElementType(),
343                           /*affineMapComposition =*/{}, 3),
344           boundingSubViewSize)
345       .getResult();
346 }
347 
348 // Deallocation callback
349 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
350   b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
351   return success();
352 }
353 
354 // Copy in call back
355 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
356                                     bool isOutput) {
357   auto floatType = src.getType().cast<MemRefType>().getElementType();
358   if (!floatType.isa<FloatType>())
359     return failure();
360   if (!isOutput) {
361     Value cst = b.create<arith::ConstantOp>(src.getLoc(),
362                                             FloatAttr::get(floatType, 42.0));
363     b.create<FillOp>(src.getLoc(), cst, dst);
364   }
365   b.create<CopyOp>(src.getLoc(), src, dst);
366   return success();
367 }
368 
369 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
370                                           RewritePatternSet &patterns) {
371   patterns.add<LinalgTilingPattern<MatmulOp>>(
372       ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
373       LinalgTransformationFilter(Identifier::get("START", ctx),
374                                  Identifier::get("PROMOTE", ctx)));
375   patterns.add<LinalgPromotionPattern<MatmulOp>>(
376       ctx,
377       LinalgPromotionOptions()
378           .setOperandsToPromote({0, 2})
379           .setUseFullTileBuffers({false, false})
380           .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
381           .setCopyInOutFns(
382               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
383                 return copyCallBackFn(b, src, dst, false);
384               },
385               [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
386                 return copyCallBackFn(b, src, dst, true);
387               }),
388       LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
389 }
390 
391 template <typename IdOp, typename NProcsOp>
392 static SmallVector<ProcInfo, 2>
393 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
394   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
395   SmallVector<ProcInfo, 2> procInfo(count);
396   const char *xyz[] = {"x", "y", "z"};
397   Type indexType = b.getIndexType();
398   for (unsigned i = 0; i < count; ++i) {
399     procInfo[count - 1 - i] = {
400         b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
401         b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
402   }
403   return procInfo;
404 }
405 
406 static void fillTileAndDistributePatterns(MLIRContext *context,
407                                           RewritePatternSet &patterns) {
408   {
409     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
410     cyclicNprocsEqNiters.distributionMethod.resize(
411         2, DistributionMethod::CyclicNumProcsEqNumIters);
412     cyclicNprocsEqNiters.procInfo =
413         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
414     patterns.add<LinalgTilingPattern<MatmulOp>>(
415         context,
416         LinalgTilingOptions()
417             .setTileSizes({8, 8, 4})
418             .setLoopType(LinalgTilingLoopType::ParallelLoops)
419             .setDistributionOptions(cyclicNprocsEqNiters),
420         LinalgTransformationFilter(
421             Identifier::get("distribute1", context),
422             Identifier::get("after_distribute1", context)));
423   }
424 
425   {
426     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
427     cyclicNprocsGeNiters.distributionMethod.resize(
428         2, DistributionMethod::CyclicNumProcsGeNumIters);
429     cyclicNprocsGeNiters.procInfo =
430         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
431     patterns.add<LinalgTilingPattern<MatmulOp>>(
432         context,
433         LinalgTilingOptions()
434             .setTileSizes({8, 8, 4})
435             .setLoopType(LinalgTilingLoopType::ParallelLoops)
436             .setDistributionOptions(cyclicNprocsGeNiters),
437         LinalgTransformationFilter(
438             Identifier::get("distribute2", context),
439             Identifier::get("after_distribute2", context)));
440   }
441 
442   {
443     LinalgLoopDistributionOptions cyclicNprocsDefault;
444     cyclicNprocsDefault.distributionMethod.resize(2,
445                                                   DistributionMethod::Cyclic);
446     cyclicNprocsDefault.procInfo =
447         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
448     patterns.add<LinalgTilingPattern<MatmulOp>>(
449         context,
450         LinalgTilingOptions()
451             .setTileSizes({8, 8, 4})
452             .setLoopType(LinalgTilingLoopType::ParallelLoops)
453             .setDistributionOptions(cyclicNprocsDefault),
454         LinalgTransformationFilter(
455             Identifier::get("distribute3", context),
456             Identifier::get("after_distribute3", context)));
457   }
458 
459   {
460     LinalgLoopDistributionOptions cyclicNprocsMixed1;
461     cyclicNprocsMixed1.distributionMethod = {
462         DistributionMethod::CyclicNumProcsEqNumIters,
463         DistributionMethod::CyclicNumProcsGeNumIters};
464     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
465     patterns.add<LinalgTilingPattern<MatmulOp>>(
466         context,
467         LinalgTilingOptions()
468             .setTileSizes({8, 8, 4})
469             .setLoopType(LinalgTilingLoopType::ParallelLoops)
470             .setDistributionOptions(cyclicNprocsMixed1),
471         LinalgTransformationFilter(
472             Identifier::get("distribute4", context),
473             Identifier::get("after_distribute4", context)));
474   }
475 
476   {
477     LinalgLoopDistributionOptions cyclicNprocsMixed2;
478     cyclicNprocsMixed2.distributionMethod = {
479         DistributionMethod::CyclicNumProcsGeNumIters,
480         DistributionMethod::Cyclic};
481     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
482     patterns.add<LinalgTilingPattern<MatmulOp>>(
483         context,
484         LinalgTilingOptions()
485             .setTileSizes({8, 8, 4})
486             .setLoopType(LinalgTilingLoopType::ParallelLoops)
487             .setDistributionOptions(cyclicNprocsMixed2),
488         LinalgTransformationFilter(
489             Identifier::get("distribute5", context),
490             Identifier::get("after_distribute5", context)));
491   }
492 
493   {
494     LinalgLoopDistributionOptions cyclicNprocsMixed3;
495     cyclicNprocsMixed3.distributionMethod = {
496         DistributionMethod::Cyclic,
497         DistributionMethod::CyclicNumProcsEqNumIters};
498     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
499 
500     patterns.add<LinalgTilingPattern<MatmulOp>>(
501         context,
502         LinalgTilingOptions()
503             .setTileSizes({8, 8, 4})
504             .setLoopType(LinalgTilingLoopType::ParallelLoops)
505             .setDistributionOptions(cyclicNprocsMixed3),
506         LinalgTransformationFilter(
507             Identifier::get("distribute6", context),
508             Identifier::get("after_distribute6", context)));
509   }
510 
511   {
512     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
513     cyclicNprocsEqNiters.distributionMethod.resize(2,
514                                                    DistributionMethod::Cyclic);
515     cyclicNprocsEqNiters.procInfo =
516         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
517     patterns.add<LinalgTilingPattern<MatmulOp>>(
518         context,
519         LinalgTilingOptions()
520             .setTileSizes({8, 8, 4})
521             .setLoopType(LinalgTilingLoopType::Loops)
522             .setDistributionOptions(cyclicNprocsEqNiters),
523         LinalgTransformationFilter(
524             Identifier::get("tensors_distribute1", context),
525             Identifier::get("tensors_after_distribute1", context)));
526   }
527 }
528 
529 static void
530 applyMatmulToVectorPatterns(FuncOp funcOp,
531                             bool testMatmulToVectorPatterns1dTiling,
532                             bool testMatmulToVectorPatterns2dTiling) {
533   MLIRContext *ctx = funcOp.getContext();
534   SmallVector<RewritePatternSet, 4> stage1Patterns;
535   if (testMatmulToVectorPatterns1dTiling) {
536     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
537                                           stage1Patterns);
538   } else if (testMatmulToVectorPatterns2dTiling) {
539     stage1Patterns.emplace_back(
540         ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
541                  ctx,
542                  LinalgTilingOptions()
543                      .setTileSizes({768, 264, 768})
544                      .setInterchange({1, 2, 0}),
545                  LinalgTransformationFilter(Identifier::get("START", ctx),
546                                             Identifier::get("L2", ctx))));
547     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
548                                           stage1Patterns);
549   }
550   {
551     // Canonicalization patterns
552     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
553     vector::populateVectorTransferPermutationMapLoweringPatterns(
554         canonicalizationPatterns);
555     vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
556     stage1Patterns.push_back(std::move(canonicalizationPatterns));
557   }
558   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
559   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
560   FrozenRewritePatternSet stage2Patterns =
561       getLinalgTilingCanonicalizationPatterns(ctx);
562   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
563                             std::move(stage2Patterns));
564 }
565 
566 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
567   RewritePatternSet forwardPattern(funcOp.getContext());
568   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
569   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
570   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
571 }
572 
573 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
574   RewritePatternSet patterns(funcOp.getContext());
575   patterns.add<LinalgVectorizationPattern>(
576       funcOp.getContext(),
577       LinalgTransformationFilter()
578           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
579   populatePadTensorOpVectorizationPatterns(patterns);
580   populateConvolutionVectorizationPatterns(patterns);
581   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
582 }
583 
584 static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
585   RewritePatternSet patterns(funcOp.getContext());
586   populateDecomposeConvolutionPatterns(patterns);
587   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
588 }
589 
590 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
591   RewritePatternSet patterns(funcOp.getContext());
592   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
593   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
594 }
595 
596 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
597   RewritePatternSet patterns(funcOp.getContext());
598   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
599   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
600 }
601 
602 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
603   RewritePatternSet patterns(funcOp.getContext());
604   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
605   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
606 }
607 
608 // For now, just assume it is the zero of type.
609 // In the future, it should be the zero of type + op.
610 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
611   auto t = getElementTypeOrSelf(op.get());
612   return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
613                                      b.getZeroAttr(t));
614 }
615 
616 static void applyTilePattern(FuncOp funcOp, std::string loopType,
617                              ArrayRef<int64_t> tileSizes,
618                              ArrayRef<int64_t> paddedOperands,
619                              ArrayRef<int64_t> nofoldOperands,
620                              ArrayRef<int64_t> peeledLoops,
621                              bool scalarizeDynamicDims) {
622   MLIRContext *context = funcOp.getContext();
623   RewritePatternSet tilingPattern(context);
624   LinalgTilingLoopType type =
625       llvm::StringSwitch<LinalgTilingLoopType>(loopType)
626           .Case("for", LinalgTilingLoopType::Loops)
627           .Case("affine", LinalgTilingLoopType::AffineLoops)
628           .Case("parallel", LinalgTilingLoopType::ParallelLoops)
629           .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
630   auto linalgTilingOptions = linalg::LinalgTilingOptions()
631                                  .setPeeledLoops(peeledLoops)
632                                  .setLoopType(type);
633   if (scalarizeDynamicDims) {
634     linalgTilingOptions.scalarizeDynamicDims();
635     assert(tileSizes.empty() &&
636            "tileSizes and scalarizeDynamicDims is mutually exclusive");
637   } else {
638     linalgTilingOptions.setTileSizes(tileSizes);
639   }
640   if (!paddedOperands.empty()) {
641     auto paddingFunc = [&](OpBuilder &b,
642                            OpOperand &opOperand) -> FailureOr<Value> {
643       if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0)
644         return failure();
645       return getNeutralOfLinalgOp(b, opOperand);
646     };
647     auto nofoldFunc = [&](OpOperand &opOperand) {
648       if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0)
649         return true;
650       return false;
651     };
652     linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
653     linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc);
654   }
655   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
656                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
657       context, linalgTilingOptions,
658       linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
659   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
660 }
661 
662 static void applyPadPattern(FuncOp funcOp, ArrayRef<int64_t> packPaddings,
663                             ArrayRef<int64_t> hoistPaddings) {
664   MLIRContext *context = funcOp.getContext();
665   RewritePatternSet padPattern(context);
666   auto linalgPaddingOptions = linalg::LinalgPaddingOptions();
667   auto packFunc = [&](OpOperand &opOperand) {
668     return opOperand.getOperandNumber() < packPaddings.size()
669                ? packPaddings[opOperand.getOperandNumber()]
670                : false;
671   };
672   auto hoistingFunc = [&](OpOperand &opOperand) {
673     return opOperand.getOperandNumber() < hoistPaddings.size()
674                ? hoistPaddings[opOperand.getOperandNumber()]
675                : 0;
676   };
677   linalgPaddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
678   linalgPaddingOptions.setPaddingNoFoldComputationFunction(packFunc);
679   linalgPaddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
680   padPattern.add<LinalgPaddingPattern>(
681       context, linalgPaddingOptions,
682       LinalgTransformationFilter(Identifier::get("pad", context)));
683   (void)applyPatternsAndFoldGreedily(funcOp, std::move(padPattern));
684 }
685 
686 static void applyInterchangePattern(FuncOp funcOp,
687                                     ArrayRef<unsigned> interchangeVector) {
688   MLIRContext *context = funcOp.getContext();
689   RewritePatternSet interchangePattern(context);
690   interchangePattern.add<GenericOpInterchangePattern>(
691       context, interchangeVector,
692       LinalgTransformationFilter(ArrayRef<Identifier>{},
693                                  Identifier::get("interchange", context)));
694   (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
695 }
696 
// Name of the attribute in which TiledLoopPeelingPattern records, as an i64
// array, the indices of the loops it has already peeled on an op.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
// Name of the unit attribute that TiledLoopPeelingPattern attaches to the loop
// returned by peeling, i.e. the one holding the partial iteration.
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
699 
700 namespace {
701 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
702 /// `idx`-th loop contains only "full" iterations and a second loop for the
703 /// remaining partial iteration (if any).
704 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
705   TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
706       : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
707   }
708 
709   LogicalResult matchAndRewrite(TiledLoopOp loopOp,
710                                 PatternRewriter &rewriter) const override {
711     SmallVector<int64_t> peeledLoops;
712     if (loopOp->hasAttr(kPeeledLoopsLabel)) {
713       auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
714       peeledLoops =
715           llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
716             return attr.cast<IntegerAttr>().getInt();
717           }));
718       // Check if the loop was already peeled.
719       if (llvm::find(peeledLoops, idx) != peeledLoops.end())
720         return failure();
721     }
722     if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
723       // No peeling of loop nests with a partial iteration.
724       return failure();
725 
726     if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
727       return failure();
728 
729     // Peel loop and canonicalize.
730     TiledLoopOp result;
731     if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
732                                                     result)))
733       return failure();
734 
735     // Apply label, so that the same loop is not rewritten a second time.
736     peeledLoops.push_back(idx);
737     rewriter.updateRootInPlace(loopOp, [&]() {
738       loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
739     });
740     result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
741     result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
742 
743     return success();
744   }
745 
746   /// Index of loop to peel.
747   int64_t idx;
748 
749   /// If set to true, do not peel TiledLoopOps with a partial iteration.
750   bool skipPartial;
751 };
752 } // namespace
753 
754 static void applyTiledLoopPeelingPattern(FuncOp funcOp,
755                                          ArrayRef<unsigned> loops,
756                                          bool skipPartial) {
757   MLIRContext *ctx = funcOp.getContext();
758   RewritePatternSet patterns(ctx);
759   for (unsigned idx : loops)
760     patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
761   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
762 
763   // Drop the markers.
764   funcOp.walk([](TiledLoopOp op) {
765     op->removeAttr(kPeeledLoopsLabel);
766     op->removeAttr(kPartialIterationLabel);
767   });
768 }
769 
/// Apply transformations specified as patterns.
///
/// The pass options are checked in a fixed order and almost every branch
/// returns right after running its transformation, so the first enabled
/// option in that order decides what runs. The one exception is the
/// `testHoistPadding` branch, which does not return and therefore falls
/// through to the checks that follow it. On every exit path the
/// `kLinalgTransformMarker` attribute is removed from all LinalgOps in the
/// function.
void TestLinalgTransforms::runOnFunction() {
  // Cleanup action: strip the transformation marker from every LinalgOp.
  auto lambda = [&](void *) {
    getFunction().walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });
  };
  // Scope guard built from a unique_ptr: the pointer value is a non-null dummy
  // that is never dereferenced; it only ensures the deleter (the lambda above)
  // is invoked when the guard is destroyed at function exit.
  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};

  if (testPromotionOptions) {
    RewritePatternSet patterns(&getContext());
    fillPromotionCallBackPatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testTileAndDistributionOptions) {
    RewritePatternSet patterns(&getContext());
    fillTileAndDistributePatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testPatterns)
    return applyPatterns(getFunction());
  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
    return applyMatmulToVectorPatterns(getFunction(),
                                       testMatmulToVectorPatterns1dTiling,
                                       testMatmulToVectorPatterns2dTiling);
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getFunction());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getFunction());
  if (testTransformPadTensor)
    return applyPadTensorToGenericPatterns(getFunction());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getFunction());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getFunction());
  if (testTiledLoopPeeling.hasValue())
    return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
                                        skipPartial);
  // Both tiling modes share applyTilePattern; they differ in whether the
  // peeled loops are forwarded and in the `scalarizeDynamicDims` flag.
  if (testTilePattern)
    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                            nofoldOperands, peeledLoops,
                            /*scalarizeDynamicDims=*/false);
  if (testTileScalarizeDynamicDims)
    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                            nofoldOperands,
                            /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
  if (testHoistPadding) {
    // Try to hoist every PadTensorOp in the function. `testHoistPadding` is
    // forwarded as the second argument of hoistPaddingOnTensors.
    // NOTE(review): presumably the number of loops to hoist across - confirm
    // against the option's declaration.
    getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
      PadTensorOp hoistedOp;
      FailureOr<Value> newResult = linalg::hoistPaddingOnTensors(
          padTensorOp, testHoistPadding, hoistedOp);
      // On success, redirect all uses to the new value and drop the original
      // op; on failure the op is left untouched.
      if (succeeded(newResult)) {
        padTensorOp.getResult().replaceAllUsesWith(newResult.getValue());
        padTensorOp->erase();
      }
    });
    // No return here: control continues with the remaining option checks.
  }
  if (testPadPattern)
    return applyPadPattern(getFunction(), packPaddings, hoistPaddings);
  if (testInterchangePattern.hasValue())
    return applyInterchangePattern(getFunction(), testInterchangePattern);
  if (testDecomposeConvolutionPattern)
    return applyDecomposeConvolutionPatterns(getFunction());
}
836 
namespace mlir {
namespace test {
/// Registers the TestLinalgTransforms pass (argument
/// "test-linalg-transform-patterns") by constructing a PassRegistration for
/// it.
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir
844