1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/Linalg/Passes.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Vector/VectorTransforms.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Pass/PassManager.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 using namespace mlir;
25 using namespace mlir::linalg;
26 using namespace mlir::vector;
27 
28 namespace {
29 
30 struct TestVectorToVectorLowering
31     : public PassWrapper<TestVectorToVectorLowering, FunctionPass> {
32   TestVectorToVectorLowering() = default;
33   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) {}
34   StringRef getArgument() const final {
35     return "test-vector-to-vector-lowering";
36   }
37   StringRef getDescription() const final {
38     return "Test lowering patterns between ops in the vector dialect";
39   }
40 
41   void getDependentDialects(DialectRegistry &registry) const override {
42     registry.insert<AffineDialect>();
43   }
44 
45   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
46                       llvm::cl::init(false)};
47 
48   void runOnFunction() override {
49     auto *ctx = &getContext();
50     RewritePatternSet patterns(ctx);
51     if (unroll) {
52       populateVectorUnrollPatterns(
53           patterns,
54           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
55               filter));
56     }
57     populateVectorToVectorCanonicalizationPatterns(patterns);
58     populateBubbleVectorBitCastOpPatterns(patterns);
59     populateCastAwayVectorLeadingOneDimPatterns(patterns);
60     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
61   }
62 
63 private:
64   // Return the target shape based on op type.
65   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
66     if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op))
67       return SmallVector<int64_t, 4>(2, 2);
68     if (isa<vector::ContractionOp>(op))
69       return SmallVector<int64_t, 4>(3, 2);
70     // For transfer ops, just propagate the shape coming from
71     // InsertStridedSlices/ExtractStridedSlices.
72     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
73       VectorType dstVec;
74       for (Operation *users : readOp->getUsers()) {
75         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
76         if (!extract)
77           return llvm::None;
78         auto vecType = extract.getResult().getType().cast<VectorType>();
79         if (dstVec && dstVec != vecType)
80           return llvm::None;
81         dstVec = vecType;
82       }
83       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
84                                      dstVec.getShape().end());
85     }
86     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
87       auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
88       if (!insert)
89         return llvm::None;
90       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
91       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
92     }
93     return llvm::None;
94   }
95 
96   static LogicalResult filter(Operation *op) {
97     return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp,
98                        TransferReadOp, TransferWriteOp>(op));
99   }
100 };
101 
102 struct TestVectorContractionLowering
103     : public PassWrapper<TestVectorContractionLowering, FunctionPass> {
104   StringRef getArgument() const final {
105     return "test-vector-contraction-lowering";
106   }
107   StringRef getDescription() const final {
108     return "Test lowering patterns that lower contract ops in the vector "
109            "dialect";
110   }
111   TestVectorContractionLowering() = default;
112   TestVectorContractionLowering(const TestVectorContractionLowering &pass) {}
113 
114   Option<bool> lowerToFlatMatrix{
115       *this, "vector-lower-matrix-intrinsics",
116       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
117       llvm::cl::init(false)};
118   Option<bool> lowerToOuterProduct{
119       *this, "vector-outerproduct",
120       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
121       llvm::cl::init(false)};
122   Option<bool> lowerToFilterOuterProduct{
123       *this, "vector-filter-outerproduct",
124       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
125                      "vectors of size 4."),
126       llvm::cl::init(false)};
127 
128   void runOnFunction() override {
129     RewritePatternSet patterns(&getContext());
130 
131     // Test on one pattern in isolation.
132     if (lowerToOuterProduct) {
133       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
134       VectorTransformsOptions options{lowering};
135       patterns.add<ContractionOpToOuterProductOpLowering>(options,
136                                                           &getContext());
137       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
138       return;
139     }
140 
141     // Test on one pattern in isolation.
142     if (lowerToFilterOuterProduct) {
143       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
144       VectorTransformsOptions options{lowering};
145       patterns.add<ContractionOpToOuterProductOpLowering>(
146           options, &getContext(), [](vector::ContractionOp op) {
147             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
148             // where M is not 4.
149             if (op.getRhsType().getShape()[0] == 4)
150               return failure();
151             return success();
152           });
153       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
154       return;
155     }
156 
157     // Test on all contract lowering patterns.
158     VectorContractLowering contractLowering = VectorContractLowering::Dot;
159     if (lowerToFlatMatrix)
160       contractLowering = VectorContractLowering::Matmul;
161     VectorMultiReductionLowering vectorMultiReductionLowering =
162         VectorMultiReductionLowering::InnerParallel;
163     VectorTransformsOptions options{contractLowering,
164                                     vectorMultiReductionLowering,
165                                     VectorTransposeLowering()};
166     populateVectorBroadcastLoweringPatterns(patterns);
167     populateVectorContractLoweringPatterns(patterns, options);
168     populateVectorMaskOpLoweringPatterns(patterns);
169     populateVectorShapeCastLoweringPatterns(patterns);
170     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
171   }
172 };
173 
174 struct TestVectorTransposeLowering
175     : public PassWrapper<TestVectorTransposeLowering, FunctionPass> {
176   StringRef getArgument() const final {
177     return "test-vector-transpose-lowering";
178   }
179   StringRef getDescription() const final {
180     return "Test lowering patterns that lower contract ops in the vector "
181            "dialect";
182   }
183   TestVectorTransposeLowering() = default;
184   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) {}
185 
186   Option<bool> lowerToEltwise{
187       *this, "eltwise",
188       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
189       llvm::cl::init(false)};
190   Option<bool> lowerToFlatTranspose{
191       *this, "flat",
192       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
193       llvm::cl::init(false)};
194   Option<bool> lowerToShuffleTranspose{
195       *this, "shuffle",
196       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
197       llvm::cl::init(false)};
198   Option<bool> lowerToAvx2{
199       *this, "avx2",
200       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
201       llvm::cl::init(false)};
202 
203   void runOnFunction() override {
204     RewritePatternSet patterns(&getContext());
205 
206     // Test on one pattern in isolation.
207     // Explicitly disable shape_cast lowering.
208     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
209                                               .enableVectorTransposeLowering()
210                                               .enableShapeCastLowering(false);
211     if (lowerToEltwise) {
212       options = options.setVectorTransformsOptions(
213           VectorTransformsOptions().setVectorTransposeLowering(
214               VectorTransposeLowering::EltWise));
215     }
216     if (lowerToFlatTranspose) {
217       options = options.setVectorTransformsOptions(
218           VectorTransformsOptions().setVectorTransposeLowering(
219               VectorTransposeLowering::Flat));
220     }
221     if (lowerToShuffleTranspose) {
222       options = options.setVectorTransformsOptions(
223           VectorTransformsOptions().setVectorTransposeLowering(
224               VectorTransposeLowering::Shuffle));
225     }
226     if (lowerToAvx2) {
227       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
228           x86vector::avx2::LoweringOptions().setTransposeOptions(
229               x86vector::avx2::TransposeLoweringOptions()
230                   .lower4x8xf32()
231                   .lower8x8xf32()));
232     }
233 
234     OpPassManager dynamicPM("builtin.func");
235     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
236     if (failed(runPipeline(dynamicPM, getFunction())))
237       return signalPassFailure();
238   }
239 };
240 
241 struct TestVectorUnrollingPatterns
242     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
243   StringRef getArgument() const final {
244     return "test-vector-unrolling-patterns";
245   }
246   StringRef getDescription() const final {
247     return "Test lowering patterns to unroll contract ops in the vector "
248            "dialect";
249   }
250   TestVectorUnrollingPatterns() = default;
251   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
252   void runOnFunction() override {
253     MLIRContext *ctx = &getContext();
254     RewritePatternSet patterns(ctx);
255     populateVectorUnrollPatterns(
256         patterns, UnrollVectorOptions()
257                       .setNativeShape(ArrayRef<int64_t>{2, 2})
258                       .setFilterConstraint([](Operation *op) {
259                         return success(isa<arith::AddFOp, vector::FMAOp>(op));
260                       }));
261 
262     if (unrollBasedOnType) {
263       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
264           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
265         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
266         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
267         if (auto floatType = contractOp.getLhsType()
268                                  .getElementType()
269                                  .dyn_cast<FloatType>()) {
270           if (floatType.getWidth() == 16) {
271             nativeShape[2] = 4;
272           }
273         }
274         return nativeShape;
275       };
276       populateVectorUnrollPatterns(patterns,
277                                    UnrollVectorOptions()
278                                        .setNativeShapeFn(nativeShapeFn)
279                                        .setFilterConstraint([](Operation *op) {
280                                          return success(isa<ContractionOp>(op));
281                                        }));
282     } else {
283       populateVectorUnrollPatterns(
284           patterns, UnrollVectorOptions()
285                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
286                         .setFilterConstraint([](Operation *op) {
287                           return success(isa<ContractionOp>(op));
288                         }));
289     }
290     populateVectorToVectorCanonicalizationPatterns(patterns);
291     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
292   }
293 
294   Option<bool> unrollBasedOnType{
295       *this, "unroll-based-on-type",
296       llvm::cl::desc("Set the unroll factor based on type of the operation"),
297       llvm::cl::init(false)};
298 };
299 
300 struct TestVectorDistributePatterns
301     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
302   StringRef getArgument() const final {
303     return "test-vector-distribute-patterns";
304   }
305   StringRef getDescription() const final {
306     return "Test lowering patterns to distribute vector ops in the vector "
307            "dialect";
308   }
309   TestVectorDistributePatterns() = default;
310   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
311   void getDependentDialects(DialectRegistry &registry) const override {
312     registry.insert<VectorDialect>();
313     registry.insert<AffineDialect>();
314   }
315   ListOption<int32_t> multiplicity{
316       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
317       llvm::cl::desc("Set the multiplicity used for distributing vector")};
318 
319   void runOnFunction() override {
320     MLIRContext *ctx = &getContext();
321     RewritePatternSet patterns(ctx);
322     FuncOp func = getFunction();
323     func.walk([&](arith::AddFOp op) {
324       OpBuilder builder(op);
325       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
326         SmallVector<int64_t, 2> mul;
327         SmallVector<AffineExpr, 2> perm;
328         SmallVector<Value, 2> ids;
329         unsigned count = 0;
330         // Remove the multiplicity of 1 and calculate the affine map based on
331         // the multiplicity.
332         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
333         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
334           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
335             mul.push_back(m[i]);
336             ids.push_back(func.getArgument(count++));
337             perm.push_back(getAffineDimExpr(i, ctx));
338           }
339         }
340         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
341                                   perm, ctx);
342         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
343             builder, op.getOperation(), ids, mul, map);
344         if (ops.hasValue()) {
345           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
346           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
347                                               extractOp);
348         }
349       }
350     });
351     populatePropagateVectorDistributionPatterns(patterns);
352     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
353   }
354 };
355 
356 struct TestVectorToLoopPatterns
357     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
358   StringRef getArgument() const final { return "test-vector-to-forloop"; }
359   StringRef getDescription() const final {
360     return "Test lowering patterns to break up a vector op into a for loop";
361   }
362   TestVectorToLoopPatterns() = default;
363   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
364   void getDependentDialects(DialectRegistry &registry) const override {
365     registry.insert<VectorDialect>();
366     registry.insert<AffineDialect>();
367   }
368   Option<int32_t> multiplicity{
369       *this, "distribution-multiplicity",
370       llvm::cl::desc("Set the multiplicity used for distributing vector"),
371       llvm::cl::init(32)};
372   void runOnFunction() override {
373     MLIRContext *ctx = &getContext();
374     RewritePatternSet patterns(ctx);
375     FuncOp func = getFunction();
376     func.walk([&](arith::AddFOp op) {
377       // Check that the operation type can be broken down into a loop.
378       VectorType type = op.getType().dyn_cast<VectorType>();
379       if (!type || type.getRank() != 1 ||
380           type.getNumElements() % multiplicity != 0)
381         return mlir::WalkResult::advance();
382       auto filterAlloc = [](Operation *op) {
383         if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op))
384           return false;
385         return true;
386       };
387       auto dependentOps = getSlice(op, filterAlloc);
388       // Create a loop and move instructions from the Op slice into the loop.
389       OpBuilder builder(op);
390       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
391       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
392       auto numIter =
393           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
394       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
395       for (Operation *it : dependentOps) {
396         it->moveBefore(forOp.getBody()->getTerminator());
397       }
398       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
399       // break up the original op and let the patterns propagate.
400       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
401           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
402           map);
403       if (ops.hasValue()) {
404         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
405         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
406       }
407       return mlir::WalkResult::interrupt();
408     });
409     populatePropagateVectorDistributionPatterns(patterns);
410     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
411   }
412 };
413 
414 struct TestVectorTransferUnrollingPatterns
415     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
416   void getDependentDialects(DialectRegistry &registry) const override {
417     registry.insert<AffineDialect>();
418   }
419   StringRef getArgument() const final {
420     return "test-vector-transfer-unrolling-patterns";
421   }
422   StringRef getDescription() const final {
423     return "Test lowering patterns to unroll transfer ops in the vector "
424            "dialect";
425   }
426   void runOnFunction() override {
427     MLIRContext *ctx = &getContext();
428     RewritePatternSet patterns(ctx);
429     populateVectorUnrollPatterns(
430         patterns,
431         UnrollVectorOptions()
432             .setNativeShape(ArrayRef<int64_t>{2, 2})
433             .setFilterConstraint([](Operation *op) {
434               return success(
435                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
436             }));
437     populateVectorToVectorCanonicalizationPatterns(patterns);
438     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
439   }
440 };
441 
442 struct TestVectorTransferFullPartialSplitPatterns
443     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
444                          FunctionPass> {
445   StringRef getArgument() const final {
446     return "test-vector-transfer-full-partial-split";
447   }
448   StringRef getDescription() const final {
449     return "Test lowering patterns to split "
450            "transfer ops via scf.if + linalg ops";
451   }
452   TestVectorTransferFullPartialSplitPatterns() = default;
453   TestVectorTransferFullPartialSplitPatterns(
454       const TestVectorTransferFullPartialSplitPatterns &pass) {}
455 
456   void getDependentDialects(DialectRegistry &registry) const override {
457     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
458                     scf::SCFDialect>();
459   }
460 
461   Option<bool> useLinalgOps{
462       *this, "use-linalg-copy",
463       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
464                      "linalg.copy operations."),
465       llvm::cl::init(false)};
466   void runOnFunction() override {
467     MLIRContext *ctx = &getContext();
468     RewritePatternSet patterns(ctx);
469     VectorTransformsOptions options;
470     if (useLinalgOps)
471       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
472     else
473       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
474     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
475     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
476   }
477 };
478 
479 struct TestVectorTransferOpt
480     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
481   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
482   StringRef getDescription() const final {
483     return "Test optimization transformations for transfer ops";
484   }
485   void runOnFunction() override { transferOpflowOpt(getFunction()); }
486 };
487 
488 struct TestVectorTransferLoweringPatterns
489     : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
490   void getDependentDialects(DialectRegistry &registry) const override {
491     registry.insert<memref::MemRefDialect>();
492   }
493   StringRef getArgument() const final {
494     return "test-vector-transfer-lowering-patterns";
495   }
496   StringRef getDescription() const final {
497     return "Test lowering patterns to lower transfer ops to other vector ops";
498   }
499   void runOnFunction() override {
500     RewritePatternSet patterns(&getContext());
501     populateVectorTransferLoweringPatterns(patterns);
502     populateVectorTransferPermutationMapLoweringPatterns(patterns);
503     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
504   }
505 };
506 
507 struct TestVectorMultiReductionLoweringPatterns
508     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
509                          FunctionPass> {
510   TestVectorMultiReductionLoweringPatterns() = default;
511   TestVectorMultiReductionLoweringPatterns(
512       const TestVectorMultiReductionLoweringPatterns &pass) {}
513   void getDependentDialects(DialectRegistry &registry) const override {
514     registry.insert<memref::MemRefDialect>();
515   }
516   StringRef getArgument() const final {
517     return "test-vector-multi-reduction-lowering-patterns";
518   }
519   StringRef getDescription() const final {
520     return "Test lowering patterns to lower vector.multi_reduction to other "
521            "vector ops";
522   }
523   Option<bool> useOuterReductions{
524       *this, "use-outer-reductions",
525       llvm::cl::desc("Move reductions to outer most dimensions"),
526       llvm::cl::init(false)};
527   void runOnFunction() override {
528     RewritePatternSet patterns(&getContext());
529     populateVectorMultiReductionLoweringPatterns(
530         patterns, useOuterReductions
531                       ? vector::VectorMultiReductionLowering::InnerParallel
532                       : vector::VectorMultiReductionLowering::InnerReduction);
533     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
534   }
535 };
536 
537 struct TestVectorTransferCollapseInnerMostContiguousDims
538     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
539                          FunctionPass> {
540   TestVectorTransferCollapseInnerMostContiguousDims() = default;
541   TestVectorTransferCollapseInnerMostContiguousDims(
542       const TestVectorTransferCollapseInnerMostContiguousDims &pass) {}
543 
544   void getDependentDialects(DialectRegistry &registry) const override {
545     registry.insert<memref::MemRefDialect, AffineDialect>();
546   }
547 
548   StringRef getArgument() const final {
549     return "test-vector-transfer-collapse-inner-most-dims";
550   }
551 
552   StringRef getDescription() const final {
553     return "Test lowering patterns that reducedes the rank of the vector "
554            "transfer memory and vector operands.";
555   }
556 
557   void runOnFunction() override {
558     RewritePatternSet patterns(&getContext());
559     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
560     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
561   }
562 };
563 
564 struct TestVectorReduceToContractPatternsPatterns
565     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
566                          FunctionPass> {
567   StringRef getArgument() const final {
568     return "test-vector-reduction-to-contract-patterns";
569   }
570   StringRef getDescription() const final {
571     return "Test patterns to convert multireduce op to contract and combine "
572            "broadcast/transpose to contract";
573   }
574   void runOnFunction() override {
575     RewritePatternSet patterns(&getContext());
576     populateVectorReductionToContractPatterns(patterns);
577     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
578   }
579 };
580 
581 } // end anonymous namespace
582 
583 namespace mlir {
584 namespace test {
585 void registerTestVectorLowerings() {
586   PassRegistration<TestVectorToVectorLowering>();
587 
588   PassRegistration<TestVectorContractionLowering>();
589 
590   PassRegistration<TestVectorTransposeLowering>();
591 
592   PassRegistration<TestVectorUnrollingPatterns>();
593 
594   PassRegistration<TestVectorTransferUnrollingPatterns>();
595 
596   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
597 
598   PassRegistration<TestVectorDistributePatterns>();
599 
600   PassRegistration<TestVectorToLoopPatterns>();
601 
602   PassRegistration<TestVectorTransferOpt>();
603 
604   PassRegistration<TestVectorTransferLoweringPatterns>();
605 
606   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
607 
608   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
609 
610   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
611 }
612 } // namespace test
613 } // namespace mlir
614