1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/SCF.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Pass/PassManager.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 using namespace mlir::vector; 28 29 namespace { 30 31 struct TestVectorToVectorLowering 32 : public PassWrapper<TestVectorToVectorLowering, OperationPass<FuncOp>> { 33 TestVectorToVectorLowering() = default; 34 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 35 : PassWrapper(pass) {} 36 StringRef getArgument() const final { 37 return "test-vector-to-vector-lowering"; 38 } 39 StringRef getDescription() const final { 40 return "Test lowering patterns between ops in the vector dialect"; 41 } 42 43 void getDependentDialects(DialectRegistry ®istry) const override { 44 registry.insert<AffineDialect>(); 45 } 46 47 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 48 llvm::cl::init(false)}; 49 50 void runOnOperation() override { 51 auto *ctx = &getContext(); 52 RewritePatternSet patterns(ctx); 53 if (unroll) { 54 populateVectorUnrollPatterns( 55 patterns, 56 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 57 filter)); 58 } 59 populateVectorToVectorCanonicalizationPatterns(patterns); 60 populateBubbleVectorBitCastOpPatterns(patterns); 61 populateCastAwayVectorLeadingOneDimPatterns(patterns); 62 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 63 } 64 65 private: 66 // Return the target shape based on op type. 67 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 68 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 69 return SmallVector<int64_t, 4>(2, 2); 70 if (isa<vector::ContractionOp>(op)) 71 return SmallVector<int64_t, 4>(3, 2); 72 // For transfer ops, just propagate the shape coming from 73 // InsertStridedSlices/ExtractStridedSlices. 74 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 75 VectorType dstVec; 76 for (Operation *users : readOp->getUsers()) { 77 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 78 if (!extract) 79 return llvm::None; 80 auto vecType = extract.getResult().getType().cast<VectorType>(); 81 if (dstVec && dstVec != vecType) 82 return llvm::None; 83 dstVec = vecType; 84 } 85 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 86 dstVec.getShape().end()); 87 } 88 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 89 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 90 if (!insert) 91 return llvm::None; 92 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 93 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 94 } 95 return llvm::None; 96 } 97 98 static LogicalResult filter(Operation *op) { 99 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 100 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 101 } 102 }; 103 104 struct TestVectorContractionLowering 105 : public PassWrapper<TestVectorContractionLowering, OperationPass<FuncOp>> { 106 StringRef getArgument() const final { 107 return "test-vector-contraction-lowering"; 108 } 109 StringRef getDescription() const final { 110 return "Test lowering patterns that lower contract ops in the vector " 111 "dialect"; 112 } 113 TestVectorContractionLowering() = default; 114 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 115 : PassWrapper(pass) {} 116 117 Option<bool> lowerToFlatMatrix{ 118 *this, "vector-lower-matrix-intrinsics", 119 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 120 llvm::cl::init(false)}; 121 Option<bool> lowerToOuterProduct{ 122 *this, "vector-outerproduct", 123 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 124 llvm::cl::init(false)}; 125 Option<bool> lowerToFilterOuterProduct{ 126 *this, "vector-filter-outerproduct", 127 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 128 "vectors of size 4."), 129 llvm::cl::init(false)}; 130 131 void runOnOperation() override { 132 RewritePatternSet patterns(&getContext()); 133 134 // Test on one pattern in isolation. 135 if (lowerToOuterProduct) { 136 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 137 VectorTransformsOptions options{lowering}; 138 patterns.add<ContractionOpToOuterProductOpLowering>(options, 139 &getContext()); 140 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 141 return; 142 } 143 144 // Test on one pattern in isolation. 145 if (lowerToFilterOuterProduct) { 146 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 147 VectorTransformsOptions options{lowering}; 148 patterns.add<ContractionOpToOuterProductOpLowering>( 149 options, &getContext(), [](vector::ContractionOp op) { 150 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 151 // where M is not 4. 152 if (op.getRhsType().getShape()[0] == 4) 153 return failure(); 154 return success(); 155 }); 156 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 157 return; 158 } 159 160 // Test on all contract lowering patterns. 161 VectorContractLowering contractLowering = VectorContractLowering::Dot; 162 if (lowerToFlatMatrix) 163 contractLowering = VectorContractLowering::Matmul; 164 VectorMultiReductionLowering vectorMultiReductionLowering = 165 VectorMultiReductionLowering::InnerParallel; 166 VectorTransformsOptions options{contractLowering, 167 vectorMultiReductionLowering, 168 VectorTransposeLowering()}; 169 populateVectorBroadcastLoweringPatterns(patterns); 170 populateVectorContractLoweringPatterns(patterns, options); 171 populateVectorMaskOpLoweringPatterns(patterns); 172 populateVectorShapeCastLoweringPatterns(patterns); 173 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 174 } 175 }; 176 177 struct TestVectorTransposeLowering 178 : public PassWrapper<TestVectorTransposeLowering, OperationPass<FuncOp>> { 179 StringRef getArgument() const final { 180 return "test-vector-transpose-lowering"; 181 } 182 StringRef getDescription() const final { 183 return "Test lowering patterns that lower contract ops in the vector " 184 "dialect"; 185 } 186 TestVectorTransposeLowering() = default; 187 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 188 : PassWrapper(pass) {} 189 190 Option<bool> lowerToEltwise{ 191 *this, "eltwise", 192 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 193 llvm::cl::init(false)}; 194 Option<bool> lowerToFlatTranspose{ 195 *this, "flat", 196 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 197 llvm::cl::init(false)}; 198 Option<bool> lowerToShuffleTranspose{ 199 *this, "shuffle", 200 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 201 llvm::cl::init(false)}; 202 Option<bool> lowerToAvx2{ 203 *this, "avx2", 204 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 205 llvm::cl::init(false)}; 206 207 void getDependentDialects(DialectRegistry ®istry) const override { 208 registry.insert<LLVM::LLVMDialect>(); 209 } 210 211 void runOnOperation() override { 212 RewritePatternSet patterns(&getContext()); 213 214 // Test on one pattern in isolation. 215 // Explicitly disable shape_cast lowering. 216 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 217 .enableVectorTransposeLowering() 218 .enableShapeCastLowering(false); 219 if (lowerToEltwise) { 220 options = options.setVectorTransformsOptions( 221 VectorTransformsOptions().setVectorTransposeLowering( 222 VectorTransposeLowering::EltWise)); 223 } 224 if (lowerToFlatTranspose) { 225 options = options.setVectorTransformsOptions( 226 VectorTransformsOptions().setVectorTransposeLowering( 227 VectorTransposeLowering::Flat)); 228 } 229 if (lowerToShuffleTranspose) { 230 options = options.setVectorTransformsOptions( 231 VectorTransformsOptions().setVectorTransposeLowering( 232 VectorTransposeLowering::Shuffle)); 233 } 234 if (lowerToAvx2) { 235 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 236 x86vector::avx2::LoweringOptions().setTransposeOptions( 237 x86vector::avx2::TransposeLoweringOptions() 238 .lower4x8xf32() 239 .lower8x8xf32())); 240 } 241 242 OpPassManager dynamicPM("builtin.func"); 243 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 244 if (failed(runPipeline(dynamicPM, getOperation()))) 245 return signalPassFailure(); 246 } 247 }; 248 249 struct TestVectorUnrollingPatterns 250 : public PassWrapper<TestVectorUnrollingPatterns, OperationPass<FuncOp>> { 251 StringRef getArgument() const final { 252 return "test-vector-unrolling-patterns"; 253 } 254 StringRef getDescription() const final { 255 return "Test lowering patterns to unroll contract ops in the vector " 256 "dialect"; 257 } 258 TestVectorUnrollingPatterns() = default; 259 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 260 : PassWrapper(pass) {} 261 void runOnOperation() override { 262 MLIRContext *ctx = &getContext(); 263 RewritePatternSet patterns(ctx); 264 populateVectorUnrollPatterns( 265 patterns, UnrollVectorOptions() 266 .setNativeShape(ArrayRef<int64_t>{2, 2}) 267 .setFilterConstraint([](Operation *op) { 268 return success(isa<arith::AddFOp, vector::FMAOp, 269 vector::MultiDimReductionOp>(op)); 270 })); 271 populateVectorUnrollPatterns( 272 patterns, UnrollVectorOptions() 273 .setNativeShape(ArrayRef<int64_t>{2}) 274 .setFilterConstraint([](Operation *op) { 275 return success(isa<vector::ReductionOp>(op)); 276 })); 277 278 if (unrollBasedOnType) { 279 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 280 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 281 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 282 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 283 if (auto floatType = contractOp.getLhsType() 284 .getElementType() 285 .dyn_cast<FloatType>()) { 286 if (floatType.getWidth() == 16) { 287 nativeShape[2] = 4; 288 } 289 } 290 return nativeShape; 291 }; 292 populateVectorUnrollPatterns(patterns, 293 UnrollVectorOptions() 294 .setNativeShapeFn(nativeShapeFn) 295 .setFilterConstraint([](Operation *op) { 296 return success(isa<ContractionOp>(op)); 297 })); 298 } else { 299 populateVectorUnrollPatterns( 300 patterns, UnrollVectorOptions() 301 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 302 .setFilterConstraint([](Operation *op) { 303 return success(isa<ContractionOp>(op)); 304 })); 305 } 306 populateVectorToVectorCanonicalizationPatterns(patterns); 307 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 308 } 309 310 Option<bool> unrollBasedOnType{ 311 *this, "unroll-based-on-type", 312 llvm::cl::desc("Set the unroll factor based on type of the operation"), 313 llvm::cl::init(false)}; 314 }; 315 316 struct TestVectorDistributePatterns 317 : public PassWrapper<TestVectorDistributePatterns, OperationPass<FuncOp>> { 318 StringRef getArgument() const final { 319 return "test-vector-distribute-patterns"; 320 } 321 StringRef getDescription() const final { 322 return "Test lowering patterns to distribute vector ops in the vector " 323 "dialect"; 324 } 325 TestVectorDistributePatterns() = default; 326 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 327 : PassWrapper(pass) {} 328 void getDependentDialects(DialectRegistry ®istry) const override { 329 registry.insert<VectorDialect>(); 330 registry.insert<AffineDialect>(); 331 } 332 ListOption<int32_t> multiplicity{ 333 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 334 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 335 336 void runOnOperation() override { 337 MLIRContext *ctx = &getContext(); 338 RewritePatternSet patterns(ctx); 339 FuncOp func = getOperation(); 340 func.walk([&](arith::AddFOp op) { 341 OpBuilder builder(op); 342 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 343 SmallVector<int64_t, 2> mul; 344 SmallVector<AffineExpr, 2> perm; 345 SmallVector<Value, 2> ids; 346 unsigned count = 0; 347 // Remove the multiplicity of 1 and calculate the affine map based on 348 // the multiplicity. 349 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 350 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 351 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 352 mul.push_back(m[i]); 353 ids.push_back(func.getArgument(count++)); 354 perm.push_back(getAffineDimExpr(i, ctx)); 355 } 356 } 357 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 358 perm, ctx); 359 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 360 builder, op.getOperation(), ids, mul, map); 361 if (ops.hasValue()) { 362 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 363 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 364 extractOp); 365 } 366 } 367 }); 368 populatePropagateVectorDistributionPatterns(patterns); 369 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 370 } 371 }; 372 373 struct TestVectorToLoopPatterns 374 : public PassWrapper<TestVectorToLoopPatterns, OperationPass<FuncOp>> { 375 StringRef getArgument() const final { return "test-vector-to-forloop"; } 376 StringRef getDescription() const final { 377 return "Test lowering patterns to break up a vector op into a for loop"; 378 } 379 TestVectorToLoopPatterns() = default; 380 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 381 : PassWrapper(pass) {} 382 void getDependentDialects(DialectRegistry ®istry) const override { 383 registry.insert<VectorDialect>(); 384 registry.insert<AffineDialect>(); 385 } 386 Option<int32_t> multiplicity{ 387 *this, "distribution-multiplicity", 388 llvm::cl::desc("Set the multiplicity used for distributing vector"), 389 llvm::cl::init(32)}; 390 void runOnOperation() override { 391 MLIRContext *ctx = &getContext(); 392 RewritePatternSet patterns(ctx); 393 FuncOp func = getOperation(); 394 func.walk([&](arith::AddFOp op) { 395 // Check that the operation type can be broken down into a loop. 396 VectorType type = op.getType().dyn_cast<VectorType>(); 397 if (!type || type.getRank() != 1 || 398 type.getNumElements() % multiplicity != 0) 399 return mlir::WalkResult::advance(); 400 auto filterAlloc = [](Operation *op) { 401 return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op); 402 }; 403 auto dependentOps = getSlice(op, filterAlloc); 404 // Create a loop and move instructions from the Op slice into the loop. 405 OpBuilder builder(op); 406 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 407 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 408 auto numIter = 409 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 410 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 411 for (Operation *it : dependentOps) { 412 it->moveBefore(forOp.getBody()->getTerminator()); 413 } 414 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 415 // break up the original op and let the patterns propagate. 416 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 417 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 418 map); 419 if (ops.hasValue()) { 420 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 421 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 422 } 423 return mlir::WalkResult::interrupt(); 424 }); 425 populatePropagateVectorDistributionPatterns(patterns); 426 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 427 } 428 }; 429 430 struct TestVectorTransferUnrollingPatterns 431 : public PassWrapper<TestVectorTransferUnrollingPatterns, 432 OperationPass<FuncOp>> { 433 void getDependentDialects(DialectRegistry ®istry) const override { 434 registry.insert<AffineDialect>(); 435 } 436 StringRef getArgument() const final { 437 return "test-vector-transfer-unrolling-patterns"; 438 } 439 StringRef getDescription() const final { 440 return "Test lowering patterns to unroll transfer ops in the vector " 441 "dialect"; 442 } 443 void runOnOperation() override { 444 MLIRContext *ctx = &getContext(); 445 RewritePatternSet patterns(ctx); 446 populateVectorUnrollPatterns( 447 patterns, 448 UnrollVectorOptions() 449 .setNativeShape(ArrayRef<int64_t>{2, 2}) 450 .setFilterConstraint([](Operation *op) { 451 return success( 452 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 453 })); 454 populateVectorToVectorCanonicalizationPatterns(patterns); 455 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 456 } 457 }; 458 459 struct TestVectorTransferFullPartialSplitPatterns 460 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 461 OperationPass<FuncOp>> { 462 StringRef getArgument() const final { 463 return "test-vector-transfer-full-partial-split"; 464 } 465 StringRef getDescription() const final { 466 return "Test lowering patterns to split " 467 "transfer ops via scf.if + linalg ops"; 468 } 469 TestVectorTransferFullPartialSplitPatterns() = default; 470 TestVectorTransferFullPartialSplitPatterns( 471 const TestVectorTransferFullPartialSplitPatterns &pass) 472 : PassWrapper(pass) {} 473 474 void getDependentDialects(DialectRegistry ®istry) const override { 475 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 476 scf::SCFDialect>(); 477 } 478 479 Option<bool> useLinalgOps{ 480 *this, "use-memref-copy", 481 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 482 "memref.copy operations."), 483 llvm::cl::init(false)}; 484 void runOnOperation() override { 485 MLIRContext *ctx = &getContext(); 486 RewritePatternSet patterns(ctx); 487 VectorTransformsOptions options; 488 if (useLinalgOps) 489 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 490 else 491 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 492 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 493 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 494 } 495 }; 496 497 struct TestVectorTransferOpt 498 : public PassWrapper<TestVectorTransferOpt, OperationPass<FuncOp>> { 499 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 500 StringRef getDescription() const final { 501 return "Test optimization transformations for transfer ops"; 502 } 503 void runOnOperation() override { transferOpflowOpt(getOperation()); } 504 }; 505 506 struct TestVectorTransferLoweringPatterns 507 : public PassWrapper<TestVectorTransferLoweringPatterns, 508 OperationPass<FuncOp>> { 509 void getDependentDialects(DialectRegistry ®istry) const override { 510 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 511 } 512 StringRef getArgument() const final { 513 return "test-vector-transfer-lowering-patterns"; 514 } 515 StringRef getDescription() const final { 516 return "Test lowering patterns to lower transfer ops to other vector ops"; 517 } 518 void runOnOperation() override { 519 RewritePatternSet patterns(&getContext()); 520 populateVectorTransferLoweringPatterns(patterns); 521 populateVectorTransferPermutationMapLoweringPatterns(patterns); 522 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 523 } 524 }; 525 526 struct TestVectorMultiReductionLoweringPatterns 527 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 528 OperationPass<FuncOp>> { 529 TestVectorMultiReductionLoweringPatterns() = default; 530 TestVectorMultiReductionLoweringPatterns( 531 const TestVectorMultiReductionLoweringPatterns &pass) 532 : PassWrapper(pass) {} 533 void getDependentDialects(DialectRegistry ®istry) const override { 534 registry.insert<memref::MemRefDialect>(); 535 } 536 StringRef getArgument() const final { 537 return "test-vector-multi-reduction-lowering-patterns"; 538 } 539 StringRef getDescription() const final { 540 return "Test lowering patterns to lower vector.multi_reduction to other " 541 "vector ops"; 542 } 543 Option<bool> useOuterReductions{ 544 *this, "use-outer-reductions", 545 llvm::cl::desc("Move reductions to outer most dimensions"), 546 llvm::cl::init(false)}; 547 void runOnOperation() override { 548 RewritePatternSet patterns(&getContext()); 549 populateVectorMultiReductionLoweringPatterns( 550 patterns, useOuterReductions 551 ? vector::VectorMultiReductionLowering::InnerParallel 552 : vector::VectorMultiReductionLowering::InnerReduction); 553 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 554 } 555 }; 556 557 struct TestVectorTransferCollapseInnerMostContiguousDims 558 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 559 OperationPass<FuncOp>> { 560 TestVectorTransferCollapseInnerMostContiguousDims() = default; 561 TestVectorTransferCollapseInnerMostContiguousDims( 562 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 563 564 void getDependentDialects(DialectRegistry ®istry) const override { 565 registry.insert<memref::MemRefDialect, AffineDialect>(); 566 } 567 568 StringRef getArgument() const final { 569 return "test-vector-transfer-collapse-inner-most-dims"; 570 } 571 572 StringRef getDescription() const final { 573 return "Test lowering patterns that reducedes the rank of the vector " 574 "transfer memory and vector operands."; 575 } 576 577 void runOnOperation() override { 578 RewritePatternSet patterns(&getContext()); 579 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 580 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 581 } 582 }; 583 584 struct TestVectorReduceToContractPatternsPatterns 585 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 586 OperationPass<FuncOp>> { 587 StringRef getArgument() const final { 588 return "test-vector-reduction-to-contract-patterns"; 589 } 590 StringRef getDescription() const final { 591 return "Test patterns to convert multireduce op to contract and combine " 592 "broadcast/transpose to contract"; 593 } 594 void runOnOperation() override { 595 RewritePatternSet patterns(&getContext()); 596 populateVectorReductionToContractPatterns(patterns); 597 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 598 } 599 }; 600 601 struct TestVectorTransferDropUnitDimsPatterns 602 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, 603 OperationPass<FuncOp>> { 604 StringRef getArgument() const final { 605 return "test-vector-transfer-drop-unit-dims-patterns"; 606 } 607 void getDependentDialects(DialectRegistry ®istry) const override { 608 registry.insert<memref::MemRefDialect>(); 609 } 610 void runOnOperation() override { 611 RewritePatternSet patterns(&getContext()); 612 populateVectorTransferDropUnitDimsPatterns(patterns); 613 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 614 } 615 }; 616 617 struct TestFlattenVectorTransferPatterns 618 : public PassWrapper<TestFlattenVectorTransferPatterns, 619 OperationPass<FuncOp>> { 620 StringRef getArgument() const final { 621 return "test-vector-transfer-flatten-patterns"; 622 } 623 StringRef getDescription() const final { 624 return "Test patterns to rewrite contiguous row-major N-dimensional " 625 "vector.transfer_{read,write} ops into 1D transfers"; 626 } 627 void getDependentDialects(DialectRegistry ®istry) const override { 628 registry.insert<memref::MemRefDialect>(); 629 } 630 void runOnOperation() override { 631 RewritePatternSet patterns(&getContext()); 632 populateFlattenVectorTransferPatterns(patterns); 633 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 634 } 635 }; 636 637 struct TestVectorScanLowering 638 : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> { 639 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 640 StringRef getDescription() const final { 641 return "Test lowering patterns that lower the scan op in the vector " 642 "dialect"; 643 } 644 void runOnOperation() override { 645 RewritePatternSet patterns(&getContext()); 646 populateVectorScanLoweringPatterns(patterns); 647 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 648 } 649 }; 650 651 } // namespace 652 653 namespace mlir { 654 namespace test { 655 void registerTestVectorLowerings() { 656 PassRegistration<TestVectorToVectorLowering>(); 657 658 PassRegistration<TestVectorContractionLowering>(); 659 660 PassRegistration<TestVectorTransposeLowering>(); 661 662 PassRegistration<TestVectorUnrollingPatterns>(); 663 664 PassRegistration<TestVectorTransferUnrollingPatterns>(); 665 666 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 667 668 PassRegistration<TestVectorDistributePatterns>(); 669 670 PassRegistration<TestVectorToLoopPatterns>(); 671 672 PassRegistration<TestVectorTransferOpt>(); 673 674 PassRegistration<TestVectorTransferLoweringPatterns>(); 675 676 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 677 678 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 679 680 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 681 682 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 683 684 PassRegistration<TestFlattenVectorTransferPatterns>(); 685 686 PassRegistration<TestVectorScanLowering>(); 687 } 688 } // namespace test 689 } // namespace mlir 690