1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Pass/PassManager.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 using namespace mlir::vector;
28 
29 namespace {
30 
31 struct TestVectorToVectorLowering
32     : public PassWrapper<TestVectorToVectorLowering, OperationPass<FuncOp>> {
33   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
34 
35   TestVectorToVectorLowering() = default;
36   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
37       : PassWrapper(pass) {}
38   StringRef getArgument() const final {
39     return "test-vector-to-vector-lowering";
40   }
41   StringRef getDescription() const final {
42     return "Test lowering patterns between ops in the vector dialect";
43   }
44 
45   void getDependentDialects(DialectRegistry &registry) const override {
46     registry.insert<AffineDialect>();
47   }
48 
49   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
50                       llvm::cl::init(false)};
51 
52   void runOnOperation() override {
53     auto *ctx = &getContext();
54     RewritePatternSet patterns(ctx);
55     if (unroll) {
56       populateVectorUnrollPatterns(
57           patterns,
58           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
59               filter));
60     }
61     populateVectorToVectorCanonicalizationPatterns(patterns);
62     populateBubbleVectorBitCastOpPatterns(patterns);
63     populateCastAwayVectorLeadingOneDimPatterns(patterns);
64     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
65   }
66 
67 private:
68   // Return the target shape based on op type.
69   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
70     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
71       return SmallVector<int64_t, 4>(2, 2);
72     if (isa<vector::ContractionOp>(op))
73       return SmallVector<int64_t, 4>(3, 2);
74     // For transfer ops, just propagate the shape coming from
75     // InsertStridedSlices/ExtractStridedSlices.
76     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
77       VectorType dstVec;
78       for (Operation *users : readOp->getUsers()) {
79         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
80         if (!extract)
81           return llvm::None;
82         auto vecType = extract.getResult().getType().cast<VectorType>();
83         if (dstVec && dstVec != vecType)
84           return llvm::None;
85         dstVec = vecType;
86       }
87       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
88                                      dstVec.getShape().end());
89     }
90     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
91       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
92       if (!insert)
93         return llvm::None;
94       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
95       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
96     }
97     return llvm::None;
98   }
99 
100   static LogicalResult filter(Operation *op) {
101     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
102                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
103   }
104 };
105 
106 struct TestVectorContractionLowering
107     : public PassWrapper<TestVectorContractionLowering, OperationPass<FuncOp>> {
108   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
109 
110   StringRef getArgument() const final {
111     return "test-vector-contraction-lowering";
112   }
113   StringRef getDescription() const final {
114     return "Test lowering patterns that lower contract ops in the vector "
115            "dialect";
116   }
117   TestVectorContractionLowering() = default;
118   TestVectorContractionLowering(const TestVectorContractionLowering &pass)
119       : PassWrapper(pass) {}
120 
121   Option<bool> lowerToFlatMatrix{
122       *this, "vector-lower-matrix-intrinsics",
123       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
124       llvm::cl::init(false)};
125   Option<bool> lowerToOuterProduct{
126       *this, "vector-outerproduct",
127       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
128       llvm::cl::init(false)};
129   Option<bool> lowerToFilterOuterProduct{
130       *this, "vector-filter-outerproduct",
131       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
132                      "vectors of size 4."),
133       llvm::cl::init(false)};
134 
135   void runOnOperation() override {
136     RewritePatternSet patterns(&getContext());
137 
138     // Test on one pattern in isolation.
139     if (lowerToOuterProduct) {
140       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
141       VectorTransformsOptions options{lowering};
142       patterns.add<ContractionOpToOuterProductOpLowering>(options,
143                                                           &getContext());
144       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
145       return;
146     }
147 
148     // Test on one pattern in isolation.
149     if (lowerToFilterOuterProduct) {
150       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
151       VectorTransformsOptions options{lowering};
152       patterns.add<ContractionOpToOuterProductOpLowering>(
153           options, &getContext(), [](vector::ContractionOp op) {
154             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
155             // where M is not 4.
156             if (op.getRhsType().getShape()[0] == 4)
157               return failure();
158             return success();
159           });
160       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
161       return;
162     }
163 
164     // Test on all contract lowering patterns.
165     VectorContractLowering contractLowering = VectorContractLowering::Dot;
166     if (lowerToFlatMatrix)
167       contractLowering = VectorContractLowering::Matmul;
168     VectorMultiReductionLowering vectorMultiReductionLowering =
169         VectorMultiReductionLowering::InnerParallel;
170     VectorTransformsOptions options{contractLowering,
171                                     vectorMultiReductionLowering,
172                                     VectorTransposeLowering()};
173     populateVectorBroadcastLoweringPatterns(patterns);
174     populateVectorContractLoweringPatterns(patterns, options);
175     populateVectorMaskOpLoweringPatterns(patterns);
176     populateVectorShapeCastLoweringPatterns(patterns);
177     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
178   }
179 };
180 
181 struct TestVectorTransposeLowering
182     : public PassWrapper<TestVectorTransposeLowering, OperationPass<FuncOp>> {
183   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
184 
185   StringRef getArgument() const final {
186     return "test-vector-transpose-lowering";
187   }
188   StringRef getDescription() const final {
189     return "Test lowering patterns that lower contract ops in the vector "
190            "dialect";
191   }
192   TestVectorTransposeLowering() = default;
193   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
194       : PassWrapper(pass) {}
195 
196   Option<bool> lowerToEltwise{
197       *this, "eltwise",
198       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
199       llvm::cl::init(false)};
200   Option<bool> lowerToFlatTranspose{
201       *this, "flat",
202       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
203       llvm::cl::init(false)};
204   Option<bool> lowerToShuffleTranspose{
205       *this, "shuffle",
206       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
207       llvm::cl::init(false)};
208   Option<bool> lowerToAvx2{
209       *this, "avx2",
210       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
211       llvm::cl::init(false)};
212 
213   void getDependentDialects(DialectRegistry &registry) const override {
214     registry.insert<LLVM::LLVMDialect>();
215   }
216 
217   void runOnOperation() override {
218     RewritePatternSet patterns(&getContext());
219 
220     // Test on one pattern in isolation.
221     // Explicitly disable shape_cast lowering.
222     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
223                                               .enableVectorTransposeLowering()
224                                               .enableShapeCastLowering(false);
225     if (lowerToEltwise) {
226       options = options.setVectorTransformsOptions(
227           VectorTransformsOptions().setVectorTransposeLowering(
228               VectorTransposeLowering::EltWise));
229     }
230     if (lowerToFlatTranspose) {
231       options = options.setVectorTransformsOptions(
232           VectorTransformsOptions().setVectorTransposeLowering(
233               VectorTransposeLowering::Flat));
234     }
235     if (lowerToShuffleTranspose) {
236       options = options.setVectorTransformsOptions(
237           VectorTransformsOptions().setVectorTransposeLowering(
238               VectorTransposeLowering::Shuffle));
239     }
240     if (lowerToAvx2) {
241       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
242           x86vector::avx2::LoweringOptions().setTransposeOptions(
243               x86vector::avx2::TransposeLoweringOptions()
244                   .lower4x8xf32()
245                   .lower8x8xf32()));
246     }
247 
248     OpPassManager dynamicPM("func.func");
249     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
250     if (failed(runPipeline(dynamicPM, getOperation())))
251       return signalPassFailure();
252   }
253 };
254 
255 struct TestVectorUnrollingPatterns
256     : public PassWrapper<TestVectorUnrollingPatterns, OperationPass<FuncOp>> {
257   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
258 
259   StringRef getArgument() const final {
260     return "test-vector-unrolling-patterns";
261   }
262   StringRef getDescription() const final {
263     return "Test lowering patterns to unroll contract ops in the vector "
264            "dialect";
265   }
266   TestVectorUnrollingPatterns() = default;
267   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
268       : PassWrapper(pass) {}
269   void runOnOperation() override {
270     MLIRContext *ctx = &getContext();
271     RewritePatternSet patterns(ctx);
272     populateVectorUnrollPatterns(
273         patterns, UnrollVectorOptions()
274                       .setNativeShape(ArrayRef<int64_t>{2, 2})
275                       .setFilterConstraint([](Operation *op) {
276                         return success(isa<arith::AddFOp, vector::FMAOp,
277                                            vector::MultiDimReductionOp>(op));
278                       }));
279     populateVectorUnrollPatterns(
280         patterns, UnrollVectorOptions()
281                       .setNativeShape(ArrayRef<int64_t>{2})
282                       .setFilterConstraint([](Operation *op) {
283                         return success(isa<vector::ReductionOp>(op));
284                       }));
285     populateVectorUnrollPatterns(
286         patterns, UnrollVectorOptions()
287                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
288                       .setFilterConstraint([](Operation *op) {
289                         return success(isa<vector::TransposeOp>(op));
290                       }));
291 
292     if (unrollBasedOnType) {
293       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
294           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
295         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
296         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
297         if (auto floatType = contractOp.getLhsType()
298                                  .getElementType()
299                                  .dyn_cast<FloatType>()) {
300           if (floatType.getWidth() == 16) {
301             nativeShape[2] = 4;
302           }
303         }
304         return nativeShape;
305       };
306       populateVectorUnrollPatterns(patterns,
307                                    UnrollVectorOptions()
308                                        .setNativeShapeFn(nativeShapeFn)
309                                        .setFilterConstraint([](Operation *op) {
310                                          return success(isa<ContractionOp>(op));
311                                        }));
312     } else {
313       populateVectorUnrollPatterns(
314           patterns, UnrollVectorOptions()
315                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
316                         .setFilterConstraint([](Operation *op) {
317                           return success(isa<ContractionOp>(op));
318                         }));
319     }
320     populateVectorToVectorCanonicalizationPatterns(patterns);
321     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
322   }
323 
324   Option<bool> unrollBasedOnType{
325       *this, "unroll-based-on-type",
326       llvm::cl::desc("Set the unroll factor based on type of the operation"),
327       llvm::cl::init(false)};
328 };
329 
330 struct TestVectorDistributePatterns
331     : public PassWrapper<TestVectorDistributePatterns, OperationPass<FuncOp>> {
332   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)
333 
334   StringRef getArgument() const final {
335     return "test-vector-distribute-patterns";
336   }
337   StringRef getDescription() const final {
338     return "Test lowering patterns to distribute vector ops in the vector "
339            "dialect";
340   }
341   TestVectorDistributePatterns() = default;
342   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
343       : PassWrapper(pass) {}
344   void getDependentDialects(DialectRegistry &registry) const override {
345     registry.insert<VectorDialect>();
346     registry.insert<AffineDialect>();
347   }
348   ListOption<int32_t> multiplicity{
349       *this, "distribution-multiplicity",
350       llvm::cl::desc("Set the multiplicity used for distributing vector")};
351 
352   void runOnOperation() override {
353     MLIRContext *ctx = &getContext();
354     RewritePatternSet patterns(ctx);
355     FuncOp func = getOperation();
356     func.walk([&](arith::AddFOp op) {
357       OpBuilder builder(op);
358       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
359         SmallVector<int64_t, 2> mul;
360         SmallVector<AffineExpr, 2> perm;
361         SmallVector<Value, 2> ids;
362         unsigned count = 0;
363         // Remove the multiplicity of 1 and calculate the affine map based on
364         // the multiplicity.
365         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
366         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
367           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
368             mul.push_back(m[i]);
369             ids.push_back(func.getArgument(count++));
370             perm.push_back(getAffineDimExpr(i, ctx));
371           }
372         }
373         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
374                                   perm, ctx);
375         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
376             builder, op.getOperation(), ids, mul, map);
377         if (ops.hasValue()) {
378           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
379           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
380                                               extractOp);
381         }
382       }
383     });
384     populatePropagateVectorDistributionPatterns(patterns);
385     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
386   }
387 };
388 
389 struct TestVectorToLoopPatterns
390     : public PassWrapper<TestVectorToLoopPatterns, OperationPass<FuncOp>> {
391   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)
392 
393   StringRef getArgument() const final { return "test-vector-to-forloop"; }
394   StringRef getDescription() const final {
395     return "Test lowering patterns to break up a vector op into a for loop";
396   }
397   TestVectorToLoopPatterns() = default;
398   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
399       : PassWrapper(pass) {}
400   void getDependentDialects(DialectRegistry &registry) const override {
401     registry.insert<VectorDialect>();
402     registry.insert<AffineDialect>();
403   }
404   Option<int32_t> multiplicity{
405       *this, "distribution-multiplicity",
406       llvm::cl::desc("Set the multiplicity used for distributing vector"),
407       llvm::cl::init(32)};
408   void runOnOperation() override {
409     MLIRContext *ctx = &getContext();
410     RewritePatternSet patterns(ctx);
411     FuncOp func = getOperation();
412     func.walk([&](arith::AddFOp op) {
413       // Check that the operation type can be broken down into a loop.
414       VectorType type = op.getType().dyn_cast<VectorType>();
415       if (!type || type.getRank() != 1 ||
416           type.getNumElements() % multiplicity != 0)
417         return mlir::WalkResult::advance();
418       auto filterAlloc = [](Operation *op) {
419         return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
420       };
421       auto dependentOps = getSlice(op, filterAlloc);
422       // Create a loop and move instructions from the Op slice into the loop.
423       OpBuilder builder(op);
424       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
425       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
426       auto numIter =
427           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
428       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
429       for (Operation *it : dependentOps) {
430         it->moveBefore(forOp.getBody()->getTerminator());
431       }
432       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
433       // break up the original op and let the patterns propagate.
434       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
435           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
436           map);
437       if (ops.hasValue()) {
438         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
439         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
440       }
441       return mlir::WalkResult::interrupt();
442     });
443     populatePropagateVectorDistributionPatterns(patterns);
444     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
445   }
446 };
447 
448 struct TestVectorTransferUnrollingPatterns
449     : public PassWrapper<TestVectorTransferUnrollingPatterns,
450                          OperationPass<FuncOp>> {
451   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
452       TestVectorTransferUnrollingPatterns)
453 
454   void getDependentDialects(DialectRegistry &registry) const override {
455     registry.insert<AffineDialect>();
456   }
457   StringRef getArgument() const final {
458     return "test-vector-transfer-unrolling-patterns";
459   }
460   StringRef getDescription() const final {
461     return "Test lowering patterns to unroll transfer ops in the vector "
462            "dialect";
463   }
464   void runOnOperation() override {
465     MLIRContext *ctx = &getContext();
466     RewritePatternSet patterns(ctx);
467     populateVectorUnrollPatterns(
468         patterns,
469         UnrollVectorOptions()
470             .setNativeShape(ArrayRef<int64_t>{2, 2})
471             .setFilterConstraint([](Operation *op) {
472               return success(
473                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
474             }));
475     populateVectorToVectorCanonicalizationPatterns(patterns);
476     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
477   }
478 };
479 
480 struct TestVectorTransferFullPartialSplitPatterns
481     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
482                          OperationPass<FuncOp>> {
483   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
484       TestVectorTransferFullPartialSplitPatterns)
485 
486   StringRef getArgument() const final {
487     return "test-vector-transfer-full-partial-split";
488   }
489   StringRef getDescription() const final {
490     return "Test lowering patterns to split "
491            "transfer ops via scf.if + linalg ops";
492   }
493   TestVectorTransferFullPartialSplitPatterns() = default;
494   TestVectorTransferFullPartialSplitPatterns(
495       const TestVectorTransferFullPartialSplitPatterns &pass)
496       : PassWrapper(pass) {}
497 
498   void getDependentDialects(DialectRegistry &registry) const override {
499     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
500                     scf::SCFDialect>();
501   }
502 
503   Option<bool> useLinalgOps{
504       *this, "use-memref-copy",
505       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
506                      "memref.copy operations."),
507       llvm::cl::init(false)};
508   void runOnOperation() override {
509     MLIRContext *ctx = &getContext();
510     RewritePatternSet patterns(ctx);
511     VectorTransformsOptions options;
512     if (useLinalgOps)
513       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
514     else
515       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
516     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
517     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
518   }
519 };
520 
521 struct TestVectorTransferOpt
522     : public PassWrapper<TestVectorTransferOpt, OperationPass<FuncOp>> {
523   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
524 
525   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
526   StringRef getDescription() const final {
527     return "Test optimization transformations for transfer ops";
528   }
529   void runOnOperation() override { transferOpflowOpt(getOperation()); }
530 };
531 
532 struct TestVectorTransferLoweringPatterns
533     : public PassWrapper<TestVectorTransferLoweringPatterns,
534                          OperationPass<FuncOp>> {
535   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
536       TestVectorTransferLoweringPatterns)
537 
538   void getDependentDialects(DialectRegistry &registry) const override {
539     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
540   }
541   StringRef getArgument() const final {
542     return "test-vector-transfer-lowering-patterns";
543   }
544   StringRef getDescription() const final {
545     return "Test lowering patterns to lower transfer ops to other vector ops";
546   }
547   void runOnOperation() override {
548     RewritePatternSet patterns(&getContext());
549     populateVectorTransferLoweringPatterns(patterns);
550     populateVectorTransferPermutationMapLoweringPatterns(patterns);
551     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
552   }
553 };
554 
555 struct TestVectorMultiReductionLoweringPatterns
556     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
557                          OperationPass<FuncOp>> {
558   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
559       TestVectorMultiReductionLoweringPatterns)
560 
561   TestVectorMultiReductionLoweringPatterns() = default;
562   TestVectorMultiReductionLoweringPatterns(
563       const TestVectorMultiReductionLoweringPatterns &pass)
564       : PassWrapper(pass) {}
565   void getDependentDialects(DialectRegistry &registry) const override {
566     registry.insert<memref::MemRefDialect>();
567   }
568   StringRef getArgument() const final {
569     return "test-vector-multi-reduction-lowering-patterns";
570   }
571   StringRef getDescription() const final {
572     return "Test lowering patterns to lower vector.multi_reduction to other "
573            "vector ops";
574   }
575   Option<bool> useOuterReductions{
576       *this, "use-outer-reductions",
577       llvm::cl::desc("Move reductions to outer most dimensions"),
578       llvm::cl::init(false)};
579   void runOnOperation() override {
580     RewritePatternSet patterns(&getContext());
581     populateVectorMultiReductionLoweringPatterns(
582         patterns, useOuterReductions
583                       ? vector::VectorMultiReductionLowering::InnerParallel
584                       : vector::VectorMultiReductionLowering::InnerReduction);
585     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
586   }
587 };
588 
589 struct TestVectorTransferCollapseInnerMostContiguousDims
590     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
591                          OperationPass<FuncOp>> {
592   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
593       TestVectorTransferCollapseInnerMostContiguousDims)
594 
595   TestVectorTransferCollapseInnerMostContiguousDims() = default;
596   TestVectorTransferCollapseInnerMostContiguousDims(
597       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
598 
599   void getDependentDialects(DialectRegistry &registry) const override {
600     registry.insert<memref::MemRefDialect, AffineDialect>();
601   }
602 
603   StringRef getArgument() const final {
604     return "test-vector-transfer-collapse-inner-most-dims";
605   }
606 
607   StringRef getDescription() const final {
608     return "Test lowering patterns that reducedes the rank of the vector "
609            "transfer memory and vector operands.";
610   }
611 
612   void runOnOperation() override {
613     RewritePatternSet patterns(&getContext());
614     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
615     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
616   }
617 };
618 
619 struct TestVectorReduceToContractPatternsPatterns
620     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
621                          OperationPass<FuncOp>> {
622   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
623       TestVectorReduceToContractPatternsPatterns)
624 
625   StringRef getArgument() const final {
626     return "test-vector-reduction-to-contract-patterns";
627   }
628   StringRef getDescription() const final {
629     return "Test patterns to convert multireduce op to contract and combine "
630            "broadcast/transpose to contract";
631   }
632   void runOnOperation() override {
633     RewritePatternSet patterns(&getContext());
634     populateVectorReductionToContractPatterns(patterns);
635     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
636   }
637 };
638 
639 struct TestVectorTransferDropUnitDimsPatterns
640     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
641                          OperationPass<FuncOp>> {
642   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
643       TestVectorTransferDropUnitDimsPatterns)
644 
645   StringRef getArgument() const final {
646     return "test-vector-transfer-drop-unit-dims-patterns";
647   }
648   void getDependentDialects(DialectRegistry &registry) const override {
649     registry.insert<memref::MemRefDialect>();
650   }
651   void runOnOperation() override {
652     RewritePatternSet patterns(&getContext());
653     populateVectorTransferDropUnitDimsPatterns(patterns);
654     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
655   }
656 };
657 
658 struct TestFlattenVectorTransferPatterns
659     : public PassWrapper<TestFlattenVectorTransferPatterns,
660                          OperationPass<FuncOp>> {
661   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
662       TestFlattenVectorTransferPatterns)
663 
664   StringRef getArgument() const final {
665     return "test-vector-transfer-flatten-patterns";
666   }
667   StringRef getDescription() const final {
668     return "Test patterns to rewrite contiguous row-major N-dimensional "
669            "vector.transfer_{read,write} ops into 1D transfers";
670   }
671   void getDependentDialects(DialectRegistry &registry) const override {
672     registry.insert<memref::MemRefDialect>();
673   }
674   void runOnOperation() override {
675     RewritePatternSet patterns(&getContext());
676     populateFlattenVectorTransferPatterns(patterns);
677     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
678   }
679 };
680 
681 struct TestVectorScanLowering
682     : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> {
683   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
684 
685   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
686   StringRef getDescription() const final {
687     return "Test lowering patterns that lower the scan op in the vector "
688            "dialect";
689   }
690   void runOnOperation() override {
691     RewritePatternSet patterns(&getContext());
692     populateVectorScanLoweringPatterns(patterns);
693     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
694   }
695 };
696 
697 } // namespace
698 
699 namespace mlir {
700 namespace test {
701 void registerTestVectorLowerings() {
702   PassRegistration<TestVectorToVectorLowering>();
703 
704   PassRegistration<TestVectorContractionLowering>();
705 
706   PassRegistration<TestVectorTransposeLowering>();
707 
708   PassRegistration<TestVectorUnrollingPatterns>();
709 
710   PassRegistration<TestVectorTransferUnrollingPatterns>();
711 
712   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
713 
714   PassRegistration<TestVectorDistributePatterns>();
715 
716   PassRegistration<TestVectorToLoopPatterns>();
717 
718   PassRegistration<TestVectorTransferOpt>();
719 
720   PassRegistration<TestVectorTransferLoweringPatterns>();
721 
722   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
723 
724   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
725 
726   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
727 
728   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
729 
730   PassRegistration<TestFlattenVectorTransferPatterns>();
731 
732   PassRegistration<TestVectorScanLowering>();
733 }
734 } // namespace test
735 } // namespace mlir
736