1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Pass/PassManager.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 using namespace mlir::vector;
28 
29 namespace {
30 
31 struct TestVectorToVectorLowering
32     : public PassWrapper<TestVectorToVectorLowering,
33                          OperationPass<func::FuncOp>> {
34   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
35 
36   TestVectorToVectorLowering() = default;
37   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
38       : PassWrapper(pass) {}
39   StringRef getArgument() const final {
40     return "test-vector-to-vector-lowering";
41   }
42   StringRef getDescription() const final {
43     return "Test lowering patterns between ops in the vector dialect";
44   }
45 
46   void getDependentDialects(DialectRegistry &registry) const override {
47     registry.insert<AffineDialect>();
48   }
49 
50   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
51                       llvm::cl::init(false)};
52 
53   void runOnOperation() override {
54     auto *ctx = &getContext();
55     RewritePatternSet patterns(ctx);
56     if (unroll) {
57       populateVectorUnrollPatterns(
58           patterns,
59           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
60               filter));
61     }
62     populateVectorToVectorCanonicalizationPatterns(patterns);
63     populateBubbleVectorBitCastOpPatterns(patterns);
64     populateCastAwayVectorLeadingOneDimPatterns(patterns);
65     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
66   }
67 
68 private:
69   // Return the target shape based on op type.
70   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
71     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
72       return SmallVector<int64_t, 4>(2, 2);
73     if (isa<vector::ContractionOp>(op))
74       return SmallVector<int64_t, 4>(3, 2);
75     // For transfer ops, just propagate the shape coming from
76     // InsertStridedSlices/ExtractStridedSlices.
77     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
78       VectorType dstVec;
79       for (Operation *users : readOp->getUsers()) {
80         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
81         if (!extract)
82           return llvm::None;
83         auto vecType = extract.getResult().getType().cast<VectorType>();
84         if (dstVec && dstVec != vecType)
85           return llvm::None;
86         dstVec = vecType;
87       }
88       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
89                                      dstVec.getShape().end());
90     }
91     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
92       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
93       if (!insert)
94         return llvm::None;
95       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
96       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
97     }
98     return llvm::None;
99   }
100 
101   static LogicalResult filter(Operation *op) {
102     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
103                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
104   }
105 };
106 
107 struct TestVectorContractionLowering
108     : public PassWrapper<TestVectorContractionLowering,
109                          OperationPass<func::FuncOp>> {
110   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
111 
112   StringRef getArgument() const final {
113     return "test-vector-contraction-lowering";
114   }
115   StringRef getDescription() const final {
116     return "Test lowering patterns that lower contract ops in the vector "
117            "dialect";
118   }
119   TestVectorContractionLowering() = default;
120   TestVectorContractionLowering(const TestVectorContractionLowering &pass)
121       : PassWrapper(pass) {}
122 
123   Option<bool> lowerToFlatMatrix{
124       *this, "vector-lower-matrix-intrinsics",
125       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
126       llvm::cl::init(false)};
127   Option<bool> lowerToOuterProduct{
128       *this, "vector-outerproduct",
129       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
130       llvm::cl::init(false)};
131   Option<bool> lowerToFilterOuterProduct{
132       *this, "vector-filter-outerproduct",
133       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
134                      "vectors of size 4."),
135       llvm::cl::init(false)};
136 
137   void runOnOperation() override {
138     RewritePatternSet patterns(&getContext());
139 
140     // Test on one pattern in isolation.
141     if (lowerToOuterProduct) {
142       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
143       VectorTransformsOptions options{lowering};
144       patterns.add<ContractionOpToOuterProductOpLowering>(options,
145                                                           &getContext());
146       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
147       return;
148     }
149 
150     // Test on one pattern in isolation.
151     if (lowerToFilterOuterProduct) {
152       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
153       VectorTransformsOptions options{lowering};
154       patterns.add<ContractionOpToOuterProductOpLowering>(
155           options, &getContext(), [](vector::ContractionOp op) {
156             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
157             // where M is not 4.
158             if (op.getRhsType().getShape()[0] == 4)
159               return failure();
160             return success();
161           });
162       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
163       return;
164     }
165 
166     // Test on all contract lowering patterns.
167     VectorContractLowering contractLowering = VectorContractLowering::Dot;
168     if (lowerToFlatMatrix)
169       contractLowering = VectorContractLowering::Matmul;
170     VectorMultiReductionLowering vectorMultiReductionLowering =
171         VectorMultiReductionLowering::InnerParallel;
172     VectorTransformsOptions options{contractLowering,
173                                     vectorMultiReductionLowering,
174                                     VectorTransposeLowering()};
175     populateVectorBroadcastLoweringPatterns(patterns);
176     populateVectorContractLoweringPatterns(patterns, options);
177     populateVectorMaskOpLoweringPatterns(patterns);
178     populateVectorShapeCastLoweringPatterns(patterns);
179     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
180   }
181 };
182 
183 struct TestVectorTransposeLowering
184     : public PassWrapper<TestVectorTransposeLowering,
185                          OperationPass<func::FuncOp>> {
186   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
187 
188   StringRef getArgument() const final {
189     return "test-vector-transpose-lowering";
190   }
191   StringRef getDescription() const final {
192     return "Test lowering patterns that lower contract ops in the vector "
193            "dialect";
194   }
195   TestVectorTransposeLowering() = default;
196   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
197       : PassWrapper(pass) {}
198 
199   Option<bool> lowerToEltwise{
200       *this, "eltwise",
201       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
202       llvm::cl::init(false)};
203   Option<bool> lowerToFlatTranspose{
204       *this, "flat",
205       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
206       llvm::cl::init(false)};
207   Option<bool> lowerToShuffleTranspose{
208       *this, "shuffle",
209       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
210       llvm::cl::init(false)};
211   Option<bool> lowerToAvx2{
212       *this, "avx2",
213       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
214       llvm::cl::init(false)};
215 
216   void getDependentDialects(DialectRegistry &registry) const override {
217     registry.insert<LLVM::LLVMDialect>();
218   }
219 
220   void runOnOperation() override {
221     RewritePatternSet patterns(&getContext());
222 
223     // Test on one pattern in isolation.
224     // Explicitly disable shape_cast lowering.
225     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
226                                               .enableVectorTransposeLowering()
227                                               .enableShapeCastLowering(false);
228     if (lowerToEltwise) {
229       options = options.setVectorTransformsOptions(
230           VectorTransformsOptions().setVectorTransposeLowering(
231               VectorTransposeLowering::EltWise));
232     }
233     if (lowerToFlatTranspose) {
234       options = options.setVectorTransformsOptions(
235           VectorTransformsOptions().setVectorTransposeLowering(
236               VectorTransposeLowering::Flat));
237     }
238     if (lowerToShuffleTranspose) {
239       options = options.setVectorTransformsOptions(
240           VectorTransformsOptions().setVectorTransposeLowering(
241               VectorTransposeLowering::Shuffle));
242     }
243     if (lowerToAvx2) {
244       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
245           x86vector::avx2::LoweringOptions().setTransposeOptions(
246               x86vector::avx2::TransposeLoweringOptions()
247                   .lower4x8xf32()
248                   .lower8x8xf32()));
249     }
250 
251     OpPassManager dynamicPM("func.func");
252     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
253     if (failed(runPipeline(dynamicPM, getOperation())))
254       return signalPassFailure();
255   }
256 };
257 
258 struct TestVectorUnrollingPatterns
259     : public PassWrapper<TestVectorUnrollingPatterns,
260                          OperationPass<func::FuncOp>> {
261   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
262 
263   StringRef getArgument() const final {
264     return "test-vector-unrolling-patterns";
265   }
266   StringRef getDescription() const final {
267     return "Test lowering patterns to unroll contract ops in the vector "
268            "dialect";
269   }
270   TestVectorUnrollingPatterns() = default;
271   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
272       : PassWrapper(pass) {}
273   void runOnOperation() override {
274     MLIRContext *ctx = &getContext();
275     RewritePatternSet patterns(ctx);
276     populateVectorUnrollPatterns(
277         patterns, UnrollVectorOptions()
278                       .setNativeShape(ArrayRef<int64_t>{2, 2})
279                       .setFilterConstraint([](Operation *op) {
280                         return success(isa<arith::AddFOp, vector::FMAOp,
281                                            vector::MultiDimReductionOp>(op));
282                       }));
283     populateVectorUnrollPatterns(
284         patterns, UnrollVectorOptions()
285                       .setNativeShape(ArrayRef<int64_t>{2})
286                       .setFilterConstraint([](Operation *op) {
287                         return success(isa<vector::ReductionOp>(op));
288                       }));
289     populateVectorUnrollPatterns(
290         patterns, UnrollVectorOptions()
291                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
292                       .setFilterConstraint([](Operation *op) {
293                         return success(isa<vector::TransposeOp>(op));
294                       }));
295 
296     if (unrollBasedOnType) {
297       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
298           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
299         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
300         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
301         if (auto floatType = contractOp.getLhsType()
302                                  .getElementType()
303                                  .dyn_cast<FloatType>()) {
304           if (floatType.getWidth() == 16) {
305             nativeShape[2] = 4;
306           }
307         }
308         return nativeShape;
309       };
310       populateVectorUnrollPatterns(patterns,
311                                    UnrollVectorOptions()
312                                        .setNativeShapeFn(nativeShapeFn)
313                                        .setFilterConstraint([](Operation *op) {
314                                          return success(isa<ContractionOp>(op));
315                                        }));
316     } else {
317       populateVectorUnrollPatterns(
318           patterns, UnrollVectorOptions()
319                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
320                         .setFilterConstraint([](Operation *op) {
321                           return success(isa<ContractionOp>(op));
322                         }));
323     }
324     populateVectorToVectorCanonicalizationPatterns(patterns);
325     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
326   }
327 
328   Option<bool> unrollBasedOnType{
329       *this, "unroll-based-on-type",
330       llvm::cl::desc("Set the unroll factor based on type of the operation"),
331       llvm::cl::init(false)};
332 };
333 
334 struct TestVectorDistributePatterns
335     : public PassWrapper<TestVectorDistributePatterns,
336                          OperationPass<func::FuncOp>> {
337   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)
338 
339   StringRef getArgument() const final {
340     return "test-vector-distribute-patterns";
341   }
342   StringRef getDescription() const final {
343     return "Test lowering patterns to distribute vector ops in the vector "
344            "dialect";
345   }
346   TestVectorDistributePatterns() = default;
347   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
348       : PassWrapper(pass) {}
349   void getDependentDialects(DialectRegistry &registry) const override {
350     registry.insert<VectorDialect>();
351     registry.insert<AffineDialect>();
352   }
353   ListOption<int32_t> multiplicity{
354       *this, "distribution-multiplicity",
355       llvm::cl::desc("Set the multiplicity used for distributing vector")};
356 
357   void runOnOperation() override {
358     MLIRContext *ctx = &getContext();
359     RewritePatternSet patterns(ctx);
360     func::FuncOp func = getOperation();
361     func.walk([&](arith::AddFOp op) {
362       OpBuilder builder(op);
363       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
364         SmallVector<int64_t, 2> mul;
365         SmallVector<AffineExpr, 2> perm;
366         SmallVector<Value, 2> ids;
367         unsigned count = 0;
368         // Remove the multiplicity of 1 and calculate the affine map based on
369         // the multiplicity.
370         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
371         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
372           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
373             mul.push_back(m[i]);
374             ids.push_back(func.getArgument(count++));
375             perm.push_back(getAffineDimExpr(i, ctx));
376           }
377         }
378         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
379                                   perm, ctx);
380         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
381             builder, op.getOperation(), ids, mul, map);
382         if (ops.hasValue()) {
383           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
384           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
385                                               extractOp);
386         }
387       }
388     });
389     populatePropagateVectorDistributionPatterns(patterns);
390     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
391   }
392 };
393 
394 struct TestVectorToLoopPatterns
395     : public PassWrapper<TestVectorToLoopPatterns,
396                          OperationPass<func::FuncOp>> {
397   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)
398 
399   StringRef getArgument() const final { return "test-vector-to-forloop"; }
400   StringRef getDescription() const final {
401     return "Test lowering patterns to break up a vector op into a for loop";
402   }
403   TestVectorToLoopPatterns() = default;
404   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
405       : PassWrapper(pass) {}
406   void getDependentDialects(DialectRegistry &registry) const override {
407     registry.insert<VectorDialect>();
408     registry.insert<AffineDialect>();
409   }
410   Option<int32_t> multiplicity{
411       *this, "distribution-multiplicity",
412       llvm::cl::desc("Set the multiplicity used for distributing vector"),
413       llvm::cl::init(32)};
414   void runOnOperation() override {
415     MLIRContext *ctx = &getContext();
416     RewritePatternSet patterns(ctx);
417     func::FuncOp func = getOperation();
418     func.walk([&](arith::AddFOp op) {
419       // Check that the operation type can be broken down into a loop.
420       VectorType type = op.getType().dyn_cast<VectorType>();
421       if (!type || type.getRank() != 1 ||
422           type.getNumElements() % multiplicity != 0)
423         return mlir::WalkResult::advance();
424       auto filterAlloc = [](Operation *op) {
425         return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
426       };
427       auto dependentOps = getSlice(op, filterAlloc);
428       // Create a loop and move instructions from the Op slice into the loop.
429       OpBuilder builder(op);
430       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
431       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
432       auto numIter =
433           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
434       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
435       for (Operation *it : dependentOps) {
436         it->moveBefore(forOp.getBody()->getTerminator());
437       }
438       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
439       // break up the original op and let the patterns propagate.
440       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
441           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
442           map);
443       if (ops.hasValue()) {
444         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
445         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
446       }
447       return mlir::WalkResult::interrupt();
448     });
449     populatePropagateVectorDistributionPatterns(patterns);
450     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
451   }
452 };
453 
454 struct TestVectorTransferUnrollingPatterns
455     : public PassWrapper<TestVectorTransferUnrollingPatterns,
456                          OperationPass<func::FuncOp>> {
457   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
458       TestVectorTransferUnrollingPatterns)
459 
460   void getDependentDialects(DialectRegistry &registry) const override {
461     registry.insert<AffineDialect>();
462   }
463   StringRef getArgument() const final {
464     return "test-vector-transfer-unrolling-patterns";
465   }
466   StringRef getDescription() const final {
467     return "Test lowering patterns to unroll transfer ops in the vector "
468            "dialect";
469   }
470   void runOnOperation() override {
471     MLIRContext *ctx = &getContext();
472     RewritePatternSet patterns(ctx);
473     populateVectorUnrollPatterns(
474         patterns,
475         UnrollVectorOptions()
476             .setNativeShape(ArrayRef<int64_t>{2, 2})
477             .setFilterConstraint([](Operation *op) {
478               return success(
479                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
480             }));
481     populateVectorToVectorCanonicalizationPatterns(patterns);
482     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
483   }
484 };
485 
486 struct TestVectorTransferFullPartialSplitPatterns
487     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
488                          OperationPass<func::FuncOp>> {
489   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
490       TestVectorTransferFullPartialSplitPatterns)
491 
492   StringRef getArgument() const final {
493     return "test-vector-transfer-full-partial-split";
494   }
495   StringRef getDescription() const final {
496     return "Test lowering patterns to split "
497            "transfer ops via scf.if + linalg ops";
498   }
499   TestVectorTransferFullPartialSplitPatterns() = default;
500   TestVectorTransferFullPartialSplitPatterns(
501       const TestVectorTransferFullPartialSplitPatterns &pass)
502       : PassWrapper(pass) {}
503 
504   void getDependentDialects(DialectRegistry &registry) const override {
505     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
506                     scf::SCFDialect>();
507   }
508 
509   Option<bool> useLinalgOps{
510       *this, "use-memref-copy",
511       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
512                      "memref.copy operations."),
513       llvm::cl::init(false)};
514   void runOnOperation() override {
515     MLIRContext *ctx = &getContext();
516     RewritePatternSet patterns(ctx);
517     VectorTransformsOptions options;
518     if (useLinalgOps)
519       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
520     else
521       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
522     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
523     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
524   }
525 };
526 
527 struct TestVectorTransferOpt
528     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
529   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
530 
531   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
532   StringRef getDescription() const final {
533     return "Test optimization transformations for transfer ops";
534   }
535   void runOnOperation() override { transferOpflowOpt(getOperation()); }
536 };
537 
538 struct TestVectorTransferLoweringPatterns
539     : public PassWrapper<TestVectorTransferLoweringPatterns,
540                          OperationPass<func::FuncOp>> {
541   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
542       TestVectorTransferLoweringPatterns)
543 
544   void getDependentDialects(DialectRegistry &registry) const override {
545     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
546   }
547   StringRef getArgument() const final {
548     return "test-vector-transfer-lowering-patterns";
549   }
550   StringRef getDescription() const final {
551     return "Test lowering patterns to lower transfer ops to other vector ops";
552   }
553   void runOnOperation() override {
554     RewritePatternSet patterns(&getContext());
555     populateVectorTransferLoweringPatterns(patterns);
556     populateVectorTransferPermutationMapLoweringPatterns(patterns);
557     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
558   }
559 };
560 
561 struct TestVectorMultiReductionLoweringPatterns
562     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
563                          OperationPass<func::FuncOp>> {
564   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
565       TestVectorMultiReductionLoweringPatterns)
566 
567   TestVectorMultiReductionLoweringPatterns() = default;
568   TestVectorMultiReductionLoweringPatterns(
569       const TestVectorMultiReductionLoweringPatterns &pass)
570       : PassWrapper(pass) {}
571   void getDependentDialects(DialectRegistry &registry) const override {
572     registry.insert<memref::MemRefDialect>();
573   }
574   StringRef getArgument() const final {
575     return "test-vector-multi-reduction-lowering-patterns";
576   }
577   StringRef getDescription() const final {
578     return "Test lowering patterns to lower vector.multi_reduction to other "
579            "vector ops";
580   }
581   Option<bool> useOuterReductions{
582       *this, "use-outer-reductions",
583       llvm::cl::desc("Move reductions to outer most dimensions"),
584       llvm::cl::init(false)};
585   void runOnOperation() override {
586     RewritePatternSet patterns(&getContext());
587     populateVectorMultiReductionLoweringPatterns(
588         patterns, useOuterReductions
589                       ? vector::VectorMultiReductionLowering::InnerParallel
590                       : vector::VectorMultiReductionLowering::InnerReduction);
591     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
592   }
593 };
594 
595 struct TestVectorTransferCollapseInnerMostContiguousDims
596     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
597                          OperationPass<func::FuncOp>> {
598   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
599       TestVectorTransferCollapseInnerMostContiguousDims)
600 
601   TestVectorTransferCollapseInnerMostContiguousDims() = default;
602   TestVectorTransferCollapseInnerMostContiguousDims(
603       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
604 
605   void getDependentDialects(DialectRegistry &registry) const override {
606     registry.insert<memref::MemRefDialect, AffineDialect>();
607   }
608 
609   StringRef getArgument() const final {
610     return "test-vector-transfer-collapse-inner-most-dims";
611   }
612 
613   StringRef getDescription() const final {
614     return "Test lowering patterns that reducedes the rank of the vector "
615            "transfer memory and vector operands.";
616   }
617 
618   void runOnOperation() override {
619     RewritePatternSet patterns(&getContext());
620     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
621     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
622   }
623 };
624 
625 struct TestVectorReduceToContractPatternsPatterns
626     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
627                          OperationPass<func::FuncOp>> {
628   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
629       TestVectorReduceToContractPatternsPatterns)
630 
631   StringRef getArgument() const final {
632     return "test-vector-reduction-to-contract-patterns";
633   }
634   StringRef getDescription() const final {
635     return "Test patterns to convert multireduce op to contract and combine "
636            "broadcast/transpose to contract";
637   }
638   void runOnOperation() override {
639     RewritePatternSet patterns(&getContext());
640     populateVectorReductionToContractPatterns(patterns);
641     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
642   }
643 };
644 
645 struct TestVectorTransferDropUnitDimsPatterns
646     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
647                          OperationPass<func::FuncOp>> {
648   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
649       TestVectorTransferDropUnitDimsPatterns)
650 
651   StringRef getArgument() const final {
652     return "test-vector-transfer-drop-unit-dims-patterns";
653   }
654   void getDependentDialects(DialectRegistry &registry) const override {
655     registry.insert<memref::MemRefDialect>();
656   }
657   void runOnOperation() override {
658     RewritePatternSet patterns(&getContext());
659     populateVectorTransferDropUnitDimsPatterns(patterns);
660     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
661   }
662 };
663 
664 struct TestFlattenVectorTransferPatterns
665     : public PassWrapper<TestFlattenVectorTransferPatterns,
666                          OperationPass<func::FuncOp>> {
667   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
668       TestFlattenVectorTransferPatterns)
669 
670   StringRef getArgument() const final {
671     return "test-vector-transfer-flatten-patterns";
672   }
673   StringRef getDescription() const final {
674     return "Test patterns to rewrite contiguous row-major N-dimensional "
675            "vector.transfer_{read,write} ops into 1D transfers";
676   }
677   void getDependentDialects(DialectRegistry &registry) const override {
678     registry.insert<memref::MemRefDialect>();
679   }
680   void runOnOperation() override {
681     RewritePatternSet patterns(&getContext());
682     populateFlattenVectorTransferPatterns(patterns);
683     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
684   }
685 };
686 
687 struct TestVectorScanLowering
688     : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
689   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
690 
691   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
692   StringRef getDescription() const final {
693     return "Test lowering patterns that lower the scan op in the vector "
694            "dialect";
695   }
696   void runOnOperation() override {
697     RewritePatternSet patterns(&getContext());
698     populateVectorScanLoweringPatterns(patterns);
699     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
700   }
701 };
702 
703 } // namespace
704 
705 namespace mlir {
706 namespace test {
707 void registerTestVectorLowerings() {
708   PassRegistration<TestVectorToVectorLowering>();
709 
710   PassRegistration<TestVectorContractionLowering>();
711 
712   PassRegistration<TestVectorTransposeLowering>();
713 
714   PassRegistration<TestVectorUnrollingPatterns>();
715 
716   PassRegistration<TestVectorTransferUnrollingPatterns>();
717 
718   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
719 
720   PassRegistration<TestVectorDistributePatterns>();
721 
722   PassRegistration<TestVectorToLoopPatterns>();
723 
724   PassRegistration<TestVectorTransferOpt>();
725 
726   PassRegistration<TestVectorTransferLoweringPatterns>();
727 
728   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
729 
730   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
731 
732   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
733 
734   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
735 
736   PassRegistration<TestFlattenVectorTransferPatterns>();
737 
738   PassRegistration<TestVectorScanLowering>();
739 }
740 } // namespace test
741 } // namespace mlir
742