134ff8573SNicolas Vasilache //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle
93fef2d26SRiver Riddle #include <type_traits>
103fef2d26SRiver Riddle
113fef2d26SRiver Riddle #include "mlir/Analysis/SliceAnalysis.h"
123fef2d26SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
14d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15b2729fdaSNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
1734ff8573SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h"
1834ff8573SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
193fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
219f122152SChristopher Bate #include "mlir/Dialect/Vector/IR/VectorOps.h"
22d02f10d9SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
2399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
243fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
2534ff8573SNicolas Vasilache #include "mlir/Pass/PassManager.h"
269f122152SChristopher Bate #include "mlir/Support/LLVM.h"
273fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
283fef2d26SRiver Riddle
293fef2d26SRiver Riddle using namespace mlir;
3034ff8573SNicolas Vasilache using namespace mlir::linalg;
313fef2d26SRiver Riddle using namespace mlir::vector;
32d054b80bSNicolas Vasilache
333fef2d26SRiver Riddle namespace {
343fef2d26SRiver Riddle
3534ff8573SNicolas Vasilache struct TestVectorToVectorLowering
3658ceae95SRiver Riddle : public PassWrapper<TestVectorToVectorLowering,
3758ceae95SRiver Riddle OperationPass<func::FuncOp>> {
385e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
395e50dd04SRiver Riddle
4034ff8573SNicolas Vasilache TestVectorToVectorLowering() = default;
TestVectorToVectorLowering__anonb56dea510111::TestVectorToVectorLowering413bab9d4eSMehdi Amini TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
423bab9d4eSMehdi Amini : PassWrapper(pass) {}
getArgument__anonb56dea510111::TestVectorToVectorLowering43b5e22e6dSMehdi Amini StringRef getArgument() const final {
4434ff8573SNicolas Vasilache return "test-vector-to-vector-lowering";
45b5e22e6dSMehdi Amini }
getDescription__anonb56dea510111::TestVectorToVectorLowering46b5e22e6dSMehdi Amini StringRef getDescription() const final {
4734ff8573SNicolas Vasilache return "Test lowering patterns between ops in the vector dialect";
48b5e22e6dSMehdi Amini }
493fef2d26SRiver Riddle
getDependentDialects__anonb56dea510111::TestVectorToVectorLowering503fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override {
513fef2d26SRiver Riddle registry.insert<AffineDialect>();
523fef2d26SRiver Riddle }
533fef2d26SRiver Riddle
543fef2d26SRiver Riddle Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
553fef2d26SRiver Riddle llvm::cl::init(false)};
563fef2d26SRiver Riddle
runOnOperation__anonb56dea510111::TestVectorToVectorLowering5741574554SRiver Riddle void runOnOperation() override {
583fef2d26SRiver Riddle auto *ctx = &getContext();
593fef2d26SRiver Riddle RewritePatternSet patterns(ctx);
603fef2d26SRiver Riddle if (unroll) {
6129102538Sthomasraoux populateVectorUnrollPatterns(
6229102538Sthomasraoux patterns,
633fef2d26SRiver Riddle UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
643fef2d26SRiver Riddle filter));
653fef2d26SRiver Riddle }
663fef2d26SRiver Riddle populateVectorToVectorCanonicalizationPatterns(patterns);
673fef2d26SRiver Riddle populateBubbleVectorBitCastOpPatterns(patterns);
683fef2d26SRiver Riddle populateCastAwayVectorLeadingOneDimPatterns(patterns);
6941574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
703fef2d26SRiver Riddle }
713fef2d26SRiver Riddle
723fef2d26SRiver Riddle private:
733fef2d26SRiver Riddle // Return the target shape based on op type.
getShape__anonb56dea510111::TestVectorToVectorLowering743fef2d26SRiver Riddle static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
75dec8af70SRiver Riddle if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
763fef2d26SRiver Riddle return SmallVector<int64_t, 4>(2, 2);
773fef2d26SRiver Riddle if (isa<vector::ContractionOp>(op))
783fef2d26SRiver Riddle return SmallVector<int64_t, 4>(3, 2);
7929102538Sthomasraoux // For transfer ops, just propagate the shape coming from
8029102538Sthomasraoux // InsertStridedSlices/ExtractStridedSlices.
8129102538Sthomasraoux if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
8229102538Sthomasraoux VectorType dstVec;
8329102538Sthomasraoux for (Operation *users : readOp->getUsers()) {
8429102538Sthomasraoux auto extract = dyn_cast<ExtractStridedSliceOp>(users);
8529102538Sthomasraoux if (!extract)
8629102538Sthomasraoux return llvm::None;
8729102538Sthomasraoux auto vecType = extract.getResult().getType().cast<VectorType>();
8829102538Sthomasraoux if (dstVec && dstVec != vecType)
8929102538Sthomasraoux return llvm::None;
9029102538Sthomasraoux dstVec = vecType;
9129102538Sthomasraoux }
9229102538Sthomasraoux return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
9329102538Sthomasraoux dstVec.getShape().end());
9429102538Sthomasraoux }
9529102538Sthomasraoux if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
967c38fd60SJacques Pienaar auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
9729102538Sthomasraoux if (!insert)
9829102538Sthomasraoux return llvm::None;
9929102538Sthomasraoux ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
10029102538Sthomasraoux return SmallVector<int64_t, 4>(shape.begin(), shape.end());
10129102538Sthomasraoux }
1023fef2d26SRiver Riddle return llvm::None;
1033fef2d26SRiver Riddle }
1043fef2d26SRiver Riddle
filter__anonb56dea510111::TestVectorToVectorLowering1053fef2d26SRiver Riddle static LogicalResult filter(Operation *op) {
106dec8af70SRiver Riddle return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
107dec8af70SRiver Riddle ContractionOp, TransferReadOp, TransferWriteOp>(op));
1083fef2d26SRiver Riddle }
1093fef2d26SRiver Riddle };
1103fef2d26SRiver Riddle
11134ff8573SNicolas Vasilache struct TestVectorContractionLowering
11258ceae95SRiver Riddle : public PassWrapper<TestVectorContractionLowering,
11358ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorContractionLowering1145e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
1155e50dd04SRiver Riddle
116b5e22e6dSMehdi Amini StringRef getArgument() const final {
11734ff8573SNicolas Vasilache return "test-vector-contraction-lowering";
118b5e22e6dSMehdi Amini }
getDescription__anonb56dea510111::TestVectorContractionLowering119b5e22e6dSMehdi Amini StringRef getDescription() const final {
12034ff8573SNicolas Vasilache return "Test lowering patterns that lower contract ops in the vector "
121b5e22e6dSMehdi Amini "dialect";
122b5e22e6dSMehdi Amini }
12334ff8573SNicolas Vasilache TestVectorContractionLowering() = default;
TestVectorContractionLowering__anonb56dea510111::TestVectorContractionLowering1243bab9d4eSMehdi Amini TestVectorContractionLowering(const TestVectorContractionLowering &pass)
1253bab9d4eSMehdi Amini : PassWrapper(pass) {}
1263fef2d26SRiver Riddle
1273fef2d26SRiver Riddle Option<bool> lowerToFlatMatrix{
1283fef2d26SRiver Riddle *this, "vector-lower-matrix-intrinsics",
1293fef2d26SRiver Riddle llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
1303fef2d26SRiver Riddle llvm::cl::init(false)};
1313fef2d26SRiver Riddle Option<bool> lowerToOuterProduct{
1323fef2d26SRiver Riddle *this, "vector-outerproduct",
1333fef2d26SRiver Riddle llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
1343fef2d26SRiver Riddle llvm::cl::init(false)};
1353fef2d26SRiver Riddle Option<bool> lowerToFilterOuterProduct{
1363fef2d26SRiver Riddle *this, "vector-filter-outerproduct",
1373fef2d26SRiver Riddle llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
1383fef2d26SRiver Riddle "vectors of size 4."),
1393fef2d26SRiver Riddle llvm::cl::init(false)};
14089aaa2d0SThomas Raoux Option<bool> lowerToParallelArith{
14189aaa2d0SThomas Raoux *this, "vector-parallel-arith",
14289aaa2d0SThomas Raoux llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
14389aaa2d0SThomas Raoux llvm::cl::init(false)};
1443fef2d26SRiver Riddle
runOnOperation__anonb56dea510111::TestVectorContractionLowering14541574554SRiver Riddle void runOnOperation() override {
1463fef2d26SRiver Riddle RewritePatternSet patterns(&getContext());
1473fef2d26SRiver Riddle
1483fef2d26SRiver Riddle // Test on one pattern in isolation.
1493fef2d26SRiver Riddle if (lowerToOuterProduct) {
1503fef2d26SRiver Riddle VectorContractLowering lowering = VectorContractLowering::OuterProduct;
1513fef2d26SRiver Riddle VectorTransformsOptions options{lowering};
1523fef2d26SRiver Riddle patterns.add<ContractionOpToOuterProductOpLowering>(options,
1533fef2d26SRiver Riddle &getContext());
15441574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1553fef2d26SRiver Riddle return;
1563fef2d26SRiver Riddle }
1573fef2d26SRiver Riddle
1583fef2d26SRiver Riddle // Test on one pattern in isolation.
1593fef2d26SRiver Riddle if (lowerToFilterOuterProduct) {
1603fef2d26SRiver Riddle VectorContractLowering lowering = VectorContractLowering::OuterProduct;
1613fef2d26SRiver Riddle VectorTransformsOptions options{lowering};
1623fef2d26SRiver Riddle patterns.add<ContractionOpToOuterProductOpLowering>(
1633fef2d26SRiver Riddle options, &getContext(), [](vector::ContractionOp op) {
1643fef2d26SRiver Riddle // Only lowers vector.contract where the lhs as a type vector<MxNx?>
1653fef2d26SRiver Riddle // where M is not 4.
1663fef2d26SRiver Riddle if (op.getRhsType().getShape()[0] == 4)
1673fef2d26SRiver Riddle return failure();
1683fef2d26SRiver Riddle return success();
1693fef2d26SRiver Riddle });
17041574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1713fef2d26SRiver Riddle return;
1723fef2d26SRiver Riddle }
1733fef2d26SRiver Riddle
17489aaa2d0SThomas Raoux if (lowerToParallelArith) {
17589aaa2d0SThomas Raoux vector::populateVectorContractLoweringPatterns(
17689aaa2d0SThomas Raoux patterns,
17789aaa2d0SThomas Raoux vector::VectorTransformsOptions().setVectorTransformsOptions(
17889aaa2d0SThomas Raoux vector::VectorContractLowering::ParallelArith));
17989aaa2d0SThomas Raoux (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
18089aaa2d0SThomas Raoux return;
18189aaa2d0SThomas Raoux }
18289aaa2d0SThomas Raoux
1833fef2d26SRiver Riddle // Test on all contract lowering patterns.
1843fef2d26SRiver Riddle VectorContractLowering contractLowering = VectorContractLowering::Dot;
1853fef2d26SRiver Riddle if (lowerToFlatMatrix)
1863fef2d26SRiver Riddle contractLowering = VectorContractLowering::Matmul;
187176a0ea5SNicolas Vasilache VectorMultiReductionLowering vectorMultiReductionLowering =
188176a0ea5SNicolas Vasilache VectorMultiReductionLowering::InnerParallel;
18934ff8573SNicolas Vasilache VectorTransformsOptions options{contractLowering,
19034ff8573SNicolas Vasilache vectorMultiReductionLowering,
19134ff8573SNicolas Vasilache VectorTransposeLowering()};
1923964c1dbSLei Zhang populateVectorBroadcastLoweringPatterns(patterns);
1933fef2d26SRiver Riddle populateVectorContractLoweringPatterns(patterns, options);
1943964c1dbSLei Zhang populateVectorMaskOpLoweringPatterns(patterns);
1953964c1dbSLei Zhang populateVectorShapeCastLoweringPatterns(patterns);
19641574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1973fef2d26SRiver Riddle }
1983fef2d26SRiver Riddle };
1993fef2d26SRiver Riddle
20034ff8573SNicolas Vasilache struct TestVectorTransposeLowering
20158ceae95SRiver Riddle : public PassWrapper<TestVectorTransposeLowering,
20258ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransposeLowering2035e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
2045e50dd04SRiver Riddle
20534ff8573SNicolas Vasilache StringRef getArgument() const final {
20634ff8573SNicolas Vasilache return "test-vector-transpose-lowering";
20734ff8573SNicolas Vasilache }
getDescription__anonb56dea510111::TestVectorTransposeLowering20834ff8573SNicolas Vasilache StringRef getDescription() const final {
20934ff8573SNicolas Vasilache return "Test lowering patterns that lower contract ops in the vector "
21034ff8573SNicolas Vasilache "dialect";
21134ff8573SNicolas Vasilache }
21234ff8573SNicolas Vasilache TestVectorTransposeLowering() = default;
TestVectorTransposeLowering__anonb56dea510111::TestVectorTransposeLowering2133bab9d4eSMehdi Amini TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
2143bab9d4eSMehdi Amini : PassWrapper(pass) {}
21534ff8573SNicolas Vasilache
21634ff8573SNicolas Vasilache Option<bool> lowerToEltwise{
21734ff8573SNicolas Vasilache *this, "eltwise",
21834ff8573SNicolas Vasilache llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
21934ff8573SNicolas Vasilache llvm::cl::init(false)};
22034ff8573SNicolas Vasilache Option<bool> lowerToFlatTranspose{
22134ff8573SNicolas Vasilache *this, "flat",
22234ff8573SNicolas Vasilache llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
22334ff8573SNicolas Vasilache llvm::cl::init(false)};
22434ff8573SNicolas Vasilache Option<bool> lowerToShuffleTranspose{
22534ff8573SNicolas Vasilache *this, "shuffle",
22634ff8573SNicolas Vasilache llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
22734ff8573SNicolas Vasilache llvm::cl::init(false)};
22834ff8573SNicolas Vasilache Option<bool> lowerToAvx2{
22934ff8573SNicolas Vasilache *this, "avx2",
23034ff8573SNicolas Vasilache llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
23134ff8573SNicolas Vasilache llvm::cl::init(false)};
23234ff8573SNicolas Vasilache
getDependentDialects__anonb56dea510111::TestVectorTransposeLowering233b2729fdaSNicolas Vasilache void getDependentDialects(DialectRegistry ®istry) const override {
234b2729fdaSNicolas Vasilache registry.insert<LLVM::LLVMDialect>();
235b2729fdaSNicolas Vasilache }
236b2729fdaSNicolas Vasilache
runOnOperation__anonb56dea510111::TestVectorTransposeLowering23741574554SRiver Riddle void runOnOperation() override {
23834ff8573SNicolas Vasilache RewritePatternSet patterns(&getContext());
23934ff8573SNicolas Vasilache
24034ff8573SNicolas Vasilache // Test on one pattern in isolation.
24134ff8573SNicolas Vasilache // Explicitly disable shape_cast lowering.
24234ff8573SNicolas Vasilache LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
24334ff8573SNicolas Vasilache .enableVectorTransposeLowering()
24434ff8573SNicolas Vasilache .enableShapeCastLowering(false);
24534ff8573SNicolas Vasilache if (lowerToEltwise) {
24634ff8573SNicolas Vasilache options = options.setVectorTransformsOptions(
24734ff8573SNicolas Vasilache VectorTransformsOptions().setVectorTransposeLowering(
24834ff8573SNicolas Vasilache VectorTransposeLowering::EltWise));
24934ff8573SNicolas Vasilache }
25034ff8573SNicolas Vasilache if (lowerToFlatTranspose) {
25134ff8573SNicolas Vasilache options = options.setVectorTransformsOptions(
25234ff8573SNicolas Vasilache VectorTransformsOptions().setVectorTransposeLowering(
25334ff8573SNicolas Vasilache VectorTransposeLowering::Flat));
25434ff8573SNicolas Vasilache }
25534ff8573SNicolas Vasilache if (lowerToShuffleTranspose) {
25634ff8573SNicolas Vasilache options = options.setVectorTransformsOptions(
25734ff8573SNicolas Vasilache VectorTransformsOptions().setVectorTransposeLowering(
25834ff8573SNicolas Vasilache VectorTransposeLowering::Shuffle));
25934ff8573SNicolas Vasilache }
26034ff8573SNicolas Vasilache if (lowerToAvx2) {
26134ff8573SNicolas Vasilache options = options.enableAVX2Lowering().setAVX2LoweringOptions(
26234ff8573SNicolas Vasilache x86vector::avx2::LoweringOptions().setTransposeOptions(
26334ff8573SNicolas Vasilache x86vector::avx2::TransposeLoweringOptions()
26434ff8573SNicolas Vasilache .lower4x8xf32()
26534ff8573SNicolas Vasilache .lower8x8xf32()));
26634ff8573SNicolas Vasilache }
26734ff8573SNicolas Vasilache
26836550692SRiver Riddle OpPassManager dynamicPM("func.func");
26934ff8573SNicolas Vasilache dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
27041574554SRiver Riddle if (failed(runPipeline(dynamicPM, getOperation())))
27134ff8573SNicolas Vasilache return signalPassFailure();
27234ff8573SNicolas Vasilache }
27334ff8573SNicolas Vasilache };
27434ff8573SNicolas Vasilache
2753fef2d26SRiver Riddle struct TestVectorUnrollingPatterns
27658ceae95SRiver Riddle : public PassWrapper<TestVectorUnrollingPatterns,
27758ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorUnrollingPatterns2785e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
2795e50dd04SRiver Riddle
280b5e22e6dSMehdi Amini StringRef getArgument() const final {
281b5e22e6dSMehdi Amini return "test-vector-unrolling-patterns";
282b5e22e6dSMehdi Amini }
getDescription__anonb56dea510111::TestVectorUnrollingPatterns283b5e22e6dSMehdi Amini StringRef getDescription() const final {
28434ff8573SNicolas Vasilache return "Test lowering patterns to unroll contract ops in the vector "
285b5e22e6dSMehdi Amini "dialect";
286b5e22e6dSMehdi Amini }
2873fef2d26SRiver Riddle TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anonb56dea510111::TestVectorUnrollingPatterns2883bab9d4eSMehdi Amini TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
2893bab9d4eSMehdi Amini : PassWrapper(pass) {}
runOnOperation__anonb56dea510111::TestVectorUnrollingPatterns29041574554SRiver Riddle void runOnOperation() override {
2913fef2d26SRiver Riddle MLIRContext *ctx = &getContext();
2923fef2d26SRiver Riddle RewritePatternSet patterns(ctx);
29329102538Sthomasraoux populateVectorUnrollPatterns(
29429102538Sthomasraoux patterns, UnrollVectorOptions()
2953fef2d26SRiver Riddle .setNativeShape(ArrayRef<int64_t>{2, 2})
2963fef2d26SRiver Riddle .setFilterConstraint([](Operation *op) {
297f69175b1SThomas Raoux return success(isa<arith::AddFOp, vector::FMAOp,
298f69175b1SThomas Raoux vector::MultiDimReductionOp>(op));
2993fef2d26SRiver Riddle }));
300de5022c7SMatthias Springer populateVectorUnrollPatterns(
301de5022c7SMatthias Springer patterns, UnrollVectorOptions()
302de5022c7SMatthias Springer .setNativeShape(ArrayRef<int64_t>{2})
303de5022c7SMatthias Springer .setFilterConstraint([](Operation *op) {
304de5022c7SMatthias Springer return success(isa<vector::ReductionOp>(op));
305de5022c7SMatthias Springer }));
3065b1b7108SThomas Raoux populateVectorUnrollPatterns(
3075b1b7108SThomas Raoux patterns, UnrollVectorOptions()
3085b1b7108SThomas Raoux .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
3095b1b7108SThomas Raoux .setFilterConstraint([](Operation *op) {
3105b1b7108SThomas Raoux return success(isa<vector::TransposeOp>(op));
3115b1b7108SThomas Raoux }));
3123fef2d26SRiver Riddle
3133fef2d26SRiver Riddle if (unrollBasedOnType) {
3143fef2d26SRiver Riddle UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
3153fef2d26SRiver Riddle [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
3163fef2d26SRiver Riddle vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
3179f122152SChristopher Bate SmallVector<int64_t, 4> nativeShape(
3189f122152SChristopher Bate contractOp.getIteratorTypes().size(), 4);
3199f122152SChristopher Bate Type lhsType = contractOp.getLhsType().getElementType();
3209f122152SChristopher Bate nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
3213fef2d26SRiver Riddle return nativeShape;
3223fef2d26SRiver Riddle };
3239f122152SChristopher Bate
3249f122152SChristopher Bate UnrollVectorOptions opts;
3259f122152SChristopher Bate opts.setNativeShapeFn(nativeShapeFn)
3269f122152SChristopher Bate .setFilterConstraint(
3279f122152SChristopher Bate [](Operation *op) { return success(isa<ContractionOp>(op)); });
3289f122152SChristopher Bate
3299f122152SChristopher Bate if (!unrollOrder.empty()) {
3309f122152SChristopher Bate opts.setUnrollTraversalOrderFn([this](Operation *op)
3319f122152SChristopher Bate -> Optional<SmallVector<int64_t>> {
3329f122152SChristopher Bate vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
3339f122152SChristopher Bate if (contractOp.getIteratorTypes().size() == unrollOrder.size())
3349f122152SChristopher Bate return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
3359f122152SChristopher Bate return None;
3369f122152SChristopher Bate });
3379f122152SChristopher Bate }
3389f122152SChristopher Bate populateVectorUnrollPatterns(patterns, opts);
3399f122152SChristopher Bate } else {
3409f122152SChristopher Bate auto nativeShapeFn =
3419f122152SChristopher Bate [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
3429f122152SChristopher Bate auto contractOp = dyn_cast<ContractionOp>(op);
3439f122152SChristopher Bate if (!contractOp)
3449f122152SChristopher Bate return None;
3459f122152SChristopher Bate return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
3469f122152SChristopher Bate };
34753fe155bSChristopher Bate populateVectorUnrollPatterns(patterns,
34853fe155bSChristopher Bate UnrollVectorOptions()
34953fe155bSChristopher Bate .setNativeShapeFn(nativeShapeFn)
35053fe155bSChristopher Bate .setFilterConstraint([](Operation *op) {
35153fe155bSChristopher Bate return success(isa<ContractionOp>(op));
35253fe155bSChristopher Bate }));
3533fef2d26SRiver Riddle }
3543fef2d26SRiver Riddle populateVectorToVectorCanonicalizationPatterns(patterns);
35541574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
3563fef2d26SRiver Riddle }
3573fef2d26SRiver Riddle
3589f122152SChristopher Bate ListOption<int64_t> unrollOrder{*this, "unroll-order",
359*62a4e6abSFangrui Song llvm::cl::desc("set the unroll order")};
3609f122152SChristopher Bate
3613fef2d26SRiver Riddle Option<bool> unrollBasedOnType{
3623fef2d26SRiver Riddle *this, "unroll-based-on-type",
3633fef2d26SRiver Riddle llvm::cl::desc("Set the unroll factor based on type of the operation"),
3643fef2d26SRiver Riddle llvm::cl::init(false)};
3653fef2d26SRiver Riddle };
3663fef2d26SRiver Riddle
3673fef2d26SRiver Riddle struct TestVectorDistributePatterns
36858ceae95SRiver Riddle : public PassWrapper<TestVectorDistributePatterns,
36958ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorDistributePatterns3705e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)
3715e50dd04SRiver Riddle
372b5e22e6dSMehdi Amini StringRef getArgument() const final {
373b5e22e6dSMehdi Amini return "test-vector-distribute-patterns";
374b5e22e6dSMehdi Amini }
getDescription__anonb56dea510111::TestVectorDistributePatterns375b5e22e6dSMehdi Amini StringRef getDescription() const final {
37634ff8573SNicolas Vasilache return "Test lowering patterns to distribute vector ops in the vector "
377b5e22e6dSMehdi Amini "dialect";
378b5e22e6dSMehdi Amini }
3793fef2d26SRiver Riddle TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anonb56dea510111::TestVectorDistributePatterns3803bab9d4eSMehdi Amini TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
3813bab9d4eSMehdi Amini : PassWrapper(pass) {}
getDependentDialects__anonb56dea510111::TestVectorDistributePatterns3823fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override {
3833fef2d26SRiver Riddle registry.insert<VectorDialect>();
3843fef2d26SRiver Riddle registry.insert<AffineDialect>();
3853fef2d26SRiver Riddle }
3863fef2d26SRiver Riddle ListOption<int32_t> multiplicity{
3876edef135SRiver Riddle *this, "distribution-multiplicity",
3883fef2d26SRiver Riddle llvm::cl::desc("Set the multiplicity used for distributing vector")};
3893fef2d26SRiver Riddle
runOnOperation__anonb56dea510111::TestVectorDistributePatterns39041574554SRiver Riddle void runOnOperation() override {
3913fef2d26SRiver Riddle MLIRContext *ctx = &getContext();
3923fef2d26SRiver Riddle RewritePatternSet patterns(ctx);
39358ceae95SRiver Riddle func::FuncOp func = getOperation();
394a54f4eaeSMogball func.walk([&](arith::AddFOp op) {
3953fef2d26SRiver Riddle OpBuilder builder(op);
3963fef2d26SRiver Riddle if (auto vecType = op.getType().dyn_cast<VectorType>()) {
3973fef2d26SRiver Riddle SmallVector<int64_t, 2> mul;
3983fef2d26SRiver Riddle SmallVector<AffineExpr, 2> perm;
3993fef2d26SRiver Riddle SmallVector<Value, 2> ids;
4003fef2d26SRiver Riddle unsigned count = 0;
4013fef2d26SRiver Riddle // Remove the multiplicity of 1 and calculate the affine map based on
4023fef2d26SRiver Riddle // the multiplicity.
4033fef2d26SRiver Riddle SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
4043fef2d26SRiver Riddle for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
4053fef2d26SRiver Riddle if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
4063fef2d26SRiver Riddle mul.push_back(m[i]);
4073fef2d26SRiver Riddle ids.push_back(func.getArgument(count++));
4083fef2d26SRiver Riddle perm.push_back(getAffineDimExpr(i, ctx));
4093fef2d26SRiver Riddle }
4103fef2d26SRiver Riddle }
4113fef2d26SRiver Riddle auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
4123fef2d26SRiver Riddle perm, ctx);
4133fef2d26SRiver Riddle Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
4143fef2d26SRiver Riddle builder, op.getOperation(), ids, mul, map);
415037f0995SKazu Hirata if (ops) {
4163fef2d26SRiver Riddle SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
4173fef2d26SRiver Riddle op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
4183fef2d26SRiver Riddle extractOp);
4193fef2d26SRiver Riddle }
4203fef2d26SRiver Riddle }
4213fef2d26SRiver Riddle });
422627733b5Sthomasraoux populatePropagateVectorDistributionPatterns(patterns);
42341574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
4243fef2d26SRiver Riddle }
4253fef2d26SRiver Riddle };
4263fef2d26SRiver Riddle
4273fef2d26SRiver Riddle struct TestVectorToLoopPatterns
42858ceae95SRiver Riddle : public PassWrapper<TestVectorToLoopPatterns,
42958ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorToLoopPatterns4305e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)
4315e50dd04SRiver Riddle
432b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-vector-to-forloop"; }
getDescription__anonb56dea510111::TestVectorToLoopPatterns433b5e22e6dSMehdi Amini StringRef getDescription() const final {
43434ff8573SNicolas Vasilache return "Test lowering patterns to break up a vector op into a for loop";
435b5e22e6dSMehdi Amini }
4363fef2d26SRiver Riddle TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anonb56dea510111::TestVectorToLoopPatterns4373bab9d4eSMehdi Amini TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
4383bab9d4eSMehdi Amini : PassWrapper(pass) {}
getDependentDialects__anonb56dea510111::TestVectorToLoopPatterns4393fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override {
4403fef2d26SRiver Riddle registry.insert<VectorDialect>();
4413fef2d26SRiver Riddle registry.insert<AffineDialect>();
4423fef2d26SRiver Riddle }
4433fef2d26SRiver Riddle Option<int32_t> multiplicity{
4443fef2d26SRiver Riddle *this, "distribution-multiplicity",
4453fef2d26SRiver Riddle llvm::cl::desc("Set the multiplicity used for distributing vector"),
4463fef2d26SRiver Riddle llvm::cl::init(32)};
runOnOperation__anonb56dea510111::TestVectorToLoopPatterns44741574554SRiver Riddle void runOnOperation() override {
4483fef2d26SRiver Riddle MLIRContext *ctx = &getContext();
4493fef2d26SRiver Riddle RewritePatternSet patterns(ctx);
45058ceae95SRiver Riddle func::FuncOp func = getOperation();
451a54f4eaeSMogball func.walk([&](arith::AddFOp op) {
4523fef2d26SRiver Riddle // Check that the operation type can be broken down into a loop.
4533fef2d26SRiver Riddle VectorType type = op.getType().dyn_cast<VectorType>();
4543fef2d26SRiver Riddle if (!type || type.getRank() != 1 ||
4553fef2d26SRiver Riddle type.getNumElements() % multiplicity != 0)
4563fef2d26SRiver Riddle return mlir::WalkResult::advance();
4573fef2d26SRiver Riddle auto filterAlloc = [](Operation *op) {
45823aa5a74SRiver Riddle return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
4593fef2d26SRiver Riddle };
4603fef2d26SRiver Riddle auto dependentOps = getSlice(op, filterAlloc);
4613fef2d26SRiver Riddle // Create a loop and move instructions from the Op slice into the loop.
4623fef2d26SRiver Riddle OpBuilder builder(op);
463a54f4eaeSMogball auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
464a54f4eaeSMogball auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
465a54f4eaeSMogball auto numIter =
466a54f4eaeSMogball builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
4673fef2d26SRiver Riddle auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
4683fef2d26SRiver Riddle for (Operation *it : dependentOps) {
4693fef2d26SRiver Riddle it->moveBefore(forOp.getBody()->getTerminator());
4703fef2d26SRiver Riddle }
4713fef2d26SRiver Riddle auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
4723fef2d26SRiver Riddle // break up the original op and let the patterns propagate.
4733fef2d26SRiver Riddle Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
4743fef2d26SRiver Riddle builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
4753fef2d26SRiver Riddle map);
476037f0995SKazu Hirata if (ops) {
4773fef2d26SRiver Riddle SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
4783fef2d26SRiver Riddle op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
4793fef2d26SRiver Riddle }
4803fef2d26SRiver Riddle return mlir::WalkResult::interrupt();
4813fef2d26SRiver Riddle });
482627733b5Sthomasraoux populatePropagateVectorDistributionPatterns(patterns);
48341574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
4843fef2d26SRiver Riddle }
4853fef2d26SRiver Riddle };
4863fef2d26SRiver Riddle
4873fef2d26SRiver Riddle struct TestVectorTransferUnrollingPatterns
48841574554SRiver Riddle : public PassWrapper<TestVectorTransferUnrollingPatterns,
48958ceae95SRiver Riddle OperationPass<func::FuncOp>> {
4905e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
4915e50dd04SRiver Riddle TestVectorTransferUnrollingPatterns)
4925e50dd04SRiver Riddle
4939f122152SChristopher Bate TestVectorTransferUnrollingPatterns() = default;
TestVectorTransferUnrollingPatterns__anonb56dea510111::TestVectorTransferUnrollingPatterns4949f122152SChristopher Bate TestVectorTransferUnrollingPatterns(
4959f122152SChristopher Bate const TestVectorTransferUnrollingPatterns &pass)
4969f122152SChristopher Bate : PassWrapper(pass) {}
4979f122152SChristopher Bate
getDependentDialects__anonb56dea510111::TestVectorTransferUnrollingPatterns4983fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override {
4993fef2d26SRiver Riddle registry.insert<AffineDialect>();
5003fef2d26SRiver Riddle }
getArgument__anonb56dea510111::TestVectorTransferUnrollingPatterns501b5e22e6dSMehdi Amini StringRef getArgument() const final {
502b5e22e6dSMehdi Amini return "test-vector-transfer-unrolling-patterns";
503b5e22e6dSMehdi Amini }
getDescription__anonb56dea510111::TestVectorTransferUnrollingPatterns504b5e22e6dSMehdi Amini StringRef getDescription() const final {
50534ff8573SNicolas Vasilache return "Test lowering patterns to unroll transfer ops in the vector "
506b5e22e6dSMehdi Amini "dialect";
507b5e22e6dSMehdi Amini }
runOnOperation__anonb56dea510111::TestVectorTransferUnrollingPatterns50841574554SRiver Riddle void runOnOperation() override {
5093fef2d26SRiver Riddle MLIRContext *ctx = &getContext();
5103fef2d26SRiver Riddle RewritePatternSet patterns(ctx);
5119f122152SChristopher Bate UnrollVectorOptions opts;
5129f122152SChristopher Bate opts.setNativeShape(ArrayRef<int64_t>{2, 2})
5133fef2d26SRiver Riddle .setFilterConstraint([](Operation *op) {
5143fef2d26SRiver Riddle return success(
5153fef2d26SRiver Riddle isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
5169f122152SChristopher Bate });
5179f122152SChristopher Bate if (reverseUnrollOrder.getValue()) {
5189f122152SChristopher Bate opts.setUnrollTraversalOrderFn(
5199f122152SChristopher Bate [](Operation *op) -> Optional<SmallVector<int64_t>> {
5209f122152SChristopher Bate int64_t numLoops = 0;
5219f122152SChristopher Bate if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
5229f122152SChristopher Bate numLoops = readOp.getVectorType().getRank();
5239f122152SChristopher Bate else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
5249f122152SChristopher Bate numLoops = writeOp.getVectorType().getRank();
5259f122152SChristopher Bate else
5269f122152SChristopher Bate return None;
5279f122152SChristopher Bate auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
5289f122152SChristopher Bate return llvm::to_vector(order);
5299f122152SChristopher Bate });
5309f122152SChristopher Bate }
5319f122152SChristopher Bate populateVectorUnrollPatterns(patterns, opts);
5323fef2d26SRiver Riddle populateVectorToVectorCanonicalizationPatterns(patterns);
53341574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
5343fef2d26SRiver Riddle }
5359f122152SChristopher Bate
5369f122152SChristopher Bate Option<bool> reverseUnrollOrder{
5379f122152SChristopher Bate *this, "reverse-unroll-order",
5389f122152SChristopher Bate llvm::cl::desc(
5399f122152SChristopher Bate "reverse the order of unrolling of vector transfer operations"),
5409f122152SChristopher Bate llvm::cl::init(false)};
5413fef2d26SRiver Riddle };
5423fef2d26SRiver Riddle
5433fef2d26SRiver Riddle struct TestVectorTransferFullPartialSplitPatterns
5443fef2d26SRiver Riddle : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
54558ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns5465e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
5475e50dd04SRiver Riddle TestVectorTransferFullPartialSplitPatterns)
5485e50dd04SRiver Riddle
549b5e22e6dSMehdi Amini StringRef getArgument() const final {
550b5e22e6dSMehdi Amini return "test-vector-transfer-full-partial-split";
551b5e22e6dSMehdi Amini }
getDescription__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns552b5e22e6dSMehdi Amini StringRef getDescription() const final {
55334ff8573SNicolas Vasilache return "Test lowering patterns to split "
554b5e22e6dSMehdi Amini "transfer ops via scf.if + linalg ops";
555b5e22e6dSMehdi Amini }
5563fef2d26SRiver Riddle TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns5573fef2d26SRiver Riddle TestVectorTransferFullPartialSplitPatterns(
5583bab9d4eSMehdi Amini const TestVectorTransferFullPartialSplitPatterns &pass)
5593bab9d4eSMehdi Amini : PassWrapper(pass) {}
5603fef2d26SRiver Riddle
getDependentDialects__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns5613fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override {
5623fef2d26SRiver Riddle registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
5633fef2d26SRiver Riddle scf::SCFDialect>();
5643fef2d26SRiver Riddle }
5653fef2d26SRiver Riddle
5663fef2d26SRiver Riddle Option<bool> useLinalgOps{
567ebc81537SAlexander Belyaev *this, "use-memref-copy",
5683fef2d26SRiver Riddle llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
569ebc81537SAlexander Belyaev "memref.copy operations."),
5703fef2d26SRiver Riddle llvm::cl::init(false)};
runOnOperation__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns57141574554SRiver Riddle void runOnOperation() override {
5723fef2d26SRiver Riddle MLIRContext *ctx = &getContext();
5733fef2d26SRiver Riddle RewritePatternSet patterns(ctx);
5743fef2d26SRiver Riddle VectorTransformsOptions options;
5753fef2d26SRiver Riddle if (useLinalgOps)
5763fef2d26SRiver Riddle options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
5773fef2d26SRiver Riddle else
5783fef2d26SRiver Riddle options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
5793fef2d26SRiver Riddle patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
58041574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
5813fef2d26SRiver Riddle }
5823fef2d26SRiver Riddle };
5833fef2d26SRiver Riddle
5843fef2d26SRiver Riddle struct TestVectorTransferOpt
58558ceae95SRiver Riddle : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferOpt5865e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
5875e50dd04SRiver Riddle
588b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-vector-transferop-opt"; }
getDescription__anonb56dea510111::TestVectorTransferOpt589b5e22e6dSMehdi Amini StringRef getDescription() const final {
590b5e22e6dSMehdi Amini return "Test optimization transformations for transfer ops";
591b5e22e6dSMehdi Amini }
runOnOperation__anonb56dea510111::TestVectorTransferOpt59241574554SRiver Riddle void runOnOperation() override { transferOpflowOpt(getOperation()); }
5933fef2d26SRiver Riddle };
5943fef2d26SRiver Riddle
5953fef2d26SRiver Riddle struct TestVectorTransferLoweringPatterns
59641574554SRiver Riddle : public PassWrapper<TestVectorTransferLoweringPatterns,
59758ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferLoweringPatterns5985e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
5995e50dd04SRiver Riddle TestVectorTransferLoweringPatterns)
6005e50dd04SRiver Riddle
6013fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override {
60257470abcSAlexander Belyaev registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
6033fef2d26SRiver Riddle }
getArgument__anonb56dea510111::TestVectorTransferLoweringPatterns604b5e22e6dSMehdi Amini StringRef getArgument() const final {
605b5e22e6dSMehdi Amini return "test-vector-transfer-lowering-patterns";
606b5e22e6dSMehdi Amini }
getDescription__anonb56dea510111::TestVectorTransferLoweringPatterns607b5e22e6dSMehdi Amini StringRef getDescription() const final {
60834ff8573SNicolas Vasilache return "Test lowering patterns to lower transfer ops to other vector ops";
609b5e22e6dSMehdi Amini }
runOnOperation__anonb56dea510111::TestVectorTransferLoweringPatterns61041574554SRiver Riddle void runOnOperation() override {
6113fef2d26SRiver Riddle RewritePatternSet patterns(&getContext());
6123fef2d26SRiver Riddle populateVectorTransferLoweringPatterns(patterns);
613d1a9e9a7SMatthias Springer populateVectorTransferPermutationMapLoweringPatterns(patterns);
61441574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
6153fef2d26SRiver Riddle }
6163fef2d26SRiver Riddle };
6173fef2d26SRiver Riddle
6183fef2d26SRiver Riddle struct TestVectorMultiReductionLoweringPatterns
6193fef2d26SRiver Riddle : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
62058ceae95SRiver Riddle OperationPass<func::FuncOp>> {
6215e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
6225e50dd04SRiver Riddle TestVectorMultiReductionLoweringPatterns)
6235e50dd04SRiver Riddle
624e33f301eSharsh-nod TestVectorMultiReductionLoweringPatterns() = default;
TestVectorMultiReductionLoweringPatterns__anonb56dea510111::TestVectorMultiReductionLoweringPatterns625e33f301eSharsh-nod TestVectorMultiReductionLoweringPatterns(
6263bab9d4eSMehdi Amini const TestVectorMultiReductionLoweringPatterns &pass)
6273bab9d4eSMehdi Amini : PassWrapper(pass) {}
getDependentDialects__anonb56dea510111::TestVectorMultiReductionLoweringPatterns6283fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override {
6293fef2d26SRiver Riddle registry.insert<memref::MemRefDialect>();
6303fef2d26SRiver Riddle }
getArgument__anonb56dea510111::TestVectorMultiReductionLoweringPatterns631b5e22e6dSMehdi Amini StringRef getArgument() const final {
632b5e22e6dSMehdi Amini return "test-vector-multi-reduction-lowering-patterns";
633b5e22e6dSMehdi Amini }
getDescription__anonb56dea510111::TestVectorMultiReductionLoweringPatterns634b5e22e6dSMehdi Amini StringRef getDescription() const final {
63534ff8573SNicolas Vasilache return "Test lowering patterns to lower vector.multi_reduction to other "
636b5e22e6dSMehdi Amini "vector ops";
637b5e22e6dSMehdi Amini }
638e33f301eSharsh-nod Option<bool> useOuterReductions{
639e33f301eSharsh-nod *this, "use-outer-reductions",
640e33f301eSharsh-nod llvm::cl::desc("Move reductions to outer most dimensions"),
641e33f301eSharsh-nod llvm::cl::init(false)};
runOnOperation__anonb56dea510111::TestVectorMultiReductionLoweringPatterns64241574554SRiver Riddle void runOnOperation() override {
6433fef2d26SRiver Riddle RewritePatternSet patterns(&getContext());
644176a0ea5SNicolas Vasilache populateVectorMultiReductionLoweringPatterns(
645176a0ea5SNicolas Vasilache patterns, useOuterReductions
646176a0ea5SNicolas Vasilache ? vector::VectorMultiReductionLowering::InnerParallel
647176a0ea5SNicolas Vasilache : vector::VectorMultiReductionLowering::InnerReduction);
64841574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
6493fef2d26SRiver Riddle }
6503fef2d26SRiver Riddle };
6513fef2d26SRiver Riddle
652a3dd4e77SAhmed S. Taei struct TestVectorTransferCollapseInnerMostContiguousDims
653a3dd4e77SAhmed S. Taei : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
65458ceae95SRiver Riddle OperationPass<func::FuncOp>> {
6555e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
6565e50dd04SRiver Riddle TestVectorTransferCollapseInnerMostContiguousDims)
6575e50dd04SRiver Riddle
658a3dd4e77SAhmed S. Taei TestVectorTransferCollapseInnerMostContiguousDims() = default;
659a3dd4e77SAhmed S. Taei TestVectorTransferCollapseInnerMostContiguousDims(
660322c8914SMehdi Amini const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
661a3dd4e77SAhmed S. Taei
getDependentDialects__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims662a3dd4e77SAhmed S. Taei void getDependentDialects(DialectRegistry ®istry) const override {
663a3dd4e77SAhmed S. Taei registry.insert<memref::MemRefDialect, AffineDialect>();
664a3dd4e77SAhmed S. Taei }
665a3dd4e77SAhmed S. Taei
getArgument__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims666a3dd4e77SAhmed S. Taei StringRef getArgument() const final {
667a3dd4e77SAhmed S. Taei return "test-vector-transfer-collapse-inner-most-dims";
668a3dd4e77SAhmed S. Taei }
669a3dd4e77SAhmed S. Taei
getDescription__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims670a3dd4e77SAhmed S. Taei StringRef getDescription() const final {
67134ff8573SNicolas Vasilache return "Test lowering patterns that reducedes the rank of the vector "
672a3dd4e77SAhmed S. Taei "transfer memory and vector operands.";
673a3dd4e77SAhmed S. Taei }
674a3dd4e77SAhmed S. Taei
runOnOperation__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims67541574554SRiver Riddle void runOnOperation() override {
676a3dd4e77SAhmed S. Taei RewritePatternSet patterns(&getContext());
677a3dd4e77SAhmed S. Taei populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
67841574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
679a3dd4e77SAhmed S. Taei }
680a3dd4e77SAhmed S. Taei };
681a3dd4e77SAhmed S. Taei
6821d8cc45bSthomasraoux struct TestVectorReduceToContractPatternsPatterns
6831d8cc45bSthomasraoux : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
68458ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorReduceToContractPatternsPatterns6855e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
6865e50dd04SRiver Riddle TestVectorReduceToContractPatternsPatterns)
6875e50dd04SRiver Riddle
6881d8cc45bSthomasraoux StringRef getArgument() const final {
6891d8cc45bSthomasraoux return "test-vector-reduction-to-contract-patterns";
6901d8cc45bSthomasraoux }
getDescription__anonb56dea510111::TestVectorReduceToContractPatternsPatterns6911d8cc45bSthomasraoux StringRef getDescription() const final {
6921d8cc45bSthomasraoux return "Test patterns to convert multireduce op to contract and combine "
6931d8cc45bSthomasraoux "broadcast/transpose to contract";
6941d8cc45bSthomasraoux }
runOnOperation__anonb56dea510111::TestVectorReduceToContractPatternsPatterns69541574554SRiver Riddle void runOnOperation() override {
6961d8cc45bSthomasraoux RewritePatternSet patterns(&getContext());
697d054b80bSNicolas Vasilache populateVectorReductionToContractPatterns(patterns);
69841574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
6991d8cc45bSthomasraoux }
7001d8cc45bSthomasraoux };
7011d8cc45bSthomasraoux
7020aea49a7SBenoit Jacob struct TestVectorTransferDropUnitDimsPatterns
70341574554SRiver Riddle : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
70458ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns7055e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
7065e50dd04SRiver Riddle TestVectorTransferDropUnitDimsPatterns)
7075e50dd04SRiver Riddle
7080aea49a7SBenoit Jacob StringRef getArgument() const final {
7090aea49a7SBenoit Jacob return "test-vector-transfer-drop-unit-dims-patterns";
7100aea49a7SBenoit Jacob }
getDependentDialects__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns7110aea49a7SBenoit Jacob void getDependentDialects(DialectRegistry ®istry) const override {
7120aea49a7SBenoit Jacob registry.insert<memref::MemRefDialect>();
7130aea49a7SBenoit Jacob }
runOnOperation__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns71441574554SRiver Riddle void runOnOperation() override {
7150aea49a7SBenoit Jacob RewritePatternSet patterns(&getContext());
7160aea49a7SBenoit Jacob populateVectorTransferDropUnitDimsPatterns(patterns);
71741574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
7180aea49a7SBenoit Jacob }
7190aea49a7SBenoit Jacob };
7200aea49a7SBenoit Jacob
721aba437ceSBenoit Jacob struct TestFlattenVectorTransferPatterns
72241574554SRiver Riddle : public PassWrapper<TestFlattenVectorTransferPatterns,
72358ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestFlattenVectorTransferPatterns7245e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
7255e50dd04SRiver Riddle TestFlattenVectorTransferPatterns)
7265e50dd04SRiver Riddle
727aba437ceSBenoit Jacob StringRef getArgument() const final {
728aba437ceSBenoit Jacob return "test-vector-transfer-flatten-patterns";
729aba437ceSBenoit Jacob }
getDescription__anonb56dea510111::TestFlattenVectorTransferPatterns730aba437ceSBenoit Jacob StringRef getDescription() const final {
731aba437ceSBenoit Jacob return "Test patterns to rewrite contiguous row-major N-dimensional "
732aba437ceSBenoit Jacob "vector.transfer_{read,write} ops into 1D transfers";
733aba437ceSBenoit Jacob }
getDependentDialects__anonb56dea510111::TestFlattenVectorTransferPatterns734aba437ceSBenoit Jacob void getDependentDialects(DialectRegistry ®istry) const override {
735aba437ceSBenoit Jacob registry.insert<memref::MemRefDialect>();
736aba437ceSBenoit Jacob }
runOnOperation__anonb56dea510111::TestFlattenVectorTransferPatterns73741574554SRiver Riddle void runOnOperation() override {
738aba437ceSBenoit Jacob RewritePatternSet patterns(&getContext());
739aba437ceSBenoit Jacob populateFlattenVectorTransferPatterns(patterns);
74041574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
741aba437ceSBenoit Jacob }
742aba437ceSBenoit Jacob };
743aba437ceSBenoit Jacob
74480e0bf1aSharsh struct TestVectorScanLowering
74558ceae95SRiver Riddle : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorScanLowering7465e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
7475e50dd04SRiver Riddle
74880e0bf1aSharsh StringRef getArgument() const final { return "test-vector-scan-lowering"; }
getDescription__anonb56dea510111::TestVectorScanLowering74980e0bf1aSharsh StringRef getDescription() const final {
75080e0bf1aSharsh return "Test lowering patterns that lower the scan op in the vector "
75180e0bf1aSharsh "dialect";
75280e0bf1aSharsh }
runOnOperation__anonb56dea510111::TestVectorScanLowering75380e0bf1aSharsh void runOnOperation() override {
75480e0bf1aSharsh RewritePatternSet patterns(&getContext());
75580e0bf1aSharsh populateVectorScanLoweringPatterns(patterns);
75680e0bf1aSharsh (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
75780e0bf1aSharsh }
75880e0bf1aSharsh };
75980e0bf1aSharsh
760d02f10d9SThomas Raoux /// Allocate shared memory for a single warp to test lowering of
761d02f10d9SThomas Raoux /// WarpExecuteOnLane0Op.
allocateGlobalSharedMemory(Location loc,OpBuilder & builder,WarpExecuteOnLane0Op warpOp,Type type)762d02f10d9SThomas Raoux static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
763d02f10d9SThomas Raoux WarpExecuteOnLane0Op warpOp,
764d02f10d9SThomas Raoux Type type) {
765d02f10d9SThomas Raoux static constexpr int64_t kSharedMemorySpace = 3;
766d02f10d9SThomas Raoux // Compute type of shared memory buffer.
767d02f10d9SThomas Raoux MemRefType memrefType;
768d02f10d9SThomas Raoux if (auto vectorType = type.dyn_cast<VectorType>()) {
769d02f10d9SThomas Raoux memrefType =
770d02f10d9SThomas Raoux MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
771d02f10d9SThomas Raoux kSharedMemorySpace);
772d02f10d9SThomas Raoux } else {
773d02f10d9SThomas Raoux memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
774d02f10d9SThomas Raoux }
775d02f10d9SThomas Raoux
776d02f10d9SThomas Raoux // Get symbol table holding all shared memory globals.
777d02f10d9SThomas Raoux ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
778d02f10d9SThomas Raoux SymbolTable symbolTable(moduleOp);
779d02f10d9SThomas Raoux
780d02f10d9SThomas Raoux // Create a pretty name.
781d02f10d9SThomas Raoux SmallString<64> buf;
782d02f10d9SThomas Raoux llvm::raw_svector_ostream os(buf);
783d02f10d9SThomas Raoux interleave(memrefType.getShape(), os, "x");
784d02f10d9SThomas Raoux os << "x" << memrefType.getElementType();
785d02f10d9SThomas Raoux std::string symbolName = (Twine("__shared_") + os.str()).str();
786d02f10d9SThomas Raoux
787d02f10d9SThomas Raoux auto ip = builder.saveInsertionPoint();
788d02f10d9SThomas Raoux builder.setInsertionPoint(moduleOp);
789d02f10d9SThomas Raoux auto global = builder.create<memref::GlobalOp>(
790d02f10d9SThomas Raoux loc,
791d02f10d9SThomas Raoux /*sym_name=*/symbolName,
792d02f10d9SThomas Raoux /*sym_visibility=*/builder.getStringAttr("private"),
793d02f10d9SThomas Raoux /*type=*/memrefType,
794d02f10d9SThomas Raoux /*initial_value=*/Attribute(),
795d02f10d9SThomas Raoux /*constant=*/false,
796d02f10d9SThomas Raoux /*alignment=*/IntegerAttr());
797d02f10d9SThomas Raoux symbolTable.insert(global);
798d02f10d9SThomas Raoux // The symbol table inserts at the end of the module, but globals are a bit
799d02f10d9SThomas Raoux // nicer if they are at the beginning.
800d02f10d9SThomas Raoux global->moveBefore(&moduleOp.front());
801d02f10d9SThomas Raoux
802d02f10d9SThomas Raoux builder.restoreInsertionPoint(ip);
803d02f10d9SThomas Raoux return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
804d02f10d9SThomas Raoux }
805d02f10d9SThomas Raoux
warpReduction(Location loc,OpBuilder & builder,Value input,CombiningKind kind,uint32_t size)8066834803cSThomas Raoux static Value warpReduction(Location loc, OpBuilder &builder, Value input,
8076834803cSThomas Raoux CombiningKind kind, uint32_t size) {
8086834803cSThomas Raoux Value laneVal = input;
8096834803cSThomas Raoux // Parallel reduction using butterfly shuffles.
8106834803cSThomas Raoux for (uint64_t i = 1; i < size; i <<= 1) {
8116834803cSThomas Raoux Value shuffled = builder
8126834803cSThomas Raoux .create<gpu::ShuffleOp>(loc, laneVal, i,
8136834803cSThomas Raoux /*width=*/size,
8146834803cSThomas Raoux /*mode=*/gpu::ShuffleMode::XOR)
8156834803cSThomas Raoux .result();
8166834803cSThomas Raoux laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
8176834803cSThomas Raoux }
8186834803cSThomas Raoux return laneVal;
8196834803cSThomas Raoux }
8206834803cSThomas Raoux
821d02f10d9SThomas Raoux struct TestVectorDistribution
822d02f10d9SThomas Raoux : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorDistribution823d02f10d9SThomas Raoux MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
824d02f10d9SThomas Raoux
825d02f10d9SThomas Raoux void getDependentDialects(DialectRegistry ®istry) const override {
826ed0288f7SThomas Raoux registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
827ed0288f7SThomas Raoux AffineDialect>();
828d02f10d9SThomas Raoux }
829d02f10d9SThomas Raoux
getArgument__anonb56dea510111::TestVectorDistribution830d02f10d9SThomas Raoux StringRef getArgument() const final { return "test-vector-warp-distribute"; }
getDescription__anonb56dea510111::TestVectorDistribution831d02f10d9SThomas Raoux StringRef getDescription() const final {
832d02f10d9SThomas Raoux return "Test vector warp distribute transformation and lowering patterns";
833d02f10d9SThomas Raoux }
834d02f10d9SThomas Raoux TestVectorDistribution() = default;
TestVectorDistribution__anonb56dea510111::TestVectorDistribution835d02f10d9SThomas Raoux TestVectorDistribution(const TestVectorDistribution &pass)
836d02f10d9SThomas Raoux : PassWrapper(pass) {}
837d02f10d9SThomas Raoux
838d02f10d9SThomas Raoux Option<bool> warpOpToSCF{
839d02f10d9SThomas Raoux *this, "rewrite-warp-ops-to-scf-if",
840d02f10d9SThomas Raoux llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
841d02f10d9SThomas Raoux llvm::cl::init(false)};
842d02f10d9SThomas Raoux
843ed0288f7SThomas Raoux Option<bool> distributeTransferWriteOps{
844ed0288f7SThomas Raoux *this, "distribute-transfer-write",
845ed0288f7SThomas Raoux llvm::cl::desc("Test distribution of transfer write"),
846ed0288f7SThomas Raoux llvm::cl::init(false)};
847ed0288f7SThomas Raoux
848ed0288f7SThomas Raoux Option<bool> hoistUniform{*this, "hoist-uniform",
849ed0288f7SThomas Raoux llvm::cl::desc("Test hoist uniform"),
850ed0288f7SThomas Raoux llvm::cl::init(false)};
851ed0288f7SThomas Raoux
85276cf33daSThomas Raoux Option<bool> propagateDistribution{
85376cf33daSThomas Raoux *this, "propagate-distribution",
85476cf33daSThomas Raoux llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
85576cf33daSThomas Raoux
runOnOperation__anonb56dea510111::TestVectorDistribution856d02f10d9SThomas Raoux void runOnOperation() override {
857d02f10d9SThomas Raoux RewritePatternSet patterns(&getContext());
858ed0288f7SThomas Raoux
859ed0288f7SThomas Raoux getOperation().walk([&](Operation *op) {
860ed0288f7SThomas Raoux if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
861ed0288f7SThomas Raoux if (hoistUniform) {
862ed0288f7SThomas Raoux moveScalarUniformCode(warpOp);
863ed0288f7SThomas Raoux }
864ed0288f7SThomas Raoux WalkResult::interrupt();
865ed0288f7SThomas Raoux }
866ed0288f7SThomas Raoux });
867ed0288f7SThomas Raoux MLIRContext *ctx = &getContext();
868ed0288f7SThomas Raoux if (distributeTransferWriteOps) {
869ed0288f7SThomas Raoux auto distributionFn = [](vector::TransferWriteOp writeOp) {
870ed0288f7SThomas Raoux // Create a map (d0, d1) -> (d1) to distribute along the inner
871ed0288f7SThomas Raoux // dimension. Once we support n-d distribution we can add more
872ed0288f7SThomas Raoux // complex cases.
873ed0288f7SThomas Raoux int64_t vecRank = writeOp.getVectorType().getRank();
874ed0288f7SThomas Raoux OpBuilder builder(writeOp.getContext());
875ed0288f7SThomas Raoux auto map =
876ed0288f7SThomas Raoux AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
877ed0288f7SThomas Raoux return map;
878ed0288f7SThomas Raoux };
879ed0288f7SThomas Raoux RewritePatternSet patterns(ctx);
880ed0288f7SThomas Raoux populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
881ed0288f7SThomas Raoux (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
882ed0288f7SThomas Raoux }
88376cf33daSThomas Raoux if (propagateDistribution) {
88476cf33daSThomas Raoux RewritePatternSet patterns(ctx);
88576cf33daSThomas Raoux vector::populatePropagateWarpVectorDistributionPatterns(patterns);
8866834803cSThomas Raoux vector::populateDistributeReduction(patterns, warpReduction);
88776cf33daSThomas Raoux (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
88876cf33daSThomas Raoux }
889d02f10d9SThomas Raoux WarpExecuteOnLane0LoweringOptions options;
890d02f10d9SThomas Raoux options.warpAllocationFn = allocateGlobalSharedMemory;
891d02f10d9SThomas Raoux options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
892d02f10d9SThomas Raoux WarpExecuteOnLane0Op warpOp) {
893d02f10d9SThomas Raoux builder.create<gpu::BarrierOp>(loc);
894d02f10d9SThomas Raoux };
895d02f10d9SThomas Raoux // Test on one pattern in isolation.
896d02f10d9SThomas Raoux if (warpOpToSCF) {
897d02f10d9SThomas Raoux populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
898d02f10d9SThomas Raoux (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
899d02f10d9SThomas Raoux return;
900d02f10d9SThomas Raoux }
901d02f10d9SThomas Raoux }
902d02f10d9SThomas Raoux };
903d02f10d9SThomas Raoux
904be0a7e9fSMehdi Amini } // namespace
9053fef2d26SRiver Riddle
9063fef2d26SRiver Riddle namespace mlir {
9073fef2d26SRiver Riddle namespace test {
registerTestVectorLowerings()90834ff8573SNicolas Vasilache void registerTestVectorLowerings() {
90934ff8573SNicolas Vasilache PassRegistration<TestVectorToVectorLowering>();
9103fef2d26SRiver Riddle
91134ff8573SNicolas Vasilache PassRegistration<TestVectorContractionLowering>();
91234ff8573SNicolas Vasilache
91334ff8573SNicolas Vasilache PassRegistration<TestVectorTransposeLowering>();
9143fef2d26SRiver Riddle
915b5e22e6dSMehdi Amini PassRegistration<TestVectorUnrollingPatterns>();
9163fef2d26SRiver Riddle
917b5e22e6dSMehdi Amini PassRegistration<TestVectorTransferUnrollingPatterns>();
9183fef2d26SRiver Riddle
919b5e22e6dSMehdi Amini PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
9203fef2d26SRiver Riddle
921b5e22e6dSMehdi Amini PassRegistration<TestVectorDistributePatterns>();
9223fef2d26SRiver Riddle
923b5e22e6dSMehdi Amini PassRegistration<TestVectorToLoopPatterns>();
9243fef2d26SRiver Riddle
925b5e22e6dSMehdi Amini PassRegistration<TestVectorTransferOpt>();
9263fef2d26SRiver Riddle
927b5e22e6dSMehdi Amini PassRegistration<TestVectorTransferLoweringPatterns>();
9283fef2d26SRiver Riddle
929b5e22e6dSMehdi Amini PassRegistration<TestVectorMultiReductionLoweringPatterns>();
930a3dd4e77SAhmed S. Taei
931a3dd4e77SAhmed S. Taei PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
9321d8cc45bSthomasraoux
9331d8cc45bSthomasraoux PassRegistration<TestVectorReduceToContractPatternsPatterns>();
9340aea49a7SBenoit Jacob
9350aea49a7SBenoit Jacob PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
936aba437ceSBenoit Jacob
937aba437ceSBenoit Jacob PassRegistration<TestFlattenVectorTransferPatterns>();
93880e0bf1aSharsh
93980e0bf1aSharsh PassRegistration<TestVectorScanLowering>();
940d02f10d9SThomas Raoux
941d02f10d9SThomas Raoux PassRegistration<TestVectorDistribution>();
9423fef2d26SRiver Riddle }
9433fef2d26SRiver Riddle } // namespace test
9443fef2d26SRiver Riddle } // namespace mlir
945