1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Passes.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/SCF/SCF.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Pass/PassManager.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 using namespace mlir::vector; 28 29 namespace { 30 31 struct TestVectorToVectorLowering 32 : public PassWrapper<TestVectorToVectorLowering, OperationPass<FuncOp>> { 33 TestVectorToVectorLowering() = default; 34 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 35 : PassWrapper(pass) {} 36 StringRef getArgument() const final { 37 return "test-vector-to-vector-lowering"; 38 } 39 StringRef getDescription() const final { 40 return "Test lowering patterns between ops in the vector dialect"; 41 } 42 43 void getDependentDialects(DialectRegistry ®istry) const override { 44 registry.insert<AffineDialect>(); 45 } 46 47 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 48 llvm::cl::init(false)}; 49 50 void runOnOperation() override { 51 auto *ctx = &getContext(); 52 RewritePatternSet patterns(ctx); 53 if (unroll) { 54 populateVectorUnrollPatterns( 55 patterns, 56 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 57 filter)); 58 } 59 populateVectorToVectorCanonicalizationPatterns(patterns); 60 populateBubbleVectorBitCastOpPatterns(patterns); 61 populateCastAwayVectorLeadingOneDimPatterns(patterns); 62 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 63 } 64 65 private: 66 // Return the target shape based on op type. 67 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 68 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 69 return SmallVector<int64_t, 4>(2, 2); 70 if (isa<vector::ContractionOp>(op)) 71 return SmallVector<int64_t, 4>(3, 2); 72 // For transfer ops, just propagate the shape coming from 73 // InsertStridedSlices/ExtractStridedSlices. 74 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 75 VectorType dstVec; 76 for (Operation *users : readOp->getUsers()) { 77 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 78 if (!extract) 79 return llvm::None; 80 auto vecType = extract.getResult().getType().cast<VectorType>(); 81 if (dstVec && dstVec != vecType) 82 return llvm::None; 83 dstVec = vecType; 84 } 85 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 86 dstVec.getShape().end()); 87 } 88 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 89 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 90 if (!insert) 91 return llvm::None; 92 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 93 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 94 } 95 return llvm::None; 96 } 97 98 static LogicalResult filter(Operation *op) { 99 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 100 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 101 } 102 }; 103 104 struct TestVectorContractionLowering 105 : public PassWrapper<TestVectorContractionLowering, OperationPass<FuncOp>> { 106 StringRef getArgument() const final { 107 return "test-vector-contraction-lowering"; 108 } 109 StringRef getDescription() const final { 110 return "Test lowering patterns that lower contract ops in the vector " 111 "dialect"; 112 } 113 TestVectorContractionLowering() = default; 114 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 115 : PassWrapper(pass) {} 116 117 Option<bool> lowerToFlatMatrix{ 118 *this, "vector-lower-matrix-intrinsics", 119 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 120 llvm::cl::init(false)}; 121 Option<bool> lowerToOuterProduct{ 122 *this, "vector-outerproduct", 123 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 124 llvm::cl::init(false)}; 125 Option<bool> lowerToFilterOuterProduct{ 126 *this, "vector-filter-outerproduct", 127 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 128 "vectors of size 4."), 129 llvm::cl::init(false)}; 130 131 void runOnOperation() override { 132 RewritePatternSet patterns(&getContext()); 133 134 // Test on one pattern in isolation. 135 if (lowerToOuterProduct) { 136 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 137 VectorTransformsOptions options{lowering}; 138 patterns.add<ContractionOpToOuterProductOpLowering>(options, 139 &getContext()); 140 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 141 return; 142 } 143 144 // Test on one pattern in isolation. 145 if (lowerToFilterOuterProduct) { 146 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 147 VectorTransformsOptions options{lowering}; 148 patterns.add<ContractionOpToOuterProductOpLowering>( 149 options, &getContext(), [](vector::ContractionOp op) { 150 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 151 // where M is not 4. 152 if (op.getRhsType().getShape()[0] == 4) 153 return failure(); 154 return success(); 155 }); 156 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 157 return; 158 } 159 160 // Test on all contract lowering patterns. 161 VectorContractLowering contractLowering = VectorContractLowering::Dot; 162 if (lowerToFlatMatrix) 163 contractLowering = VectorContractLowering::Matmul; 164 VectorMultiReductionLowering vectorMultiReductionLowering = 165 VectorMultiReductionLowering::InnerParallel; 166 VectorTransformsOptions options{contractLowering, 167 vectorMultiReductionLowering, 168 VectorTransposeLowering()}; 169 populateVectorBroadcastLoweringPatterns(patterns); 170 populateVectorContractLoweringPatterns(patterns, options); 171 populateVectorMaskOpLoweringPatterns(patterns); 172 populateVectorShapeCastLoweringPatterns(patterns); 173 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 174 } 175 }; 176 177 struct TestVectorTransposeLowering 178 : public PassWrapper<TestVectorTransposeLowering, OperationPass<FuncOp>> { 179 StringRef getArgument() const final { 180 return "test-vector-transpose-lowering"; 181 } 182 StringRef getDescription() const final { 183 return "Test lowering patterns that lower contract ops in the vector " 184 "dialect"; 185 } 186 TestVectorTransposeLowering() = default; 187 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 188 : PassWrapper(pass) {} 189 190 Option<bool> lowerToEltwise{ 191 *this, "eltwise", 192 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 193 llvm::cl::init(false)}; 194 Option<bool> lowerToFlatTranspose{ 195 *this, "flat", 196 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 197 llvm::cl::init(false)}; 198 Option<bool> lowerToShuffleTranspose{ 199 *this, "shuffle", 200 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 201 llvm::cl::init(false)}; 202 Option<bool> lowerToAvx2{ 203 *this, "avx2", 204 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 205 llvm::cl::init(false)}; 206 207 void getDependentDialects(DialectRegistry ®istry) const override { 208 registry.insert<LLVM::LLVMDialect>(); 209 } 210 211 void runOnOperation() override { 212 RewritePatternSet patterns(&getContext()); 213 214 // Test on one pattern in isolation. 215 // Explicitly disable shape_cast lowering. 216 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 217 .enableVectorTransposeLowering() 218 .enableShapeCastLowering(false); 219 if (lowerToEltwise) { 220 options = options.setVectorTransformsOptions( 221 VectorTransformsOptions().setVectorTransposeLowering( 222 VectorTransposeLowering::EltWise)); 223 } 224 if (lowerToFlatTranspose) { 225 options = options.setVectorTransformsOptions( 226 VectorTransformsOptions().setVectorTransposeLowering( 227 VectorTransposeLowering::Flat)); 228 } 229 if (lowerToShuffleTranspose) { 230 options = options.setVectorTransformsOptions( 231 VectorTransformsOptions().setVectorTransposeLowering( 232 VectorTransposeLowering::Shuffle)); 233 } 234 if (lowerToAvx2) { 235 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 236 x86vector::avx2::LoweringOptions().setTransposeOptions( 237 x86vector::avx2::TransposeLoweringOptions() 238 .lower4x8xf32() 239 .lower8x8xf32())); 240 } 241 242 OpPassManager dynamicPM("builtin.func"); 243 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 244 if (failed(runPipeline(dynamicPM, getOperation()))) 245 return signalPassFailure(); 246 } 247 }; 248 249 struct TestVectorUnrollingPatterns 250 : public PassWrapper<TestVectorUnrollingPatterns, OperationPass<FuncOp>> { 251 StringRef getArgument() const final { 252 return "test-vector-unrolling-patterns"; 253 } 254 StringRef getDescription() const final { 255 return "Test lowering patterns to unroll contract ops in the vector " 256 "dialect"; 257 } 258 TestVectorUnrollingPatterns() = default; 259 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 260 : PassWrapper(pass) {} 261 void runOnOperation() override { 262 MLIRContext *ctx = &getContext(); 263 RewritePatternSet patterns(ctx); 264 populateVectorUnrollPatterns( 265 patterns, UnrollVectorOptions() 266 .setNativeShape(ArrayRef<int64_t>{2, 2}) 267 .setFilterConstraint([](Operation *op) { 268 return success(isa<arith::AddFOp, vector::FMAOp>(op)); 269 })); 270 271 if (unrollBasedOnType) { 272 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 273 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 274 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 275 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 276 if (auto floatType = contractOp.getLhsType() 277 .getElementType() 278 .dyn_cast<FloatType>()) { 279 if (floatType.getWidth() == 16) { 280 nativeShape[2] = 4; 281 } 282 } 283 return nativeShape; 284 }; 285 populateVectorUnrollPatterns(patterns, 286 UnrollVectorOptions() 287 .setNativeShapeFn(nativeShapeFn) 288 .setFilterConstraint([](Operation *op) { 289 return success(isa<ContractionOp>(op)); 290 })); 291 } else { 292 populateVectorUnrollPatterns( 293 patterns, UnrollVectorOptions() 294 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 295 .setFilterConstraint([](Operation *op) { 296 return success(isa<ContractionOp>(op)); 297 })); 298 } 299 populateVectorToVectorCanonicalizationPatterns(patterns); 300 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 301 } 302 303 Option<bool> unrollBasedOnType{ 304 *this, "unroll-based-on-type", 305 llvm::cl::desc("Set the unroll factor based on type of the operation"), 306 llvm::cl::init(false)}; 307 }; 308 309 struct TestVectorDistributePatterns 310 : public PassWrapper<TestVectorDistributePatterns, OperationPass<FuncOp>> { 311 StringRef getArgument() const final { 312 return "test-vector-distribute-patterns"; 313 } 314 StringRef getDescription() const final { 315 return "Test lowering patterns to distribute vector ops in the vector " 316 "dialect"; 317 } 318 TestVectorDistributePatterns() = default; 319 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 320 : PassWrapper(pass) {} 321 void getDependentDialects(DialectRegistry ®istry) const override { 322 registry.insert<VectorDialect>(); 323 registry.insert<AffineDialect>(); 324 } 325 ListOption<int32_t> multiplicity{ 326 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 327 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 328 329 void runOnOperation() override { 330 MLIRContext *ctx = &getContext(); 331 RewritePatternSet patterns(ctx); 332 FuncOp func = getOperation(); 333 func.walk([&](arith::AddFOp op) { 334 OpBuilder builder(op); 335 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 336 SmallVector<int64_t, 2> mul; 337 SmallVector<AffineExpr, 2> perm; 338 SmallVector<Value, 2> ids; 339 unsigned count = 0; 340 // Remove the multiplicity of 1 and calculate the affine map based on 341 // the multiplicity. 342 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 343 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 344 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 345 mul.push_back(m[i]); 346 ids.push_back(func.getArgument(count++)); 347 perm.push_back(getAffineDimExpr(i, ctx)); 348 } 349 } 350 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 351 perm, ctx); 352 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 353 builder, op.getOperation(), ids, mul, map); 354 if (ops.hasValue()) { 355 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 356 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 357 extractOp); 358 } 359 } 360 }); 361 populatePropagateVectorDistributionPatterns(patterns); 362 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 363 } 364 }; 365 366 struct TestVectorToLoopPatterns 367 : public PassWrapper<TestVectorToLoopPatterns, OperationPass<FuncOp>> { 368 StringRef getArgument() const final { return "test-vector-to-forloop"; } 369 StringRef getDescription() const final { 370 return "Test lowering patterns to break up a vector op into a for loop"; 371 } 372 TestVectorToLoopPatterns() = default; 373 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 374 : PassWrapper(pass) {} 375 void getDependentDialects(DialectRegistry ®istry) const override { 376 registry.insert<VectorDialect>(); 377 registry.insert<AffineDialect>(); 378 } 379 Option<int32_t> multiplicity{ 380 *this, "distribution-multiplicity", 381 llvm::cl::desc("Set the multiplicity used for distributing vector"), 382 llvm::cl::init(32)}; 383 void runOnOperation() override { 384 MLIRContext *ctx = &getContext(); 385 RewritePatternSet patterns(ctx); 386 FuncOp func = getOperation(); 387 func.walk([&](arith::AddFOp op) { 388 // Check that the operation type can be broken down into a loop. 389 VectorType type = op.getType().dyn_cast<VectorType>(); 390 if (!type || type.getRank() != 1 || 391 type.getNumElements() % multiplicity != 0) 392 return mlir::WalkResult::advance(); 393 auto filterAlloc = [](Operation *op) { 394 return !isa<arith::ConstantOp, memref::AllocOp, CallOp>(op); 395 }; 396 auto dependentOps = getSlice(op, filterAlloc); 397 // Create a loop and move instructions from the Op slice into the loop. 398 OpBuilder builder(op); 399 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 400 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 401 auto numIter = 402 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 403 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 404 for (Operation *it : dependentOps) { 405 it->moveBefore(forOp.getBody()->getTerminator()); 406 } 407 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 408 // break up the original op and let the patterns propagate. 409 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 410 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 411 map); 412 if (ops.hasValue()) { 413 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 414 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 415 } 416 return mlir::WalkResult::interrupt(); 417 }); 418 populatePropagateVectorDistributionPatterns(patterns); 419 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 420 } 421 }; 422 423 struct TestVectorTransferUnrollingPatterns 424 : public PassWrapper<TestVectorTransferUnrollingPatterns, 425 OperationPass<FuncOp>> { 426 void getDependentDialects(DialectRegistry ®istry) const override { 427 registry.insert<AffineDialect>(); 428 } 429 StringRef getArgument() const final { 430 return "test-vector-transfer-unrolling-patterns"; 431 } 432 StringRef getDescription() const final { 433 return "Test lowering patterns to unroll transfer ops in the vector " 434 "dialect"; 435 } 436 void runOnOperation() override { 437 MLIRContext *ctx = &getContext(); 438 RewritePatternSet patterns(ctx); 439 populateVectorUnrollPatterns( 440 patterns, 441 UnrollVectorOptions() 442 .setNativeShape(ArrayRef<int64_t>{2, 2}) 443 .setFilterConstraint([](Operation *op) { 444 return success( 445 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 446 })); 447 populateVectorToVectorCanonicalizationPatterns(patterns); 448 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 449 } 450 }; 451 452 struct TestVectorTransferFullPartialSplitPatterns 453 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 454 OperationPass<FuncOp>> { 455 StringRef getArgument() const final { 456 return "test-vector-transfer-full-partial-split"; 457 } 458 StringRef getDescription() const final { 459 return "Test lowering patterns to split " 460 "transfer ops via scf.if + linalg ops"; 461 } 462 TestVectorTransferFullPartialSplitPatterns() = default; 463 TestVectorTransferFullPartialSplitPatterns( 464 const TestVectorTransferFullPartialSplitPatterns &pass) 465 : PassWrapper(pass) {} 466 467 void getDependentDialects(DialectRegistry ®istry) const override { 468 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 469 scf::SCFDialect>(); 470 } 471 472 Option<bool> useLinalgOps{ 473 *this, "use-memref-copy", 474 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 475 "memref.copy operations."), 476 llvm::cl::init(false)}; 477 void runOnOperation() override { 478 MLIRContext *ctx = &getContext(); 479 RewritePatternSet patterns(ctx); 480 VectorTransformsOptions options; 481 if (useLinalgOps) 482 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 483 else 484 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 485 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 486 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 487 } 488 }; 489 490 struct TestVectorTransferOpt 491 : public PassWrapper<TestVectorTransferOpt, OperationPass<FuncOp>> { 492 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 493 StringRef getDescription() const final { 494 return "Test optimization transformations for transfer ops"; 495 } 496 void runOnOperation() override { transferOpflowOpt(getOperation()); } 497 }; 498 499 struct TestVectorTransferLoweringPatterns 500 : public PassWrapper<TestVectorTransferLoweringPatterns, 501 OperationPass<FuncOp>> { 502 void getDependentDialects(DialectRegistry ®istry) const override { 503 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 504 } 505 StringRef getArgument() const final { 506 return "test-vector-transfer-lowering-patterns"; 507 } 508 StringRef getDescription() const final { 509 return "Test lowering patterns to lower transfer ops to other vector ops"; 510 } 511 void runOnOperation() override { 512 RewritePatternSet patterns(&getContext()); 513 populateVectorTransferLoweringPatterns(patterns); 514 populateVectorTransferPermutationMapLoweringPatterns(patterns); 515 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 516 } 517 }; 518 519 struct TestVectorMultiReductionLoweringPatterns 520 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 521 OperationPass<FuncOp>> { 522 TestVectorMultiReductionLoweringPatterns() = default; 523 TestVectorMultiReductionLoweringPatterns( 524 const TestVectorMultiReductionLoweringPatterns &pass) 525 : PassWrapper(pass) {} 526 void getDependentDialects(DialectRegistry ®istry) const override { 527 registry.insert<memref::MemRefDialect>(); 528 } 529 StringRef getArgument() const final { 530 return "test-vector-multi-reduction-lowering-patterns"; 531 } 532 StringRef getDescription() const final { 533 return "Test lowering patterns to lower vector.multi_reduction to other " 534 "vector ops"; 535 } 536 Option<bool> useOuterReductions{ 537 *this, "use-outer-reductions", 538 llvm::cl::desc("Move reductions to outer most dimensions"), 539 llvm::cl::init(false)}; 540 void runOnOperation() override { 541 RewritePatternSet patterns(&getContext()); 542 populateVectorMultiReductionLoweringPatterns( 543 patterns, useOuterReductions 544 ? vector::VectorMultiReductionLowering::InnerParallel 545 : vector::VectorMultiReductionLowering::InnerReduction); 546 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 547 } 548 }; 549 550 struct TestVectorTransferCollapseInnerMostContiguousDims 551 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 552 OperationPass<FuncOp>> { 553 TestVectorTransferCollapseInnerMostContiguousDims() = default; 554 TestVectorTransferCollapseInnerMostContiguousDims( 555 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 556 557 void getDependentDialects(DialectRegistry ®istry) const override { 558 registry.insert<memref::MemRefDialect, AffineDialect>(); 559 } 560 561 StringRef getArgument() const final { 562 return "test-vector-transfer-collapse-inner-most-dims"; 563 } 564 565 StringRef getDescription() const final { 566 return "Test lowering patterns that reducedes the rank of the vector " 567 "transfer memory and vector operands."; 568 } 569 570 void runOnOperation() override { 571 RewritePatternSet patterns(&getContext()); 572 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 573 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 574 } 575 }; 576 577 struct TestVectorReduceToContractPatternsPatterns 578 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 579 OperationPass<FuncOp>> { 580 StringRef getArgument() const final { 581 return "test-vector-reduction-to-contract-patterns"; 582 } 583 StringRef getDescription() const final { 584 return "Test patterns to convert multireduce op to contract and combine " 585 "broadcast/transpose to contract"; 586 } 587 void runOnOperation() override { 588 RewritePatternSet patterns(&getContext()); 589 populateVectorReductionToContractPatterns(patterns); 590 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 591 } 592 }; 593 594 struct TestVectorTransferDropUnitDimsPatterns 595 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, 596 OperationPass<FuncOp>> { 597 StringRef getArgument() const final { 598 return "test-vector-transfer-drop-unit-dims-patterns"; 599 } 600 void getDependentDialects(DialectRegistry ®istry) const override { 601 registry.insert<memref::MemRefDialect>(); 602 } 603 void runOnOperation() override { 604 RewritePatternSet patterns(&getContext()); 605 populateVectorTransferDropUnitDimsPatterns(patterns); 606 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 607 } 608 }; 609 610 struct TestFlattenVectorTransferPatterns 611 : public PassWrapper<TestFlattenVectorTransferPatterns, 612 OperationPass<FuncOp>> { 613 StringRef getArgument() const final { 614 return "test-vector-transfer-flatten-patterns"; 615 } 616 StringRef getDescription() const final { 617 return "Test patterns to rewrite contiguous row-major N-dimensional " 618 "vector.transfer_{read,write} ops into 1D transfers"; 619 } 620 void getDependentDialects(DialectRegistry ®istry) const override { 621 registry.insert<memref::MemRefDialect>(); 622 } 623 void runOnOperation() override { 624 RewritePatternSet patterns(&getContext()); 625 populateFlattenVectorTransferPatterns(patterns); 626 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 627 } 628 }; 629 630 struct TestVectorScanLowering 631 : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> { 632 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 633 StringRef getDescription() const final { 634 return "Test lowering patterns that lower the scan op in the vector " 635 "dialect"; 636 } 637 void runOnOperation() override { 638 RewritePatternSet patterns(&getContext()); 639 populateVectorScanLoweringPatterns(patterns); 640 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 641 } 642 }; 643 644 } // namespace 645 646 namespace mlir { 647 namespace test { 648 void registerTestVectorLowerings() { 649 PassRegistration<TestVectorToVectorLowering>(); 650 651 PassRegistration<TestVectorContractionLowering>(); 652 653 PassRegistration<TestVectorTransposeLowering>(); 654 655 PassRegistration<TestVectorUnrollingPatterns>(); 656 657 PassRegistration<TestVectorTransferUnrollingPatterns>(); 658 659 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 660 661 PassRegistration<TestVectorDistributePatterns>(); 662 663 PassRegistration<TestVectorToLoopPatterns>(); 664 665 PassRegistration<TestVectorTransferOpt>(); 666 667 PassRegistration<TestVectorTransferLoweringPatterns>(); 668 669 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 670 671 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 672 673 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 674 675 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 676 677 PassRegistration<TestFlattenVectorTransferPatterns>(); 678 679 PassRegistration<TestVectorScanLowering>(); 680 } 681 } // namespace test 682 } // namespace mlir 683