1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Pass/PassManager.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 using namespace mlir::vector;
28 
29 namespace {
30 
31 struct TestVectorToVectorLowering
32     : public PassWrapper<TestVectorToVectorLowering, OperationPass<FuncOp>> {
33   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
34 
35   TestVectorToVectorLowering() = default;
36   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
37       : PassWrapper(pass) {}
38   StringRef getArgument() const final {
39     return "test-vector-to-vector-lowering";
40   }
41   StringRef getDescription() const final {
42     return "Test lowering patterns between ops in the vector dialect";
43   }
44 
45   void getDependentDialects(DialectRegistry &registry) const override {
46     registry.insert<AffineDialect>();
47   }
48 
49   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
50                       llvm::cl::init(false)};
51 
52   void runOnOperation() override {
53     auto *ctx = &getContext();
54     RewritePatternSet patterns(ctx);
55     if (unroll) {
56       populateVectorUnrollPatterns(
57           patterns,
58           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
59               filter));
60     }
61     populateVectorToVectorCanonicalizationPatterns(patterns);
62     populateBubbleVectorBitCastOpPatterns(patterns);
63     populateCastAwayVectorLeadingOneDimPatterns(patterns);
64     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
65   }
66 
67 private:
68   // Return the target shape based on op type.
69   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
70     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
71       return SmallVector<int64_t, 4>(2, 2);
72     if (isa<vector::ContractionOp>(op))
73       return SmallVector<int64_t, 4>(3, 2);
74     // For transfer ops, just propagate the shape coming from
75     // InsertStridedSlices/ExtractStridedSlices.
76     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
77       VectorType dstVec;
78       for (Operation *users : readOp->getUsers()) {
79         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
80         if (!extract)
81           return llvm::None;
82         auto vecType = extract.getResult().getType().cast<VectorType>();
83         if (dstVec && dstVec != vecType)
84           return llvm::None;
85         dstVec = vecType;
86       }
87       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
88                                      dstVec.getShape().end());
89     }
90     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
91       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
92       if (!insert)
93         return llvm::None;
94       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
95       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
96     }
97     return llvm::None;
98   }
99 
100   static LogicalResult filter(Operation *op) {
101     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
102                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
103   }
104 };
105 
106 struct TestVectorContractionLowering
107     : public PassWrapper<TestVectorContractionLowering, OperationPass<FuncOp>> {
108   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
109 
110   StringRef getArgument() const final {
111     return "test-vector-contraction-lowering";
112   }
113   StringRef getDescription() const final {
114     return "Test lowering patterns that lower contract ops in the vector "
115            "dialect";
116   }
117   TestVectorContractionLowering() = default;
118   TestVectorContractionLowering(const TestVectorContractionLowering &pass)
119       : PassWrapper(pass) {}
120 
121   Option<bool> lowerToFlatMatrix{
122       *this, "vector-lower-matrix-intrinsics",
123       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
124       llvm::cl::init(false)};
125   Option<bool> lowerToOuterProduct{
126       *this, "vector-outerproduct",
127       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
128       llvm::cl::init(false)};
129   Option<bool> lowerToFilterOuterProduct{
130       *this, "vector-filter-outerproduct",
131       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
132                      "vectors of size 4."),
133       llvm::cl::init(false)};
134 
135   void runOnOperation() override {
136     RewritePatternSet patterns(&getContext());
137 
138     // Test on one pattern in isolation.
139     if (lowerToOuterProduct) {
140       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
141       VectorTransformsOptions options{lowering};
142       patterns.add<ContractionOpToOuterProductOpLowering>(options,
143                                                           &getContext());
144       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
145       return;
146     }
147 
148     // Test on one pattern in isolation.
149     if (lowerToFilterOuterProduct) {
150       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
151       VectorTransformsOptions options{lowering};
152       patterns.add<ContractionOpToOuterProductOpLowering>(
153           options, &getContext(), [](vector::ContractionOp op) {
154             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
155             // where M is not 4.
156             if (op.getRhsType().getShape()[0] == 4)
157               return failure();
158             return success();
159           });
160       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
161       return;
162     }
163 
164     // Test on all contract lowering patterns.
165     VectorContractLowering contractLowering = VectorContractLowering::Dot;
166     if (lowerToFlatMatrix)
167       contractLowering = VectorContractLowering::Matmul;
168     VectorMultiReductionLowering vectorMultiReductionLowering =
169         VectorMultiReductionLowering::InnerParallel;
170     VectorTransformsOptions options{contractLowering,
171                                     vectorMultiReductionLowering,
172                                     VectorTransposeLowering()};
173     populateVectorBroadcastLoweringPatterns(patterns);
174     populateVectorContractLoweringPatterns(patterns, options);
175     populateVectorMaskOpLoweringPatterns(patterns);
176     populateVectorShapeCastLoweringPatterns(patterns);
177     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
178   }
179 };
180 
181 struct TestVectorTransposeLowering
182     : public PassWrapper<TestVectorTransposeLowering, OperationPass<FuncOp>> {
183   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
184 
185   StringRef getArgument() const final {
186     return "test-vector-transpose-lowering";
187   }
188   StringRef getDescription() const final {
189     return "Test lowering patterns that lower contract ops in the vector "
190            "dialect";
191   }
192   TestVectorTransposeLowering() = default;
193   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
194       : PassWrapper(pass) {}
195 
196   Option<bool> lowerToEltwise{
197       *this, "eltwise",
198       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
199       llvm::cl::init(false)};
200   Option<bool> lowerToFlatTranspose{
201       *this, "flat",
202       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
203       llvm::cl::init(false)};
204   Option<bool> lowerToShuffleTranspose{
205       *this, "shuffle",
206       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
207       llvm::cl::init(false)};
208   Option<bool> lowerToAvx2{
209       *this, "avx2",
210       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
211       llvm::cl::init(false)};
212 
213   void getDependentDialects(DialectRegistry &registry) const override {
214     registry.insert<LLVM::LLVMDialect>();
215   }
216 
217   void runOnOperation() override {
218     RewritePatternSet patterns(&getContext());
219 
220     // Test on one pattern in isolation.
221     // Explicitly disable shape_cast lowering.
222     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
223                                               .enableVectorTransposeLowering()
224                                               .enableShapeCastLowering(false);
225     if (lowerToEltwise) {
226       options = options.setVectorTransformsOptions(
227           VectorTransformsOptions().setVectorTransposeLowering(
228               VectorTransposeLowering::EltWise));
229     }
230     if (lowerToFlatTranspose) {
231       options = options.setVectorTransformsOptions(
232           VectorTransformsOptions().setVectorTransposeLowering(
233               VectorTransposeLowering::Flat));
234     }
235     if (lowerToShuffleTranspose) {
236       options = options.setVectorTransformsOptions(
237           VectorTransformsOptions().setVectorTransposeLowering(
238               VectorTransposeLowering::Shuffle));
239     }
240     if (lowerToAvx2) {
241       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
242           x86vector::avx2::LoweringOptions().setTransposeOptions(
243               x86vector::avx2::TransposeLoweringOptions()
244                   .lower4x8xf32()
245                   .lower8x8xf32()));
246     }
247 
248     OpPassManager dynamicPM("func.func");
249     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
250     if (failed(runPipeline(dynamicPM, getOperation())))
251       return signalPassFailure();
252   }
253 };
254 
255 struct TestVectorUnrollingPatterns
256     : public PassWrapper<TestVectorUnrollingPatterns, OperationPass<FuncOp>> {
257   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
258 
259   StringRef getArgument() const final {
260     return "test-vector-unrolling-patterns";
261   }
262   StringRef getDescription() const final {
263     return "Test lowering patterns to unroll contract ops in the vector "
264            "dialect";
265   }
266   TestVectorUnrollingPatterns() = default;
267   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
268       : PassWrapper(pass) {}
269   void runOnOperation() override {
270     MLIRContext *ctx = &getContext();
271     RewritePatternSet patterns(ctx);
272     populateVectorUnrollPatterns(
273         patterns, UnrollVectorOptions()
274                       .setNativeShape(ArrayRef<int64_t>{2, 2})
275                       .setFilterConstraint([](Operation *op) {
276                         return success(isa<arith::AddFOp, vector::FMAOp,
277                                            vector::MultiDimReductionOp>(op));
278                       }));
279     populateVectorUnrollPatterns(
280         patterns, UnrollVectorOptions()
281                       .setNativeShape(ArrayRef<int64_t>{2})
282                       .setFilterConstraint([](Operation *op) {
283                         return success(isa<vector::ReductionOp>(op));
284                       }));
285 
286     if (unrollBasedOnType) {
287       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
288           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
289         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
290         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
291         if (auto floatType = contractOp.getLhsType()
292                                  .getElementType()
293                                  .dyn_cast<FloatType>()) {
294           if (floatType.getWidth() == 16) {
295             nativeShape[2] = 4;
296           }
297         }
298         return nativeShape;
299       };
300       populateVectorUnrollPatterns(patterns,
301                                    UnrollVectorOptions()
302                                        .setNativeShapeFn(nativeShapeFn)
303                                        .setFilterConstraint([](Operation *op) {
304                                          return success(isa<ContractionOp>(op));
305                                        }));
306     } else {
307       populateVectorUnrollPatterns(
308           patterns, UnrollVectorOptions()
309                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
310                         .setFilterConstraint([](Operation *op) {
311                           return success(isa<ContractionOp>(op));
312                         }));
313     }
314     populateVectorToVectorCanonicalizationPatterns(patterns);
315     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
316   }
317 
318   Option<bool> unrollBasedOnType{
319       *this, "unroll-based-on-type",
320       llvm::cl::desc("Set the unroll factor based on type of the operation"),
321       llvm::cl::init(false)};
322 };
323 
324 struct TestVectorDistributePatterns
325     : public PassWrapper<TestVectorDistributePatterns, OperationPass<FuncOp>> {
326   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)
327 
328   StringRef getArgument() const final {
329     return "test-vector-distribute-patterns";
330   }
331   StringRef getDescription() const final {
332     return "Test lowering patterns to distribute vector ops in the vector "
333            "dialect";
334   }
335   TestVectorDistributePatterns() = default;
336   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
337       : PassWrapper(pass) {}
338   void getDependentDialects(DialectRegistry &registry) const override {
339     registry.insert<VectorDialect>();
340     registry.insert<AffineDialect>();
341   }
342   ListOption<int32_t> multiplicity{
343       *this, "distribution-multiplicity",
344       llvm::cl::desc("Set the multiplicity used for distributing vector")};
345 
346   void runOnOperation() override {
347     MLIRContext *ctx = &getContext();
348     RewritePatternSet patterns(ctx);
349     FuncOp func = getOperation();
350     func.walk([&](arith::AddFOp op) {
351       OpBuilder builder(op);
352       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
353         SmallVector<int64_t, 2> mul;
354         SmallVector<AffineExpr, 2> perm;
355         SmallVector<Value, 2> ids;
356         unsigned count = 0;
357         // Remove the multiplicity of 1 and calculate the affine map based on
358         // the multiplicity.
359         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
360         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
361           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
362             mul.push_back(m[i]);
363             ids.push_back(func.getArgument(count++));
364             perm.push_back(getAffineDimExpr(i, ctx));
365           }
366         }
367         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
368                                   perm, ctx);
369         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
370             builder, op.getOperation(), ids, mul, map);
371         if (ops.hasValue()) {
372           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
373           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
374                                               extractOp);
375         }
376       }
377     });
378     populatePropagateVectorDistributionPatterns(patterns);
379     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
380   }
381 };
382 
383 struct TestVectorToLoopPatterns
384     : public PassWrapper<TestVectorToLoopPatterns, OperationPass<FuncOp>> {
385   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)
386 
387   StringRef getArgument() const final { return "test-vector-to-forloop"; }
388   StringRef getDescription() const final {
389     return "Test lowering patterns to break up a vector op into a for loop";
390   }
391   TestVectorToLoopPatterns() = default;
392   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
393       : PassWrapper(pass) {}
394   void getDependentDialects(DialectRegistry &registry) const override {
395     registry.insert<VectorDialect>();
396     registry.insert<AffineDialect>();
397   }
398   Option<int32_t> multiplicity{
399       *this, "distribution-multiplicity",
400       llvm::cl::desc("Set the multiplicity used for distributing vector"),
401       llvm::cl::init(32)};
402   void runOnOperation() override {
403     MLIRContext *ctx = &getContext();
404     RewritePatternSet patterns(ctx);
405     FuncOp func = getOperation();
406     func.walk([&](arith::AddFOp op) {
407       // Check that the operation type can be broken down into a loop.
408       VectorType type = op.getType().dyn_cast<VectorType>();
409       if (!type || type.getRank() != 1 ||
410           type.getNumElements() % multiplicity != 0)
411         return mlir::WalkResult::advance();
412       auto filterAlloc = [](Operation *op) {
413         return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
414       };
415       auto dependentOps = getSlice(op, filterAlloc);
416       // Create a loop and move instructions from the Op slice into the loop.
417       OpBuilder builder(op);
418       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
419       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
420       auto numIter =
421           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
422       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
423       for (Operation *it : dependentOps) {
424         it->moveBefore(forOp.getBody()->getTerminator());
425       }
426       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
427       // break up the original op and let the patterns propagate.
428       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
429           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
430           map);
431       if (ops.hasValue()) {
432         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
433         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
434       }
435       return mlir::WalkResult::interrupt();
436     });
437     populatePropagateVectorDistributionPatterns(patterns);
438     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
439   }
440 };
441 
442 struct TestVectorTransferUnrollingPatterns
443     : public PassWrapper<TestVectorTransferUnrollingPatterns,
444                          OperationPass<FuncOp>> {
445   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
446       TestVectorTransferUnrollingPatterns)
447 
448   void getDependentDialects(DialectRegistry &registry) const override {
449     registry.insert<AffineDialect>();
450   }
451   StringRef getArgument() const final {
452     return "test-vector-transfer-unrolling-patterns";
453   }
454   StringRef getDescription() const final {
455     return "Test lowering patterns to unroll transfer ops in the vector "
456            "dialect";
457   }
458   void runOnOperation() override {
459     MLIRContext *ctx = &getContext();
460     RewritePatternSet patterns(ctx);
461     populateVectorUnrollPatterns(
462         patterns,
463         UnrollVectorOptions()
464             .setNativeShape(ArrayRef<int64_t>{2, 2})
465             .setFilterConstraint([](Operation *op) {
466               return success(
467                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
468             }));
469     populateVectorToVectorCanonicalizationPatterns(patterns);
470     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
471   }
472 };
473 
474 struct TestVectorTransferFullPartialSplitPatterns
475     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
476                          OperationPass<FuncOp>> {
477   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
478       TestVectorTransferFullPartialSplitPatterns)
479 
480   StringRef getArgument() const final {
481     return "test-vector-transfer-full-partial-split";
482   }
483   StringRef getDescription() const final {
484     return "Test lowering patterns to split "
485            "transfer ops via scf.if + linalg ops";
486   }
487   TestVectorTransferFullPartialSplitPatterns() = default;
488   TestVectorTransferFullPartialSplitPatterns(
489       const TestVectorTransferFullPartialSplitPatterns &pass)
490       : PassWrapper(pass) {}
491 
492   void getDependentDialects(DialectRegistry &registry) const override {
493     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
494                     scf::SCFDialect>();
495   }
496 
497   Option<bool> useLinalgOps{
498       *this, "use-memref-copy",
499       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
500                      "memref.copy operations."),
501       llvm::cl::init(false)};
502   void runOnOperation() override {
503     MLIRContext *ctx = &getContext();
504     RewritePatternSet patterns(ctx);
505     VectorTransformsOptions options;
506     if (useLinalgOps)
507       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
508     else
509       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
510     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
511     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
512   }
513 };
514 
515 struct TestVectorTransferOpt
516     : public PassWrapper<TestVectorTransferOpt, OperationPass<FuncOp>> {
517   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
518 
519   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
520   StringRef getDescription() const final {
521     return "Test optimization transformations for transfer ops";
522   }
523   void runOnOperation() override { transferOpflowOpt(getOperation()); }
524 };
525 
526 struct TestVectorTransferLoweringPatterns
527     : public PassWrapper<TestVectorTransferLoweringPatterns,
528                          OperationPass<FuncOp>> {
529   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
530       TestVectorTransferLoweringPatterns)
531 
532   void getDependentDialects(DialectRegistry &registry) const override {
533     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
534   }
535   StringRef getArgument() const final {
536     return "test-vector-transfer-lowering-patterns";
537   }
538   StringRef getDescription() const final {
539     return "Test lowering patterns to lower transfer ops to other vector ops";
540   }
541   void runOnOperation() override {
542     RewritePatternSet patterns(&getContext());
543     populateVectorTransferLoweringPatterns(patterns);
544     populateVectorTransferPermutationMapLoweringPatterns(patterns);
545     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
546   }
547 };
548 
549 struct TestVectorMultiReductionLoweringPatterns
550     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
551                          OperationPass<FuncOp>> {
552   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
553       TestVectorMultiReductionLoweringPatterns)
554 
555   TestVectorMultiReductionLoweringPatterns() = default;
556   TestVectorMultiReductionLoweringPatterns(
557       const TestVectorMultiReductionLoweringPatterns &pass)
558       : PassWrapper(pass) {}
559   void getDependentDialects(DialectRegistry &registry) const override {
560     registry.insert<memref::MemRefDialect>();
561   }
562   StringRef getArgument() const final {
563     return "test-vector-multi-reduction-lowering-patterns";
564   }
565   StringRef getDescription() const final {
566     return "Test lowering patterns to lower vector.multi_reduction to other "
567            "vector ops";
568   }
569   Option<bool> useOuterReductions{
570       *this, "use-outer-reductions",
571       llvm::cl::desc("Move reductions to outer most dimensions"),
572       llvm::cl::init(false)};
573   void runOnOperation() override {
574     RewritePatternSet patterns(&getContext());
575     populateVectorMultiReductionLoweringPatterns(
576         patterns, useOuterReductions
577                       ? vector::VectorMultiReductionLowering::InnerParallel
578                       : vector::VectorMultiReductionLowering::InnerReduction);
579     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
580   }
581 };
582 
583 struct TestVectorTransferCollapseInnerMostContiguousDims
584     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
585                          OperationPass<FuncOp>> {
586   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
587       TestVectorTransferCollapseInnerMostContiguousDims)
588 
589   TestVectorTransferCollapseInnerMostContiguousDims() = default;
590   TestVectorTransferCollapseInnerMostContiguousDims(
591       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
592 
593   void getDependentDialects(DialectRegistry &registry) const override {
594     registry.insert<memref::MemRefDialect, AffineDialect>();
595   }
596 
597   StringRef getArgument() const final {
598     return "test-vector-transfer-collapse-inner-most-dims";
599   }
600 
601   StringRef getDescription() const final {
602     return "Test lowering patterns that reducedes the rank of the vector "
603            "transfer memory and vector operands.";
604   }
605 
606   void runOnOperation() override {
607     RewritePatternSet patterns(&getContext());
608     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
609     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
610   }
611 };
612 
613 struct TestVectorReduceToContractPatternsPatterns
614     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
615                          OperationPass<FuncOp>> {
616   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
617       TestVectorReduceToContractPatternsPatterns)
618 
619   StringRef getArgument() const final {
620     return "test-vector-reduction-to-contract-patterns";
621   }
622   StringRef getDescription() const final {
623     return "Test patterns to convert multireduce op to contract and combine "
624            "broadcast/transpose to contract";
625   }
626   void runOnOperation() override {
627     RewritePatternSet patterns(&getContext());
628     populateVectorReductionToContractPatterns(patterns);
629     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
630   }
631 };
632 
633 struct TestVectorTransferDropUnitDimsPatterns
634     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
635                          OperationPass<FuncOp>> {
636   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
637       TestVectorTransferDropUnitDimsPatterns)
638 
639   StringRef getArgument() const final {
640     return "test-vector-transfer-drop-unit-dims-patterns";
641   }
642   void getDependentDialects(DialectRegistry &registry) const override {
643     registry.insert<memref::MemRefDialect>();
644   }
645   void runOnOperation() override {
646     RewritePatternSet patterns(&getContext());
647     populateVectorTransferDropUnitDimsPatterns(patterns);
648     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
649   }
650 };
651 
652 struct TestFlattenVectorTransferPatterns
653     : public PassWrapper<TestFlattenVectorTransferPatterns,
654                          OperationPass<FuncOp>> {
655   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
656       TestFlattenVectorTransferPatterns)
657 
658   StringRef getArgument() const final {
659     return "test-vector-transfer-flatten-patterns";
660   }
661   StringRef getDescription() const final {
662     return "Test patterns to rewrite contiguous row-major N-dimensional "
663            "vector.transfer_{read,write} ops into 1D transfers";
664   }
665   void getDependentDialects(DialectRegistry &registry) const override {
666     registry.insert<memref::MemRefDialect>();
667   }
668   void runOnOperation() override {
669     RewritePatternSet patterns(&getContext());
670     populateFlattenVectorTransferPatterns(patterns);
671     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
672   }
673 };
674 
675 struct TestVectorScanLowering
676     : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> {
677   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
678 
679   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
680   StringRef getDescription() const final {
681     return "Test lowering patterns that lower the scan op in the vector "
682            "dialect";
683   }
684   void runOnOperation() override {
685     RewritePatternSet patterns(&getContext());
686     populateVectorScanLoweringPatterns(patterns);
687     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
688   }
689 };
690 
691 } // namespace
692 
693 namespace mlir {
694 namespace test {
695 void registerTestVectorLowerings() {
696   PassRegistration<TestVectorToVectorLowering>();
697 
698   PassRegistration<TestVectorContractionLowering>();
699 
700   PassRegistration<TestVectorTransposeLowering>();
701 
702   PassRegistration<TestVectorUnrollingPatterns>();
703 
704   PassRegistration<TestVectorTransferUnrollingPatterns>();
705 
706   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
707 
708   PassRegistration<TestVectorDistributePatterns>();
709 
710   PassRegistration<TestVectorToLoopPatterns>();
711 
712   PassRegistration<TestVectorTransferOpt>();
713 
714   PassRegistration<TestVectorTransferLoweringPatterns>();
715 
716   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
717 
718   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
719 
720   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
721 
722   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
723 
724   PassRegistration<TestFlattenVectorTransferPatterns>();
725 
726   PassRegistration<TestVectorScanLowering>();
727 }
728 } // namespace test
729 } // namespace mlir
730