1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Passes.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/SCF/SCF.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorTransforms.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Pass/PassManager.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 using namespace mlir::vector;
28 
29 namespace {
30 
31 struct TestVectorToVectorLowering
32     : public PassWrapper<TestVectorToVectorLowering, FunctionPass> {
33   TestVectorToVectorLowering() = default;
34   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) {}
35   StringRef getArgument() const final {
36     return "test-vector-to-vector-lowering";
37   }
38   StringRef getDescription() const final {
39     return "Test lowering patterns between ops in the vector dialect";
40   }
41 
42   void getDependentDialects(DialectRegistry &registry) const override {
43     registry.insert<AffineDialect>();
44   }
45 
46   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
47                       llvm::cl::init(false)};
48 
49   void runOnFunction() override {
50     auto *ctx = &getContext();
51     RewritePatternSet patterns(ctx);
52     if (unroll) {
53       populateVectorUnrollPatterns(
54           patterns,
55           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
56               filter));
57     }
58     populateVectorToVectorCanonicalizationPatterns(patterns);
59     populateBubbleVectorBitCastOpPatterns(patterns);
60     populateCastAwayVectorLeadingOneDimPatterns(patterns);
61     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
62   }
63 
64 private:
65   // Return the target shape based on op type.
66   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
67     if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op))
68       return SmallVector<int64_t, 4>(2, 2);
69     if (isa<vector::ContractionOp>(op))
70       return SmallVector<int64_t, 4>(3, 2);
71     // For transfer ops, just propagate the shape coming from
72     // InsertStridedSlices/ExtractStridedSlices.
73     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
74       VectorType dstVec;
75       for (Operation *users : readOp->getUsers()) {
76         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
77         if (!extract)
78           return llvm::None;
79         auto vecType = extract.getResult().getType().cast<VectorType>();
80         if (dstVec && dstVec != vecType)
81           return llvm::None;
82         dstVec = vecType;
83       }
84       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
85                                      dstVec.getShape().end());
86     }
87     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
88       auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
89       if (!insert)
90         return llvm::None;
91       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
92       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
93     }
94     return llvm::None;
95   }
96 
97   static LogicalResult filter(Operation *op) {
98     return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp,
99                        TransferReadOp, TransferWriteOp>(op));
100   }
101 };
102 
103 struct TestVectorContractionLowering
104     : public PassWrapper<TestVectorContractionLowering, FunctionPass> {
105   StringRef getArgument() const final {
106     return "test-vector-contraction-lowering";
107   }
108   StringRef getDescription() const final {
109     return "Test lowering patterns that lower contract ops in the vector "
110            "dialect";
111   }
112   TestVectorContractionLowering() = default;
113   TestVectorContractionLowering(const TestVectorContractionLowering &pass) {}
114 
115   Option<bool> lowerToFlatMatrix{
116       *this, "vector-lower-matrix-intrinsics",
117       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
118       llvm::cl::init(false)};
119   Option<bool> lowerToOuterProduct{
120       *this, "vector-outerproduct",
121       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
122       llvm::cl::init(false)};
123   Option<bool> lowerToFilterOuterProduct{
124       *this, "vector-filter-outerproduct",
125       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
126                      "vectors of size 4."),
127       llvm::cl::init(false)};
128 
129   void runOnFunction() override {
130     RewritePatternSet patterns(&getContext());
131 
132     // Test on one pattern in isolation.
133     if (lowerToOuterProduct) {
134       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
135       VectorTransformsOptions options{lowering};
136       patterns.add<ContractionOpToOuterProductOpLowering>(options,
137                                                           &getContext());
138       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139       return;
140     }
141 
142     // Test on one pattern in isolation.
143     if (lowerToFilterOuterProduct) {
144       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
145       VectorTransformsOptions options{lowering};
146       patterns.add<ContractionOpToOuterProductOpLowering>(
147           options, &getContext(), [](vector::ContractionOp op) {
148             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
149             // where M is not 4.
150             if (op.getRhsType().getShape()[0] == 4)
151               return failure();
152             return success();
153           });
154       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
155       return;
156     }
157 
158     // Test on all contract lowering patterns.
159     VectorContractLowering contractLowering = VectorContractLowering::Dot;
160     if (lowerToFlatMatrix)
161       contractLowering = VectorContractLowering::Matmul;
162     VectorMultiReductionLowering vectorMultiReductionLowering =
163         VectorMultiReductionLowering::InnerParallel;
164     VectorTransformsOptions options{contractLowering,
165                                     vectorMultiReductionLowering,
166                                     VectorTransposeLowering()};
167     populateVectorBroadcastLoweringPatterns(patterns);
168     populateVectorContractLoweringPatterns(patterns, options);
169     populateVectorMaskOpLoweringPatterns(patterns);
170     populateVectorShapeCastLoweringPatterns(patterns);
171     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
172   }
173 };
174 
175 struct TestVectorTransposeLowering
176     : public PassWrapper<TestVectorTransposeLowering, FunctionPass> {
177   StringRef getArgument() const final {
178     return "test-vector-transpose-lowering";
179   }
180   StringRef getDescription() const final {
181     return "Test lowering patterns that lower contract ops in the vector "
182            "dialect";
183   }
184   TestVectorTransposeLowering() = default;
185   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) {}
186 
187   Option<bool> lowerToEltwise{
188       *this, "eltwise",
189       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
190       llvm::cl::init(false)};
191   Option<bool> lowerToFlatTranspose{
192       *this, "flat",
193       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
194       llvm::cl::init(false)};
195   Option<bool> lowerToShuffleTranspose{
196       *this, "shuffle",
197       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
198       llvm::cl::init(false)};
199   Option<bool> lowerToAvx2{
200       *this, "avx2",
201       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
202       llvm::cl::init(false)};
203 
204   void getDependentDialects(DialectRegistry &registry) const override {
205     registry.insert<LLVM::LLVMDialect>();
206   }
207 
208   void runOnFunction() override {
209     RewritePatternSet patterns(&getContext());
210 
211     // Test on one pattern in isolation.
212     // Explicitly disable shape_cast lowering.
213     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
214                                               .enableVectorTransposeLowering()
215                                               .enableShapeCastLowering(false);
216     if (lowerToEltwise) {
217       options = options.setVectorTransformsOptions(
218           VectorTransformsOptions().setVectorTransposeLowering(
219               VectorTransposeLowering::EltWise));
220     }
221     if (lowerToFlatTranspose) {
222       options = options.setVectorTransformsOptions(
223           VectorTransformsOptions().setVectorTransposeLowering(
224               VectorTransposeLowering::Flat));
225     }
226     if (lowerToShuffleTranspose) {
227       options = options.setVectorTransformsOptions(
228           VectorTransformsOptions().setVectorTransposeLowering(
229               VectorTransposeLowering::Shuffle));
230     }
231     if (lowerToAvx2) {
232       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
233           x86vector::avx2::LoweringOptions().setTransposeOptions(
234               x86vector::avx2::TransposeLoweringOptions()
235                   .lower4x8xf32()
236                   .lower8x8xf32()));
237     }
238 
239     OpPassManager dynamicPM("builtin.func");
240     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
241     if (failed(runPipeline(dynamicPM, getFunction())))
242       return signalPassFailure();
243   }
244 };
245 
246 struct TestVectorUnrollingPatterns
247     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
248   StringRef getArgument() const final {
249     return "test-vector-unrolling-patterns";
250   }
251   StringRef getDescription() const final {
252     return "Test lowering patterns to unroll contract ops in the vector "
253            "dialect";
254   }
255   TestVectorUnrollingPatterns() = default;
256   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
257   void runOnFunction() override {
258     MLIRContext *ctx = &getContext();
259     RewritePatternSet patterns(ctx);
260     populateVectorUnrollPatterns(
261         patterns, UnrollVectorOptions()
262                       .setNativeShape(ArrayRef<int64_t>{2, 2})
263                       .setFilterConstraint([](Operation *op) {
264                         return success(isa<arith::AddFOp, vector::FMAOp>(op));
265                       }));
266 
267     if (unrollBasedOnType) {
268       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
269           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
270         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
271         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
272         if (auto floatType = contractOp.getLhsType()
273                                  .getElementType()
274                                  .dyn_cast<FloatType>()) {
275           if (floatType.getWidth() == 16) {
276             nativeShape[2] = 4;
277           }
278         }
279         return nativeShape;
280       };
281       populateVectorUnrollPatterns(patterns,
282                                    UnrollVectorOptions()
283                                        .setNativeShapeFn(nativeShapeFn)
284                                        .setFilterConstraint([](Operation *op) {
285                                          return success(isa<ContractionOp>(op));
286                                        }));
287     } else {
288       populateVectorUnrollPatterns(
289           patterns, UnrollVectorOptions()
290                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
291                         .setFilterConstraint([](Operation *op) {
292                           return success(isa<ContractionOp>(op));
293                         }));
294     }
295     populateVectorToVectorCanonicalizationPatterns(patterns);
296     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
297   }
298 
299   Option<bool> unrollBasedOnType{
300       *this, "unroll-based-on-type",
301       llvm::cl::desc("Set the unroll factor based on type of the operation"),
302       llvm::cl::init(false)};
303 };
304 
305 struct TestVectorDistributePatterns
306     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
307   StringRef getArgument() const final {
308     return "test-vector-distribute-patterns";
309   }
310   StringRef getDescription() const final {
311     return "Test lowering patterns to distribute vector ops in the vector "
312            "dialect";
313   }
314   TestVectorDistributePatterns() = default;
315   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
316   void getDependentDialects(DialectRegistry &registry) const override {
317     registry.insert<VectorDialect>();
318     registry.insert<AffineDialect>();
319   }
320   ListOption<int32_t> multiplicity{
321       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
322       llvm::cl::desc("Set the multiplicity used for distributing vector")};
323 
324   void runOnFunction() override {
325     MLIRContext *ctx = &getContext();
326     RewritePatternSet patterns(ctx);
327     FuncOp func = getFunction();
328     func.walk([&](arith::AddFOp op) {
329       OpBuilder builder(op);
330       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
331         SmallVector<int64_t, 2> mul;
332         SmallVector<AffineExpr, 2> perm;
333         SmallVector<Value, 2> ids;
334         unsigned count = 0;
335         // Remove the multiplicity of 1 and calculate the affine map based on
336         // the multiplicity.
337         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
338         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
339           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
340             mul.push_back(m[i]);
341             ids.push_back(func.getArgument(count++));
342             perm.push_back(getAffineDimExpr(i, ctx));
343           }
344         }
345         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
346                                   perm, ctx);
347         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
348             builder, op.getOperation(), ids, mul, map);
349         if (ops.hasValue()) {
350           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
351           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
352                                               extractOp);
353         }
354       }
355     });
356     populatePropagateVectorDistributionPatterns(patterns);
357     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
358   }
359 };
360 
361 struct TestVectorToLoopPatterns
362     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
363   StringRef getArgument() const final { return "test-vector-to-forloop"; }
364   StringRef getDescription() const final {
365     return "Test lowering patterns to break up a vector op into a for loop";
366   }
367   TestVectorToLoopPatterns() = default;
368   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
369   void getDependentDialects(DialectRegistry &registry) const override {
370     registry.insert<VectorDialect>();
371     registry.insert<AffineDialect>();
372   }
373   Option<int32_t> multiplicity{
374       *this, "distribution-multiplicity",
375       llvm::cl::desc("Set the multiplicity used for distributing vector"),
376       llvm::cl::init(32)};
377   void runOnFunction() override {
378     MLIRContext *ctx = &getContext();
379     RewritePatternSet patterns(ctx);
380     FuncOp func = getFunction();
381     func.walk([&](arith::AddFOp op) {
382       // Check that the operation type can be broken down into a loop.
383       VectorType type = op.getType().dyn_cast<VectorType>();
384       if (!type || type.getRank() != 1 ||
385           type.getNumElements() % multiplicity != 0)
386         return mlir::WalkResult::advance();
387       auto filterAlloc = [](Operation *op) {
388         if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op))
389           return false;
390         return true;
391       };
392       auto dependentOps = getSlice(op, filterAlloc);
393       // Create a loop and move instructions from the Op slice into the loop.
394       OpBuilder builder(op);
395       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
396       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
397       auto numIter =
398           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
399       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
400       for (Operation *it : dependentOps) {
401         it->moveBefore(forOp.getBody()->getTerminator());
402       }
403       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
404       // break up the original op and let the patterns propagate.
405       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
406           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
407           map);
408       if (ops.hasValue()) {
409         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
410         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
411       }
412       return mlir::WalkResult::interrupt();
413     });
414     populatePropagateVectorDistributionPatterns(patterns);
415     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
416   }
417 };
418 
419 struct TestVectorTransferUnrollingPatterns
420     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
421   void getDependentDialects(DialectRegistry &registry) const override {
422     registry.insert<AffineDialect>();
423   }
424   StringRef getArgument() const final {
425     return "test-vector-transfer-unrolling-patterns";
426   }
427   StringRef getDescription() const final {
428     return "Test lowering patterns to unroll transfer ops in the vector "
429            "dialect";
430   }
431   void runOnFunction() override {
432     MLIRContext *ctx = &getContext();
433     RewritePatternSet patterns(ctx);
434     populateVectorUnrollPatterns(
435         patterns,
436         UnrollVectorOptions()
437             .setNativeShape(ArrayRef<int64_t>{2, 2})
438             .setFilterConstraint([](Operation *op) {
439               return success(
440                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
441             }));
442     populateVectorToVectorCanonicalizationPatterns(patterns);
443     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
444   }
445 };
446 
447 struct TestVectorTransferFullPartialSplitPatterns
448     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
449                          FunctionPass> {
450   StringRef getArgument() const final {
451     return "test-vector-transfer-full-partial-split";
452   }
453   StringRef getDescription() const final {
454     return "Test lowering patterns to split "
455            "transfer ops via scf.if + linalg ops";
456   }
457   TestVectorTransferFullPartialSplitPatterns() = default;
458   TestVectorTransferFullPartialSplitPatterns(
459       const TestVectorTransferFullPartialSplitPatterns &pass) {}
460 
461   void getDependentDialects(DialectRegistry &registry) const override {
462     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
463                     scf::SCFDialect>();
464   }
465 
466   Option<bool> useLinalgOps{
467       *this, "use-linalg-copy",
468       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
469                      "linalg.copy operations."),
470       llvm::cl::init(false)};
471   void runOnFunction() override {
472     MLIRContext *ctx = &getContext();
473     RewritePatternSet patterns(ctx);
474     VectorTransformsOptions options;
475     if (useLinalgOps)
476       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
477     else
478       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
479     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
480     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
481   }
482 };
483 
484 struct TestVectorTransferOpt
485     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
486   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
487   StringRef getDescription() const final {
488     return "Test optimization transformations for transfer ops";
489   }
490   void runOnFunction() override { transferOpflowOpt(getFunction()); }
491 };
492 
493 struct TestVectorTransferLoweringPatterns
494     : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
495   void getDependentDialects(DialectRegistry &registry) const override {
496     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
497   }
498   StringRef getArgument() const final {
499     return "test-vector-transfer-lowering-patterns";
500   }
501   StringRef getDescription() const final {
502     return "Test lowering patterns to lower transfer ops to other vector ops";
503   }
504   void runOnFunction() override {
505     RewritePatternSet patterns(&getContext());
506     populateVectorTransferLoweringPatterns(patterns);
507     populateVectorTransferPermutationMapLoweringPatterns(patterns);
508     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
509   }
510 };
511 
512 struct TestVectorMultiReductionLoweringPatterns
513     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
514                          FunctionPass> {
515   TestVectorMultiReductionLoweringPatterns() = default;
516   TestVectorMultiReductionLoweringPatterns(
517       const TestVectorMultiReductionLoweringPatterns &pass) {}
518   void getDependentDialects(DialectRegistry &registry) const override {
519     registry.insert<memref::MemRefDialect>();
520   }
521   StringRef getArgument() const final {
522     return "test-vector-multi-reduction-lowering-patterns";
523   }
524   StringRef getDescription() const final {
525     return "Test lowering patterns to lower vector.multi_reduction to other "
526            "vector ops";
527   }
528   Option<bool> useOuterReductions{
529       *this, "use-outer-reductions",
530       llvm::cl::desc("Move reductions to outer most dimensions"),
531       llvm::cl::init(false)};
532   void runOnFunction() override {
533     RewritePatternSet patterns(&getContext());
534     populateVectorMultiReductionLoweringPatterns(
535         patterns, useOuterReductions
536                       ? vector::VectorMultiReductionLowering::InnerParallel
537                       : vector::VectorMultiReductionLowering::InnerReduction);
538     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
539   }
540 };
541 
542 struct TestVectorTransferCollapseInnerMostContiguousDims
543     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
544                          FunctionPass> {
545   TestVectorTransferCollapseInnerMostContiguousDims() = default;
546   TestVectorTransferCollapseInnerMostContiguousDims(
547       const TestVectorTransferCollapseInnerMostContiguousDims &pass) {}
548 
549   void getDependentDialects(DialectRegistry &registry) const override {
550     registry.insert<memref::MemRefDialect, AffineDialect>();
551   }
552 
553   StringRef getArgument() const final {
554     return "test-vector-transfer-collapse-inner-most-dims";
555   }
556 
557   StringRef getDescription() const final {
558     return "Test lowering patterns that reducedes the rank of the vector "
559            "transfer memory and vector operands.";
560   }
561 
562   void runOnFunction() override {
563     RewritePatternSet patterns(&getContext());
564     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
565     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
566   }
567 };
568 
569 struct TestVectorReduceToContractPatternsPatterns
570     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
571                          FunctionPass> {
572   StringRef getArgument() const final {
573     return "test-vector-reduction-to-contract-patterns";
574   }
575   StringRef getDescription() const final {
576     return "Test patterns to convert multireduce op to contract and combine "
577            "broadcast/transpose to contract";
578   }
579   void runOnFunction() override {
580     RewritePatternSet patterns(&getContext());
581     populateVectorReductionToContractPatterns(patterns);
582     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
583   }
584 };
585 
586 struct TestVectorTransferDropUnitDimsPatterns
587     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, FunctionPass> {
588   StringRef getArgument() const final {
589     return "test-vector-transfer-drop-unit-dims-patterns";
590   }
591   void getDependentDialects(DialectRegistry &registry) const override {
592     registry.insert<memref::MemRefDialect>();
593   }
594   void runOnFunction() override {
595     RewritePatternSet patterns(&getContext());
596     populateVectorTransferDropUnitDimsPatterns(patterns);
597     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
598   }
599 };
600 
601 struct TestFlattenVectorTransferPatterns
602     : public PassWrapper<TestFlattenVectorTransferPatterns, FunctionPass> {
603   StringRef getArgument() const final {
604     return "test-vector-transfer-flatten-patterns";
605   }
606   StringRef getDescription() const final {
607     return "Test patterns to rewrite contiguous row-major N-dimensional "
608            "vector.transfer_{read,write} ops into 1D transfers";
609   }
610   void getDependentDialects(DialectRegistry &registry) const override {
611     registry.insert<memref::MemRefDialect>();
612   }
613   void runOnFunction() override {
614     RewritePatternSet patterns(&getContext());
615     populateFlattenVectorTransferPatterns(patterns);
616     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
617   }
618 };
619 
620 } // namespace
621 
622 namespace mlir {
623 namespace test {
624 void registerTestVectorLowerings() {
625   PassRegistration<TestVectorToVectorLowering>();
626 
627   PassRegistration<TestVectorContractionLowering>();
628 
629   PassRegistration<TestVectorTransposeLowering>();
630 
631   PassRegistration<TestVectorUnrollingPatterns>();
632 
633   PassRegistration<TestVectorTransferUnrollingPatterns>();
634 
635   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
636 
637   PassRegistration<TestVectorDistributePatterns>();
638 
639   PassRegistration<TestVectorToLoopPatterns>();
640 
641   PassRegistration<TestVectorTransferOpt>();
642 
643   PassRegistration<TestVectorTransferLoweringPatterns>();
644 
645   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
646 
647   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
648 
649   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
650 
651   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
652 
653   PassRegistration<TestFlattenVectorTransferPatterns>();
654 }
655 } // namespace test
656 } // namespace mlir
657