1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
23 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 using namespace mlir::vector;
31 
32 namespace {
33 
34 struct TestVectorToVectorLowering
35     : public PassWrapper<TestVectorToVectorLowering,
36                          OperationPass<func::FuncOp>> {
37   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
38 
39   TestVectorToVectorLowering() = default;
40   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
41       : PassWrapper(pass) {}
42   StringRef getArgument() const final {
43     return "test-vector-to-vector-lowering";
44   }
45   StringRef getDescription() const final {
46     return "Test lowering patterns between ops in the vector dialect";
47   }
48 
49   void getDependentDialects(DialectRegistry &registry) const override {
50     registry.insert<AffineDialect>();
51   }
52 
53   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
54                       llvm::cl::init(false)};
55 
56   void runOnOperation() override {
57     auto *ctx = &getContext();
58     RewritePatternSet patterns(ctx);
59     if (unroll) {
60       populateVectorUnrollPatterns(
61           patterns,
62           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
63               filter));
64     }
65     populateVectorToVectorCanonicalizationPatterns(patterns);
66     populateBubbleVectorBitCastOpPatterns(patterns);
67     populateCastAwayVectorLeadingOneDimPatterns(patterns);
68     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
69   }
70 
71 private:
72   // Return the target shape based on op type.
73   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
74     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
75       return SmallVector<int64_t, 4>(2, 2);
76     if (isa<vector::ContractionOp>(op))
77       return SmallVector<int64_t, 4>(3, 2);
78     // For transfer ops, just propagate the shape coming from
79     // InsertStridedSlices/ExtractStridedSlices.
80     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
81       VectorType dstVec;
82       for (Operation *users : readOp->getUsers()) {
83         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
84         if (!extract)
85           return llvm::None;
86         auto vecType = extract.getResult().getType().cast<VectorType>();
87         if (dstVec && dstVec != vecType)
88           return llvm::None;
89         dstVec = vecType;
90       }
91       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
92                                      dstVec.getShape().end());
93     }
94     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
95       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
96       if (!insert)
97         return llvm::None;
98       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
99       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
100     }
101     return llvm::None;
102   }
103 
104   static LogicalResult filter(Operation *op) {
105     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
106                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
107   }
108 };
109 
110 struct TestVectorContractionLowering
111     : public PassWrapper<TestVectorContractionLowering,
112                          OperationPass<func::FuncOp>> {
113   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
114 
115   StringRef getArgument() const final {
116     return "test-vector-contraction-lowering";
117   }
118   StringRef getDescription() const final {
119     return "Test lowering patterns that lower contract ops in the vector "
120            "dialect";
121   }
122   TestVectorContractionLowering() = default;
123   TestVectorContractionLowering(const TestVectorContractionLowering &pass)
124       : PassWrapper(pass) {}
125 
126   Option<bool> lowerToFlatMatrix{
127       *this, "vector-lower-matrix-intrinsics",
128       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
129       llvm::cl::init(false)};
130   Option<bool> lowerToOuterProduct{
131       *this, "vector-outerproduct",
132       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
133       llvm::cl::init(false)};
134   Option<bool> lowerToFilterOuterProduct{
135       *this, "vector-filter-outerproduct",
136       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
137                      "vectors of size 4."),
138       llvm::cl::init(false)};
139   Option<bool> lowerToParallelArith{
140       *this, "vector-parallel-arith",
141       llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
142       llvm::cl::init(false)};
143 
144   void runOnOperation() override {
145     RewritePatternSet patterns(&getContext());
146 
147     // Test on one pattern in isolation.
148     if (lowerToOuterProduct) {
149       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
150       VectorTransformsOptions options{lowering};
151       patterns.add<ContractionOpToOuterProductOpLowering>(options,
152                                                           &getContext());
153       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
154       return;
155     }
156 
157     // Test on one pattern in isolation.
158     if (lowerToFilterOuterProduct) {
159       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
160       VectorTransformsOptions options{lowering};
161       patterns.add<ContractionOpToOuterProductOpLowering>(
162           options, &getContext(), [](vector::ContractionOp op) {
163             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
164             // where M is not 4.
165             if (op.getRhsType().getShape()[0] == 4)
166               return failure();
167             return success();
168           });
169       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
170       return;
171     }
172 
173     if (lowerToParallelArith) {
174       vector::populateVectorContractLoweringPatterns(
175           patterns,
176           vector::VectorTransformsOptions().setVectorTransformsOptions(
177               vector::VectorContractLowering::ParallelArith));
178       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
179       return;
180     }
181 
182     // Test on all contract lowering patterns.
183     VectorContractLowering contractLowering = VectorContractLowering::Dot;
184     if (lowerToFlatMatrix)
185       contractLowering = VectorContractLowering::Matmul;
186     VectorMultiReductionLowering vectorMultiReductionLowering =
187         VectorMultiReductionLowering::InnerParallel;
188     VectorTransformsOptions options{contractLowering,
189                                     vectorMultiReductionLowering,
190                                     VectorTransposeLowering()};
191     populateVectorBroadcastLoweringPatterns(patterns);
192     populateVectorContractLoweringPatterns(patterns, options);
193     populateVectorMaskOpLoweringPatterns(patterns);
194     populateVectorShapeCastLoweringPatterns(patterns);
195     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
196   }
197 };
198 
199 struct TestVectorTransposeLowering
200     : public PassWrapper<TestVectorTransposeLowering,
201                          OperationPass<func::FuncOp>> {
202   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
203 
204   StringRef getArgument() const final {
205     return "test-vector-transpose-lowering";
206   }
207   StringRef getDescription() const final {
208     return "Test lowering patterns that lower contract ops in the vector "
209            "dialect";
210   }
211   TestVectorTransposeLowering() = default;
212   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
213       : PassWrapper(pass) {}
214 
215   Option<bool> lowerToEltwise{
216       *this, "eltwise",
217       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
218       llvm::cl::init(false)};
219   Option<bool> lowerToFlatTranspose{
220       *this, "flat",
221       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
222       llvm::cl::init(false)};
223   Option<bool> lowerToShuffleTranspose{
224       *this, "shuffle",
225       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
226       llvm::cl::init(false)};
227   Option<bool> lowerToAvx2{
228       *this, "avx2",
229       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
230       llvm::cl::init(false)};
231 
232   void getDependentDialects(DialectRegistry &registry) const override {
233     registry.insert<LLVM::LLVMDialect>();
234   }
235 
236   void runOnOperation() override {
237     RewritePatternSet patterns(&getContext());
238 
239     // Test on one pattern in isolation.
240     // Explicitly disable shape_cast lowering.
241     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
242                                               .enableVectorTransposeLowering()
243                                               .enableShapeCastLowering(false);
244     if (lowerToEltwise) {
245       options = options.setVectorTransformsOptions(
246           VectorTransformsOptions().setVectorTransposeLowering(
247               VectorTransposeLowering::EltWise));
248     }
249     if (lowerToFlatTranspose) {
250       options = options.setVectorTransformsOptions(
251           VectorTransformsOptions().setVectorTransposeLowering(
252               VectorTransposeLowering::Flat));
253     }
254     if (lowerToShuffleTranspose) {
255       options = options.setVectorTransformsOptions(
256           VectorTransformsOptions().setVectorTransposeLowering(
257               VectorTransposeLowering::Shuffle));
258     }
259     if (lowerToAvx2) {
260       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
261           x86vector::avx2::LoweringOptions().setTransposeOptions(
262               x86vector::avx2::TransposeLoweringOptions()
263                   .lower4x8xf32()
264                   .lower8x8xf32()));
265     }
266 
267     OpPassManager dynamicPM("func.func");
268     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
269     if (failed(runPipeline(dynamicPM, getOperation())))
270       return signalPassFailure();
271   }
272 };
273 
274 struct TestVectorUnrollingPatterns
275     : public PassWrapper<TestVectorUnrollingPatterns,
276                          OperationPass<func::FuncOp>> {
277   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
278 
279   StringRef getArgument() const final {
280     return "test-vector-unrolling-patterns";
281   }
282   StringRef getDescription() const final {
283     return "Test lowering patterns to unroll contract ops in the vector "
284            "dialect";
285   }
286   TestVectorUnrollingPatterns() = default;
287   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
288       : PassWrapper(pass) {}
289   void runOnOperation() override {
290     MLIRContext *ctx = &getContext();
291     RewritePatternSet patterns(ctx);
292     populateVectorUnrollPatterns(
293         patterns, UnrollVectorOptions()
294                       .setNativeShape(ArrayRef<int64_t>{2, 2})
295                       .setFilterConstraint([](Operation *op) {
296                         return success(isa<arith::AddFOp, vector::FMAOp,
297                                            vector::MultiDimReductionOp>(op));
298                       }));
299     populateVectorUnrollPatterns(
300         patterns, UnrollVectorOptions()
301                       .setNativeShape(ArrayRef<int64_t>{2})
302                       .setFilterConstraint([](Operation *op) {
303                         return success(isa<vector::ReductionOp>(op));
304                       }));
305     populateVectorUnrollPatterns(
306         patterns, UnrollVectorOptions()
307                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
308                       .setFilterConstraint([](Operation *op) {
309                         return success(isa<vector::TransposeOp>(op));
310                       }));
311 
312     if (unrollBasedOnType) {
313       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
314           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
315         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
316         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
317         if (auto floatType = contractOp.getLhsType()
318                                  .getElementType()
319                                  .dyn_cast<FloatType>()) {
320           if (floatType.getWidth() == 16) {
321             nativeShape[2] = 4;
322           }
323         }
324         return nativeShape;
325       };
326 
327       UnrollVectorOptions opts;
328       opts.setNativeShapeFn(nativeShapeFn)
329           .setFilterConstraint(
330               [](Operation *op) { return success(isa<ContractionOp>(op)); });
331       if (!unrollOrder.empty()) {
332         opts.setUnrollTraversalOrderFn([this](Operation *op)
333                                            -> Optional<SmallVector<int64_t>> {
334           return SmallVector<int64_t>{unrollOrder.begin(), unrollOrder.end()};
335         });
336       }
337       populateVectorUnrollPatterns(patterns, opts);
338     } else {
339       populateVectorUnrollPatterns(
340           patterns, UnrollVectorOptions()
341                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
342                         .setFilterConstraint([](Operation *op) {
343                           return success(isa<ContractionOp>(op));
344                         }));
345     }
346     populateVectorToVectorCanonicalizationPatterns(patterns);
347     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
348   }
349 
350   ListOption<int64_t> unrollOrder{*this, "unroll-order",
351                                   llvm::cl::desc("set the unroll order"),
352                                   llvm::cl::ZeroOrMore};
353 
354   Option<bool> unrollBasedOnType{
355       *this, "unroll-based-on-type",
356       llvm::cl::desc("Set the unroll factor based on type of the operation"),
357       llvm::cl::init(false)};
358 };
359 
360 struct TestVectorDistributePatterns
361     : public PassWrapper<TestVectorDistributePatterns,
362                          OperationPass<func::FuncOp>> {
363   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)
364 
365   StringRef getArgument() const final {
366     return "test-vector-distribute-patterns";
367   }
368   StringRef getDescription() const final {
369     return "Test lowering patterns to distribute vector ops in the vector "
370            "dialect";
371   }
372   TestVectorDistributePatterns() = default;
373   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
374       : PassWrapper(pass) {}
375   void getDependentDialects(DialectRegistry &registry) const override {
376     registry.insert<VectorDialect>();
377     registry.insert<AffineDialect>();
378   }
379   ListOption<int32_t> multiplicity{
380       *this, "distribution-multiplicity",
381       llvm::cl::desc("Set the multiplicity used for distributing vector")};
382 
383   void runOnOperation() override {
384     MLIRContext *ctx = &getContext();
385     RewritePatternSet patterns(ctx);
386     func::FuncOp func = getOperation();
387     func.walk([&](arith::AddFOp op) {
388       OpBuilder builder(op);
389       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
390         SmallVector<int64_t, 2> mul;
391         SmallVector<AffineExpr, 2> perm;
392         SmallVector<Value, 2> ids;
393         unsigned count = 0;
394         // Remove the multiplicity of 1 and calculate the affine map based on
395         // the multiplicity.
396         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
397         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
398           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
399             mul.push_back(m[i]);
400             ids.push_back(func.getArgument(count++));
401             perm.push_back(getAffineDimExpr(i, ctx));
402           }
403         }
404         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
405                                   perm, ctx);
406         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
407             builder, op.getOperation(), ids, mul, map);
408         if (ops.hasValue()) {
409           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
410           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
411                                               extractOp);
412         }
413       }
414     });
415     populatePropagateVectorDistributionPatterns(patterns);
416     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
417   }
418 };
419 
420 struct TestVectorToLoopPatterns
421     : public PassWrapper<TestVectorToLoopPatterns,
422                          OperationPass<func::FuncOp>> {
423   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)
424 
425   StringRef getArgument() const final { return "test-vector-to-forloop"; }
426   StringRef getDescription() const final {
427     return "Test lowering patterns to break up a vector op into a for loop";
428   }
429   TestVectorToLoopPatterns() = default;
430   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
431       : PassWrapper(pass) {}
432   void getDependentDialects(DialectRegistry &registry) const override {
433     registry.insert<VectorDialect>();
434     registry.insert<AffineDialect>();
435   }
436   Option<int32_t> multiplicity{
437       *this, "distribution-multiplicity",
438       llvm::cl::desc("Set the multiplicity used for distributing vector"),
439       llvm::cl::init(32)};
440   void runOnOperation() override {
441     MLIRContext *ctx = &getContext();
442     RewritePatternSet patterns(ctx);
443     func::FuncOp func = getOperation();
444     func.walk([&](arith::AddFOp op) {
445       // Check that the operation type can be broken down into a loop.
446       VectorType type = op.getType().dyn_cast<VectorType>();
447       if (!type || type.getRank() != 1 ||
448           type.getNumElements() % multiplicity != 0)
449         return mlir::WalkResult::advance();
450       auto filterAlloc = [](Operation *op) {
451         return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
452       };
453       auto dependentOps = getSlice(op, filterAlloc);
454       // Create a loop and move instructions from the Op slice into the loop.
455       OpBuilder builder(op);
456       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
457       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
458       auto numIter =
459           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
460       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
461       for (Operation *it : dependentOps) {
462         it->moveBefore(forOp.getBody()->getTerminator());
463       }
464       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
465       // break up the original op and let the patterns propagate.
466       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
467           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
468           map);
469       if (ops.hasValue()) {
470         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
471         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
472       }
473       return mlir::WalkResult::interrupt();
474     });
475     populatePropagateVectorDistributionPatterns(patterns);
476     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
477   }
478 };
479 
480 struct TestVectorTransferUnrollingPatterns
481     : public PassWrapper<TestVectorTransferUnrollingPatterns,
482                          OperationPass<func::FuncOp>> {
483   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
484       TestVectorTransferUnrollingPatterns)
485 
486   TestVectorTransferUnrollingPatterns() = default;
487   TestVectorTransferUnrollingPatterns(
488       const TestVectorTransferUnrollingPatterns &pass)
489       : PassWrapper(pass) {}
490 
491   void getDependentDialects(DialectRegistry &registry) const override {
492     registry.insert<AffineDialect>();
493   }
494   StringRef getArgument() const final {
495     return "test-vector-transfer-unrolling-patterns";
496   }
497   StringRef getDescription() const final {
498     return "Test lowering patterns to unroll transfer ops in the vector "
499            "dialect";
500   }
501   void runOnOperation() override {
502     MLIRContext *ctx = &getContext();
503     RewritePatternSet patterns(ctx);
504     UnrollVectorOptions opts;
505     opts.setNativeShape(ArrayRef<int64_t>{2, 2})
506         .setFilterConstraint([](Operation *op) {
507           return success(
508               isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
509         });
510     if (reverseUnrollOrder.getValue()) {
511       opts.setUnrollTraversalOrderFn(
512           [](Operation *op) -> Optional<SmallVector<int64_t>> {
513             int64_t numLoops = 0;
514             if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
515               numLoops = readOp.getVectorType().getRank();
516             else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
517               numLoops = writeOp.getVectorType().getRank();
518             else
519               return None;
520             auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
521             return llvm::to_vector(order);
522           });
523     }
524     populateVectorUnrollPatterns(patterns, opts);
525     populateVectorToVectorCanonicalizationPatterns(patterns);
526     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
527   }
528 
529   Option<bool> reverseUnrollOrder{
530       *this, "reverse-unroll-order",
531       llvm::cl::desc(
532           "reverse the order of unrolling of vector transfer operations"),
533       llvm::cl::init(false)};
534 };
535 
536 struct TestVectorTransferFullPartialSplitPatterns
537     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
538                          OperationPass<func::FuncOp>> {
539   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
540       TestVectorTransferFullPartialSplitPatterns)
541 
542   StringRef getArgument() const final {
543     return "test-vector-transfer-full-partial-split";
544   }
545   StringRef getDescription() const final {
546     return "Test lowering patterns to split "
547            "transfer ops via scf.if + linalg ops";
548   }
549   TestVectorTransferFullPartialSplitPatterns() = default;
550   TestVectorTransferFullPartialSplitPatterns(
551       const TestVectorTransferFullPartialSplitPatterns &pass)
552       : PassWrapper(pass) {}
553 
554   void getDependentDialects(DialectRegistry &registry) const override {
555     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
556                     scf::SCFDialect>();
557   }
558 
559   Option<bool> useLinalgOps{
560       *this, "use-memref-copy",
561       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
562                      "memref.copy operations."),
563       llvm::cl::init(false)};
564   void runOnOperation() override {
565     MLIRContext *ctx = &getContext();
566     RewritePatternSet patterns(ctx);
567     VectorTransformsOptions options;
568     if (useLinalgOps)
569       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
570     else
571       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
572     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
573     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
574   }
575 };
576 
577 struct TestVectorTransferOpt
578     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
579   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
580 
581   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
582   StringRef getDescription() const final {
583     return "Test optimization transformations for transfer ops";
584   }
585   void runOnOperation() override { transferOpflowOpt(getOperation()); }
586 };
587 
588 struct TestVectorTransferLoweringPatterns
589     : public PassWrapper<TestVectorTransferLoweringPatterns,
590                          OperationPass<func::FuncOp>> {
591   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
592       TestVectorTransferLoweringPatterns)
593 
594   void getDependentDialects(DialectRegistry &registry) const override {
595     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
596   }
597   StringRef getArgument() const final {
598     return "test-vector-transfer-lowering-patterns";
599   }
600   StringRef getDescription() const final {
601     return "Test lowering patterns to lower transfer ops to other vector ops";
602   }
603   void runOnOperation() override {
604     RewritePatternSet patterns(&getContext());
605     populateVectorTransferLoweringPatterns(patterns);
606     populateVectorTransferPermutationMapLoweringPatterns(patterns);
607     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
608   }
609 };
610 
611 struct TestVectorMultiReductionLoweringPatterns
612     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
613                          OperationPass<func::FuncOp>> {
614   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
615       TestVectorMultiReductionLoweringPatterns)
616 
617   TestVectorMultiReductionLoweringPatterns() = default;
618   TestVectorMultiReductionLoweringPatterns(
619       const TestVectorMultiReductionLoweringPatterns &pass)
620       : PassWrapper(pass) {}
621   void getDependentDialects(DialectRegistry &registry) const override {
622     registry.insert<memref::MemRefDialect>();
623   }
624   StringRef getArgument() const final {
625     return "test-vector-multi-reduction-lowering-patterns";
626   }
627   StringRef getDescription() const final {
628     return "Test lowering patterns to lower vector.multi_reduction to other "
629            "vector ops";
630   }
631   Option<bool> useOuterReductions{
632       *this, "use-outer-reductions",
633       llvm::cl::desc("Move reductions to outer most dimensions"),
634       llvm::cl::init(false)};
635   void runOnOperation() override {
636     RewritePatternSet patterns(&getContext());
637     populateVectorMultiReductionLoweringPatterns(
638         patterns, useOuterReductions
639                       ? vector::VectorMultiReductionLowering::InnerParallel
640                       : vector::VectorMultiReductionLowering::InnerReduction);
641     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
642   }
643 };
644 
645 struct TestVectorTransferCollapseInnerMostContiguousDims
646     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
647                          OperationPass<func::FuncOp>> {
648   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
649       TestVectorTransferCollapseInnerMostContiguousDims)
650 
651   TestVectorTransferCollapseInnerMostContiguousDims() = default;
652   TestVectorTransferCollapseInnerMostContiguousDims(
653       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
654 
655   void getDependentDialects(DialectRegistry &registry) const override {
656     registry.insert<memref::MemRefDialect, AffineDialect>();
657   }
658 
659   StringRef getArgument() const final {
660     return "test-vector-transfer-collapse-inner-most-dims";
661   }
662 
663   StringRef getDescription() const final {
664     return "Test lowering patterns that reducedes the rank of the vector "
665            "transfer memory and vector operands.";
666   }
667 
668   void runOnOperation() override {
669     RewritePatternSet patterns(&getContext());
670     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
671     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
672   }
673 };
674 
675 struct TestVectorReduceToContractPatternsPatterns
676     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
677                          OperationPass<func::FuncOp>> {
678   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
679       TestVectorReduceToContractPatternsPatterns)
680 
681   StringRef getArgument() const final {
682     return "test-vector-reduction-to-contract-patterns";
683   }
684   StringRef getDescription() const final {
685     return "Test patterns to convert multireduce op to contract and combine "
686            "broadcast/transpose to contract";
687   }
688   void runOnOperation() override {
689     RewritePatternSet patterns(&getContext());
690     populateVectorReductionToContractPatterns(patterns);
691     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
692   }
693 };
694 
695 struct TestVectorTransferDropUnitDimsPatterns
696     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
697                          OperationPass<func::FuncOp>> {
698   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
699       TestVectorTransferDropUnitDimsPatterns)
700 
701   StringRef getArgument() const final {
702     return "test-vector-transfer-drop-unit-dims-patterns";
703   }
704   void getDependentDialects(DialectRegistry &registry) const override {
705     registry.insert<memref::MemRefDialect>();
706   }
707   void runOnOperation() override {
708     RewritePatternSet patterns(&getContext());
709     populateVectorTransferDropUnitDimsPatterns(patterns);
710     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
711   }
712 };
713 
714 struct TestFlattenVectorTransferPatterns
715     : public PassWrapper<TestFlattenVectorTransferPatterns,
716                          OperationPass<func::FuncOp>> {
717   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
718       TestFlattenVectorTransferPatterns)
719 
720   StringRef getArgument() const final {
721     return "test-vector-transfer-flatten-patterns";
722   }
723   StringRef getDescription() const final {
724     return "Test patterns to rewrite contiguous row-major N-dimensional "
725            "vector.transfer_{read,write} ops into 1D transfers";
726   }
727   void getDependentDialects(DialectRegistry &registry) const override {
728     registry.insert<memref::MemRefDialect>();
729   }
730   void runOnOperation() override {
731     RewritePatternSet patterns(&getContext());
732     populateFlattenVectorTransferPatterns(patterns);
733     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
734   }
735 };
736 
737 struct TestVectorScanLowering
738     : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
739   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
740 
741   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
742   StringRef getDescription() const final {
743     return "Test lowering patterns that lower the scan op in the vector "
744            "dialect";
745   }
746   void runOnOperation() override {
747     RewritePatternSet patterns(&getContext());
748     populateVectorScanLoweringPatterns(patterns);
749     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
750   }
751 };
752 
753 /// Allocate shared memory for a single warp to test lowering of
754 /// WarpExecuteOnLane0Op.
755 static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
756                                         WarpExecuteOnLane0Op warpOp,
757                                         Type type) {
758   static constexpr int64_t kSharedMemorySpace = 3;
759   // Compute type of shared memory buffer.
760   MemRefType memrefType;
761   if (auto vectorType = type.dyn_cast<VectorType>()) {
762     memrefType =
763         MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
764                         kSharedMemorySpace);
765   } else {
766     memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
767   }
768 
769   // Get symbol table holding all shared memory globals.
770   ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
771   SymbolTable symbolTable(moduleOp);
772 
773   // Create a pretty name.
774   SmallString<64> buf;
775   llvm::raw_svector_ostream os(buf);
776   interleave(memrefType.getShape(), os, "x");
777   os << "x" << memrefType.getElementType();
778   std::string symbolName = (Twine("__shared_") + os.str()).str();
779 
780   auto ip = builder.saveInsertionPoint();
781   builder.setInsertionPoint(moduleOp);
782   auto global = builder.create<memref::GlobalOp>(
783       loc,
784       /*sym_name=*/symbolName,
785       /*sym_visibility=*/builder.getStringAttr("private"),
786       /*type=*/memrefType,
787       /*initial_value=*/Attribute(),
788       /*constant=*/false,
789       /*alignment=*/IntegerAttr());
790   symbolTable.insert(global);
791   // The symbol table inserts at the end of the module, but globals are a bit
792   // nicer if they are at the beginning.
793   global->moveBefore(&moduleOp.front());
794 
795   builder.restoreInsertionPoint(ip);
796   return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
797 }
798 
799 struct TestVectorDistribution
800     : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
801   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
802 
803   void getDependentDialects(DialectRegistry &registry) const override {
804     registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>();
805   }
806 
807   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
808   StringRef getDescription() const final {
809     return "Test vector warp distribute transformation and lowering patterns";
810   }
811   TestVectorDistribution() = default;
812   TestVectorDistribution(const TestVectorDistribution &pass)
813       : PassWrapper(pass) {}
814 
815   Option<bool> warpOpToSCF{
816       *this, "rewrite-warp-ops-to-scf-if",
817       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
818       llvm::cl::init(false)};
819 
820   void runOnOperation() override {
821     RewritePatternSet patterns(&getContext());
822     WarpExecuteOnLane0LoweringOptions options;
823     options.warpAllocationFn = allocateGlobalSharedMemory;
824     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
825                                       WarpExecuteOnLane0Op warpOp) {
826       builder.create<gpu::BarrierOp>(loc);
827     };
828     // Test on one pattern in isolation.
829     if (warpOpToSCF) {
830       populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
831       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
832       return;
833     }
834   }
835 };
836 
837 } // namespace
838 
839 namespace mlir {
840 namespace test {
841 void registerTestVectorLowerings() {
842   PassRegistration<TestVectorToVectorLowering>();
843 
844   PassRegistration<TestVectorContractionLowering>();
845 
846   PassRegistration<TestVectorTransposeLowering>();
847 
848   PassRegistration<TestVectorUnrollingPatterns>();
849 
850   PassRegistration<TestVectorTransferUnrollingPatterns>();
851 
852   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
853 
854   PassRegistration<TestVectorDistributePatterns>();
855 
856   PassRegistration<TestVectorToLoopPatterns>();
857 
858   PassRegistration<TestVectorTransferOpt>();
859 
860   PassRegistration<TestVectorTransferLoweringPatterns>();
861 
862   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
863 
864   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
865 
866   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
867 
868   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
869 
870   PassRegistration<TestFlattenVectorTransferPatterns>();
871 
872   PassRegistration<TestVectorScanLowering>();
873 
874   PassRegistration<TestVectorDistribution>();
875 }
876 } // namespace test
877 } // namespace mlir
878