1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 using namespace mlir::vector;
30 
31 namespace {
32 
33 struct TestVectorToVectorLowering
34     : public PassWrapper<TestVectorToVectorLowering,
35                          OperationPass<func::FuncOp>> {
36   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
37 
38   TestVectorToVectorLowering() = default;
39   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
40       : PassWrapper(pass) {}
41   StringRef getArgument() const final {
42     return "test-vector-to-vector-lowering";
43   }
44   StringRef getDescription() const final {
45     return "Test lowering patterns between ops in the vector dialect";
46   }
47 
48   void getDependentDialects(DialectRegistry &registry) const override {
49     registry.insert<AffineDialect>();
50   }
51 
52   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
53                       llvm::cl::init(false)};
54 
55   void runOnOperation() override {
56     auto *ctx = &getContext();
57     RewritePatternSet patterns(ctx);
58     if (unroll) {
59       populateVectorUnrollPatterns(
60           patterns,
61           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
62               filter));
63     }
64     populateVectorToVectorCanonicalizationPatterns(patterns);
65     populateBubbleVectorBitCastOpPatterns(patterns);
66     populateCastAwayVectorLeadingOneDimPatterns(patterns);
67     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
68   }
69 
70 private:
71   // Return the target shape based on op type.
72   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
73     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
74       return SmallVector<int64_t, 4>(2, 2);
75     if (isa<vector::ContractionOp>(op))
76       return SmallVector<int64_t, 4>(3, 2);
77     // For transfer ops, just propagate the shape coming from
78     // InsertStridedSlices/ExtractStridedSlices.
79     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
80       VectorType dstVec;
81       for (Operation *users : readOp->getUsers()) {
82         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
83         if (!extract)
84           return llvm::None;
85         auto vecType = extract.getResult().getType().cast<VectorType>();
86         if (dstVec && dstVec != vecType)
87           return llvm::None;
88         dstVec = vecType;
89       }
90       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
91                                      dstVec.getShape().end());
92     }
93     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
94       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
95       if (!insert)
96         return llvm::None;
97       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
98       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
99     }
100     return llvm::None;
101   }
102 
103   static LogicalResult filter(Operation *op) {
104     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
105                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
106   }
107 };
108 
109 struct TestVectorContractionLowering
110     : public PassWrapper<TestVectorContractionLowering,
111                          OperationPass<func::FuncOp>> {
112   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
113 
114   StringRef getArgument() const final {
115     return "test-vector-contraction-lowering";
116   }
117   StringRef getDescription() const final {
118     return "Test lowering patterns that lower contract ops in the vector "
119            "dialect";
120   }
121   TestVectorContractionLowering() = default;
122   TestVectorContractionLowering(const TestVectorContractionLowering &pass)
123       : PassWrapper(pass) {}
124 
125   Option<bool> lowerToFlatMatrix{
126       *this, "vector-lower-matrix-intrinsics",
127       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
128       llvm::cl::init(false)};
129   Option<bool> lowerToOuterProduct{
130       *this, "vector-outerproduct",
131       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
132       llvm::cl::init(false)};
133   Option<bool> lowerToFilterOuterProduct{
134       *this, "vector-filter-outerproduct",
135       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
136                      "vectors of size 4."),
137       llvm::cl::init(false)};
138 
139   void runOnOperation() override {
140     RewritePatternSet patterns(&getContext());
141 
142     // Test on one pattern in isolation.
143     if (lowerToOuterProduct) {
144       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
145       VectorTransformsOptions options{lowering};
146       patterns.add<ContractionOpToOuterProductOpLowering>(options,
147                                                           &getContext());
148       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
149       return;
150     }
151 
152     // Test on one pattern in isolation.
153     if (lowerToFilterOuterProduct) {
154       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
155       VectorTransformsOptions options{lowering};
156       patterns.add<ContractionOpToOuterProductOpLowering>(
157           options, &getContext(), [](vector::ContractionOp op) {
158             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
159             // where M is not 4.
160             if (op.getRhsType().getShape()[0] == 4)
161               return failure();
162             return success();
163           });
164       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
165       return;
166     }
167 
168     // Test on all contract lowering patterns.
169     VectorContractLowering contractLowering = VectorContractLowering::Dot;
170     if (lowerToFlatMatrix)
171       contractLowering = VectorContractLowering::Matmul;
172     VectorMultiReductionLowering vectorMultiReductionLowering =
173         VectorMultiReductionLowering::InnerParallel;
174     VectorTransformsOptions options{contractLowering,
175                                     vectorMultiReductionLowering,
176                                     VectorTransposeLowering()};
177     populateVectorBroadcastLoweringPatterns(patterns);
178     populateVectorContractLoweringPatterns(patterns, options);
179     populateVectorMaskOpLoweringPatterns(patterns);
180     populateVectorShapeCastLoweringPatterns(patterns);
181     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
182   }
183 };
184 
185 struct TestVectorTransposeLowering
186     : public PassWrapper<TestVectorTransposeLowering,
187                          OperationPass<func::FuncOp>> {
188   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
189 
190   StringRef getArgument() const final {
191     return "test-vector-transpose-lowering";
192   }
193   StringRef getDescription() const final {
194     return "Test lowering patterns that lower contract ops in the vector "
195            "dialect";
196   }
197   TestVectorTransposeLowering() = default;
198   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
199       : PassWrapper(pass) {}
200 
201   Option<bool> lowerToEltwise{
202       *this, "eltwise",
203       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
204       llvm::cl::init(false)};
205   Option<bool> lowerToFlatTranspose{
206       *this, "flat",
207       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
208       llvm::cl::init(false)};
209   Option<bool> lowerToShuffleTranspose{
210       *this, "shuffle",
211       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
212       llvm::cl::init(false)};
213   Option<bool> lowerToAvx2{
214       *this, "avx2",
215       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
216       llvm::cl::init(false)};
217 
218   void getDependentDialects(DialectRegistry &registry) const override {
219     registry.insert<LLVM::LLVMDialect>();
220   }
221 
222   void runOnOperation() override {
223     RewritePatternSet patterns(&getContext());
224 
225     // Test on one pattern in isolation.
226     // Explicitly disable shape_cast lowering.
227     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
228                                               .enableVectorTransposeLowering()
229                                               .enableShapeCastLowering(false);
230     if (lowerToEltwise) {
231       options = options.setVectorTransformsOptions(
232           VectorTransformsOptions().setVectorTransposeLowering(
233               VectorTransposeLowering::EltWise));
234     }
235     if (lowerToFlatTranspose) {
236       options = options.setVectorTransformsOptions(
237           VectorTransformsOptions().setVectorTransposeLowering(
238               VectorTransposeLowering::Flat));
239     }
240     if (lowerToShuffleTranspose) {
241       options = options.setVectorTransformsOptions(
242           VectorTransformsOptions().setVectorTransposeLowering(
243               VectorTransposeLowering::Shuffle));
244     }
245     if (lowerToAvx2) {
246       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
247           x86vector::avx2::LoweringOptions().setTransposeOptions(
248               x86vector::avx2::TransposeLoweringOptions()
249                   .lower4x8xf32()
250                   .lower8x8xf32()));
251     }
252 
253     OpPassManager dynamicPM("func.func");
254     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
255     if (failed(runPipeline(dynamicPM, getOperation())))
256       return signalPassFailure();
257   }
258 };
259 
260 struct TestVectorUnrollingPatterns
261     : public PassWrapper<TestVectorUnrollingPatterns,
262                          OperationPass<func::FuncOp>> {
263   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
264 
265   StringRef getArgument() const final {
266     return "test-vector-unrolling-patterns";
267   }
268   StringRef getDescription() const final {
269     return "Test lowering patterns to unroll contract ops in the vector "
270            "dialect";
271   }
272   TestVectorUnrollingPatterns() = default;
273   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
274       : PassWrapper(pass) {}
275   void runOnOperation() override {
276     MLIRContext *ctx = &getContext();
277     RewritePatternSet patterns(ctx);
278     populateVectorUnrollPatterns(
279         patterns, UnrollVectorOptions()
280                       .setNativeShape(ArrayRef<int64_t>{2, 2})
281                       .setFilterConstraint([](Operation *op) {
282                         return success(isa<arith::AddFOp, vector::FMAOp,
283                                            vector::MultiDimReductionOp>(op));
284                       }));
285     populateVectorUnrollPatterns(
286         patterns, UnrollVectorOptions()
287                       .setNativeShape(ArrayRef<int64_t>{2})
288                       .setFilterConstraint([](Operation *op) {
289                         return success(isa<vector::ReductionOp>(op));
290                       }));
291     populateVectorUnrollPatterns(
292         patterns, UnrollVectorOptions()
293                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
294                       .setFilterConstraint([](Operation *op) {
295                         return success(isa<vector::TransposeOp>(op));
296                       }));
297 
298     if (unrollBasedOnType) {
299       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
300           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
301         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
302         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
303         if (auto floatType = contractOp.getLhsType()
304                                  .getElementType()
305                                  .dyn_cast<FloatType>()) {
306           if (floatType.getWidth() == 16) {
307             nativeShape[2] = 4;
308           }
309         }
310         return nativeShape;
311       };
312       populateVectorUnrollPatterns(patterns,
313                                    UnrollVectorOptions()
314                                        .setNativeShapeFn(nativeShapeFn)
315                                        .setFilterConstraint([](Operation *op) {
316                                          return success(isa<ContractionOp>(op));
317                                        }));
318     } else {
319       populateVectorUnrollPatterns(
320           patterns, UnrollVectorOptions()
321                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
322                         .setFilterConstraint([](Operation *op) {
323                           return success(isa<ContractionOp>(op));
324                         }));
325     }
326     populateVectorToVectorCanonicalizationPatterns(patterns);
327     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
328   }
329 
330   Option<bool> unrollBasedOnType{
331       *this, "unroll-based-on-type",
332       llvm::cl::desc("Set the unroll factor based on type of the operation"),
333       llvm::cl::init(false)};
334 };
335 
336 struct TestVectorDistributePatterns
337     : public PassWrapper<TestVectorDistributePatterns,
338                          OperationPass<func::FuncOp>> {
339   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)
340 
341   StringRef getArgument() const final {
342     return "test-vector-distribute-patterns";
343   }
344   StringRef getDescription() const final {
345     return "Test lowering patterns to distribute vector ops in the vector "
346            "dialect";
347   }
348   TestVectorDistributePatterns() = default;
349   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
350       : PassWrapper(pass) {}
351   void getDependentDialects(DialectRegistry &registry) const override {
352     registry.insert<VectorDialect>();
353     registry.insert<AffineDialect>();
354   }
355   ListOption<int32_t> multiplicity{
356       *this, "distribution-multiplicity",
357       llvm::cl::desc("Set the multiplicity used for distributing vector")};
358 
359   void runOnOperation() override {
360     MLIRContext *ctx = &getContext();
361     RewritePatternSet patterns(ctx);
362     func::FuncOp func = getOperation();
363     func.walk([&](arith::AddFOp op) {
364       OpBuilder builder(op);
365       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
366         SmallVector<int64_t, 2> mul;
367         SmallVector<AffineExpr, 2> perm;
368         SmallVector<Value, 2> ids;
369         unsigned count = 0;
370         // Remove the multiplicity of 1 and calculate the affine map based on
371         // the multiplicity.
372         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
373         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
374           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
375             mul.push_back(m[i]);
376             ids.push_back(func.getArgument(count++));
377             perm.push_back(getAffineDimExpr(i, ctx));
378           }
379         }
380         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
381                                   perm, ctx);
382         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
383             builder, op.getOperation(), ids, mul, map);
384         if (ops.hasValue()) {
385           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
386           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
387                                               extractOp);
388         }
389       }
390     });
391     populatePropagateVectorDistributionPatterns(patterns);
392     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
393   }
394 };
395 
396 struct TestVectorToLoopPatterns
397     : public PassWrapper<TestVectorToLoopPatterns,
398                          OperationPass<func::FuncOp>> {
399   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)
400 
401   StringRef getArgument() const final { return "test-vector-to-forloop"; }
402   StringRef getDescription() const final {
403     return "Test lowering patterns to break up a vector op into a for loop";
404   }
405   TestVectorToLoopPatterns() = default;
406   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
407       : PassWrapper(pass) {}
408   void getDependentDialects(DialectRegistry &registry) const override {
409     registry.insert<VectorDialect>();
410     registry.insert<AffineDialect>();
411   }
412   Option<int32_t> multiplicity{
413       *this, "distribution-multiplicity",
414       llvm::cl::desc("Set the multiplicity used for distributing vector"),
415       llvm::cl::init(32)};
416   void runOnOperation() override {
417     MLIRContext *ctx = &getContext();
418     RewritePatternSet patterns(ctx);
419     func::FuncOp func = getOperation();
420     func.walk([&](arith::AddFOp op) {
421       // Check that the operation type can be broken down into a loop.
422       VectorType type = op.getType().dyn_cast<VectorType>();
423       if (!type || type.getRank() != 1 ||
424           type.getNumElements() % multiplicity != 0)
425         return mlir::WalkResult::advance();
426       auto filterAlloc = [](Operation *op) {
427         return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
428       };
429       auto dependentOps = getSlice(op, filterAlloc);
430       // Create a loop and move instructions from the Op slice into the loop.
431       OpBuilder builder(op);
432       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
433       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
434       auto numIter =
435           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
436       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
437       for (Operation *it : dependentOps) {
438         it->moveBefore(forOp.getBody()->getTerminator());
439       }
440       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
441       // break up the original op and let the patterns propagate.
442       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
443           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
444           map);
445       if (ops.hasValue()) {
446         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
447         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
448       }
449       return mlir::WalkResult::interrupt();
450     });
451     populatePropagateVectorDistributionPatterns(patterns);
452     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
453   }
454 };
455 
456 struct TestVectorTransferUnrollingPatterns
457     : public PassWrapper<TestVectorTransferUnrollingPatterns,
458                          OperationPass<func::FuncOp>> {
459   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
460       TestVectorTransferUnrollingPatterns)
461 
462   void getDependentDialects(DialectRegistry &registry) const override {
463     registry.insert<AffineDialect>();
464   }
465   StringRef getArgument() const final {
466     return "test-vector-transfer-unrolling-patterns";
467   }
468   StringRef getDescription() const final {
469     return "Test lowering patterns to unroll transfer ops in the vector "
470            "dialect";
471   }
472   void runOnOperation() override {
473     MLIRContext *ctx = &getContext();
474     RewritePatternSet patterns(ctx);
475     populateVectorUnrollPatterns(
476         patterns,
477         UnrollVectorOptions()
478             .setNativeShape(ArrayRef<int64_t>{2, 2})
479             .setFilterConstraint([](Operation *op) {
480               return success(
481                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
482             }));
483     populateVectorToVectorCanonicalizationPatterns(patterns);
484     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
485   }
486 };
487 
488 struct TestVectorTransferFullPartialSplitPatterns
489     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
490                          OperationPass<func::FuncOp>> {
491   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
492       TestVectorTransferFullPartialSplitPatterns)
493 
494   StringRef getArgument() const final {
495     return "test-vector-transfer-full-partial-split";
496   }
497   StringRef getDescription() const final {
498     return "Test lowering patterns to split "
499            "transfer ops via scf.if + linalg ops";
500   }
501   TestVectorTransferFullPartialSplitPatterns() = default;
502   TestVectorTransferFullPartialSplitPatterns(
503       const TestVectorTransferFullPartialSplitPatterns &pass)
504       : PassWrapper(pass) {}
505 
506   void getDependentDialects(DialectRegistry &registry) const override {
507     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
508                     scf::SCFDialect>();
509   }
510 
511   Option<bool> useLinalgOps{
512       *this, "use-memref-copy",
513       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
514                      "memref.copy operations."),
515       llvm::cl::init(false)};
516   void runOnOperation() override {
517     MLIRContext *ctx = &getContext();
518     RewritePatternSet patterns(ctx);
519     VectorTransformsOptions options;
520     if (useLinalgOps)
521       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
522     else
523       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
524     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
525     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
526   }
527 };
528 
529 struct TestVectorTransferOpt
530     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
531   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
532 
533   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
534   StringRef getDescription() const final {
535     return "Test optimization transformations for transfer ops";
536   }
537   void runOnOperation() override { transferOpflowOpt(getOperation()); }
538 };
539 
540 struct TestVectorTransferLoweringPatterns
541     : public PassWrapper<TestVectorTransferLoweringPatterns,
542                          OperationPass<func::FuncOp>> {
543   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
544       TestVectorTransferLoweringPatterns)
545 
546   void getDependentDialects(DialectRegistry &registry) const override {
547     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
548   }
549   StringRef getArgument() const final {
550     return "test-vector-transfer-lowering-patterns";
551   }
552   StringRef getDescription() const final {
553     return "Test lowering patterns to lower transfer ops to other vector ops";
554   }
555   void runOnOperation() override {
556     RewritePatternSet patterns(&getContext());
557     populateVectorTransferLoweringPatterns(patterns);
558     populateVectorTransferPermutationMapLoweringPatterns(patterns);
559     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
560   }
561 };
562 
563 struct TestVectorMultiReductionLoweringPatterns
564     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
565                          OperationPass<func::FuncOp>> {
566   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
567       TestVectorMultiReductionLoweringPatterns)
568 
569   TestVectorMultiReductionLoweringPatterns() = default;
570   TestVectorMultiReductionLoweringPatterns(
571       const TestVectorMultiReductionLoweringPatterns &pass)
572       : PassWrapper(pass) {}
573   void getDependentDialects(DialectRegistry &registry) const override {
574     registry.insert<memref::MemRefDialect>();
575   }
576   StringRef getArgument() const final {
577     return "test-vector-multi-reduction-lowering-patterns";
578   }
579   StringRef getDescription() const final {
580     return "Test lowering patterns to lower vector.multi_reduction to other "
581            "vector ops";
582   }
583   Option<bool> useOuterReductions{
584       *this, "use-outer-reductions",
585       llvm::cl::desc("Move reductions to outer most dimensions"),
586       llvm::cl::init(false)};
587   void runOnOperation() override {
588     RewritePatternSet patterns(&getContext());
589     populateVectorMultiReductionLoweringPatterns(
590         patterns, useOuterReductions
591                       ? vector::VectorMultiReductionLowering::InnerParallel
592                       : vector::VectorMultiReductionLowering::InnerReduction);
593     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
594   }
595 };
596 
597 struct TestVectorTransferCollapseInnerMostContiguousDims
598     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
599                          OperationPass<func::FuncOp>> {
600   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
601       TestVectorTransferCollapseInnerMostContiguousDims)
602 
603   TestVectorTransferCollapseInnerMostContiguousDims() = default;
604   TestVectorTransferCollapseInnerMostContiguousDims(
605       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
606 
607   void getDependentDialects(DialectRegistry &registry) const override {
608     registry.insert<memref::MemRefDialect, AffineDialect>();
609   }
610 
611   StringRef getArgument() const final {
612     return "test-vector-transfer-collapse-inner-most-dims";
613   }
614 
615   StringRef getDescription() const final {
616     return "Test lowering patterns that reducedes the rank of the vector "
617            "transfer memory and vector operands.";
618   }
619 
620   void runOnOperation() override {
621     RewritePatternSet patterns(&getContext());
622     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
623     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
624   }
625 };
626 
627 struct TestVectorReduceToContractPatternsPatterns
628     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
629                          OperationPass<func::FuncOp>> {
630   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
631       TestVectorReduceToContractPatternsPatterns)
632 
633   StringRef getArgument() const final {
634     return "test-vector-reduction-to-contract-patterns";
635   }
636   StringRef getDescription() const final {
637     return "Test patterns to convert multireduce op to contract and combine "
638            "broadcast/transpose to contract";
639   }
640   void runOnOperation() override {
641     RewritePatternSet patterns(&getContext());
642     populateVectorReductionToContractPatterns(patterns);
643     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
644   }
645 };
646 
647 struct TestVectorTransferDropUnitDimsPatterns
648     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
649                          OperationPass<func::FuncOp>> {
650   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
651       TestVectorTransferDropUnitDimsPatterns)
652 
653   StringRef getArgument() const final {
654     return "test-vector-transfer-drop-unit-dims-patterns";
655   }
656   void getDependentDialects(DialectRegistry &registry) const override {
657     registry.insert<memref::MemRefDialect>();
658   }
659   void runOnOperation() override {
660     RewritePatternSet patterns(&getContext());
661     populateVectorTransferDropUnitDimsPatterns(patterns);
662     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
663   }
664 };
665 
666 struct TestFlattenVectorTransferPatterns
667     : public PassWrapper<TestFlattenVectorTransferPatterns,
668                          OperationPass<func::FuncOp>> {
669   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
670       TestFlattenVectorTransferPatterns)
671 
672   StringRef getArgument() const final {
673     return "test-vector-transfer-flatten-patterns";
674   }
675   StringRef getDescription() const final {
676     return "Test patterns to rewrite contiguous row-major N-dimensional "
677            "vector.transfer_{read,write} ops into 1D transfers";
678   }
679   void getDependentDialects(DialectRegistry &registry) const override {
680     registry.insert<memref::MemRefDialect>();
681   }
682   void runOnOperation() override {
683     RewritePatternSet patterns(&getContext());
684     populateFlattenVectorTransferPatterns(patterns);
685     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
686   }
687 };
688 
689 struct TestVectorScanLowering
690     : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
691   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
692 
693   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
694   StringRef getDescription() const final {
695     return "Test lowering patterns that lower the scan op in the vector "
696            "dialect";
697   }
698   void runOnOperation() override {
699     RewritePatternSet patterns(&getContext());
700     populateVectorScanLoweringPatterns(patterns);
701     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
702   }
703 };
704 
705 /// Allocate shared memory for a single warp to test lowering of
706 /// WarpExecuteOnLane0Op.
707 static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
708                                         WarpExecuteOnLane0Op warpOp,
709                                         Type type) {
710   static constexpr int64_t kSharedMemorySpace = 3;
711   // Compute type of shared memory buffer.
712   MemRefType memrefType;
713   if (auto vectorType = type.dyn_cast<VectorType>()) {
714     memrefType =
715         MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
716                         kSharedMemorySpace);
717   } else {
718     memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
719   }
720 
721   // Get symbol table holding all shared memory globals.
722   ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
723   SymbolTable symbolTable(moduleOp);
724 
725   // Create a pretty name.
726   SmallString<64> buf;
727   llvm::raw_svector_ostream os(buf);
728   interleave(memrefType.getShape(), os, "x");
729   os << "x" << memrefType.getElementType();
730   std::string symbolName = (Twine("__shared_") + os.str()).str();
731 
732   auto ip = builder.saveInsertionPoint();
733   builder.setInsertionPoint(moduleOp);
734   auto global = builder.create<memref::GlobalOp>(
735       loc,
736       /*sym_name=*/symbolName,
737       /*sym_visibility=*/builder.getStringAttr("private"),
738       /*type=*/memrefType,
739       /*initial_value=*/Attribute(),
740       /*constant=*/false,
741       /*alignment=*/IntegerAttr());
742   symbolTable.insert(global);
743   // The symbol table inserts at the end of the module, but globals are a bit
744   // nicer if they are at the beginning.
745   global->moveBefore(&moduleOp.front());
746 
747   builder.restoreInsertionPoint(ip);
748   return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
749 }
750 
751 struct TestVectorDistribution
752     : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
753   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
754 
755   void getDependentDialects(DialectRegistry &registry) const override {
756     registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>();
757   }
758 
759   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
760   StringRef getDescription() const final {
761     return "Test vector warp distribute transformation and lowering patterns";
762   }
763   TestVectorDistribution() = default;
764   TestVectorDistribution(const TestVectorDistribution &pass)
765       : PassWrapper(pass) {}
766 
767   Option<bool> warpOpToSCF{
768       *this, "rewrite-warp-ops-to-scf-if",
769       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
770       llvm::cl::init(false)};
771 
772   void runOnOperation() override {
773     RewritePatternSet patterns(&getContext());
774     WarpExecuteOnLane0LoweringOptions options;
775     options.warpAllocationFn = allocateGlobalSharedMemory;
776     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
777                                       WarpExecuteOnLane0Op warpOp) {
778       builder.create<gpu::BarrierOp>(loc);
779     };
780     // Test on one pattern in isolation.
781     if (warpOpToSCF) {
782       populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
783       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
784       return;
785     }
786   }
787 };
788 
789 } // namespace
790 
791 namespace mlir {
792 namespace test {
793 void registerTestVectorLowerings() {
794   PassRegistration<TestVectorToVectorLowering>();
795 
796   PassRegistration<TestVectorContractionLowering>();
797 
798   PassRegistration<TestVectorTransposeLowering>();
799 
800   PassRegistration<TestVectorUnrollingPatterns>();
801 
802   PassRegistration<TestVectorTransferUnrollingPatterns>();
803 
804   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
805 
806   PassRegistration<TestVectorDistributePatterns>();
807 
808   PassRegistration<TestVectorToLoopPatterns>();
809 
810   PassRegistration<TestVectorTransferOpt>();
811 
812   PassRegistration<TestVectorTransferLoweringPatterns>();
813 
814   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
815 
816   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
817 
818   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
819 
820   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
821 
822   PassRegistration<TestFlattenVectorTransferPatterns>();
823 
824   PassRegistration<TestVectorScanLowering>();
825 
826   PassRegistration<TestVectorDistribution>();
827 }
828 } // namespace test
829 } // namespace mlir
830