1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/SCF.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Pass/PassManager.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 using namespace mlir::vector; 28 29 namespace { 30 31 struct TestVectorToVectorLowering 32 : public PassWrapper<TestVectorToVectorLowering, OperationPass<FuncOp>> { 33 TestVectorToVectorLowering() = default; 34 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 35 : PassWrapper(pass) {} 36 StringRef getArgument() const final { 37 return "test-vector-to-vector-lowering"; 38 } 39 StringRef getDescription() const final { 40 return "Test lowering patterns between ops in the vector dialect"; 41 } 42 43 void getDependentDialects(DialectRegistry ®istry) const override { 44 registry.insert<AffineDialect>(); 45 } 46 47 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 48 llvm::cl::init(false)}; 49 50 void runOnOperation() override { 51 auto *ctx = &getContext(); 52 RewritePatternSet patterns(ctx); 53 if (unroll) { 54 populateVectorUnrollPatterns( 55 patterns, 56 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 57 filter)); 58 } 59 populateVectorToVectorCanonicalizationPatterns(patterns); 60 populateBubbleVectorBitCastOpPatterns(patterns); 61 populateCastAwayVectorLeadingOneDimPatterns(patterns); 62 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 63 } 64 65 private: 66 // Return the target shape based on op type. 67 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 68 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 69 return SmallVector<int64_t, 4>(2, 2); 70 if (isa<vector::ContractionOp>(op)) 71 return SmallVector<int64_t, 4>(3, 2); 72 // For transfer ops, just propagate the shape coming from 73 // InsertStridedSlices/ExtractStridedSlices. 74 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 75 VectorType dstVec; 76 for (Operation *users : readOp->getUsers()) { 77 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 78 if (!extract) 79 return llvm::None; 80 auto vecType = extract.getResult().getType().cast<VectorType>(); 81 if (dstVec && dstVec != vecType) 82 return llvm::None; 83 dstVec = vecType; 84 } 85 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 86 dstVec.getShape().end()); 87 } 88 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 89 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 90 if (!insert) 91 return llvm::None; 92 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 93 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 94 } 95 return llvm::None; 96 } 97 98 static LogicalResult filter(Operation *op) { 99 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 100 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 101 } 102 }; 103 104 struct TestVectorContractionLowering 105 : public PassWrapper<TestVectorContractionLowering, OperationPass<FuncOp>> { 106 StringRef getArgument() const final { 107 return "test-vector-contraction-lowering"; 108 } 109 StringRef getDescription() const final { 110 return "Test lowering patterns that lower contract ops in the vector " 111 "dialect"; 112 } 113 TestVectorContractionLowering() = default; 114 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 115 : PassWrapper(pass) {} 116 117 Option<bool> lowerToFlatMatrix{ 118 *this, "vector-lower-matrix-intrinsics", 119 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 120 llvm::cl::init(false)}; 121 Option<bool> lowerToOuterProduct{ 122 *this, "vector-outerproduct", 123 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 124 llvm::cl::init(false)}; 125 Option<bool> lowerToFilterOuterProduct{ 126 *this, "vector-filter-outerproduct", 127 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 128 "vectors of size 4."), 129 llvm::cl::init(false)}; 130 131 void runOnOperation() override { 132 RewritePatternSet patterns(&getContext()); 133 134 // Test on one pattern in isolation. 135 if (lowerToOuterProduct) { 136 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 137 VectorTransformsOptions options{lowering}; 138 patterns.add<ContractionOpToOuterProductOpLowering>(options, 139 &getContext()); 140 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 141 return; 142 } 143 144 // Test on one pattern in isolation. 145 if (lowerToFilterOuterProduct) { 146 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 147 VectorTransformsOptions options{lowering}; 148 patterns.add<ContractionOpToOuterProductOpLowering>( 149 options, &getContext(), [](vector::ContractionOp op) { 150 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 151 // where M is not 4. 152 if (op.getRhsType().getShape()[0] == 4) 153 return failure(); 154 return success(); 155 }); 156 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 157 return; 158 } 159 160 // Test on all contract lowering patterns. 161 VectorContractLowering contractLowering = VectorContractLowering::Dot; 162 if (lowerToFlatMatrix) 163 contractLowering = VectorContractLowering::Matmul; 164 VectorMultiReductionLowering vectorMultiReductionLowering = 165 VectorMultiReductionLowering::InnerParallel; 166 VectorTransformsOptions options{contractLowering, 167 vectorMultiReductionLowering, 168 VectorTransposeLowering()}; 169 populateVectorBroadcastLoweringPatterns(patterns); 170 populateVectorContractLoweringPatterns(patterns, options); 171 populateVectorMaskOpLoweringPatterns(patterns); 172 populateVectorShapeCastLoweringPatterns(patterns); 173 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 174 } 175 }; 176 177 struct TestVectorTransposeLowering 178 : public PassWrapper<TestVectorTransposeLowering, OperationPass<FuncOp>> { 179 StringRef getArgument() const final { 180 return "test-vector-transpose-lowering"; 181 } 182 StringRef getDescription() const final { 183 return "Test lowering patterns that lower contract ops in the vector " 184 "dialect"; 185 } 186 TestVectorTransposeLowering() = default; 187 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 188 : PassWrapper(pass) {} 189 190 Option<bool> lowerToEltwise{ 191 *this, "eltwise", 192 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 193 llvm::cl::init(false)}; 194 Option<bool> lowerToFlatTranspose{ 195 *this, "flat", 196 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 197 llvm::cl::init(false)}; 198 Option<bool> lowerToShuffleTranspose{ 199 *this, "shuffle", 200 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 201 llvm::cl::init(false)}; 202 Option<bool> lowerToAvx2{ 203 *this, "avx2", 204 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 205 llvm::cl::init(false)}; 206 207 void getDependentDialects(DialectRegistry ®istry) const override { 208 registry.insert<LLVM::LLVMDialect>(); 209 } 210 211 void runOnOperation() override { 212 RewritePatternSet patterns(&getContext()); 213 214 // Test on one pattern in isolation. 215 // Explicitly disable shape_cast lowering. 216 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 217 .enableVectorTransposeLowering() 218 .enableShapeCastLowering(false); 219 if (lowerToEltwise) { 220 options = options.setVectorTransformsOptions( 221 VectorTransformsOptions().setVectorTransposeLowering( 222 VectorTransposeLowering::EltWise)); 223 } 224 if (lowerToFlatTranspose) { 225 options = options.setVectorTransformsOptions( 226 VectorTransformsOptions().setVectorTransposeLowering( 227 VectorTransposeLowering::Flat)); 228 } 229 if (lowerToShuffleTranspose) { 230 options = options.setVectorTransformsOptions( 231 VectorTransformsOptions().setVectorTransposeLowering( 232 VectorTransposeLowering::Shuffle)); 233 } 234 if (lowerToAvx2) { 235 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 236 x86vector::avx2::LoweringOptions().setTransposeOptions( 237 x86vector::avx2::TransposeLoweringOptions() 238 .lower4x8xf32() 239 .lower8x8xf32())); 240 } 241 242 OpPassManager dynamicPM("builtin.func"); 243 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 244 if (failed(runPipeline(dynamicPM, getOperation()))) 245 return signalPassFailure(); 246 } 247 }; 248 249 struct TestVectorUnrollingPatterns 250 : public PassWrapper<TestVectorUnrollingPatterns, OperationPass<FuncOp>> { 251 StringRef getArgument() const final { 252 return "test-vector-unrolling-patterns"; 253 } 254 StringRef getDescription() const final { 255 return "Test lowering patterns to unroll contract ops in the vector " 256 "dialect"; 257 } 258 TestVectorUnrollingPatterns() = default; 259 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 260 : PassWrapper(pass) {} 261 void runOnOperation() override { 262 MLIRContext *ctx = &getContext(); 263 RewritePatternSet patterns(ctx); 264 populateVectorUnrollPatterns( 265 patterns, UnrollVectorOptions() 266 .setNativeShape(ArrayRef<int64_t>{2, 2}) 267 .setFilterConstraint([](Operation *op) { 268 return success(isa<arith::AddFOp, vector::FMAOp, 269 vector::MultiDimReductionOp>(op)); 270 })); 271 272 if (unrollBasedOnType) { 273 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 274 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 275 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 276 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 277 if (auto floatType = contractOp.getLhsType() 278 .getElementType() 279 .dyn_cast<FloatType>()) { 280 if (floatType.getWidth() == 16) { 281 nativeShape[2] = 4; 282 } 283 } 284 return nativeShape; 285 }; 286 populateVectorUnrollPatterns(patterns, 287 UnrollVectorOptions() 288 .setNativeShapeFn(nativeShapeFn) 289 .setFilterConstraint([](Operation *op) { 290 return success(isa<ContractionOp>(op)); 291 })); 292 } else { 293 populateVectorUnrollPatterns( 294 patterns, UnrollVectorOptions() 295 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 296 .setFilterConstraint([](Operation *op) { 297 return success(isa<ContractionOp>(op)); 298 })); 299 } 300 populateVectorToVectorCanonicalizationPatterns(patterns); 301 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 302 } 303 304 Option<bool> unrollBasedOnType{ 305 *this, "unroll-based-on-type", 306 llvm::cl::desc("Set the unroll factor based on type of the operation"), 307 llvm::cl::init(false)}; 308 }; 309 310 struct TestVectorDistributePatterns 311 : public PassWrapper<TestVectorDistributePatterns, OperationPass<FuncOp>> { 312 StringRef getArgument() const final { 313 return "test-vector-distribute-patterns"; 314 } 315 StringRef getDescription() const final { 316 return "Test lowering patterns to distribute vector ops in the vector " 317 "dialect"; 318 } 319 TestVectorDistributePatterns() = default; 320 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 321 : PassWrapper(pass) {} 322 void getDependentDialects(DialectRegistry ®istry) const override { 323 registry.insert<VectorDialect>(); 324 registry.insert<AffineDialect>(); 325 } 326 ListOption<int32_t> multiplicity{ 327 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 328 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 329 330 void runOnOperation() override { 331 MLIRContext *ctx = &getContext(); 332 RewritePatternSet patterns(ctx); 333 FuncOp func = getOperation(); 334 func.walk([&](arith::AddFOp op) { 335 OpBuilder builder(op); 336 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 337 SmallVector<int64_t, 2> mul; 338 SmallVector<AffineExpr, 2> perm; 339 SmallVector<Value, 2> ids; 340 unsigned count = 0; 341 // Remove the multiplicity of 1 and calculate the affine map based on 342 // the multiplicity. 343 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 344 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 345 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 346 mul.push_back(m[i]); 347 ids.push_back(func.getArgument(count++)); 348 perm.push_back(getAffineDimExpr(i, ctx)); 349 } 350 } 351 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 352 perm, ctx); 353 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 354 builder, op.getOperation(), ids, mul, map); 355 if (ops.hasValue()) { 356 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 357 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 358 extractOp); 359 } 360 } 361 }); 362 populatePropagateVectorDistributionPatterns(patterns); 363 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 364 } 365 }; 366 367 struct TestVectorToLoopPatterns 368 : public PassWrapper<TestVectorToLoopPatterns, OperationPass<FuncOp>> { 369 StringRef getArgument() const final { return "test-vector-to-forloop"; } 370 StringRef getDescription() const final { 371 return "Test lowering patterns to break up a vector op into a for loop"; 372 } 373 TestVectorToLoopPatterns() = default; 374 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 375 : PassWrapper(pass) {} 376 void getDependentDialects(DialectRegistry ®istry) const override { 377 registry.insert<VectorDialect>(); 378 registry.insert<AffineDialect>(); 379 } 380 Option<int32_t> multiplicity{ 381 *this, "distribution-multiplicity", 382 llvm::cl::desc("Set the multiplicity used for distributing vector"), 383 llvm::cl::init(32)}; 384 void runOnOperation() override { 385 MLIRContext *ctx = &getContext(); 386 RewritePatternSet patterns(ctx); 387 FuncOp func = getOperation(); 388 func.walk([&](arith::AddFOp op) { 389 // Check that the operation type can be broken down into a loop. 390 VectorType type = op.getType().dyn_cast<VectorType>(); 391 if (!type || type.getRank() != 1 || 392 type.getNumElements() % multiplicity != 0) 393 return mlir::WalkResult::advance(); 394 auto filterAlloc = [](Operation *op) { 395 return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op); 396 }; 397 auto dependentOps = getSlice(op, filterAlloc); 398 // Create a loop and move instructions from the Op slice into the loop. 399 OpBuilder builder(op); 400 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 401 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 402 auto numIter = 403 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 404 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 405 for (Operation *it : dependentOps) { 406 it->moveBefore(forOp.getBody()->getTerminator()); 407 } 408 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 409 // break up the original op and let the patterns propagate. 410 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 411 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 412 map); 413 if (ops.hasValue()) { 414 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 415 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 416 } 417 return mlir::WalkResult::interrupt(); 418 }); 419 populatePropagateVectorDistributionPatterns(patterns); 420 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 421 } 422 }; 423 424 struct TestVectorTransferUnrollingPatterns 425 : public PassWrapper<TestVectorTransferUnrollingPatterns, 426 OperationPass<FuncOp>> { 427 void getDependentDialects(DialectRegistry ®istry) const override { 428 registry.insert<AffineDialect>(); 429 } 430 StringRef getArgument() const final { 431 return "test-vector-transfer-unrolling-patterns"; 432 } 433 StringRef getDescription() const final { 434 return "Test lowering patterns to unroll transfer ops in the vector " 435 "dialect"; 436 } 437 void runOnOperation() override { 438 MLIRContext *ctx = &getContext(); 439 RewritePatternSet patterns(ctx); 440 populateVectorUnrollPatterns( 441 patterns, 442 UnrollVectorOptions() 443 .setNativeShape(ArrayRef<int64_t>{2, 2}) 444 .setFilterConstraint([](Operation *op) { 445 return success( 446 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 447 })); 448 populateVectorToVectorCanonicalizationPatterns(patterns); 449 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 450 } 451 }; 452 453 struct TestVectorTransferFullPartialSplitPatterns 454 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 455 OperationPass<FuncOp>> { 456 StringRef getArgument() const final { 457 return "test-vector-transfer-full-partial-split"; 458 } 459 StringRef getDescription() const final { 460 return "Test lowering patterns to split " 461 "transfer ops via scf.if + linalg ops"; 462 } 463 TestVectorTransferFullPartialSplitPatterns() = default; 464 TestVectorTransferFullPartialSplitPatterns( 465 const TestVectorTransferFullPartialSplitPatterns &pass) 466 : PassWrapper(pass) {} 467 468 void getDependentDialects(DialectRegistry ®istry) const override { 469 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 470 scf::SCFDialect>(); 471 } 472 473 Option<bool> useLinalgOps{ 474 *this, "use-memref-copy", 475 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 476 "memref.copy operations."), 477 llvm::cl::init(false)}; 478 void runOnOperation() override { 479 MLIRContext *ctx = &getContext(); 480 RewritePatternSet patterns(ctx); 481 VectorTransformsOptions options; 482 if (useLinalgOps) 483 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 484 else 485 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 486 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 487 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 488 } 489 }; 490 491 struct TestVectorTransferOpt 492 : public PassWrapper<TestVectorTransferOpt, OperationPass<FuncOp>> { 493 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 494 StringRef getDescription() const final { 495 return "Test optimization transformations for transfer ops"; 496 } 497 void runOnOperation() override { transferOpflowOpt(getOperation()); } 498 }; 499 500 struct TestVectorTransferLoweringPatterns 501 : public PassWrapper<TestVectorTransferLoweringPatterns, 502 OperationPass<FuncOp>> { 503 void getDependentDialects(DialectRegistry ®istry) const override { 504 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 505 } 506 StringRef getArgument() const final { 507 return "test-vector-transfer-lowering-patterns"; 508 } 509 StringRef getDescription() const final { 510 return "Test lowering patterns to lower transfer ops to other vector ops"; 511 } 512 void runOnOperation() override { 513 RewritePatternSet patterns(&getContext()); 514 populateVectorTransferLoweringPatterns(patterns); 515 populateVectorTransferPermutationMapLoweringPatterns(patterns); 516 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 517 } 518 }; 519 520 struct TestVectorMultiReductionLoweringPatterns 521 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 522 OperationPass<FuncOp>> { 523 TestVectorMultiReductionLoweringPatterns() = default; 524 TestVectorMultiReductionLoweringPatterns( 525 const TestVectorMultiReductionLoweringPatterns &pass) 526 : PassWrapper(pass) {} 527 void getDependentDialects(DialectRegistry ®istry) const override { 528 registry.insert<memref::MemRefDialect>(); 529 } 530 StringRef getArgument() const final { 531 return "test-vector-multi-reduction-lowering-patterns"; 532 } 533 StringRef getDescription() const final { 534 return "Test lowering patterns to lower vector.multi_reduction to other " 535 "vector ops"; 536 } 537 Option<bool> useOuterReductions{ 538 *this, "use-outer-reductions", 539 llvm::cl::desc("Move reductions to outer most dimensions"), 540 llvm::cl::init(false)}; 541 void runOnOperation() override { 542 RewritePatternSet patterns(&getContext()); 543 populateVectorMultiReductionLoweringPatterns( 544 patterns, useOuterReductions 545 ? vector::VectorMultiReductionLowering::InnerParallel 546 : vector::VectorMultiReductionLowering::InnerReduction); 547 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 548 } 549 }; 550 551 struct TestVectorTransferCollapseInnerMostContiguousDims 552 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 553 OperationPass<FuncOp>> { 554 TestVectorTransferCollapseInnerMostContiguousDims() = default; 555 TestVectorTransferCollapseInnerMostContiguousDims( 556 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 557 558 void getDependentDialects(DialectRegistry ®istry) const override { 559 registry.insert<memref::MemRefDialect, AffineDialect>(); 560 } 561 562 StringRef getArgument() const final { 563 return "test-vector-transfer-collapse-inner-most-dims"; 564 } 565 566 StringRef getDescription() const final { 567 return "Test lowering patterns that reducedes the rank of the vector " 568 "transfer memory and vector operands."; 569 } 570 571 void runOnOperation() override { 572 RewritePatternSet patterns(&getContext()); 573 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 574 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 575 } 576 }; 577 578 struct TestVectorReduceToContractPatternsPatterns 579 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 580 OperationPass<FuncOp>> { 581 StringRef getArgument() const final { 582 return "test-vector-reduction-to-contract-patterns"; 583 } 584 StringRef getDescription() const final { 585 return "Test patterns to convert multireduce op to contract and combine " 586 "broadcast/transpose to contract"; 587 } 588 void runOnOperation() override { 589 RewritePatternSet patterns(&getContext()); 590 populateVectorReductionToContractPatterns(patterns); 591 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 592 } 593 }; 594 595 struct TestVectorTransferDropUnitDimsPatterns 596 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, 597 OperationPass<FuncOp>> { 598 StringRef getArgument() const final { 599 return "test-vector-transfer-drop-unit-dims-patterns"; 600 } 601 void getDependentDialects(DialectRegistry ®istry) const override { 602 registry.insert<memref::MemRefDialect>(); 603 } 604 void runOnOperation() override { 605 RewritePatternSet patterns(&getContext()); 606 populateVectorTransferDropUnitDimsPatterns(patterns); 607 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 608 } 609 }; 610 611 struct TestFlattenVectorTransferPatterns 612 : public PassWrapper<TestFlattenVectorTransferPatterns, 613 OperationPass<FuncOp>> { 614 StringRef getArgument() const final { 615 return "test-vector-transfer-flatten-patterns"; 616 } 617 StringRef getDescription() const final { 618 return "Test patterns to rewrite contiguous row-major N-dimensional " 619 "vector.transfer_{read,write} ops into 1D transfers"; 620 } 621 void getDependentDialects(DialectRegistry ®istry) const override { 622 registry.insert<memref::MemRefDialect>(); 623 } 624 void runOnOperation() override { 625 RewritePatternSet patterns(&getContext()); 626 populateFlattenVectorTransferPatterns(patterns); 627 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 628 } 629 }; 630 631 struct TestVectorScanLowering 632 : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> { 633 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 634 StringRef getDescription() const final { 635 return "Test lowering patterns that lower the scan op in the vector " 636 "dialect"; 637 } 638 void runOnOperation() override { 639 RewritePatternSet patterns(&getContext()); 640 populateVectorScanLoweringPatterns(patterns); 641 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 642 } 643 }; 644 645 } // namespace 646 647 namespace mlir { 648 namespace test { 649 void registerTestVectorLowerings() { 650 PassRegistration<TestVectorToVectorLowering>(); 651 652 PassRegistration<TestVectorContractionLowering>(); 653 654 PassRegistration<TestVectorTransposeLowering>(); 655 656 PassRegistration<TestVectorUnrollingPatterns>(); 657 658 PassRegistration<TestVectorTransferUnrollingPatterns>(); 659 660 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 661 662 PassRegistration<TestVectorDistributePatterns>(); 663 664 PassRegistration<TestVectorToLoopPatterns>(); 665 666 PassRegistration<TestVectorTransferOpt>(); 667 668 PassRegistration<TestVectorTransferLoweringPatterns>(); 669 670 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 671 672 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 673 674 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 675 676 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 677 678 PassRegistration<TestFlattenVectorTransferPatterns>(); 679 680 PassRegistration<TestVectorScanLowering>(); 681 } 682 } // namespace test 683 } // namespace mlir 684