1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Passes.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/SCF/SCF.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorTransforms.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Pass/PassManager.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 using namespace mlir::vector; 28 29 namespace { 30 31 struct TestVectorToVectorLowering 32 : public PassWrapper<TestVectorToVectorLowering, FunctionPass> { 33 TestVectorToVectorLowering() = default; 34 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 35 : PassWrapper(pass) {} 36 StringRef getArgument() const final { 37 return "test-vector-to-vector-lowering"; 38 } 39 StringRef getDescription() const final { 40 return "Test lowering patterns between ops in the vector dialect"; 41 } 42 43 void getDependentDialects(DialectRegistry ®istry) const override { 44 registry.insert<AffineDialect>(); 45 } 46 47 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 48 llvm::cl::init(false)}; 49 50 void runOnFunction() override { 51 auto *ctx = &getContext(); 52 RewritePatternSet patterns(ctx); 53 if (unroll) { 54 populateVectorUnrollPatterns( 55 patterns, 56 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 57 filter)); 58 } 59 populateVectorToVectorCanonicalizationPatterns(patterns); 60 populateBubbleVectorBitCastOpPatterns(patterns); 61 populateCastAwayVectorLeadingOneDimPatterns(patterns); 62 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 63 } 64 65 private: 66 // Return the target shape based on op type. 67 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 68 if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op)) 69 return SmallVector<int64_t, 4>(2, 2); 70 if (isa<vector::ContractionOp>(op)) 71 return SmallVector<int64_t, 4>(3, 2); 72 // For transfer ops, just propagate the shape coming from 73 // InsertStridedSlices/ExtractStridedSlices. 74 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 75 VectorType dstVec; 76 for (Operation *users : readOp->getUsers()) { 77 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 78 if (!extract) 79 return llvm::None; 80 auto vecType = extract.getResult().getType().cast<VectorType>(); 81 if (dstVec && dstVec != vecType) 82 return llvm::None; 83 dstVec = vecType; 84 } 85 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 86 dstVec.getShape().end()); 87 } 88 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 89 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 90 if (!insert) 91 return llvm::None; 92 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 93 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 94 } 95 return llvm::None; 96 } 97 98 static LogicalResult filter(Operation *op) { 99 return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp, 100 TransferReadOp, TransferWriteOp>(op)); 101 } 102 }; 103 104 struct TestVectorContractionLowering 105 : public PassWrapper<TestVectorContractionLowering, FunctionPass> { 106 StringRef getArgument() const final { 107 return "test-vector-contraction-lowering"; 108 } 109 StringRef getDescription() const final { 110 return "Test lowering patterns that lower contract ops in the vector " 111 "dialect"; 112 } 113 TestVectorContractionLowering() = default; 114 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 115 : PassWrapper(pass) {} 116 117 Option<bool> lowerToFlatMatrix{ 118 *this, "vector-lower-matrix-intrinsics", 119 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 120 llvm::cl::init(false)}; 121 Option<bool> lowerToOuterProduct{ 122 *this, "vector-outerproduct", 123 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 124 llvm::cl::init(false)}; 125 Option<bool> lowerToFilterOuterProduct{ 126 *this, "vector-filter-outerproduct", 127 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 128 "vectors of size 4."), 129 llvm::cl::init(false)}; 130 131 void runOnFunction() override { 132 RewritePatternSet patterns(&getContext()); 133 134 // Test on one pattern in isolation. 135 if (lowerToOuterProduct) { 136 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 137 VectorTransformsOptions options{lowering}; 138 patterns.add<ContractionOpToOuterProductOpLowering>(options, 139 &getContext()); 140 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 141 return; 142 } 143 144 // Test on one pattern in isolation. 145 if (lowerToFilterOuterProduct) { 146 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 147 VectorTransformsOptions options{lowering}; 148 patterns.add<ContractionOpToOuterProductOpLowering>( 149 options, &getContext(), [](vector::ContractionOp op) { 150 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 151 // where M is not 4. 152 if (op.getRhsType().getShape()[0] == 4) 153 return failure(); 154 return success(); 155 }); 156 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 157 return; 158 } 159 160 // Test on all contract lowering patterns. 161 VectorContractLowering contractLowering = VectorContractLowering::Dot; 162 if (lowerToFlatMatrix) 163 contractLowering = VectorContractLowering::Matmul; 164 VectorMultiReductionLowering vectorMultiReductionLowering = 165 VectorMultiReductionLowering::InnerParallel; 166 VectorTransformsOptions options{contractLowering, 167 vectorMultiReductionLowering, 168 VectorTransposeLowering()}; 169 populateVectorBroadcastLoweringPatterns(patterns); 170 populateVectorContractLoweringPatterns(patterns, options); 171 populateVectorMaskOpLoweringPatterns(patterns); 172 populateVectorShapeCastLoweringPatterns(patterns); 173 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 174 } 175 }; 176 177 struct TestVectorTransposeLowering 178 : public PassWrapper<TestVectorTransposeLowering, FunctionPass> { 179 StringRef getArgument() const final { 180 return "test-vector-transpose-lowering"; 181 } 182 StringRef getDescription() const final { 183 return "Test lowering patterns that lower contract ops in the vector " 184 "dialect"; 185 } 186 TestVectorTransposeLowering() = default; 187 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 188 : PassWrapper(pass) {} 189 190 Option<bool> lowerToEltwise{ 191 *this, "eltwise", 192 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 193 llvm::cl::init(false)}; 194 Option<bool> lowerToFlatTranspose{ 195 *this, "flat", 196 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 197 llvm::cl::init(false)}; 198 Option<bool> lowerToShuffleTranspose{ 199 *this, "shuffle", 200 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 201 llvm::cl::init(false)}; 202 Option<bool> lowerToAvx2{ 203 *this, "avx2", 204 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 205 llvm::cl::init(false)}; 206 207 void getDependentDialects(DialectRegistry ®istry) const override { 208 registry.insert<LLVM::LLVMDialect>(); 209 } 210 211 void runOnFunction() override { 212 RewritePatternSet patterns(&getContext()); 213 214 // Test on one pattern in isolation. 215 // Explicitly disable shape_cast lowering. 216 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 217 .enableVectorTransposeLowering() 218 .enableShapeCastLowering(false); 219 if (lowerToEltwise) { 220 options = options.setVectorTransformsOptions( 221 VectorTransformsOptions().setVectorTransposeLowering( 222 VectorTransposeLowering::EltWise)); 223 } 224 if (lowerToFlatTranspose) { 225 options = options.setVectorTransformsOptions( 226 VectorTransformsOptions().setVectorTransposeLowering( 227 VectorTransposeLowering::Flat)); 228 } 229 if (lowerToShuffleTranspose) { 230 options = options.setVectorTransformsOptions( 231 VectorTransformsOptions().setVectorTransposeLowering( 232 VectorTransposeLowering::Shuffle)); 233 } 234 if (lowerToAvx2) { 235 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 236 x86vector::avx2::LoweringOptions().setTransposeOptions( 237 x86vector::avx2::TransposeLoweringOptions() 238 .lower4x8xf32() 239 .lower8x8xf32())); 240 } 241 242 OpPassManager dynamicPM("builtin.func"); 243 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 244 if (failed(runPipeline(dynamicPM, getFunction()))) 245 return signalPassFailure(); 246 } 247 }; 248 249 struct TestVectorUnrollingPatterns 250 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 251 StringRef getArgument() const final { 252 return "test-vector-unrolling-patterns"; 253 } 254 StringRef getDescription() const final { 255 return "Test lowering patterns to unroll contract ops in the vector " 256 "dialect"; 257 } 258 TestVectorUnrollingPatterns() = default; 259 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 260 : PassWrapper(pass) {} 261 void runOnFunction() override { 262 MLIRContext *ctx = &getContext(); 263 RewritePatternSet patterns(ctx); 264 populateVectorUnrollPatterns( 265 patterns, UnrollVectorOptions() 266 .setNativeShape(ArrayRef<int64_t>{2, 2}) 267 .setFilterConstraint([](Operation *op) { 268 return success(isa<arith::AddFOp, vector::FMAOp>(op)); 269 })); 270 271 if (unrollBasedOnType) { 272 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 273 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 274 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 275 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 276 if (auto floatType = contractOp.getLhsType() 277 .getElementType() 278 .dyn_cast<FloatType>()) { 279 if (floatType.getWidth() == 16) { 280 nativeShape[2] = 4; 281 } 282 } 283 return nativeShape; 284 }; 285 populateVectorUnrollPatterns(patterns, 286 UnrollVectorOptions() 287 .setNativeShapeFn(nativeShapeFn) 288 .setFilterConstraint([](Operation *op) { 289 return success(isa<ContractionOp>(op)); 290 })); 291 } else { 292 populateVectorUnrollPatterns( 293 patterns, UnrollVectorOptions() 294 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 295 .setFilterConstraint([](Operation *op) { 296 return success(isa<ContractionOp>(op)); 297 })); 298 } 299 populateVectorToVectorCanonicalizationPatterns(patterns); 300 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 301 } 302 303 Option<bool> unrollBasedOnType{ 304 *this, "unroll-based-on-type", 305 llvm::cl::desc("Set the unroll factor based on type of the operation"), 306 llvm::cl::init(false)}; 307 }; 308 309 struct TestVectorDistributePatterns 310 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 311 StringRef getArgument() const final { 312 return "test-vector-distribute-patterns"; 313 } 314 StringRef getDescription() const final { 315 return "Test lowering patterns to distribute vector ops in the vector " 316 "dialect"; 317 } 318 TestVectorDistributePatterns() = default; 319 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 320 : PassWrapper(pass) {} 321 void getDependentDialects(DialectRegistry ®istry) const override { 322 registry.insert<VectorDialect>(); 323 registry.insert<AffineDialect>(); 324 } 325 ListOption<int32_t> multiplicity{ 326 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 327 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 328 329 void runOnFunction() override { 330 MLIRContext *ctx = &getContext(); 331 RewritePatternSet patterns(ctx); 332 FuncOp func = getFunction(); 333 func.walk([&](arith::AddFOp op) { 334 OpBuilder builder(op); 335 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 336 SmallVector<int64_t, 2> mul; 337 SmallVector<AffineExpr, 2> perm; 338 SmallVector<Value, 2> ids; 339 unsigned count = 0; 340 // Remove the multiplicity of 1 and calculate the affine map based on 341 // the multiplicity. 342 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 343 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 344 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 345 mul.push_back(m[i]); 346 ids.push_back(func.getArgument(count++)); 347 perm.push_back(getAffineDimExpr(i, ctx)); 348 } 349 } 350 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 351 perm, ctx); 352 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 353 builder, op.getOperation(), ids, mul, map); 354 if (ops.hasValue()) { 355 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 356 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 357 extractOp); 358 } 359 } 360 }); 361 populatePropagateVectorDistributionPatterns(patterns); 362 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 363 } 364 }; 365 366 struct TestVectorToLoopPatterns 367 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 368 StringRef getArgument() const final { return "test-vector-to-forloop"; } 369 StringRef getDescription() const final { 370 return "Test lowering patterns to break up a vector op into a for loop"; 371 } 372 TestVectorToLoopPatterns() = default; 373 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 374 : PassWrapper(pass) {} 375 void getDependentDialects(DialectRegistry ®istry) const override { 376 registry.insert<VectorDialect>(); 377 registry.insert<AffineDialect>(); 378 } 379 Option<int32_t> multiplicity{ 380 *this, "distribution-multiplicity", 381 llvm::cl::desc("Set the multiplicity used for distributing vector"), 382 llvm::cl::init(32)}; 383 void runOnFunction() override { 384 MLIRContext *ctx = &getContext(); 385 RewritePatternSet patterns(ctx); 386 FuncOp func = getFunction(); 387 func.walk([&](arith::AddFOp op) { 388 // Check that the operation type can be broken down into a loop. 389 VectorType type = op.getType().dyn_cast<VectorType>(); 390 if (!type || type.getRank() != 1 || 391 type.getNumElements() % multiplicity != 0) 392 return mlir::WalkResult::advance(); 393 auto filterAlloc = [](Operation *op) { 394 return !isa<arith::ConstantOp, memref::AllocOp, CallOp>(op); 395 }; 396 auto dependentOps = getSlice(op, filterAlloc); 397 // Create a loop and move instructions from the Op slice into the loop. 398 OpBuilder builder(op); 399 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 400 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 401 auto numIter = 402 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 403 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 404 for (Operation *it : dependentOps) { 405 it->moveBefore(forOp.getBody()->getTerminator()); 406 } 407 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 408 // break up the original op and let the patterns propagate. 409 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 410 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 411 map); 412 if (ops.hasValue()) { 413 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 414 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 415 } 416 return mlir::WalkResult::interrupt(); 417 }); 418 populatePropagateVectorDistributionPatterns(patterns); 419 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 420 } 421 }; 422 423 struct TestVectorTransferUnrollingPatterns 424 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 425 void getDependentDialects(DialectRegistry ®istry) const override { 426 registry.insert<AffineDialect>(); 427 } 428 StringRef getArgument() const final { 429 return "test-vector-transfer-unrolling-patterns"; 430 } 431 StringRef getDescription() const final { 432 return "Test lowering patterns to unroll transfer ops in the vector " 433 "dialect"; 434 } 435 void runOnFunction() override { 436 MLIRContext *ctx = &getContext(); 437 RewritePatternSet patterns(ctx); 438 populateVectorUnrollPatterns( 439 patterns, 440 UnrollVectorOptions() 441 .setNativeShape(ArrayRef<int64_t>{2, 2}) 442 .setFilterConstraint([](Operation *op) { 443 return success( 444 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 445 })); 446 populateVectorToVectorCanonicalizationPatterns(patterns); 447 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 448 } 449 }; 450 451 struct TestVectorTransferFullPartialSplitPatterns 452 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 453 FunctionPass> { 454 StringRef getArgument() const final { 455 return "test-vector-transfer-full-partial-split"; 456 } 457 StringRef getDescription() const final { 458 return "Test lowering patterns to split " 459 "transfer ops via scf.if + linalg ops"; 460 } 461 TestVectorTransferFullPartialSplitPatterns() = default; 462 TestVectorTransferFullPartialSplitPatterns( 463 const TestVectorTransferFullPartialSplitPatterns &pass) 464 : PassWrapper(pass) {} 465 466 void getDependentDialects(DialectRegistry ®istry) const override { 467 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 468 scf::SCFDialect>(); 469 } 470 471 Option<bool> useLinalgOps{ 472 *this, "use-linalg-copy", 473 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 474 "linalg.copy operations."), 475 llvm::cl::init(false)}; 476 void runOnFunction() override { 477 MLIRContext *ctx = &getContext(); 478 RewritePatternSet patterns(ctx); 479 VectorTransformsOptions options; 480 if (useLinalgOps) 481 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 482 else 483 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 484 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 485 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 486 } 487 }; 488 489 struct TestVectorTransferOpt 490 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 491 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 492 StringRef getDescription() const final { 493 return "Test optimization transformations for transfer ops"; 494 } 495 void runOnFunction() override { transferOpflowOpt(getFunction()); } 496 }; 497 498 struct TestVectorTransferLoweringPatterns 499 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 500 void getDependentDialects(DialectRegistry ®istry) const override { 501 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 502 } 503 StringRef getArgument() const final { 504 return "test-vector-transfer-lowering-patterns"; 505 } 506 StringRef getDescription() const final { 507 return "Test lowering patterns to lower transfer ops to other vector ops"; 508 } 509 void runOnFunction() override { 510 RewritePatternSet patterns(&getContext()); 511 populateVectorTransferLoweringPatterns(patterns); 512 populateVectorTransferPermutationMapLoweringPatterns(patterns); 513 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 514 } 515 }; 516 517 struct TestVectorMultiReductionLoweringPatterns 518 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 519 FunctionPass> { 520 TestVectorMultiReductionLoweringPatterns() = default; 521 TestVectorMultiReductionLoweringPatterns( 522 const TestVectorMultiReductionLoweringPatterns &pass) 523 : PassWrapper(pass) {} 524 void getDependentDialects(DialectRegistry ®istry) const override { 525 registry.insert<memref::MemRefDialect>(); 526 } 527 StringRef getArgument() const final { 528 return "test-vector-multi-reduction-lowering-patterns"; 529 } 530 StringRef getDescription() const final { 531 return "Test lowering patterns to lower vector.multi_reduction to other " 532 "vector ops"; 533 } 534 Option<bool> useOuterReductions{ 535 *this, "use-outer-reductions", 536 llvm::cl::desc("Move reductions to outer most dimensions"), 537 llvm::cl::init(false)}; 538 void runOnFunction() override { 539 RewritePatternSet patterns(&getContext()); 540 populateVectorMultiReductionLoweringPatterns( 541 patterns, useOuterReductions 542 ? vector::VectorMultiReductionLowering::InnerParallel 543 : vector::VectorMultiReductionLowering::InnerReduction); 544 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 545 } 546 }; 547 548 struct TestVectorTransferCollapseInnerMostContiguousDims 549 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 550 FunctionPass> { 551 TestVectorTransferCollapseInnerMostContiguousDims() = default; 552 TestVectorTransferCollapseInnerMostContiguousDims( 553 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 554 555 void getDependentDialects(DialectRegistry ®istry) const override { 556 registry.insert<memref::MemRefDialect, AffineDialect>(); 557 } 558 559 StringRef getArgument() const final { 560 return "test-vector-transfer-collapse-inner-most-dims"; 561 } 562 563 StringRef getDescription() const final { 564 return "Test lowering patterns that reducedes the rank of the vector " 565 "transfer memory and vector operands."; 566 } 567 568 void runOnFunction() override { 569 RewritePatternSet patterns(&getContext()); 570 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 571 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 572 } 573 }; 574 575 struct TestVectorReduceToContractPatternsPatterns 576 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 577 FunctionPass> { 578 StringRef getArgument() const final { 579 return "test-vector-reduction-to-contract-patterns"; 580 } 581 StringRef getDescription() const final { 582 return "Test patterns to convert multireduce op to contract and combine " 583 "broadcast/transpose to contract"; 584 } 585 void runOnFunction() override { 586 RewritePatternSet patterns(&getContext()); 587 populateVectorReductionToContractPatterns(patterns); 588 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 589 } 590 }; 591 592 struct TestVectorTransferDropUnitDimsPatterns 593 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, FunctionPass> { 594 StringRef getArgument() const final { 595 return "test-vector-transfer-drop-unit-dims-patterns"; 596 } 597 void getDependentDialects(DialectRegistry ®istry) const override { 598 registry.insert<memref::MemRefDialect>(); 599 } 600 void runOnFunction() override { 601 RewritePatternSet patterns(&getContext()); 602 populateVectorTransferDropUnitDimsPatterns(patterns); 603 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 604 } 605 }; 606 607 struct TestFlattenVectorTransferPatterns 608 : public PassWrapper<TestFlattenVectorTransferPatterns, FunctionPass> { 609 StringRef getArgument() const final { 610 return "test-vector-transfer-flatten-patterns"; 611 } 612 StringRef getDescription() const final { 613 return "Test patterns to rewrite contiguous row-major N-dimensional " 614 "vector.transfer_{read,write} ops into 1D transfers"; 615 } 616 void getDependentDialects(DialectRegistry ®istry) const override { 617 registry.insert<memref::MemRefDialect>(); 618 } 619 void runOnFunction() override { 620 RewritePatternSet patterns(&getContext()); 621 populateFlattenVectorTransferPatterns(patterns); 622 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 623 } 624 }; 625 626 } // namespace 627 628 namespace mlir { 629 namespace test { 630 void registerTestVectorLowerings() { 631 PassRegistration<TestVectorToVectorLowering>(); 632 633 PassRegistration<TestVectorContractionLowering>(); 634 635 PassRegistration<TestVectorTransposeLowering>(); 636 637 PassRegistration<TestVectorUnrollingPatterns>(); 638 639 PassRegistration<TestVectorTransferUnrollingPatterns>(); 640 641 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 642 643 PassRegistration<TestVectorDistributePatterns>(); 644 645 PassRegistration<TestVectorToLoopPatterns>(); 646 647 PassRegistration<TestVectorTransferOpt>(); 648 649 PassRegistration<TestVectorTransferLoweringPatterns>(); 650 651 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 652 653 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 654 655 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 656 657 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 658 659 PassRegistration<TestFlattenVectorTransferPatterns>(); 660 } 661 } // namespace test 662 } // namespace mlir 663