1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 using namespace mlir::vector;
30 
31 namespace {
32 
33 struct TestVectorToVectorLowering
34     : public PassWrapper<TestVectorToVectorLowering,
35                          OperationPass<func::FuncOp>> {
36   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
37 
38   TestVectorToVectorLowering() = default;
39   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
40       : PassWrapper(pass) {}
41   StringRef getArgument() const final {
42     return "test-vector-to-vector-lowering";
43   }
44   StringRef getDescription() const final {
45     return "Test lowering patterns between ops in the vector dialect";
46   }
47 
48   void getDependentDialects(DialectRegistry &registry) const override {
49     registry.insert<AffineDialect>();
50   }
51 
52   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
53                       llvm::cl::init(false)};
54 
55   void runOnOperation() override {
56     auto *ctx = &getContext();
57     RewritePatternSet patterns(ctx);
58     if (unroll) {
59       populateVectorUnrollPatterns(
60           patterns,
61           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
62               filter));
63     }
64     populateVectorToVectorCanonicalizationPatterns(patterns);
65     populateBubbleVectorBitCastOpPatterns(patterns);
66     populateCastAwayVectorLeadingOneDimPatterns(patterns);
67     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
68   }
69 
70 private:
71   // Return the target shape based on op type.
72   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
73     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
74       return SmallVector<int64_t, 4>(2, 2);
75     if (isa<vector::ContractionOp>(op))
76       return SmallVector<int64_t, 4>(3, 2);
77     // For transfer ops, just propagate the shape coming from
78     // InsertStridedSlices/ExtractStridedSlices.
79     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
80       VectorType dstVec;
81       for (Operation *users : readOp->getUsers()) {
82         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
83         if (!extract)
84           return llvm::None;
85         auto vecType = extract.getResult().getType().cast<VectorType>();
86         if (dstVec && dstVec != vecType)
87           return llvm::None;
88         dstVec = vecType;
89       }
90       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
91                                      dstVec.getShape().end());
92     }
93     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
94       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
95       if (!insert)
96         return llvm::None;
97       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
98       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
99     }
100     return llvm::None;
101   }
102 
103   static LogicalResult filter(Operation *op) {
104     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
105                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
106   }
107 };
108 
109 struct TestVectorContractionLowering
110     : public PassWrapper<TestVectorContractionLowering,
111                          OperationPass<func::FuncOp>> {
112   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
113 
114   StringRef getArgument() const final {
115     return "test-vector-contraction-lowering";
116   }
117   StringRef getDescription() const final {
118     return "Test lowering patterns that lower contract ops in the vector "
119            "dialect";
120   }
121   TestVectorContractionLowering() = default;
122   TestVectorContractionLowering(const TestVectorContractionLowering &pass)
123       : PassWrapper(pass) {}
124 
125   Option<bool> lowerToFlatMatrix{
126       *this, "vector-lower-matrix-intrinsics",
127       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
128       llvm::cl::init(false)};
129   Option<bool> lowerToOuterProduct{
130       *this, "vector-outerproduct",
131       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
132       llvm::cl::init(false)};
133   Option<bool> lowerToFilterOuterProduct{
134       *this, "vector-filter-outerproduct",
135       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
136                      "vectors of size 4."),
137       llvm::cl::init(false)};
138   Option<bool> lowerToParallelArith{
139       *this, "vector-parallel-arith",
140       llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
141       llvm::cl::init(false)};
142 
143   void runOnOperation() override {
144     RewritePatternSet patterns(&getContext());
145 
146     // Test on one pattern in isolation.
147     if (lowerToOuterProduct) {
148       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
149       VectorTransformsOptions options{lowering};
150       patterns.add<ContractionOpToOuterProductOpLowering>(options,
151                                                           &getContext());
152       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
153       return;
154     }
155 
156     // Test on one pattern in isolation.
157     if (lowerToFilterOuterProduct) {
158       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
159       VectorTransformsOptions options{lowering};
160       patterns.add<ContractionOpToOuterProductOpLowering>(
161           options, &getContext(), [](vector::ContractionOp op) {
162             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
163             // where M is not 4.
164             if (op.getRhsType().getShape()[0] == 4)
165               return failure();
166             return success();
167           });
168       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
169       return;
170     }
171 
172     if (lowerToParallelArith) {
173       vector::populateVectorContractLoweringPatterns(
174           patterns,
175           vector::VectorTransformsOptions().setVectorTransformsOptions(
176               vector::VectorContractLowering::ParallelArith));
177       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
178       return;
179     }
180 
181     // Test on all contract lowering patterns.
182     VectorContractLowering contractLowering = VectorContractLowering::Dot;
183     if (lowerToFlatMatrix)
184       contractLowering = VectorContractLowering::Matmul;
185     VectorMultiReductionLowering vectorMultiReductionLowering =
186         VectorMultiReductionLowering::InnerParallel;
187     VectorTransformsOptions options{contractLowering,
188                                     vectorMultiReductionLowering,
189                                     VectorTransposeLowering()};
190     populateVectorBroadcastLoweringPatterns(patterns);
191     populateVectorContractLoweringPatterns(patterns, options);
192     populateVectorMaskOpLoweringPatterns(patterns);
193     populateVectorShapeCastLoweringPatterns(patterns);
194     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
195   }
196 };
197 
198 struct TestVectorTransposeLowering
199     : public PassWrapper<TestVectorTransposeLowering,
200                          OperationPass<func::FuncOp>> {
201   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
202 
203   StringRef getArgument() const final {
204     return "test-vector-transpose-lowering";
205   }
206   StringRef getDescription() const final {
207     return "Test lowering patterns that lower contract ops in the vector "
208            "dialect";
209   }
210   TestVectorTransposeLowering() = default;
211   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
212       : PassWrapper(pass) {}
213 
214   Option<bool> lowerToEltwise{
215       *this, "eltwise",
216       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
217       llvm::cl::init(false)};
218   Option<bool> lowerToFlatTranspose{
219       *this, "flat",
220       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
221       llvm::cl::init(false)};
222   Option<bool> lowerToShuffleTranspose{
223       *this, "shuffle",
224       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
225       llvm::cl::init(false)};
226   Option<bool> lowerToAvx2{
227       *this, "avx2",
228       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
229       llvm::cl::init(false)};
230 
231   void getDependentDialects(DialectRegistry &registry) const override {
232     registry.insert<LLVM::LLVMDialect>();
233   }
234 
235   void runOnOperation() override {
236     RewritePatternSet patterns(&getContext());
237 
238     // Test on one pattern in isolation.
239     // Explicitly disable shape_cast lowering.
240     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
241                                               .enableVectorTransposeLowering()
242                                               .enableShapeCastLowering(false);
243     if (lowerToEltwise) {
244       options = options.setVectorTransformsOptions(
245           VectorTransformsOptions().setVectorTransposeLowering(
246               VectorTransposeLowering::EltWise));
247     }
248     if (lowerToFlatTranspose) {
249       options = options.setVectorTransformsOptions(
250           VectorTransformsOptions().setVectorTransposeLowering(
251               VectorTransposeLowering::Flat));
252     }
253     if (lowerToShuffleTranspose) {
254       options = options.setVectorTransformsOptions(
255           VectorTransformsOptions().setVectorTransposeLowering(
256               VectorTransposeLowering::Shuffle));
257     }
258     if (lowerToAvx2) {
259       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
260           x86vector::avx2::LoweringOptions().setTransposeOptions(
261               x86vector::avx2::TransposeLoweringOptions()
262                   .lower4x8xf32()
263                   .lower8x8xf32()));
264     }
265 
266     OpPassManager dynamicPM("func.func");
267     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
268     if (failed(runPipeline(dynamicPM, getOperation())))
269       return signalPassFailure();
270   }
271 };
272 
273 struct TestVectorUnrollingPatterns
274     : public PassWrapper<TestVectorUnrollingPatterns,
275                          OperationPass<func::FuncOp>> {
276   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
277 
278   StringRef getArgument() const final {
279     return "test-vector-unrolling-patterns";
280   }
281   StringRef getDescription() const final {
282     return "Test lowering patterns to unroll contract ops in the vector "
283            "dialect";
284   }
285   TestVectorUnrollingPatterns() = default;
286   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
287       : PassWrapper(pass) {}
288   void runOnOperation() override {
289     MLIRContext *ctx = &getContext();
290     RewritePatternSet patterns(ctx);
291     populateVectorUnrollPatterns(
292         patterns, UnrollVectorOptions()
293                       .setNativeShape(ArrayRef<int64_t>{2, 2})
294                       .setFilterConstraint([](Operation *op) {
295                         return success(isa<arith::AddFOp, vector::FMAOp,
296                                            vector::MultiDimReductionOp>(op));
297                       }));
298     populateVectorUnrollPatterns(
299         patterns, UnrollVectorOptions()
300                       .setNativeShape(ArrayRef<int64_t>{2})
301                       .setFilterConstraint([](Operation *op) {
302                         return success(isa<vector::ReductionOp>(op));
303                       }));
304     populateVectorUnrollPatterns(
305         patterns, UnrollVectorOptions()
306                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
307                       .setFilterConstraint([](Operation *op) {
308                         return success(isa<vector::TransposeOp>(op));
309                       }));
310 
311     if (unrollBasedOnType) {
312       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
313           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
314         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
315         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
316         if (auto floatType = contractOp.getLhsType()
317                                  .getElementType()
318                                  .dyn_cast<FloatType>()) {
319           if (floatType.getWidth() == 16) {
320             nativeShape[2] = 4;
321           }
322         }
323         return nativeShape;
324       };
325       populateVectorUnrollPatterns(patterns,
326                                    UnrollVectorOptions()
327                                        .setNativeShapeFn(nativeShapeFn)
328                                        .setFilterConstraint([](Operation *op) {
329                                          return success(isa<ContractionOp>(op));
330                                        }));
331     } else {
332       populateVectorUnrollPatterns(
333           patterns, UnrollVectorOptions()
334                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
335                         .setFilterConstraint([](Operation *op) {
336                           return success(isa<ContractionOp>(op));
337                         }));
338     }
339     populateVectorToVectorCanonicalizationPatterns(patterns);
340     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
341   }
342 
343   Option<bool> unrollBasedOnType{
344       *this, "unroll-based-on-type",
345       llvm::cl::desc("Set the unroll factor based on type of the operation"),
346       llvm::cl::init(false)};
347 };
348 
349 struct TestVectorDistributePatterns
350     : public PassWrapper<TestVectorDistributePatterns,
351                          OperationPass<func::FuncOp>> {
352   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)
353 
354   StringRef getArgument() const final {
355     return "test-vector-distribute-patterns";
356   }
357   StringRef getDescription() const final {
358     return "Test lowering patterns to distribute vector ops in the vector "
359            "dialect";
360   }
361   TestVectorDistributePatterns() = default;
362   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
363       : PassWrapper(pass) {}
364   void getDependentDialects(DialectRegistry &registry) const override {
365     registry.insert<VectorDialect>();
366     registry.insert<AffineDialect>();
367   }
368   ListOption<int32_t> multiplicity{
369       *this, "distribution-multiplicity",
370       llvm::cl::desc("Set the multiplicity used for distributing vector")};
371 
372   void runOnOperation() override {
373     MLIRContext *ctx = &getContext();
374     RewritePatternSet patterns(ctx);
375     func::FuncOp func = getOperation();
376     func.walk([&](arith::AddFOp op) {
377       OpBuilder builder(op);
378       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
379         SmallVector<int64_t, 2> mul;
380         SmallVector<AffineExpr, 2> perm;
381         SmallVector<Value, 2> ids;
382         unsigned count = 0;
383         // Remove the multiplicity of 1 and calculate the affine map based on
384         // the multiplicity.
385         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
386         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
387           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
388             mul.push_back(m[i]);
389             ids.push_back(func.getArgument(count++));
390             perm.push_back(getAffineDimExpr(i, ctx));
391           }
392         }
393         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
394                                   perm, ctx);
395         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
396             builder, op.getOperation(), ids, mul, map);
397         if (ops.hasValue()) {
398           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
399           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
400                                               extractOp);
401         }
402       }
403     });
404     populatePropagateVectorDistributionPatterns(patterns);
405     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
406   }
407 };
408 
/// Test pass that takes the first eligible 1-D vector arith.addf of a
/// function, wraps the ops of its slice into an scf.for loop with
/// `distribution-multiplicity` iterations, distributes the addf over the loop
/// induction variable and then applies the distribution propagation patterns.
struct TestVectorToLoopPatterns
    : public PassWrapper<TestVectorToLoopPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)

  StringRef getArgument() const final { return "test-vector-to-forloop"; }
  StringRef getDescription() const final {
    return "Test lowering patterns to break up a vector op into a for loop";
  }
  TestVectorToLoopPatterns() = default;
  TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  // Number of loop iterations, also used as the distribution multiplicity
  // (defaults to 32).
  Option<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector"),
      llvm::cl::init(32)};
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    func::FuncOp func = getOperation();
    func.walk([&](arith::AddFOp op) {
      // Check that the operation type can be broken down into a loop.
      // Only rank-1 vectors whose element count is a multiple of the
      // multiplicity qualify; anything else is skipped and the walk goes on.
      VectorType type = op.getType().dyn_cast<VectorType>();
      if (!type || type.getRank() != 1 ||
          type.getNumElements() % multiplicity != 0)
        return mlir::WalkResult::advance();
      // Constants, allocations and calls are filtered out of the slice, so
      // they are not moved into the loop below.
      auto filterAlloc = [](Operation *op) {
        return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
      };
      auto dependentOps = getSlice(op, filterAlloc);
      // Create a loop and move instructions from the Op slice into the loop.
      // The loop runs from 0 to `multiplicity` with step 1 and is created
      // right before `op`.
      OpBuilder builder(op);
      auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
      auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
      auto numIter =
          builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
      auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
      for (Operation *it : dependentOps) {
        it->moveBefore(forOp.getBody()->getTerminator());
      }
      auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
      // break up the original op and let the patterns propagate.
      // The loop induction variable serves as the distribution id.
      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
          builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
          map);
      if (ops.hasValue()) {
        // Replace all uses of the original result with the inserted value,
        // except inside the newly created extract/insert ops themselves.
        SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
      }
      // Only the first eligible addf is transformed; stop walking afterwards.
      return mlir::WalkResult::interrupt();
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
468 
469 struct TestVectorTransferUnrollingPatterns
470     : public PassWrapper<TestVectorTransferUnrollingPatterns,
471                          OperationPass<func::FuncOp>> {
472   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
473       TestVectorTransferUnrollingPatterns)
474 
475   void getDependentDialects(DialectRegistry &registry) const override {
476     registry.insert<AffineDialect>();
477   }
478   StringRef getArgument() const final {
479     return "test-vector-transfer-unrolling-patterns";
480   }
481   StringRef getDescription() const final {
482     return "Test lowering patterns to unroll transfer ops in the vector "
483            "dialect";
484   }
485   void runOnOperation() override {
486     MLIRContext *ctx = &getContext();
487     RewritePatternSet patterns(ctx);
488     populateVectorUnrollPatterns(
489         patterns,
490         UnrollVectorOptions()
491             .setNativeShape(ArrayRef<int64_t>{2, 2})
492             .setFilterConstraint([](Operation *op) {
493               return success(
494                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
495             }));
496     populateVectorToVectorCanonicalizationPatterns(patterns);
497     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
498   }
499 };
500 
501 struct TestVectorTransferFullPartialSplitPatterns
502     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
503                          OperationPass<func::FuncOp>> {
504   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
505       TestVectorTransferFullPartialSplitPatterns)
506 
507   StringRef getArgument() const final {
508     return "test-vector-transfer-full-partial-split";
509   }
510   StringRef getDescription() const final {
511     return "Test lowering patterns to split "
512            "transfer ops via scf.if + linalg ops";
513   }
514   TestVectorTransferFullPartialSplitPatterns() = default;
515   TestVectorTransferFullPartialSplitPatterns(
516       const TestVectorTransferFullPartialSplitPatterns &pass)
517       : PassWrapper(pass) {}
518 
519   void getDependentDialects(DialectRegistry &registry) const override {
520     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
521                     scf::SCFDialect>();
522   }
523 
524   Option<bool> useLinalgOps{
525       *this, "use-memref-copy",
526       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
527                      "memref.copy operations."),
528       llvm::cl::init(false)};
529   void runOnOperation() override {
530     MLIRContext *ctx = &getContext();
531     RewritePatternSet patterns(ctx);
532     VectorTransformsOptions options;
533     if (useLinalgOps)
534       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
535     else
536       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
537     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
538     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
539   }
540 };
541 
542 struct TestVectorTransferOpt
543     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
544   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
545 
546   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
547   StringRef getDescription() const final {
548     return "Test optimization transformations for transfer ops";
549   }
550   void runOnOperation() override { transferOpflowOpt(getOperation()); }
551 };
552 
553 struct TestVectorTransferLoweringPatterns
554     : public PassWrapper<TestVectorTransferLoweringPatterns,
555                          OperationPass<func::FuncOp>> {
556   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
557       TestVectorTransferLoweringPatterns)
558 
559   void getDependentDialects(DialectRegistry &registry) const override {
560     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
561   }
562   StringRef getArgument() const final {
563     return "test-vector-transfer-lowering-patterns";
564   }
565   StringRef getDescription() const final {
566     return "Test lowering patterns to lower transfer ops to other vector ops";
567   }
568   void runOnOperation() override {
569     RewritePatternSet patterns(&getContext());
570     populateVectorTransferLoweringPatterns(patterns);
571     populateVectorTransferPermutationMapLoweringPatterns(patterns);
572     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
573   }
574 };
575 
576 struct TestVectorMultiReductionLoweringPatterns
577     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
578                          OperationPass<func::FuncOp>> {
579   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
580       TestVectorMultiReductionLoweringPatterns)
581 
582   TestVectorMultiReductionLoweringPatterns() = default;
583   TestVectorMultiReductionLoweringPatterns(
584       const TestVectorMultiReductionLoweringPatterns &pass)
585       : PassWrapper(pass) {}
586   void getDependentDialects(DialectRegistry &registry) const override {
587     registry.insert<memref::MemRefDialect>();
588   }
589   StringRef getArgument() const final {
590     return "test-vector-multi-reduction-lowering-patterns";
591   }
592   StringRef getDescription() const final {
593     return "Test lowering patterns to lower vector.multi_reduction to other "
594            "vector ops";
595   }
596   Option<bool> useOuterReductions{
597       *this, "use-outer-reductions",
598       llvm::cl::desc("Move reductions to outer most dimensions"),
599       llvm::cl::init(false)};
600   void runOnOperation() override {
601     RewritePatternSet patterns(&getContext());
602     populateVectorMultiReductionLoweringPatterns(
603         patterns, useOuterReductions
604                       ? vector::VectorMultiReductionLowering::InnerParallel
605                       : vector::VectorMultiReductionLowering::InnerReduction);
606     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
607   }
608 };
609 
610 struct TestVectorTransferCollapseInnerMostContiguousDims
611     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
612                          OperationPass<func::FuncOp>> {
613   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
614       TestVectorTransferCollapseInnerMostContiguousDims)
615 
616   TestVectorTransferCollapseInnerMostContiguousDims() = default;
617   TestVectorTransferCollapseInnerMostContiguousDims(
618       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
619 
620   void getDependentDialects(DialectRegistry &registry) const override {
621     registry.insert<memref::MemRefDialect, AffineDialect>();
622   }
623 
624   StringRef getArgument() const final {
625     return "test-vector-transfer-collapse-inner-most-dims";
626   }
627 
628   StringRef getDescription() const final {
629     return "Test lowering patterns that reducedes the rank of the vector "
630            "transfer memory and vector operands.";
631   }
632 
633   void runOnOperation() override {
634     RewritePatternSet patterns(&getContext());
635     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
636     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
637   }
638 };
639 
640 struct TestVectorReduceToContractPatternsPatterns
641     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
642                          OperationPass<func::FuncOp>> {
643   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
644       TestVectorReduceToContractPatternsPatterns)
645 
646   StringRef getArgument() const final {
647     return "test-vector-reduction-to-contract-patterns";
648   }
649   StringRef getDescription() const final {
650     return "Test patterns to convert multireduce op to contract and combine "
651            "broadcast/transpose to contract";
652   }
653   void runOnOperation() override {
654     RewritePatternSet patterns(&getContext());
655     populateVectorReductionToContractPatterns(patterns);
656     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
657   }
658 };
659 
660 struct TestVectorTransferDropUnitDimsPatterns
661     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
662                          OperationPass<func::FuncOp>> {
663   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
664       TestVectorTransferDropUnitDimsPatterns)
665 
666   StringRef getArgument() const final {
667     return "test-vector-transfer-drop-unit-dims-patterns";
668   }
669   void getDependentDialects(DialectRegistry &registry) const override {
670     registry.insert<memref::MemRefDialect>();
671   }
672   void runOnOperation() override {
673     RewritePatternSet patterns(&getContext());
674     populateVectorTransferDropUnitDimsPatterns(patterns);
675     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
676   }
677 };
678 
679 struct TestFlattenVectorTransferPatterns
680     : public PassWrapper<TestFlattenVectorTransferPatterns,
681                          OperationPass<func::FuncOp>> {
682   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
683       TestFlattenVectorTransferPatterns)
684 
685   StringRef getArgument() const final {
686     return "test-vector-transfer-flatten-patterns";
687   }
688   StringRef getDescription() const final {
689     return "Test patterns to rewrite contiguous row-major N-dimensional "
690            "vector.transfer_{read,write} ops into 1D transfers";
691   }
692   void getDependentDialects(DialectRegistry &registry) const override {
693     registry.insert<memref::MemRefDialect>();
694   }
695   void runOnOperation() override {
696     RewritePatternSet patterns(&getContext());
697     populateFlattenVectorTransferPatterns(patterns);
698     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
699   }
700 };
701 
702 struct TestVectorScanLowering
703     : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
704   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
705 
706   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
707   StringRef getDescription() const final {
708     return "Test lowering patterns that lower the scan op in the vector "
709            "dialect";
710   }
711   void runOnOperation() override {
712     RewritePatternSet patterns(&getContext());
713     populateVectorScanLoweringPatterns(patterns);
714     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
715   }
716 };
717 
718 /// Allocate shared memory for a single warp to test lowering of
719 /// WarpExecuteOnLane0Op.
720 static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
721                                         WarpExecuteOnLane0Op warpOp,
722                                         Type type) {
723   static constexpr int64_t kSharedMemorySpace = 3;
724   // Compute type of shared memory buffer.
725   MemRefType memrefType;
726   if (auto vectorType = type.dyn_cast<VectorType>()) {
727     memrefType =
728         MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
729                         kSharedMemorySpace);
730   } else {
731     memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
732   }
733 
734   // Get symbol table holding all shared memory globals.
735   ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
736   SymbolTable symbolTable(moduleOp);
737 
738   // Create a pretty name.
739   SmallString<64> buf;
740   llvm::raw_svector_ostream os(buf);
741   interleave(memrefType.getShape(), os, "x");
742   os << "x" << memrefType.getElementType();
743   std::string symbolName = (Twine("__shared_") + os.str()).str();
744 
745   auto ip = builder.saveInsertionPoint();
746   builder.setInsertionPoint(moduleOp);
747   auto global = builder.create<memref::GlobalOp>(
748       loc,
749       /*sym_name=*/symbolName,
750       /*sym_visibility=*/builder.getStringAttr("private"),
751       /*type=*/memrefType,
752       /*initial_value=*/Attribute(),
753       /*constant=*/false,
754       /*alignment=*/IntegerAttr());
755   symbolTable.insert(global);
756   // The symbol table inserts at the end of the module, but globals are a bit
757   // nicer if they are at the beginning.
758   global->moveBefore(&moduleOp.front());
759 
760   builder.restoreInsertionPoint(ip);
761   return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
762 }
763 
764 struct TestVectorDistribution
765     : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
766   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
767 
768   void getDependentDialects(DialectRegistry &registry) const override {
769     registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>();
770   }
771 
772   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
773   StringRef getDescription() const final {
774     return "Test vector warp distribute transformation and lowering patterns";
775   }
776   TestVectorDistribution() = default;
777   TestVectorDistribution(const TestVectorDistribution &pass)
778       : PassWrapper(pass) {}
779 
780   Option<bool> warpOpToSCF{
781       *this, "rewrite-warp-ops-to-scf-if",
782       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
783       llvm::cl::init(false)};
784 
785   void runOnOperation() override {
786     RewritePatternSet patterns(&getContext());
787     WarpExecuteOnLane0LoweringOptions options;
788     options.warpAllocationFn = allocateGlobalSharedMemory;
789     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
790                                       WarpExecuteOnLane0Op warpOp) {
791       builder.create<gpu::BarrierOp>(loc);
792     };
793     // Test on one pattern in isolation.
794     if (warpOpToSCF) {
795       populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
796       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
797       return;
798     }
799   }
800 };
801 
802 } // namespace
803 
804 namespace mlir {
805 namespace test {
806 void registerTestVectorLowerings() {
807   PassRegistration<TestVectorToVectorLowering>();
808 
809   PassRegistration<TestVectorContractionLowering>();
810 
811   PassRegistration<TestVectorTransposeLowering>();
812 
813   PassRegistration<TestVectorUnrollingPatterns>();
814 
815   PassRegistration<TestVectorTransferUnrollingPatterns>();
816 
817   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
818 
819   PassRegistration<TestVectorDistributePatterns>();
820 
821   PassRegistration<TestVectorToLoopPatterns>();
822 
823   PassRegistration<TestVectorTransferOpt>();
824 
825   PassRegistration<TestVectorTransferLoweringPatterns>();
826 
827   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
828 
829   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
830 
831   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
832 
833   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
834 
835   PassRegistration<TestFlattenVectorTransferPatterns>();
836 
837   PassRegistration<TestVectorScanLowering>();
838 
839   PassRegistration<TestVectorDistribution>();
840 }
841 } // namespace test
842 } // namespace mlir
843