1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Pass/PassManager.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 using namespace mlir::vector;
28 
29 namespace {
30 
31 struct TestVectorToVectorLowering
32     : public PassWrapper<TestVectorToVectorLowering, OperationPass<FuncOp>> {
33   TestVectorToVectorLowering() = default;
34   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
35       : PassWrapper(pass) {}
36   StringRef getArgument() const final {
37     return "test-vector-to-vector-lowering";
38   }
39   StringRef getDescription() const final {
40     return "Test lowering patterns between ops in the vector dialect";
41   }
42 
43   void getDependentDialects(DialectRegistry &registry) const override {
44     registry.insert<AffineDialect>();
45   }
46 
47   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
48                       llvm::cl::init(false)};
49 
50   void runOnOperation() override {
51     auto *ctx = &getContext();
52     RewritePatternSet patterns(ctx);
53     if (unroll) {
54       populateVectorUnrollPatterns(
55           patterns,
56           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
57               filter));
58     }
59     populateVectorToVectorCanonicalizationPatterns(patterns);
60     populateBubbleVectorBitCastOpPatterns(patterns);
61     populateCastAwayVectorLeadingOneDimPatterns(patterns);
62     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
63   }
64 
65 private:
66   // Return the target shape based on op type.
67   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
68     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
69       return SmallVector<int64_t, 4>(2, 2);
70     if (isa<vector::ContractionOp>(op))
71       return SmallVector<int64_t, 4>(3, 2);
72     // For transfer ops, just propagate the shape coming from
73     // InsertStridedSlices/ExtractStridedSlices.
74     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
75       VectorType dstVec;
76       for (Operation *users : readOp->getUsers()) {
77         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
78         if (!extract)
79           return llvm::None;
80         auto vecType = extract.getResult().getType().cast<VectorType>();
81         if (dstVec && dstVec != vecType)
82           return llvm::None;
83         dstVec = vecType;
84       }
85       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
86                                      dstVec.getShape().end());
87     }
88     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
89       auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
90       if (!insert)
91         return llvm::None;
92       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
93       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
94     }
95     return llvm::None;
96   }
97 
98   static LogicalResult filter(Operation *op) {
99     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
100                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
101   }
102 };
103 
104 struct TestVectorContractionLowering
105     : public PassWrapper<TestVectorContractionLowering, OperationPass<FuncOp>> {
106   StringRef getArgument() const final {
107     return "test-vector-contraction-lowering";
108   }
109   StringRef getDescription() const final {
110     return "Test lowering patterns that lower contract ops in the vector "
111            "dialect";
112   }
113   TestVectorContractionLowering() = default;
114   TestVectorContractionLowering(const TestVectorContractionLowering &pass)
115       : PassWrapper(pass) {}
116 
117   Option<bool> lowerToFlatMatrix{
118       *this, "vector-lower-matrix-intrinsics",
119       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
120       llvm::cl::init(false)};
121   Option<bool> lowerToOuterProduct{
122       *this, "vector-outerproduct",
123       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
124       llvm::cl::init(false)};
125   Option<bool> lowerToFilterOuterProduct{
126       *this, "vector-filter-outerproduct",
127       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
128                      "vectors of size 4."),
129       llvm::cl::init(false)};
130 
131   void runOnOperation() override {
132     RewritePatternSet patterns(&getContext());
133 
134     // Test on one pattern in isolation.
135     if (lowerToOuterProduct) {
136       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
137       VectorTransformsOptions options{lowering};
138       patterns.add<ContractionOpToOuterProductOpLowering>(options,
139                                                           &getContext());
140       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
141       return;
142     }
143 
144     // Test on one pattern in isolation.
145     if (lowerToFilterOuterProduct) {
146       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
147       VectorTransformsOptions options{lowering};
148       patterns.add<ContractionOpToOuterProductOpLowering>(
149           options, &getContext(), [](vector::ContractionOp op) {
150             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
151             // where M is not 4.
152             if (op.getRhsType().getShape()[0] == 4)
153               return failure();
154             return success();
155           });
156       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
157       return;
158     }
159 
160     // Test on all contract lowering patterns.
161     VectorContractLowering contractLowering = VectorContractLowering::Dot;
162     if (lowerToFlatMatrix)
163       contractLowering = VectorContractLowering::Matmul;
164     VectorMultiReductionLowering vectorMultiReductionLowering =
165         VectorMultiReductionLowering::InnerParallel;
166     VectorTransformsOptions options{contractLowering,
167                                     vectorMultiReductionLowering,
168                                     VectorTransposeLowering()};
169     populateVectorBroadcastLoweringPatterns(patterns);
170     populateVectorContractLoweringPatterns(patterns, options);
171     populateVectorMaskOpLoweringPatterns(patterns);
172     populateVectorShapeCastLoweringPatterns(patterns);
173     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
174   }
175 };
176 
177 struct TestVectorTransposeLowering
178     : public PassWrapper<TestVectorTransposeLowering, OperationPass<FuncOp>> {
179   StringRef getArgument() const final {
180     return "test-vector-transpose-lowering";
181   }
182   StringRef getDescription() const final {
183     return "Test lowering patterns that lower contract ops in the vector "
184            "dialect";
185   }
186   TestVectorTransposeLowering() = default;
187   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
188       : PassWrapper(pass) {}
189 
190   Option<bool> lowerToEltwise{
191       *this, "eltwise",
192       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
193       llvm::cl::init(false)};
194   Option<bool> lowerToFlatTranspose{
195       *this, "flat",
196       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
197       llvm::cl::init(false)};
198   Option<bool> lowerToShuffleTranspose{
199       *this, "shuffle",
200       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
201       llvm::cl::init(false)};
202   Option<bool> lowerToAvx2{
203       *this, "avx2",
204       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
205       llvm::cl::init(false)};
206 
207   void getDependentDialects(DialectRegistry &registry) const override {
208     registry.insert<LLVM::LLVMDialect>();
209   }
210 
211   void runOnOperation() override {
212     RewritePatternSet patterns(&getContext());
213 
214     // Test on one pattern in isolation.
215     // Explicitly disable shape_cast lowering.
216     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
217                                               .enableVectorTransposeLowering()
218                                               .enableShapeCastLowering(false);
219     if (lowerToEltwise) {
220       options = options.setVectorTransformsOptions(
221           VectorTransformsOptions().setVectorTransposeLowering(
222               VectorTransposeLowering::EltWise));
223     }
224     if (lowerToFlatTranspose) {
225       options = options.setVectorTransformsOptions(
226           VectorTransformsOptions().setVectorTransposeLowering(
227               VectorTransposeLowering::Flat));
228     }
229     if (lowerToShuffleTranspose) {
230       options = options.setVectorTransformsOptions(
231           VectorTransformsOptions().setVectorTransposeLowering(
232               VectorTransposeLowering::Shuffle));
233     }
234     if (lowerToAvx2) {
235       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
236           x86vector::avx2::LoweringOptions().setTransposeOptions(
237               x86vector::avx2::TransposeLoweringOptions()
238                   .lower4x8xf32()
239                   .lower8x8xf32()));
240     }
241 
242     OpPassManager dynamicPM("builtin.func");
243     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
244     if (failed(runPipeline(dynamicPM, getOperation())))
245       return signalPassFailure();
246   }
247 };
248 
249 struct TestVectorUnrollingPatterns
250     : public PassWrapper<TestVectorUnrollingPatterns, OperationPass<FuncOp>> {
251   StringRef getArgument() const final {
252     return "test-vector-unrolling-patterns";
253   }
254   StringRef getDescription() const final {
255     return "Test lowering patterns to unroll contract ops in the vector "
256            "dialect";
257   }
258   TestVectorUnrollingPatterns() = default;
259   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
260       : PassWrapper(pass) {}
261   void runOnOperation() override {
262     MLIRContext *ctx = &getContext();
263     RewritePatternSet patterns(ctx);
264     populateVectorUnrollPatterns(
265         patterns, UnrollVectorOptions()
266                       .setNativeShape(ArrayRef<int64_t>{2, 2})
267                       .setFilterConstraint([](Operation *op) {
268                         return success(isa<arith::AddFOp, vector::FMAOp,
269                                            vector::MultiDimReductionOp>(op));
270                       }));
271     populateVectorUnrollPatterns(
272         patterns, UnrollVectorOptions()
273                       .setNativeShape(ArrayRef<int64_t>{2})
274                       .setFilterConstraint([](Operation *op) {
275                         return success(isa<vector::ReductionOp>(op));
276                       }));
277 
278     if (unrollBasedOnType) {
279       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
280           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
281         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
282         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
283         if (auto floatType = contractOp.getLhsType()
284                                  .getElementType()
285                                  .dyn_cast<FloatType>()) {
286           if (floatType.getWidth() == 16) {
287             nativeShape[2] = 4;
288           }
289         }
290         return nativeShape;
291       };
292       populateVectorUnrollPatterns(patterns,
293                                    UnrollVectorOptions()
294                                        .setNativeShapeFn(nativeShapeFn)
295                                        .setFilterConstraint([](Operation *op) {
296                                          return success(isa<ContractionOp>(op));
297                                        }));
298     } else {
299       populateVectorUnrollPatterns(
300           patterns, UnrollVectorOptions()
301                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
302                         .setFilterConstraint([](Operation *op) {
303                           return success(isa<ContractionOp>(op));
304                         }));
305     }
306     populateVectorToVectorCanonicalizationPatterns(patterns);
307     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
308   }
309 
310   Option<bool> unrollBasedOnType{
311       *this, "unroll-based-on-type",
312       llvm::cl::desc("Set the unroll factor based on type of the operation"),
313       llvm::cl::init(false)};
314 };
315 
316 struct TestVectorDistributePatterns
317     : public PassWrapper<TestVectorDistributePatterns, OperationPass<FuncOp>> {
318   StringRef getArgument() const final {
319     return "test-vector-distribute-patterns";
320   }
321   StringRef getDescription() const final {
322     return "Test lowering patterns to distribute vector ops in the vector "
323            "dialect";
324   }
325   TestVectorDistributePatterns() = default;
326   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
327       : PassWrapper(pass) {}
328   void getDependentDialects(DialectRegistry &registry) const override {
329     registry.insert<VectorDialect>();
330     registry.insert<AffineDialect>();
331   }
332   ListOption<int32_t> multiplicity{
333       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
334       llvm::cl::desc("Set the multiplicity used for distributing vector")};
335 
336   void runOnOperation() override {
337     MLIRContext *ctx = &getContext();
338     RewritePatternSet patterns(ctx);
339     FuncOp func = getOperation();
340     func.walk([&](arith::AddFOp op) {
341       OpBuilder builder(op);
342       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
343         SmallVector<int64_t, 2> mul;
344         SmallVector<AffineExpr, 2> perm;
345         SmallVector<Value, 2> ids;
346         unsigned count = 0;
347         // Remove the multiplicity of 1 and calculate the affine map based on
348         // the multiplicity.
349         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
350         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
351           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
352             mul.push_back(m[i]);
353             ids.push_back(func.getArgument(count++));
354             perm.push_back(getAffineDimExpr(i, ctx));
355           }
356         }
357         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
358                                   perm, ctx);
359         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
360             builder, op.getOperation(), ids, mul, map);
361         if (ops.hasValue()) {
362           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
363           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
364                                               extractOp);
365         }
366       }
367     });
368     populatePropagateVectorDistributionPatterns(patterns);
369     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
370   }
371 };
372 
373 struct TestVectorToLoopPatterns
374     : public PassWrapper<TestVectorToLoopPatterns, OperationPass<FuncOp>> {
375   StringRef getArgument() const final { return "test-vector-to-forloop"; }
376   StringRef getDescription() const final {
377     return "Test lowering patterns to break up a vector op into a for loop";
378   }
379   TestVectorToLoopPatterns() = default;
380   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
381       : PassWrapper(pass) {}
382   void getDependentDialects(DialectRegistry &registry) const override {
383     registry.insert<VectorDialect>();
384     registry.insert<AffineDialect>();
385   }
386   Option<int32_t> multiplicity{
387       *this, "distribution-multiplicity",
388       llvm::cl::desc("Set the multiplicity used for distributing vector"),
389       llvm::cl::init(32)};
390   void runOnOperation() override {
391     MLIRContext *ctx = &getContext();
392     RewritePatternSet patterns(ctx);
393     FuncOp func = getOperation();
394     func.walk([&](arith::AddFOp op) {
395       // Check that the operation type can be broken down into a loop.
396       VectorType type = op.getType().dyn_cast<VectorType>();
397       if (!type || type.getRank() != 1 ||
398           type.getNumElements() % multiplicity != 0)
399         return mlir::WalkResult::advance();
400       auto filterAlloc = [](Operation *op) {
401         return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
402       };
403       auto dependentOps = getSlice(op, filterAlloc);
404       // Create a loop and move instructions from the Op slice into the loop.
405       OpBuilder builder(op);
406       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
407       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
408       auto numIter =
409           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
410       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
411       for (Operation *it : dependentOps) {
412         it->moveBefore(forOp.getBody()->getTerminator());
413       }
414       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
415       // break up the original op and let the patterns propagate.
416       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
417           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
418           map);
419       if (ops.hasValue()) {
420         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
421         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
422       }
423       return mlir::WalkResult::interrupt();
424     });
425     populatePropagateVectorDistributionPatterns(patterns);
426     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
427   }
428 };
429 
430 struct TestVectorTransferUnrollingPatterns
431     : public PassWrapper<TestVectorTransferUnrollingPatterns,
432                          OperationPass<FuncOp>> {
433   void getDependentDialects(DialectRegistry &registry) const override {
434     registry.insert<AffineDialect>();
435   }
436   StringRef getArgument() const final {
437     return "test-vector-transfer-unrolling-patterns";
438   }
439   StringRef getDescription() const final {
440     return "Test lowering patterns to unroll transfer ops in the vector "
441            "dialect";
442   }
443   void runOnOperation() override {
444     MLIRContext *ctx = &getContext();
445     RewritePatternSet patterns(ctx);
446     populateVectorUnrollPatterns(
447         patterns,
448         UnrollVectorOptions()
449             .setNativeShape(ArrayRef<int64_t>{2, 2})
450             .setFilterConstraint([](Operation *op) {
451               return success(
452                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
453             }));
454     populateVectorToVectorCanonicalizationPatterns(patterns);
455     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
456   }
457 };
458 
459 struct TestVectorTransferFullPartialSplitPatterns
460     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
461                          OperationPass<FuncOp>> {
462   StringRef getArgument() const final {
463     return "test-vector-transfer-full-partial-split";
464   }
465   StringRef getDescription() const final {
466     return "Test lowering patterns to split "
467            "transfer ops via scf.if + linalg ops";
468   }
469   TestVectorTransferFullPartialSplitPatterns() = default;
470   TestVectorTransferFullPartialSplitPatterns(
471       const TestVectorTransferFullPartialSplitPatterns &pass)
472       : PassWrapper(pass) {}
473 
474   void getDependentDialects(DialectRegistry &registry) const override {
475     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
476                     scf::SCFDialect>();
477   }
478 
479   Option<bool> useLinalgOps{
480       *this, "use-memref-copy",
481       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
482                      "memref.copy operations."),
483       llvm::cl::init(false)};
484   void runOnOperation() override {
485     MLIRContext *ctx = &getContext();
486     RewritePatternSet patterns(ctx);
487     VectorTransformsOptions options;
488     if (useLinalgOps)
489       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
490     else
491       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
492     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
493     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
494   }
495 };
496 
497 struct TestVectorTransferOpt
498     : public PassWrapper<TestVectorTransferOpt, OperationPass<FuncOp>> {
499   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
500   StringRef getDescription() const final {
501     return "Test optimization transformations for transfer ops";
502   }
503   void runOnOperation() override { transferOpflowOpt(getOperation()); }
504 };
505 
506 struct TestVectorTransferLoweringPatterns
507     : public PassWrapper<TestVectorTransferLoweringPatterns,
508                          OperationPass<FuncOp>> {
509   void getDependentDialects(DialectRegistry &registry) const override {
510     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
511   }
512   StringRef getArgument() const final {
513     return "test-vector-transfer-lowering-patterns";
514   }
515   StringRef getDescription() const final {
516     return "Test lowering patterns to lower transfer ops to other vector ops";
517   }
518   void runOnOperation() override {
519     RewritePatternSet patterns(&getContext());
520     populateVectorTransferLoweringPatterns(patterns);
521     populateVectorTransferPermutationMapLoweringPatterns(patterns);
522     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
523   }
524 };
525 
526 struct TestVectorMultiReductionLoweringPatterns
527     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
528                          OperationPass<FuncOp>> {
529   TestVectorMultiReductionLoweringPatterns() = default;
530   TestVectorMultiReductionLoweringPatterns(
531       const TestVectorMultiReductionLoweringPatterns &pass)
532       : PassWrapper(pass) {}
533   void getDependentDialects(DialectRegistry &registry) const override {
534     registry.insert<memref::MemRefDialect>();
535   }
536   StringRef getArgument() const final {
537     return "test-vector-multi-reduction-lowering-patterns";
538   }
539   StringRef getDescription() const final {
540     return "Test lowering patterns to lower vector.multi_reduction to other "
541            "vector ops";
542   }
543   Option<bool> useOuterReductions{
544       *this, "use-outer-reductions",
545       llvm::cl::desc("Move reductions to outer most dimensions"),
546       llvm::cl::init(false)};
547   void runOnOperation() override {
548     RewritePatternSet patterns(&getContext());
549     populateVectorMultiReductionLoweringPatterns(
550         patterns, useOuterReductions
551                       ? vector::VectorMultiReductionLowering::InnerParallel
552                       : vector::VectorMultiReductionLowering::InnerReduction);
553     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
554   }
555 };
556 
557 struct TestVectorTransferCollapseInnerMostContiguousDims
558     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
559                          OperationPass<FuncOp>> {
560   TestVectorTransferCollapseInnerMostContiguousDims() = default;
561   TestVectorTransferCollapseInnerMostContiguousDims(
562       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
563 
564   void getDependentDialects(DialectRegistry &registry) const override {
565     registry.insert<memref::MemRefDialect, AffineDialect>();
566   }
567 
568   StringRef getArgument() const final {
569     return "test-vector-transfer-collapse-inner-most-dims";
570   }
571 
572   StringRef getDescription() const final {
573     return "Test lowering patterns that reducedes the rank of the vector "
574            "transfer memory and vector operands.";
575   }
576 
577   void runOnOperation() override {
578     RewritePatternSet patterns(&getContext());
579     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
580     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
581   }
582 };
583 
584 struct TestVectorReduceToContractPatternsPatterns
585     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
586                          OperationPass<FuncOp>> {
587   StringRef getArgument() const final {
588     return "test-vector-reduction-to-contract-patterns";
589   }
590   StringRef getDescription() const final {
591     return "Test patterns to convert multireduce op to contract and combine "
592            "broadcast/transpose to contract";
593   }
594   void runOnOperation() override {
595     RewritePatternSet patterns(&getContext());
596     populateVectorReductionToContractPatterns(patterns);
597     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
598   }
599 };
600 
601 struct TestVectorTransferDropUnitDimsPatterns
602     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
603                          OperationPass<FuncOp>> {
604   StringRef getArgument() const final {
605     return "test-vector-transfer-drop-unit-dims-patterns";
606   }
607   void getDependentDialects(DialectRegistry &registry) const override {
608     registry.insert<memref::MemRefDialect>();
609   }
610   void runOnOperation() override {
611     RewritePatternSet patterns(&getContext());
612     populateVectorTransferDropUnitDimsPatterns(patterns);
613     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
614   }
615 };
616 
617 struct TestFlattenVectorTransferPatterns
618     : public PassWrapper<TestFlattenVectorTransferPatterns,
619                          OperationPass<FuncOp>> {
620   StringRef getArgument() const final {
621     return "test-vector-transfer-flatten-patterns";
622   }
623   StringRef getDescription() const final {
624     return "Test patterns to rewrite contiguous row-major N-dimensional "
625            "vector.transfer_{read,write} ops into 1D transfers";
626   }
627   void getDependentDialects(DialectRegistry &registry) const override {
628     registry.insert<memref::MemRefDialect>();
629   }
630   void runOnOperation() override {
631     RewritePatternSet patterns(&getContext());
632     populateFlattenVectorTransferPatterns(patterns);
633     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
634   }
635 };
636 
637 struct TestVectorScanLowering
638     : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> {
639   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
640   StringRef getDescription() const final {
641     return "Test lowering patterns that lower the scan op in the vector "
642            "dialect";
643   }
644   void runOnOperation() override {
645     RewritePatternSet patterns(&getContext());
646     populateVectorScanLoweringPatterns(patterns);
647     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
648   }
649 };
650 
651 } // namespace
652 
653 namespace mlir {
654 namespace test {
655 void registerTestVectorLowerings() {
656   PassRegistration<TestVectorToVectorLowering>();
657 
658   PassRegistration<TestVectorContractionLowering>();
659 
660   PassRegistration<TestVectorTransposeLowering>();
661 
662   PassRegistration<TestVectorUnrollingPatterns>();
663 
664   PassRegistration<TestVectorTransferUnrollingPatterns>();
665 
666   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
667 
668   PassRegistration<TestVectorDistributePatterns>();
669 
670   PassRegistration<TestVectorToLoopPatterns>();
671 
672   PassRegistration<TestVectorTransferOpt>();
673 
674   PassRegistration<TestVectorTransferLoweringPatterns>();
675 
676   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
677 
678   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
679 
680   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
681 
682   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
683 
684   PassRegistration<TestFlattenVectorTransferPatterns>();
685 
686   PassRegistration<TestVectorScanLowering>();
687 }
688 } // namespace test
689 } // namespace mlir
690