134ff8573SNicolas Vasilache //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle 
93fef2d26SRiver Riddle #include <type_traits>
103fef2d26SRiver Riddle 
113fef2d26SRiver Riddle #include "mlir/Analysis/SliceAnalysis.h"
123fef2d26SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
14d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15b2729fdaSNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
1734ff8573SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h"
1834ff8573SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
193fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
219f122152SChristopher Bate #include "mlir/Dialect/Vector/IR/VectorOps.h"
22d02f10d9SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
2399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
243fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
2534ff8573SNicolas Vasilache #include "mlir/Pass/PassManager.h"
269f122152SChristopher Bate #include "mlir/Support/LLVM.h"
273fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
283fef2d26SRiver Riddle 
293fef2d26SRiver Riddle using namespace mlir;
3034ff8573SNicolas Vasilache using namespace mlir::linalg;
313fef2d26SRiver Riddle using namespace mlir::vector;
32d054b80bSNicolas Vasilache 
333fef2d26SRiver Riddle namespace {
343fef2d26SRiver Riddle 
3534ff8573SNicolas Vasilache struct TestVectorToVectorLowering
3658ceae95SRiver Riddle     : public PassWrapper<TestVectorToVectorLowering,
3758ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
385e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
395e50dd04SRiver Riddle 
4034ff8573SNicolas Vasilache   TestVectorToVectorLowering() = default;
TestVectorToVectorLowering__anonb56dea510111::TestVectorToVectorLowering413bab9d4eSMehdi Amini   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
423bab9d4eSMehdi Amini       : PassWrapper(pass) {}
getArgument__anonb56dea510111::TestVectorToVectorLowering43b5e22e6dSMehdi Amini   StringRef getArgument() const final {
4434ff8573SNicolas Vasilache     return "test-vector-to-vector-lowering";
45b5e22e6dSMehdi Amini   }
getDescription__anonb56dea510111::TestVectorToVectorLowering46b5e22e6dSMehdi Amini   StringRef getDescription() const final {
4734ff8573SNicolas Vasilache     return "Test lowering patterns between ops in the vector dialect";
48b5e22e6dSMehdi Amini   }
493fef2d26SRiver Riddle 
getDependentDialects__anonb56dea510111::TestVectorToVectorLowering503fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
513fef2d26SRiver Riddle     registry.insert<AffineDialect>();
523fef2d26SRiver Riddle   }
533fef2d26SRiver Riddle 
543fef2d26SRiver Riddle   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
553fef2d26SRiver Riddle                       llvm::cl::init(false)};
563fef2d26SRiver Riddle 
runOnOperation__anonb56dea510111::TestVectorToVectorLowering5741574554SRiver Riddle   void runOnOperation() override {
583fef2d26SRiver Riddle     auto *ctx = &getContext();
593fef2d26SRiver Riddle     RewritePatternSet patterns(ctx);
603fef2d26SRiver Riddle     if (unroll) {
6129102538Sthomasraoux       populateVectorUnrollPatterns(
6229102538Sthomasraoux           patterns,
633fef2d26SRiver Riddle           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
643fef2d26SRiver Riddle               filter));
653fef2d26SRiver Riddle     }
663fef2d26SRiver Riddle     populateVectorToVectorCanonicalizationPatterns(patterns);
673fef2d26SRiver Riddle     populateBubbleVectorBitCastOpPatterns(patterns);
683fef2d26SRiver Riddle     populateCastAwayVectorLeadingOneDimPatterns(patterns);
6941574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
703fef2d26SRiver Riddle   }
713fef2d26SRiver Riddle 
723fef2d26SRiver Riddle private:
733fef2d26SRiver Riddle   // Return the target shape based on op type.
getShape__anonb56dea510111::TestVectorToVectorLowering743fef2d26SRiver Riddle   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
75dec8af70SRiver Riddle     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
763fef2d26SRiver Riddle       return SmallVector<int64_t, 4>(2, 2);
773fef2d26SRiver Riddle     if (isa<vector::ContractionOp>(op))
783fef2d26SRiver Riddle       return SmallVector<int64_t, 4>(3, 2);
7929102538Sthomasraoux     // For transfer ops, just propagate the shape coming from
8029102538Sthomasraoux     // InsertStridedSlices/ExtractStridedSlices.
8129102538Sthomasraoux     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
8229102538Sthomasraoux       VectorType dstVec;
8329102538Sthomasraoux       for (Operation *users : readOp->getUsers()) {
8429102538Sthomasraoux         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
8529102538Sthomasraoux         if (!extract)
8629102538Sthomasraoux           return llvm::None;
8729102538Sthomasraoux         auto vecType = extract.getResult().getType().cast<VectorType>();
8829102538Sthomasraoux         if (dstVec && dstVec != vecType)
8929102538Sthomasraoux           return llvm::None;
9029102538Sthomasraoux         dstVec = vecType;
9129102538Sthomasraoux       }
9229102538Sthomasraoux       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
9329102538Sthomasraoux                                      dstVec.getShape().end());
9429102538Sthomasraoux     }
9529102538Sthomasraoux     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
967c38fd60SJacques Pienaar       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
9729102538Sthomasraoux       if (!insert)
9829102538Sthomasraoux         return llvm::None;
9929102538Sthomasraoux       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
10029102538Sthomasraoux       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
10129102538Sthomasraoux     }
1023fef2d26SRiver Riddle     return llvm::None;
1033fef2d26SRiver Riddle   }
1043fef2d26SRiver Riddle 
filter__anonb56dea510111::TestVectorToVectorLowering1053fef2d26SRiver Riddle   static LogicalResult filter(Operation *op) {
106dec8af70SRiver Riddle     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
107dec8af70SRiver Riddle                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
1083fef2d26SRiver Riddle   }
1093fef2d26SRiver Riddle };
1103fef2d26SRiver Riddle 
11134ff8573SNicolas Vasilache struct TestVectorContractionLowering
11258ceae95SRiver Riddle     : public PassWrapper<TestVectorContractionLowering,
11358ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorContractionLowering1145e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
1155e50dd04SRiver Riddle 
116b5e22e6dSMehdi Amini   StringRef getArgument() const final {
11734ff8573SNicolas Vasilache     return "test-vector-contraction-lowering";
118b5e22e6dSMehdi Amini   }
getDescription__anonb56dea510111::TestVectorContractionLowering119b5e22e6dSMehdi Amini   StringRef getDescription() const final {
12034ff8573SNicolas Vasilache     return "Test lowering patterns that lower contract ops in the vector "
121b5e22e6dSMehdi Amini            "dialect";
122b5e22e6dSMehdi Amini   }
12334ff8573SNicolas Vasilache   TestVectorContractionLowering() = default;
TestVectorContractionLowering__anonb56dea510111::TestVectorContractionLowering1243bab9d4eSMehdi Amini   TestVectorContractionLowering(const TestVectorContractionLowering &pass)
1253bab9d4eSMehdi Amini       : PassWrapper(pass) {}
1263fef2d26SRiver Riddle 
1273fef2d26SRiver Riddle   Option<bool> lowerToFlatMatrix{
1283fef2d26SRiver Riddle       *this, "vector-lower-matrix-intrinsics",
1293fef2d26SRiver Riddle       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
1303fef2d26SRiver Riddle       llvm::cl::init(false)};
1313fef2d26SRiver Riddle   Option<bool> lowerToOuterProduct{
1323fef2d26SRiver Riddle       *this, "vector-outerproduct",
1333fef2d26SRiver Riddle       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
1343fef2d26SRiver Riddle       llvm::cl::init(false)};
1353fef2d26SRiver Riddle   Option<bool> lowerToFilterOuterProduct{
1363fef2d26SRiver Riddle       *this, "vector-filter-outerproduct",
1373fef2d26SRiver Riddle       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
1383fef2d26SRiver Riddle                      "vectors of size 4."),
1393fef2d26SRiver Riddle       llvm::cl::init(false)};
14089aaa2d0SThomas Raoux   Option<bool> lowerToParallelArith{
14189aaa2d0SThomas Raoux       *this, "vector-parallel-arith",
14289aaa2d0SThomas Raoux       llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
14389aaa2d0SThomas Raoux       llvm::cl::init(false)};
1443fef2d26SRiver Riddle 
runOnOperation__anonb56dea510111::TestVectorContractionLowering14541574554SRiver Riddle   void runOnOperation() override {
1463fef2d26SRiver Riddle     RewritePatternSet patterns(&getContext());
1473fef2d26SRiver Riddle 
1483fef2d26SRiver Riddle     // Test on one pattern in isolation.
1493fef2d26SRiver Riddle     if (lowerToOuterProduct) {
1503fef2d26SRiver Riddle       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
1513fef2d26SRiver Riddle       VectorTransformsOptions options{lowering};
1523fef2d26SRiver Riddle       patterns.add<ContractionOpToOuterProductOpLowering>(options,
1533fef2d26SRiver Riddle                                                           &getContext());
15441574554SRiver Riddle       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1553fef2d26SRiver Riddle       return;
1563fef2d26SRiver Riddle     }
1573fef2d26SRiver Riddle 
1583fef2d26SRiver Riddle     // Test on one pattern in isolation.
1593fef2d26SRiver Riddle     if (lowerToFilterOuterProduct) {
1603fef2d26SRiver Riddle       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
1613fef2d26SRiver Riddle       VectorTransformsOptions options{lowering};
1623fef2d26SRiver Riddle       patterns.add<ContractionOpToOuterProductOpLowering>(
1633fef2d26SRiver Riddle           options, &getContext(), [](vector::ContractionOp op) {
1643fef2d26SRiver Riddle             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
1653fef2d26SRiver Riddle             // where M is not 4.
1663fef2d26SRiver Riddle             if (op.getRhsType().getShape()[0] == 4)
1673fef2d26SRiver Riddle               return failure();
1683fef2d26SRiver Riddle             return success();
1693fef2d26SRiver Riddle           });
17041574554SRiver Riddle       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1713fef2d26SRiver Riddle       return;
1723fef2d26SRiver Riddle     }
1733fef2d26SRiver Riddle 
17489aaa2d0SThomas Raoux     if (lowerToParallelArith) {
17589aaa2d0SThomas Raoux       vector::populateVectorContractLoweringPatterns(
17689aaa2d0SThomas Raoux           patterns,
17789aaa2d0SThomas Raoux           vector::VectorTransformsOptions().setVectorTransformsOptions(
17889aaa2d0SThomas Raoux               vector::VectorContractLowering::ParallelArith));
17989aaa2d0SThomas Raoux       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
18089aaa2d0SThomas Raoux       return;
18189aaa2d0SThomas Raoux     }
18289aaa2d0SThomas Raoux 
1833fef2d26SRiver Riddle     // Test on all contract lowering patterns.
1843fef2d26SRiver Riddle     VectorContractLowering contractLowering = VectorContractLowering::Dot;
1853fef2d26SRiver Riddle     if (lowerToFlatMatrix)
1863fef2d26SRiver Riddle       contractLowering = VectorContractLowering::Matmul;
187176a0ea5SNicolas Vasilache     VectorMultiReductionLowering vectorMultiReductionLowering =
188176a0ea5SNicolas Vasilache         VectorMultiReductionLowering::InnerParallel;
18934ff8573SNicolas Vasilache     VectorTransformsOptions options{contractLowering,
19034ff8573SNicolas Vasilache                                     vectorMultiReductionLowering,
19134ff8573SNicolas Vasilache                                     VectorTransposeLowering()};
1923964c1dbSLei Zhang     populateVectorBroadcastLoweringPatterns(patterns);
1933fef2d26SRiver Riddle     populateVectorContractLoweringPatterns(patterns, options);
1943964c1dbSLei Zhang     populateVectorMaskOpLoweringPatterns(patterns);
1953964c1dbSLei Zhang     populateVectorShapeCastLoweringPatterns(patterns);
19641574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1973fef2d26SRiver Riddle   }
1983fef2d26SRiver Riddle };
1993fef2d26SRiver Riddle 
20034ff8573SNicolas Vasilache struct TestVectorTransposeLowering
20158ceae95SRiver Riddle     : public PassWrapper<TestVectorTransposeLowering,
20258ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransposeLowering2035e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
2045e50dd04SRiver Riddle 
20534ff8573SNicolas Vasilache   StringRef getArgument() const final {
20634ff8573SNicolas Vasilache     return "test-vector-transpose-lowering";
20734ff8573SNicolas Vasilache   }
getDescription__anonb56dea510111::TestVectorTransposeLowering20834ff8573SNicolas Vasilache   StringRef getDescription() const final {
20934ff8573SNicolas Vasilache     return "Test lowering patterns that lower contract ops in the vector "
21034ff8573SNicolas Vasilache            "dialect";
21134ff8573SNicolas Vasilache   }
21234ff8573SNicolas Vasilache   TestVectorTransposeLowering() = default;
TestVectorTransposeLowering__anonb56dea510111::TestVectorTransposeLowering2133bab9d4eSMehdi Amini   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
2143bab9d4eSMehdi Amini       : PassWrapper(pass) {}
21534ff8573SNicolas Vasilache 
21634ff8573SNicolas Vasilache   Option<bool> lowerToEltwise{
21734ff8573SNicolas Vasilache       *this, "eltwise",
21834ff8573SNicolas Vasilache       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
21934ff8573SNicolas Vasilache       llvm::cl::init(false)};
22034ff8573SNicolas Vasilache   Option<bool> lowerToFlatTranspose{
22134ff8573SNicolas Vasilache       *this, "flat",
22234ff8573SNicolas Vasilache       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
22334ff8573SNicolas Vasilache       llvm::cl::init(false)};
22434ff8573SNicolas Vasilache   Option<bool> lowerToShuffleTranspose{
22534ff8573SNicolas Vasilache       *this, "shuffle",
22634ff8573SNicolas Vasilache       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
22734ff8573SNicolas Vasilache       llvm::cl::init(false)};
22834ff8573SNicolas Vasilache   Option<bool> lowerToAvx2{
22934ff8573SNicolas Vasilache       *this, "avx2",
23034ff8573SNicolas Vasilache       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
23134ff8573SNicolas Vasilache       llvm::cl::init(false)};
23234ff8573SNicolas Vasilache 
getDependentDialects__anonb56dea510111::TestVectorTransposeLowering233b2729fdaSNicolas Vasilache   void getDependentDialects(DialectRegistry &registry) const override {
234b2729fdaSNicolas Vasilache     registry.insert<LLVM::LLVMDialect>();
235b2729fdaSNicolas Vasilache   }
236b2729fdaSNicolas Vasilache 
runOnOperation__anonb56dea510111::TestVectorTransposeLowering23741574554SRiver Riddle   void runOnOperation() override {
23834ff8573SNicolas Vasilache     RewritePatternSet patterns(&getContext());
23934ff8573SNicolas Vasilache 
24034ff8573SNicolas Vasilache     // Test on one pattern in isolation.
24134ff8573SNicolas Vasilache     // Explicitly disable shape_cast lowering.
24234ff8573SNicolas Vasilache     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
24334ff8573SNicolas Vasilache                                               .enableVectorTransposeLowering()
24434ff8573SNicolas Vasilache                                               .enableShapeCastLowering(false);
24534ff8573SNicolas Vasilache     if (lowerToEltwise) {
24634ff8573SNicolas Vasilache       options = options.setVectorTransformsOptions(
24734ff8573SNicolas Vasilache           VectorTransformsOptions().setVectorTransposeLowering(
24834ff8573SNicolas Vasilache               VectorTransposeLowering::EltWise));
24934ff8573SNicolas Vasilache     }
25034ff8573SNicolas Vasilache     if (lowerToFlatTranspose) {
25134ff8573SNicolas Vasilache       options = options.setVectorTransformsOptions(
25234ff8573SNicolas Vasilache           VectorTransformsOptions().setVectorTransposeLowering(
25334ff8573SNicolas Vasilache               VectorTransposeLowering::Flat));
25434ff8573SNicolas Vasilache     }
25534ff8573SNicolas Vasilache     if (lowerToShuffleTranspose) {
25634ff8573SNicolas Vasilache       options = options.setVectorTransformsOptions(
25734ff8573SNicolas Vasilache           VectorTransformsOptions().setVectorTransposeLowering(
25834ff8573SNicolas Vasilache               VectorTransposeLowering::Shuffle));
25934ff8573SNicolas Vasilache     }
26034ff8573SNicolas Vasilache     if (lowerToAvx2) {
26134ff8573SNicolas Vasilache       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
26234ff8573SNicolas Vasilache           x86vector::avx2::LoweringOptions().setTransposeOptions(
26334ff8573SNicolas Vasilache               x86vector::avx2::TransposeLoweringOptions()
26434ff8573SNicolas Vasilache                   .lower4x8xf32()
26534ff8573SNicolas Vasilache                   .lower8x8xf32()));
26634ff8573SNicolas Vasilache     }
26734ff8573SNicolas Vasilache 
26836550692SRiver Riddle     OpPassManager dynamicPM("func.func");
26934ff8573SNicolas Vasilache     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
27041574554SRiver Riddle     if (failed(runPipeline(dynamicPM, getOperation())))
27134ff8573SNicolas Vasilache       return signalPassFailure();
27234ff8573SNicolas Vasilache   }
27334ff8573SNicolas Vasilache };
27434ff8573SNicolas Vasilache 
2753fef2d26SRiver Riddle struct TestVectorUnrollingPatterns
27658ceae95SRiver Riddle     : public PassWrapper<TestVectorUnrollingPatterns,
27758ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorUnrollingPatterns2785e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
2795e50dd04SRiver Riddle 
280b5e22e6dSMehdi Amini   StringRef getArgument() const final {
281b5e22e6dSMehdi Amini     return "test-vector-unrolling-patterns";
282b5e22e6dSMehdi Amini   }
getDescription__anonb56dea510111::TestVectorUnrollingPatterns283b5e22e6dSMehdi Amini   StringRef getDescription() const final {
28434ff8573SNicolas Vasilache     return "Test lowering patterns to unroll contract ops in the vector "
285b5e22e6dSMehdi Amini            "dialect";
286b5e22e6dSMehdi Amini   }
2873fef2d26SRiver Riddle   TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anonb56dea510111::TestVectorUnrollingPatterns2883bab9d4eSMehdi Amini   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
2893bab9d4eSMehdi Amini       : PassWrapper(pass) {}
runOnOperation__anonb56dea510111::TestVectorUnrollingPatterns29041574554SRiver Riddle   void runOnOperation() override {
2913fef2d26SRiver Riddle     MLIRContext *ctx = &getContext();
2923fef2d26SRiver Riddle     RewritePatternSet patterns(ctx);
29329102538Sthomasraoux     populateVectorUnrollPatterns(
29429102538Sthomasraoux         patterns, UnrollVectorOptions()
2953fef2d26SRiver Riddle                       .setNativeShape(ArrayRef<int64_t>{2, 2})
2963fef2d26SRiver Riddle                       .setFilterConstraint([](Operation *op) {
297f69175b1SThomas Raoux                         return success(isa<arith::AddFOp, vector::FMAOp,
298f69175b1SThomas Raoux                                            vector::MultiDimReductionOp>(op));
2993fef2d26SRiver Riddle                       }));
300de5022c7SMatthias Springer     populateVectorUnrollPatterns(
301de5022c7SMatthias Springer         patterns, UnrollVectorOptions()
302de5022c7SMatthias Springer                       .setNativeShape(ArrayRef<int64_t>{2})
303de5022c7SMatthias Springer                       .setFilterConstraint([](Operation *op) {
304de5022c7SMatthias Springer                         return success(isa<vector::ReductionOp>(op));
305de5022c7SMatthias Springer                       }));
3065b1b7108SThomas Raoux     populateVectorUnrollPatterns(
3075b1b7108SThomas Raoux         patterns, UnrollVectorOptions()
3085b1b7108SThomas Raoux                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
3095b1b7108SThomas Raoux                       .setFilterConstraint([](Operation *op) {
3105b1b7108SThomas Raoux                         return success(isa<vector::TransposeOp>(op));
3115b1b7108SThomas Raoux                       }));
3123fef2d26SRiver Riddle 
3133fef2d26SRiver Riddle     if (unrollBasedOnType) {
3143fef2d26SRiver Riddle       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
3153fef2d26SRiver Riddle           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
3163fef2d26SRiver Riddle         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
3179f122152SChristopher Bate         SmallVector<int64_t, 4> nativeShape(
3189f122152SChristopher Bate             contractOp.getIteratorTypes().size(), 4);
3199f122152SChristopher Bate         Type lhsType = contractOp.getLhsType().getElementType();
3209f122152SChristopher Bate         nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
3213fef2d26SRiver Riddle         return nativeShape;
3223fef2d26SRiver Riddle       };
3239f122152SChristopher Bate 
3249f122152SChristopher Bate       UnrollVectorOptions opts;
3259f122152SChristopher Bate       opts.setNativeShapeFn(nativeShapeFn)
3269f122152SChristopher Bate           .setFilterConstraint(
3279f122152SChristopher Bate               [](Operation *op) { return success(isa<ContractionOp>(op)); });
3289f122152SChristopher Bate 
3299f122152SChristopher Bate       if (!unrollOrder.empty()) {
3309f122152SChristopher Bate         opts.setUnrollTraversalOrderFn([this](Operation *op)
3319f122152SChristopher Bate                                            -> Optional<SmallVector<int64_t>> {
3329f122152SChristopher Bate           vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
3339f122152SChristopher Bate           if (contractOp.getIteratorTypes().size() == unrollOrder.size())
3349f122152SChristopher Bate             return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
3359f122152SChristopher Bate           return None;
3369f122152SChristopher Bate         });
3379f122152SChristopher Bate       }
3389f122152SChristopher Bate       populateVectorUnrollPatterns(patterns, opts);
3399f122152SChristopher Bate     } else {
3409f122152SChristopher Bate       auto nativeShapeFn =
3419f122152SChristopher Bate           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
3429f122152SChristopher Bate         auto contractOp = dyn_cast<ContractionOp>(op);
3439f122152SChristopher Bate         if (!contractOp)
3449f122152SChristopher Bate           return None;
3459f122152SChristopher Bate         return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
3469f122152SChristopher Bate       };
34753fe155bSChristopher Bate       populateVectorUnrollPatterns(patterns,
34853fe155bSChristopher Bate                                    UnrollVectorOptions()
34953fe155bSChristopher Bate                                        .setNativeShapeFn(nativeShapeFn)
35053fe155bSChristopher Bate                                        .setFilterConstraint([](Operation *op) {
35153fe155bSChristopher Bate                                          return success(isa<ContractionOp>(op));
35253fe155bSChristopher Bate                                        }));
3533fef2d26SRiver Riddle     }
3543fef2d26SRiver Riddle     populateVectorToVectorCanonicalizationPatterns(patterns);
35541574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
3563fef2d26SRiver Riddle   }
3573fef2d26SRiver Riddle 
3589f122152SChristopher Bate   ListOption<int64_t> unrollOrder{*this, "unroll-order",
359*62a4e6abSFangrui Song                                   llvm::cl::desc("set the unroll order")};
3609f122152SChristopher Bate 
3613fef2d26SRiver Riddle   Option<bool> unrollBasedOnType{
3623fef2d26SRiver Riddle       *this, "unroll-based-on-type",
3633fef2d26SRiver Riddle       llvm::cl::desc("Set the unroll factor based on type of the operation"),
3643fef2d26SRiver Riddle       llvm::cl::init(false)};
3653fef2d26SRiver Riddle };
3663fef2d26SRiver Riddle 
3673fef2d26SRiver Riddle struct TestVectorDistributePatterns
36858ceae95SRiver Riddle     : public PassWrapper<TestVectorDistributePatterns,
36958ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorDistributePatterns3705e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)
3715e50dd04SRiver Riddle 
372b5e22e6dSMehdi Amini   StringRef getArgument() const final {
373b5e22e6dSMehdi Amini     return "test-vector-distribute-patterns";
374b5e22e6dSMehdi Amini   }
getDescription__anonb56dea510111::TestVectorDistributePatterns375b5e22e6dSMehdi Amini   StringRef getDescription() const final {
37634ff8573SNicolas Vasilache     return "Test lowering patterns to distribute vector ops in the vector "
377b5e22e6dSMehdi Amini            "dialect";
378b5e22e6dSMehdi Amini   }
3793fef2d26SRiver Riddle   TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anonb56dea510111::TestVectorDistributePatterns3803bab9d4eSMehdi Amini   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
3813bab9d4eSMehdi Amini       : PassWrapper(pass) {}
getDependentDialects__anonb56dea510111::TestVectorDistributePatterns3823fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
3833fef2d26SRiver Riddle     registry.insert<VectorDialect>();
3843fef2d26SRiver Riddle     registry.insert<AffineDialect>();
3853fef2d26SRiver Riddle   }
3863fef2d26SRiver Riddle   ListOption<int32_t> multiplicity{
3876edef135SRiver Riddle       *this, "distribution-multiplicity",
3883fef2d26SRiver Riddle       llvm::cl::desc("Set the multiplicity used for distributing vector")};
3893fef2d26SRiver Riddle 
runOnOperation__anonb56dea510111::TestVectorDistributePatterns39041574554SRiver Riddle   void runOnOperation() override {
3913fef2d26SRiver Riddle     MLIRContext *ctx = &getContext();
3923fef2d26SRiver Riddle     RewritePatternSet patterns(ctx);
39358ceae95SRiver Riddle     func::FuncOp func = getOperation();
394a54f4eaeSMogball     func.walk([&](arith::AddFOp op) {
3953fef2d26SRiver Riddle       OpBuilder builder(op);
3963fef2d26SRiver Riddle       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
3973fef2d26SRiver Riddle         SmallVector<int64_t, 2> mul;
3983fef2d26SRiver Riddle         SmallVector<AffineExpr, 2> perm;
3993fef2d26SRiver Riddle         SmallVector<Value, 2> ids;
4003fef2d26SRiver Riddle         unsigned count = 0;
4013fef2d26SRiver Riddle         // Remove the multiplicity of 1 and calculate the affine map based on
4023fef2d26SRiver Riddle         // the multiplicity.
4033fef2d26SRiver Riddle         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
4043fef2d26SRiver Riddle         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
4053fef2d26SRiver Riddle           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
4063fef2d26SRiver Riddle             mul.push_back(m[i]);
4073fef2d26SRiver Riddle             ids.push_back(func.getArgument(count++));
4083fef2d26SRiver Riddle             perm.push_back(getAffineDimExpr(i, ctx));
4093fef2d26SRiver Riddle           }
4103fef2d26SRiver Riddle         }
4113fef2d26SRiver Riddle         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
4123fef2d26SRiver Riddle                                   perm, ctx);
4133fef2d26SRiver Riddle         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
4143fef2d26SRiver Riddle             builder, op.getOperation(), ids, mul, map);
415037f0995SKazu Hirata         if (ops) {
4163fef2d26SRiver Riddle           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
4173fef2d26SRiver Riddle           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
4183fef2d26SRiver Riddle                                               extractOp);
4193fef2d26SRiver Riddle         }
4203fef2d26SRiver Riddle       }
4213fef2d26SRiver Riddle     });
422627733b5Sthomasraoux     populatePropagateVectorDistributionPatterns(patterns);
42341574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
4243fef2d26SRiver Riddle   }
4253fef2d26SRiver Riddle };
4263fef2d26SRiver Riddle 
4273fef2d26SRiver Riddle struct TestVectorToLoopPatterns
42858ceae95SRiver Riddle     : public PassWrapper<TestVectorToLoopPatterns,
42958ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorToLoopPatterns4305e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)
4315e50dd04SRiver Riddle 
432b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-vector-to-forloop"; }
getDescription__anonb56dea510111::TestVectorToLoopPatterns433b5e22e6dSMehdi Amini   StringRef getDescription() const final {
43434ff8573SNicolas Vasilache     return "Test lowering patterns to break up a vector op into a for loop";
435b5e22e6dSMehdi Amini   }
4363fef2d26SRiver Riddle   TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anonb56dea510111::TestVectorToLoopPatterns4373bab9d4eSMehdi Amini   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
4383bab9d4eSMehdi Amini       : PassWrapper(pass) {}
getDependentDialects__anonb56dea510111::TestVectorToLoopPatterns4393fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
4403fef2d26SRiver Riddle     registry.insert<VectorDialect>();
4413fef2d26SRiver Riddle     registry.insert<AffineDialect>();
4423fef2d26SRiver Riddle   }
4433fef2d26SRiver Riddle   Option<int32_t> multiplicity{
4443fef2d26SRiver Riddle       *this, "distribution-multiplicity",
4453fef2d26SRiver Riddle       llvm::cl::desc("Set the multiplicity used for distributing vector"),
4463fef2d26SRiver Riddle       llvm::cl::init(32)};
runOnOperation__anonb56dea510111::TestVectorToLoopPatterns44741574554SRiver Riddle   void runOnOperation() override {
4483fef2d26SRiver Riddle     MLIRContext *ctx = &getContext();
4493fef2d26SRiver Riddle     RewritePatternSet patterns(ctx);
45058ceae95SRiver Riddle     func::FuncOp func = getOperation();
451a54f4eaeSMogball     func.walk([&](arith::AddFOp op) {
4523fef2d26SRiver Riddle       // Check that the operation type can be broken down into a loop.
4533fef2d26SRiver Riddle       VectorType type = op.getType().dyn_cast<VectorType>();
4543fef2d26SRiver Riddle       if (!type || type.getRank() != 1 ||
4553fef2d26SRiver Riddle           type.getNumElements() % multiplicity != 0)
4563fef2d26SRiver Riddle         return mlir::WalkResult::advance();
4573fef2d26SRiver Riddle       auto filterAlloc = [](Operation *op) {
45823aa5a74SRiver Riddle         return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
4593fef2d26SRiver Riddle       };
4603fef2d26SRiver Riddle       auto dependentOps = getSlice(op, filterAlloc);
4613fef2d26SRiver Riddle       // Create a loop and move instructions from the Op slice into the loop.
4623fef2d26SRiver Riddle       OpBuilder builder(op);
463a54f4eaeSMogball       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
464a54f4eaeSMogball       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
465a54f4eaeSMogball       auto numIter =
466a54f4eaeSMogball           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
4673fef2d26SRiver Riddle       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
4683fef2d26SRiver Riddle       for (Operation *it : dependentOps) {
4693fef2d26SRiver Riddle         it->moveBefore(forOp.getBody()->getTerminator());
4703fef2d26SRiver Riddle       }
4713fef2d26SRiver Riddle       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
4723fef2d26SRiver Riddle       // break up the original op and let the patterns propagate.
4733fef2d26SRiver Riddle       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
4743fef2d26SRiver Riddle           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
4753fef2d26SRiver Riddle           map);
476037f0995SKazu Hirata       if (ops) {
4773fef2d26SRiver Riddle         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
4783fef2d26SRiver Riddle         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
4793fef2d26SRiver Riddle       }
4803fef2d26SRiver Riddle       return mlir::WalkResult::interrupt();
4813fef2d26SRiver Riddle     });
482627733b5Sthomasraoux     populatePropagateVectorDistributionPatterns(patterns);
48341574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
4843fef2d26SRiver Riddle   }
4853fef2d26SRiver Riddle };
4863fef2d26SRiver Riddle 
4873fef2d26SRiver Riddle struct TestVectorTransferUnrollingPatterns
48841574554SRiver Riddle     : public PassWrapper<TestVectorTransferUnrollingPatterns,
48958ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
4905e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
4915e50dd04SRiver Riddle       TestVectorTransferUnrollingPatterns)
4925e50dd04SRiver Riddle 
4939f122152SChristopher Bate   TestVectorTransferUnrollingPatterns() = default;
TestVectorTransferUnrollingPatterns__anonb56dea510111::TestVectorTransferUnrollingPatterns4949f122152SChristopher Bate   TestVectorTransferUnrollingPatterns(
4959f122152SChristopher Bate       const TestVectorTransferUnrollingPatterns &pass)
4969f122152SChristopher Bate       : PassWrapper(pass) {}
4979f122152SChristopher Bate 
getDependentDialects__anonb56dea510111::TestVectorTransferUnrollingPatterns4983fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
4993fef2d26SRiver Riddle     registry.insert<AffineDialect>();
5003fef2d26SRiver Riddle   }
getArgument__anonb56dea510111::TestVectorTransferUnrollingPatterns501b5e22e6dSMehdi Amini   StringRef getArgument() const final {
502b5e22e6dSMehdi Amini     return "test-vector-transfer-unrolling-patterns";
503b5e22e6dSMehdi Amini   }
getDescription__anonb56dea510111::TestVectorTransferUnrollingPatterns504b5e22e6dSMehdi Amini   StringRef getDescription() const final {
50534ff8573SNicolas Vasilache     return "Test lowering patterns to unroll transfer ops in the vector "
506b5e22e6dSMehdi Amini            "dialect";
507b5e22e6dSMehdi Amini   }
runOnOperation__anonb56dea510111::TestVectorTransferUnrollingPatterns50841574554SRiver Riddle   void runOnOperation() override {
5093fef2d26SRiver Riddle     MLIRContext *ctx = &getContext();
5103fef2d26SRiver Riddle     RewritePatternSet patterns(ctx);
5119f122152SChristopher Bate     UnrollVectorOptions opts;
5129f122152SChristopher Bate     opts.setNativeShape(ArrayRef<int64_t>{2, 2})
5133fef2d26SRiver Riddle         .setFilterConstraint([](Operation *op) {
5143fef2d26SRiver Riddle           return success(
5153fef2d26SRiver Riddle               isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
5169f122152SChristopher Bate         });
5179f122152SChristopher Bate     if (reverseUnrollOrder.getValue()) {
5189f122152SChristopher Bate       opts.setUnrollTraversalOrderFn(
5199f122152SChristopher Bate           [](Operation *op) -> Optional<SmallVector<int64_t>> {
5209f122152SChristopher Bate             int64_t numLoops = 0;
5219f122152SChristopher Bate             if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
5229f122152SChristopher Bate               numLoops = readOp.getVectorType().getRank();
5239f122152SChristopher Bate             else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
5249f122152SChristopher Bate               numLoops = writeOp.getVectorType().getRank();
5259f122152SChristopher Bate             else
5269f122152SChristopher Bate               return None;
5279f122152SChristopher Bate             auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
5289f122152SChristopher Bate             return llvm::to_vector(order);
5299f122152SChristopher Bate           });
5309f122152SChristopher Bate     }
5319f122152SChristopher Bate     populateVectorUnrollPatterns(patterns, opts);
5323fef2d26SRiver Riddle     populateVectorToVectorCanonicalizationPatterns(patterns);
53341574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
5343fef2d26SRiver Riddle   }
5359f122152SChristopher Bate 
5369f122152SChristopher Bate   Option<bool> reverseUnrollOrder{
5379f122152SChristopher Bate       *this, "reverse-unroll-order",
5389f122152SChristopher Bate       llvm::cl::desc(
5399f122152SChristopher Bate           "reverse the order of unrolling of vector transfer operations"),
5409f122152SChristopher Bate       llvm::cl::init(false)};
5413fef2d26SRiver Riddle };
5423fef2d26SRiver Riddle 
5433fef2d26SRiver Riddle struct TestVectorTransferFullPartialSplitPatterns
5443fef2d26SRiver Riddle     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
54558ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns5465e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
5475e50dd04SRiver Riddle       TestVectorTransferFullPartialSplitPatterns)
5485e50dd04SRiver Riddle 
549b5e22e6dSMehdi Amini   StringRef getArgument() const final {
550b5e22e6dSMehdi Amini     return "test-vector-transfer-full-partial-split";
551b5e22e6dSMehdi Amini   }
getDescription__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns552b5e22e6dSMehdi Amini   StringRef getDescription() const final {
55334ff8573SNicolas Vasilache     return "Test lowering patterns to split "
554b5e22e6dSMehdi Amini            "transfer ops via scf.if + linalg ops";
555b5e22e6dSMehdi Amini   }
5563fef2d26SRiver Riddle   TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns5573fef2d26SRiver Riddle   TestVectorTransferFullPartialSplitPatterns(
5583bab9d4eSMehdi Amini       const TestVectorTransferFullPartialSplitPatterns &pass)
5593bab9d4eSMehdi Amini       : PassWrapper(pass) {}
5603fef2d26SRiver Riddle 
getDependentDialects__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns5613fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
5623fef2d26SRiver Riddle     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
5633fef2d26SRiver Riddle                     scf::SCFDialect>();
5643fef2d26SRiver Riddle   }
5653fef2d26SRiver Riddle 
5663fef2d26SRiver Riddle   Option<bool> useLinalgOps{
567ebc81537SAlexander Belyaev       *this, "use-memref-copy",
5683fef2d26SRiver Riddle       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
569ebc81537SAlexander Belyaev                      "memref.copy operations."),
5703fef2d26SRiver Riddle       llvm::cl::init(false)};
runOnOperation__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns57141574554SRiver Riddle   void runOnOperation() override {
5723fef2d26SRiver Riddle     MLIRContext *ctx = &getContext();
5733fef2d26SRiver Riddle     RewritePatternSet patterns(ctx);
5743fef2d26SRiver Riddle     VectorTransformsOptions options;
5753fef2d26SRiver Riddle     if (useLinalgOps)
5763fef2d26SRiver Riddle       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
5773fef2d26SRiver Riddle     else
5783fef2d26SRiver Riddle       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
5793fef2d26SRiver Riddle     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
58041574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
5813fef2d26SRiver Riddle   }
5823fef2d26SRiver Riddle };
5833fef2d26SRiver Riddle 
5843fef2d26SRiver Riddle struct TestVectorTransferOpt
58558ceae95SRiver Riddle     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferOpt5865e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
5875e50dd04SRiver Riddle 
588b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
getDescription__anonb56dea510111::TestVectorTransferOpt589b5e22e6dSMehdi Amini   StringRef getDescription() const final {
590b5e22e6dSMehdi Amini     return "Test optimization transformations for transfer ops";
591b5e22e6dSMehdi Amini   }
runOnOperation__anonb56dea510111::TestVectorTransferOpt59241574554SRiver Riddle   void runOnOperation() override { transferOpflowOpt(getOperation()); }
5933fef2d26SRiver Riddle };
5943fef2d26SRiver Riddle 
5953fef2d26SRiver Riddle struct TestVectorTransferLoweringPatterns
59641574554SRiver Riddle     : public PassWrapper<TestVectorTransferLoweringPatterns,
59758ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferLoweringPatterns5985e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
5995e50dd04SRiver Riddle       TestVectorTransferLoweringPatterns)
6005e50dd04SRiver Riddle 
6013fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
60257470abcSAlexander Belyaev     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
6033fef2d26SRiver Riddle   }
getArgument__anonb56dea510111::TestVectorTransferLoweringPatterns604b5e22e6dSMehdi Amini   StringRef getArgument() const final {
605b5e22e6dSMehdi Amini     return "test-vector-transfer-lowering-patterns";
606b5e22e6dSMehdi Amini   }
getDescription__anonb56dea510111::TestVectorTransferLoweringPatterns607b5e22e6dSMehdi Amini   StringRef getDescription() const final {
60834ff8573SNicolas Vasilache     return "Test lowering patterns to lower transfer ops to other vector ops";
609b5e22e6dSMehdi Amini   }
runOnOperation__anonb56dea510111::TestVectorTransferLoweringPatterns61041574554SRiver Riddle   void runOnOperation() override {
6113fef2d26SRiver Riddle     RewritePatternSet patterns(&getContext());
6123fef2d26SRiver Riddle     populateVectorTransferLoweringPatterns(patterns);
613d1a9e9a7SMatthias Springer     populateVectorTransferPermutationMapLoweringPatterns(patterns);
61441574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
6153fef2d26SRiver Riddle   }
6163fef2d26SRiver Riddle };
6173fef2d26SRiver Riddle 
6183fef2d26SRiver Riddle struct TestVectorMultiReductionLoweringPatterns
6193fef2d26SRiver Riddle     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
62058ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
6215e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
6225e50dd04SRiver Riddle       TestVectorMultiReductionLoweringPatterns)
6235e50dd04SRiver Riddle 
624e33f301eSharsh-nod   TestVectorMultiReductionLoweringPatterns() = default;
TestVectorMultiReductionLoweringPatterns__anonb56dea510111::TestVectorMultiReductionLoweringPatterns625e33f301eSharsh-nod   TestVectorMultiReductionLoweringPatterns(
6263bab9d4eSMehdi Amini       const TestVectorMultiReductionLoweringPatterns &pass)
6273bab9d4eSMehdi Amini       : PassWrapper(pass) {}
getDependentDialects__anonb56dea510111::TestVectorMultiReductionLoweringPatterns6283fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
6293fef2d26SRiver Riddle     registry.insert<memref::MemRefDialect>();
6303fef2d26SRiver Riddle   }
getArgument__anonb56dea510111::TestVectorMultiReductionLoweringPatterns631b5e22e6dSMehdi Amini   StringRef getArgument() const final {
632b5e22e6dSMehdi Amini     return "test-vector-multi-reduction-lowering-patterns";
633b5e22e6dSMehdi Amini   }
getDescription__anonb56dea510111::TestVectorMultiReductionLoweringPatterns634b5e22e6dSMehdi Amini   StringRef getDescription() const final {
63534ff8573SNicolas Vasilache     return "Test lowering patterns to lower vector.multi_reduction to other "
636b5e22e6dSMehdi Amini            "vector ops";
637b5e22e6dSMehdi Amini   }
638e33f301eSharsh-nod   Option<bool> useOuterReductions{
639e33f301eSharsh-nod       *this, "use-outer-reductions",
640e33f301eSharsh-nod       llvm::cl::desc("Move reductions to outer most dimensions"),
641e33f301eSharsh-nod       llvm::cl::init(false)};
runOnOperation__anonb56dea510111::TestVectorMultiReductionLoweringPatterns64241574554SRiver Riddle   void runOnOperation() override {
6433fef2d26SRiver Riddle     RewritePatternSet patterns(&getContext());
644176a0ea5SNicolas Vasilache     populateVectorMultiReductionLoweringPatterns(
645176a0ea5SNicolas Vasilache         patterns, useOuterReductions
646176a0ea5SNicolas Vasilache                       ? vector::VectorMultiReductionLowering::InnerParallel
647176a0ea5SNicolas Vasilache                       : vector::VectorMultiReductionLowering::InnerReduction);
64841574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
6493fef2d26SRiver Riddle   }
6503fef2d26SRiver Riddle };
6513fef2d26SRiver Riddle 
652a3dd4e77SAhmed S. Taei struct TestVectorTransferCollapseInnerMostContiguousDims
653a3dd4e77SAhmed S. Taei     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
65458ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
6555e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
6565e50dd04SRiver Riddle       TestVectorTransferCollapseInnerMostContiguousDims)
6575e50dd04SRiver Riddle 
658a3dd4e77SAhmed S. Taei   TestVectorTransferCollapseInnerMostContiguousDims() = default;
659a3dd4e77SAhmed S. Taei   TestVectorTransferCollapseInnerMostContiguousDims(
660322c8914SMehdi Amini       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
661a3dd4e77SAhmed S. Taei 
getDependentDialects__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims662a3dd4e77SAhmed S. Taei   void getDependentDialects(DialectRegistry &registry) const override {
663a3dd4e77SAhmed S. Taei     registry.insert<memref::MemRefDialect, AffineDialect>();
664a3dd4e77SAhmed S. Taei   }
665a3dd4e77SAhmed S. Taei 
getArgument__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims666a3dd4e77SAhmed S. Taei   StringRef getArgument() const final {
667a3dd4e77SAhmed S. Taei     return "test-vector-transfer-collapse-inner-most-dims";
668a3dd4e77SAhmed S. Taei   }
669a3dd4e77SAhmed S. Taei 
getDescription__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims670a3dd4e77SAhmed S. Taei   StringRef getDescription() const final {
67134ff8573SNicolas Vasilache     return "Test lowering patterns that reducedes the rank of the vector "
672a3dd4e77SAhmed S. Taei            "transfer memory and vector operands.";
673a3dd4e77SAhmed S. Taei   }
674a3dd4e77SAhmed S. Taei 
runOnOperation__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims67541574554SRiver Riddle   void runOnOperation() override {
676a3dd4e77SAhmed S. Taei     RewritePatternSet patterns(&getContext());
677a3dd4e77SAhmed S. Taei     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
67841574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
679a3dd4e77SAhmed S. Taei   }
680a3dd4e77SAhmed S. Taei };
681a3dd4e77SAhmed S. Taei 
6821d8cc45bSthomasraoux struct TestVectorReduceToContractPatternsPatterns
6831d8cc45bSthomasraoux     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
68458ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorReduceToContractPatternsPatterns6855e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
6865e50dd04SRiver Riddle       TestVectorReduceToContractPatternsPatterns)
6875e50dd04SRiver Riddle 
6881d8cc45bSthomasraoux   StringRef getArgument() const final {
6891d8cc45bSthomasraoux     return "test-vector-reduction-to-contract-patterns";
6901d8cc45bSthomasraoux   }
getDescription__anonb56dea510111::TestVectorReduceToContractPatternsPatterns6911d8cc45bSthomasraoux   StringRef getDescription() const final {
6921d8cc45bSthomasraoux     return "Test patterns to convert multireduce op to contract and combine "
6931d8cc45bSthomasraoux            "broadcast/transpose to contract";
6941d8cc45bSthomasraoux   }
runOnOperation__anonb56dea510111::TestVectorReduceToContractPatternsPatterns69541574554SRiver Riddle   void runOnOperation() override {
6961d8cc45bSthomasraoux     RewritePatternSet patterns(&getContext());
697d054b80bSNicolas Vasilache     populateVectorReductionToContractPatterns(patterns);
69841574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
6991d8cc45bSthomasraoux   }
7001d8cc45bSthomasraoux };
7011d8cc45bSthomasraoux 
7020aea49a7SBenoit Jacob struct TestVectorTransferDropUnitDimsPatterns
70341574554SRiver Riddle     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
70458ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns7055e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
7065e50dd04SRiver Riddle       TestVectorTransferDropUnitDimsPatterns)
7075e50dd04SRiver Riddle 
7080aea49a7SBenoit Jacob   StringRef getArgument() const final {
7090aea49a7SBenoit Jacob     return "test-vector-transfer-drop-unit-dims-patterns";
7100aea49a7SBenoit Jacob   }
getDependentDialects__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns7110aea49a7SBenoit Jacob   void getDependentDialects(DialectRegistry &registry) const override {
7120aea49a7SBenoit Jacob     registry.insert<memref::MemRefDialect>();
7130aea49a7SBenoit Jacob   }
runOnOperation__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns71441574554SRiver Riddle   void runOnOperation() override {
7150aea49a7SBenoit Jacob     RewritePatternSet patterns(&getContext());
7160aea49a7SBenoit Jacob     populateVectorTransferDropUnitDimsPatterns(patterns);
71741574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
7180aea49a7SBenoit Jacob   }
7190aea49a7SBenoit Jacob };
7200aea49a7SBenoit Jacob 
721aba437ceSBenoit Jacob struct TestFlattenVectorTransferPatterns
72241574554SRiver Riddle     : public PassWrapper<TestFlattenVectorTransferPatterns,
72358ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestFlattenVectorTransferPatterns7245e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
7255e50dd04SRiver Riddle       TestFlattenVectorTransferPatterns)
7265e50dd04SRiver Riddle 
727aba437ceSBenoit Jacob   StringRef getArgument() const final {
728aba437ceSBenoit Jacob     return "test-vector-transfer-flatten-patterns";
729aba437ceSBenoit Jacob   }
getDescription__anonb56dea510111::TestFlattenVectorTransferPatterns730aba437ceSBenoit Jacob   StringRef getDescription() const final {
731aba437ceSBenoit Jacob     return "Test patterns to rewrite contiguous row-major N-dimensional "
732aba437ceSBenoit Jacob            "vector.transfer_{read,write} ops into 1D transfers";
733aba437ceSBenoit Jacob   }
getDependentDialects__anonb56dea510111::TestFlattenVectorTransferPatterns734aba437ceSBenoit Jacob   void getDependentDialects(DialectRegistry &registry) const override {
735aba437ceSBenoit Jacob     registry.insert<memref::MemRefDialect>();
736aba437ceSBenoit Jacob   }
runOnOperation__anonb56dea510111::TestFlattenVectorTransferPatterns73741574554SRiver Riddle   void runOnOperation() override {
738aba437ceSBenoit Jacob     RewritePatternSet patterns(&getContext());
739aba437ceSBenoit Jacob     populateFlattenVectorTransferPatterns(patterns);
74041574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
741aba437ceSBenoit Jacob   }
742aba437ceSBenoit Jacob };
743aba437ceSBenoit Jacob 
74480e0bf1aSharsh struct TestVectorScanLowering
74558ceae95SRiver Riddle     : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorScanLowering7465e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
7475e50dd04SRiver Riddle 
74880e0bf1aSharsh   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
getDescription__anonb56dea510111::TestVectorScanLowering74980e0bf1aSharsh   StringRef getDescription() const final {
75080e0bf1aSharsh     return "Test lowering patterns that lower the scan op in the vector "
75180e0bf1aSharsh            "dialect";
75280e0bf1aSharsh   }
runOnOperation__anonb56dea510111::TestVectorScanLowering75380e0bf1aSharsh   void runOnOperation() override {
75480e0bf1aSharsh     RewritePatternSet patterns(&getContext());
75580e0bf1aSharsh     populateVectorScanLoweringPatterns(patterns);
75680e0bf1aSharsh     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
75780e0bf1aSharsh   }
75880e0bf1aSharsh };
75980e0bf1aSharsh 
760d02f10d9SThomas Raoux /// Allocate shared memory for a single warp to test lowering of
761d02f10d9SThomas Raoux /// WarpExecuteOnLane0Op.
allocateGlobalSharedMemory(Location loc,OpBuilder & builder,WarpExecuteOnLane0Op warpOp,Type type)762d02f10d9SThomas Raoux static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
763d02f10d9SThomas Raoux                                         WarpExecuteOnLane0Op warpOp,
764d02f10d9SThomas Raoux                                         Type type) {
765d02f10d9SThomas Raoux   static constexpr int64_t kSharedMemorySpace = 3;
766d02f10d9SThomas Raoux   // Compute type of shared memory buffer.
767d02f10d9SThomas Raoux   MemRefType memrefType;
768d02f10d9SThomas Raoux   if (auto vectorType = type.dyn_cast<VectorType>()) {
769d02f10d9SThomas Raoux     memrefType =
770d02f10d9SThomas Raoux         MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
771d02f10d9SThomas Raoux                         kSharedMemorySpace);
772d02f10d9SThomas Raoux   } else {
773d02f10d9SThomas Raoux     memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
774d02f10d9SThomas Raoux   }
775d02f10d9SThomas Raoux 
776d02f10d9SThomas Raoux   // Get symbol table holding all shared memory globals.
777d02f10d9SThomas Raoux   ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
778d02f10d9SThomas Raoux   SymbolTable symbolTable(moduleOp);
779d02f10d9SThomas Raoux 
780d02f10d9SThomas Raoux   // Create a pretty name.
781d02f10d9SThomas Raoux   SmallString<64> buf;
782d02f10d9SThomas Raoux   llvm::raw_svector_ostream os(buf);
783d02f10d9SThomas Raoux   interleave(memrefType.getShape(), os, "x");
784d02f10d9SThomas Raoux   os << "x" << memrefType.getElementType();
785d02f10d9SThomas Raoux   std::string symbolName = (Twine("__shared_") + os.str()).str();
786d02f10d9SThomas Raoux 
787d02f10d9SThomas Raoux   auto ip = builder.saveInsertionPoint();
788d02f10d9SThomas Raoux   builder.setInsertionPoint(moduleOp);
789d02f10d9SThomas Raoux   auto global = builder.create<memref::GlobalOp>(
790d02f10d9SThomas Raoux       loc,
791d02f10d9SThomas Raoux       /*sym_name=*/symbolName,
792d02f10d9SThomas Raoux       /*sym_visibility=*/builder.getStringAttr("private"),
793d02f10d9SThomas Raoux       /*type=*/memrefType,
794d02f10d9SThomas Raoux       /*initial_value=*/Attribute(),
795d02f10d9SThomas Raoux       /*constant=*/false,
796d02f10d9SThomas Raoux       /*alignment=*/IntegerAttr());
797d02f10d9SThomas Raoux   symbolTable.insert(global);
798d02f10d9SThomas Raoux   // The symbol table inserts at the end of the module, but globals are a bit
799d02f10d9SThomas Raoux   // nicer if they are at the beginning.
800d02f10d9SThomas Raoux   global->moveBefore(&moduleOp.front());
801d02f10d9SThomas Raoux 
802d02f10d9SThomas Raoux   builder.restoreInsertionPoint(ip);
803d02f10d9SThomas Raoux   return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
804d02f10d9SThomas Raoux }
805d02f10d9SThomas Raoux 
warpReduction(Location loc,OpBuilder & builder,Value input,CombiningKind kind,uint32_t size)8066834803cSThomas Raoux static Value warpReduction(Location loc, OpBuilder &builder, Value input,
8076834803cSThomas Raoux                            CombiningKind kind, uint32_t size) {
8086834803cSThomas Raoux   Value laneVal = input;
8096834803cSThomas Raoux   // Parallel reduction using butterfly shuffles.
8106834803cSThomas Raoux   for (uint64_t i = 1; i < size; i <<= 1) {
8116834803cSThomas Raoux     Value shuffled = builder
8126834803cSThomas Raoux                          .create<gpu::ShuffleOp>(loc, laneVal, i,
8136834803cSThomas Raoux                                                  /*width=*/size,
8146834803cSThomas Raoux                                                  /*mode=*/gpu::ShuffleMode::XOR)
8156834803cSThomas Raoux                          .result();
8166834803cSThomas Raoux     laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
8176834803cSThomas Raoux   }
8186834803cSThomas Raoux   return laneVal;
8196834803cSThomas Raoux }
8206834803cSThomas Raoux 
821d02f10d9SThomas Raoux struct TestVectorDistribution
822d02f10d9SThomas Raoux     : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorDistribution823d02f10d9SThomas Raoux   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
824d02f10d9SThomas Raoux 
825d02f10d9SThomas Raoux   void getDependentDialects(DialectRegistry &registry) const override {
826ed0288f7SThomas Raoux     registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
827ed0288f7SThomas Raoux                     AffineDialect>();
828d02f10d9SThomas Raoux   }
829d02f10d9SThomas Raoux 
getArgument__anonb56dea510111::TestVectorDistribution830d02f10d9SThomas Raoux   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
getDescription__anonb56dea510111::TestVectorDistribution831d02f10d9SThomas Raoux   StringRef getDescription() const final {
832d02f10d9SThomas Raoux     return "Test vector warp distribute transformation and lowering patterns";
833d02f10d9SThomas Raoux   }
834d02f10d9SThomas Raoux   TestVectorDistribution() = default;
TestVectorDistribution__anonb56dea510111::TestVectorDistribution835d02f10d9SThomas Raoux   TestVectorDistribution(const TestVectorDistribution &pass)
836d02f10d9SThomas Raoux       : PassWrapper(pass) {}
837d02f10d9SThomas Raoux 
838d02f10d9SThomas Raoux   Option<bool> warpOpToSCF{
839d02f10d9SThomas Raoux       *this, "rewrite-warp-ops-to-scf-if",
840d02f10d9SThomas Raoux       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
841d02f10d9SThomas Raoux       llvm::cl::init(false)};
842d02f10d9SThomas Raoux 
843ed0288f7SThomas Raoux   Option<bool> distributeTransferWriteOps{
844ed0288f7SThomas Raoux       *this, "distribute-transfer-write",
845ed0288f7SThomas Raoux       llvm::cl::desc("Test distribution of transfer write"),
846ed0288f7SThomas Raoux       llvm::cl::init(false)};
847ed0288f7SThomas Raoux 
848ed0288f7SThomas Raoux   Option<bool> hoistUniform{*this, "hoist-uniform",
849ed0288f7SThomas Raoux                             llvm::cl::desc("Test hoist uniform"),
850ed0288f7SThomas Raoux                             llvm::cl::init(false)};
851ed0288f7SThomas Raoux 
85276cf33daSThomas Raoux   Option<bool> propagateDistribution{
85376cf33daSThomas Raoux       *this, "propagate-distribution",
85476cf33daSThomas Raoux       llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
85576cf33daSThomas Raoux 
runOnOperation__anonb56dea510111::TestVectorDistribution856d02f10d9SThomas Raoux   void runOnOperation() override {
857d02f10d9SThomas Raoux     RewritePatternSet patterns(&getContext());
858ed0288f7SThomas Raoux 
859ed0288f7SThomas Raoux     getOperation().walk([&](Operation *op) {
860ed0288f7SThomas Raoux       if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
861ed0288f7SThomas Raoux         if (hoistUniform) {
862ed0288f7SThomas Raoux           moveScalarUniformCode(warpOp);
863ed0288f7SThomas Raoux         }
864ed0288f7SThomas Raoux         WalkResult::interrupt();
865ed0288f7SThomas Raoux       }
866ed0288f7SThomas Raoux     });
867ed0288f7SThomas Raoux     MLIRContext *ctx = &getContext();
868ed0288f7SThomas Raoux     if (distributeTransferWriteOps) {
869ed0288f7SThomas Raoux       auto distributionFn = [](vector::TransferWriteOp writeOp) {
870ed0288f7SThomas Raoux         // Create a map (d0, d1) -> (d1) to distribute along the inner
871ed0288f7SThomas Raoux         // dimension. Once we support n-d distribution we can add more
872ed0288f7SThomas Raoux         // complex cases.
873ed0288f7SThomas Raoux         int64_t vecRank = writeOp.getVectorType().getRank();
874ed0288f7SThomas Raoux         OpBuilder builder(writeOp.getContext());
875ed0288f7SThomas Raoux         auto map =
876ed0288f7SThomas Raoux             AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
877ed0288f7SThomas Raoux         return map;
878ed0288f7SThomas Raoux       };
879ed0288f7SThomas Raoux       RewritePatternSet patterns(ctx);
880ed0288f7SThomas Raoux       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
881ed0288f7SThomas Raoux       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
882ed0288f7SThomas Raoux     }
88376cf33daSThomas Raoux     if (propagateDistribution) {
88476cf33daSThomas Raoux       RewritePatternSet patterns(ctx);
88576cf33daSThomas Raoux       vector::populatePropagateWarpVectorDistributionPatterns(patterns);
8866834803cSThomas Raoux       vector::populateDistributeReduction(patterns, warpReduction);
88776cf33daSThomas Raoux       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
88876cf33daSThomas Raoux     }
889d02f10d9SThomas Raoux     WarpExecuteOnLane0LoweringOptions options;
890d02f10d9SThomas Raoux     options.warpAllocationFn = allocateGlobalSharedMemory;
891d02f10d9SThomas Raoux     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
892d02f10d9SThomas Raoux                                       WarpExecuteOnLane0Op warpOp) {
893d02f10d9SThomas Raoux       builder.create<gpu::BarrierOp>(loc);
894d02f10d9SThomas Raoux     };
895d02f10d9SThomas Raoux     // Test on one pattern in isolation.
896d02f10d9SThomas Raoux     if (warpOpToSCF) {
897d02f10d9SThomas Raoux       populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
898d02f10d9SThomas Raoux       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
899d02f10d9SThomas Raoux       return;
900d02f10d9SThomas Raoux     }
901d02f10d9SThomas Raoux   }
902d02f10d9SThomas Raoux };
903d02f10d9SThomas Raoux 
904be0a7e9fSMehdi Amini } // namespace
9053fef2d26SRiver Riddle 
9063fef2d26SRiver Riddle namespace mlir {
9073fef2d26SRiver Riddle namespace test {
registerTestVectorLowerings()90834ff8573SNicolas Vasilache void registerTestVectorLowerings() {
90934ff8573SNicolas Vasilache   PassRegistration<TestVectorToVectorLowering>();
9103fef2d26SRiver Riddle 
91134ff8573SNicolas Vasilache   PassRegistration<TestVectorContractionLowering>();
91234ff8573SNicolas Vasilache 
91334ff8573SNicolas Vasilache   PassRegistration<TestVectorTransposeLowering>();
9143fef2d26SRiver Riddle 
915b5e22e6dSMehdi Amini   PassRegistration<TestVectorUnrollingPatterns>();
9163fef2d26SRiver Riddle 
917b5e22e6dSMehdi Amini   PassRegistration<TestVectorTransferUnrollingPatterns>();
9183fef2d26SRiver Riddle 
919b5e22e6dSMehdi Amini   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
9203fef2d26SRiver Riddle 
921b5e22e6dSMehdi Amini   PassRegistration<TestVectorDistributePatterns>();
9223fef2d26SRiver Riddle 
923b5e22e6dSMehdi Amini   PassRegistration<TestVectorToLoopPatterns>();
9243fef2d26SRiver Riddle 
925b5e22e6dSMehdi Amini   PassRegistration<TestVectorTransferOpt>();
9263fef2d26SRiver Riddle 
927b5e22e6dSMehdi Amini   PassRegistration<TestVectorTransferLoweringPatterns>();
9283fef2d26SRiver Riddle 
929b5e22e6dSMehdi Amini   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
930a3dd4e77SAhmed S. Taei 
931a3dd4e77SAhmed S. Taei   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
9321d8cc45bSthomasraoux 
9331d8cc45bSthomasraoux   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
9340aea49a7SBenoit Jacob 
9350aea49a7SBenoit Jacob   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
936aba437ceSBenoit Jacob 
937aba437ceSBenoit Jacob   PassRegistration<TestFlattenVectorTransferPatterns>();
93880e0bf1aSharsh 
93980e0bf1aSharsh   PassRegistration<TestVectorScanLowering>();
940d02f10d9SThomas Raoux 
941d02f10d9SThomas Raoux   PassRegistration<TestVectorDistribution>();
9423fef2d26SRiver Riddle }
9433fef2d26SRiver Riddle } // namespace test
9443fef2d26SRiver Riddle } // namespace mlir
945