1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Passes.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/SCF/SCF.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorTransforms.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Pass/PassManager.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 using namespace mlir::vector; 28 29 namespace { 30 31 struct TestVectorToVectorLowering 32 : public PassWrapper<TestVectorToVectorLowering, FunctionPass> { 33 TestVectorToVectorLowering() = default; 34 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) {} 35 StringRef getArgument() const final { 36 return "test-vector-to-vector-lowering"; 37 } 38 StringRef getDescription() const final { 39 return "Test lowering patterns between ops in the vector dialect"; 40 } 41 42 void getDependentDialects(DialectRegistry ®istry) const override { 43 registry.insert<AffineDialect>(); 44 } 45 46 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 47 llvm::cl::init(false)}; 48 49 void runOnFunction() override { 50 auto *ctx = &getContext(); 51 RewritePatternSet patterns(ctx); 52 if (unroll) { 53 populateVectorUnrollPatterns( 54 patterns, 55 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 56 filter)); 57 } 58 populateVectorToVectorCanonicalizationPatterns(patterns); 59 populateBubbleVectorBitCastOpPatterns(patterns); 60 populateCastAwayVectorLeadingOneDimPatterns(patterns); 61 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 62 } 63 64 private: 65 // Return the target shape based on op type. 66 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 67 if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op)) 68 return SmallVector<int64_t, 4>(2, 2); 69 if (isa<vector::ContractionOp>(op)) 70 return SmallVector<int64_t, 4>(3, 2); 71 // For transfer ops, just propagate the shape coming from 72 // InsertStridedSlices/ExtractStridedSlices. 73 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 74 VectorType dstVec; 75 for (Operation *users : readOp->getUsers()) { 76 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 77 if (!extract) 78 return llvm::None; 79 auto vecType = extract.getResult().getType().cast<VectorType>(); 80 if (dstVec && dstVec != vecType) 81 return llvm::None; 82 dstVec = vecType; 83 } 84 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 85 dstVec.getShape().end()); 86 } 87 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 88 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 89 if (!insert) 90 return llvm::None; 91 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 92 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 93 } 94 return llvm::None; 95 } 96 97 static LogicalResult filter(Operation *op) { 98 return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp, 99 TransferReadOp, TransferWriteOp>(op)); 100 } 101 }; 102 103 struct TestVectorContractionLowering 104 : public PassWrapper<TestVectorContractionLowering, FunctionPass> { 105 StringRef getArgument() const final { 106 return "test-vector-contraction-lowering"; 107 } 108 StringRef getDescription() const final { 109 return "Test lowering patterns that lower contract ops in the vector " 110 "dialect"; 111 } 112 TestVectorContractionLowering() = default; 113 TestVectorContractionLowering(const TestVectorContractionLowering &pass) {} 114 115 Option<bool> lowerToFlatMatrix{ 116 *this, "vector-lower-matrix-intrinsics", 117 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 118 llvm::cl::init(false)}; 119 Option<bool> lowerToOuterProduct{ 120 *this, "vector-outerproduct", 121 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 122 llvm::cl::init(false)}; 123 Option<bool> lowerToFilterOuterProduct{ 124 *this, "vector-filter-outerproduct", 125 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 126 "vectors of size 4."), 127 llvm::cl::init(false)}; 128 129 void runOnFunction() override { 130 RewritePatternSet patterns(&getContext()); 131 132 // Test on one pattern in isolation. 133 if (lowerToOuterProduct) { 134 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 135 VectorTransformsOptions options{lowering}; 136 patterns.add<ContractionOpToOuterProductOpLowering>(options, 137 &getContext()); 138 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 139 return; 140 } 141 142 // Test on one pattern in isolation. 143 if (lowerToFilterOuterProduct) { 144 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 145 VectorTransformsOptions options{lowering}; 146 patterns.add<ContractionOpToOuterProductOpLowering>( 147 options, &getContext(), [](vector::ContractionOp op) { 148 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 149 // where M is not 4. 150 if (op.getRhsType().getShape()[0] == 4) 151 return failure(); 152 return success(); 153 }); 154 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 155 return; 156 } 157 158 // Test on all contract lowering patterns. 159 VectorContractLowering contractLowering = VectorContractLowering::Dot; 160 if (lowerToFlatMatrix) 161 contractLowering = VectorContractLowering::Matmul; 162 VectorMultiReductionLowering vectorMultiReductionLowering = 163 VectorMultiReductionLowering::InnerParallel; 164 VectorTransformsOptions options{contractLowering, 165 vectorMultiReductionLowering, 166 VectorTransposeLowering()}; 167 populateVectorBroadcastLoweringPatterns(patterns); 168 populateVectorContractLoweringPatterns(patterns, options); 169 populateVectorMaskOpLoweringPatterns(patterns); 170 populateVectorShapeCastLoweringPatterns(patterns); 171 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 172 } 173 }; 174 175 struct TestVectorTransposeLowering 176 : public PassWrapper<TestVectorTransposeLowering, FunctionPass> { 177 StringRef getArgument() const final { 178 return "test-vector-transpose-lowering"; 179 } 180 StringRef getDescription() const final { 181 return "Test lowering patterns that lower contract ops in the vector " 182 "dialect"; 183 } 184 TestVectorTransposeLowering() = default; 185 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) {} 186 187 Option<bool> lowerToEltwise{ 188 *this, "eltwise", 189 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 190 llvm::cl::init(false)}; 191 Option<bool> lowerToFlatTranspose{ 192 *this, "flat", 193 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 194 llvm::cl::init(false)}; 195 Option<bool> lowerToShuffleTranspose{ 196 *this, "shuffle", 197 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 198 llvm::cl::init(false)}; 199 Option<bool> lowerToAvx2{ 200 *this, "avx2", 201 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 202 llvm::cl::init(false)}; 203 204 void getDependentDialects(DialectRegistry ®istry) const override { 205 registry.insert<LLVM::LLVMDialect>(); 206 } 207 208 void runOnFunction() override { 209 RewritePatternSet patterns(&getContext()); 210 211 // Test on one pattern in isolation. 212 // Explicitly disable shape_cast lowering. 213 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 214 .enableVectorTransposeLowering() 215 .enableShapeCastLowering(false); 216 if (lowerToEltwise) { 217 options = options.setVectorTransformsOptions( 218 VectorTransformsOptions().setVectorTransposeLowering( 219 VectorTransposeLowering::EltWise)); 220 } 221 if (lowerToFlatTranspose) { 222 options = options.setVectorTransformsOptions( 223 VectorTransformsOptions().setVectorTransposeLowering( 224 VectorTransposeLowering::Flat)); 225 } 226 if (lowerToShuffleTranspose) { 227 options = options.setVectorTransformsOptions( 228 VectorTransformsOptions().setVectorTransposeLowering( 229 VectorTransposeLowering::Shuffle)); 230 } 231 if (lowerToAvx2) { 232 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 233 x86vector::avx2::LoweringOptions().setTransposeOptions( 234 x86vector::avx2::TransposeLoweringOptions() 235 .lower4x8xf32() 236 .lower8x8xf32())); 237 } 238 239 OpPassManager dynamicPM("builtin.func"); 240 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 241 if (failed(runPipeline(dynamicPM, getFunction()))) 242 return signalPassFailure(); 243 } 244 }; 245 246 struct TestVectorUnrollingPatterns 247 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 248 StringRef getArgument() const final { 249 return "test-vector-unrolling-patterns"; 250 } 251 StringRef getDescription() const final { 252 return "Test lowering patterns to unroll contract ops in the vector " 253 "dialect"; 254 } 255 TestVectorUnrollingPatterns() = default; 256 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 257 void runOnFunction() override { 258 MLIRContext *ctx = &getContext(); 259 RewritePatternSet patterns(ctx); 260 populateVectorUnrollPatterns( 261 patterns, UnrollVectorOptions() 262 .setNativeShape(ArrayRef<int64_t>{2, 2}) 263 .setFilterConstraint([](Operation *op) { 264 return success(isa<arith::AddFOp, vector::FMAOp>(op)); 265 })); 266 267 if (unrollBasedOnType) { 268 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 269 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 270 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 271 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 272 if (auto floatType = contractOp.getLhsType() 273 .getElementType() 274 .dyn_cast<FloatType>()) { 275 if (floatType.getWidth() == 16) { 276 nativeShape[2] = 4; 277 } 278 } 279 return nativeShape; 280 }; 281 populateVectorUnrollPatterns(patterns, 282 UnrollVectorOptions() 283 .setNativeShapeFn(nativeShapeFn) 284 .setFilterConstraint([](Operation *op) { 285 return success(isa<ContractionOp>(op)); 286 })); 287 } else { 288 populateVectorUnrollPatterns( 289 patterns, UnrollVectorOptions() 290 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 291 .setFilterConstraint([](Operation *op) { 292 return success(isa<ContractionOp>(op)); 293 })); 294 } 295 populateVectorToVectorCanonicalizationPatterns(patterns); 296 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 297 } 298 299 Option<bool> unrollBasedOnType{ 300 *this, "unroll-based-on-type", 301 llvm::cl::desc("Set the unroll factor based on type of the operation"), 302 llvm::cl::init(false)}; 303 }; 304 305 struct TestVectorDistributePatterns 306 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 307 StringRef getArgument() const final { 308 return "test-vector-distribute-patterns"; 309 } 310 StringRef getDescription() const final { 311 return "Test lowering patterns to distribute vector ops in the vector " 312 "dialect"; 313 } 314 TestVectorDistributePatterns() = default; 315 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 316 void getDependentDialects(DialectRegistry ®istry) const override { 317 registry.insert<VectorDialect>(); 318 registry.insert<AffineDialect>(); 319 } 320 ListOption<int32_t> multiplicity{ 321 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 322 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 323 324 void runOnFunction() override { 325 MLIRContext *ctx = &getContext(); 326 RewritePatternSet patterns(ctx); 327 FuncOp func = getFunction(); 328 func.walk([&](arith::AddFOp op) { 329 OpBuilder builder(op); 330 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 331 SmallVector<int64_t, 2> mul; 332 SmallVector<AffineExpr, 2> perm; 333 SmallVector<Value, 2> ids; 334 unsigned count = 0; 335 // Remove the multiplicity of 1 and calculate the affine map based on 336 // the multiplicity. 337 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 338 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 339 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 340 mul.push_back(m[i]); 341 ids.push_back(func.getArgument(count++)); 342 perm.push_back(getAffineDimExpr(i, ctx)); 343 } 344 } 345 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 346 perm, ctx); 347 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 348 builder, op.getOperation(), ids, mul, map); 349 if (ops.hasValue()) { 350 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 351 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 352 extractOp); 353 } 354 } 355 }); 356 populatePropagateVectorDistributionPatterns(patterns); 357 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 358 } 359 }; 360 361 struct TestVectorToLoopPatterns 362 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 363 StringRef getArgument() const final { return "test-vector-to-forloop"; } 364 StringRef getDescription() const final { 365 return "Test lowering patterns to break up a vector op into a for loop"; 366 } 367 TestVectorToLoopPatterns() = default; 368 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 369 void getDependentDialects(DialectRegistry ®istry) const override { 370 registry.insert<VectorDialect>(); 371 registry.insert<AffineDialect>(); 372 } 373 Option<int32_t> multiplicity{ 374 *this, "distribution-multiplicity", 375 llvm::cl::desc("Set the multiplicity used for distributing vector"), 376 llvm::cl::init(32)}; 377 void runOnFunction() override { 378 MLIRContext *ctx = &getContext(); 379 RewritePatternSet patterns(ctx); 380 FuncOp func = getFunction(); 381 func.walk([&](arith::AddFOp op) { 382 // Check that the operation type can be broken down into a loop. 383 VectorType type = op.getType().dyn_cast<VectorType>(); 384 if (!type || type.getRank() != 1 || 385 type.getNumElements() % multiplicity != 0) 386 return mlir::WalkResult::advance(); 387 auto filterAlloc = [](Operation *op) { 388 if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op)) 389 return false; 390 return true; 391 }; 392 auto dependentOps = getSlice(op, filterAlloc); 393 // Create a loop and move instructions from the Op slice into the loop. 394 OpBuilder builder(op); 395 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 396 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 397 auto numIter = 398 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 399 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 400 for (Operation *it : dependentOps) { 401 it->moveBefore(forOp.getBody()->getTerminator()); 402 } 403 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 404 // break up the original op and let the patterns propagate. 405 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 406 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 407 map); 408 if (ops.hasValue()) { 409 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 410 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 411 } 412 return mlir::WalkResult::interrupt(); 413 }); 414 populatePropagateVectorDistributionPatterns(patterns); 415 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 416 } 417 }; 418 419 struct TestVectorTransferUnrollingPatterns 420 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 421 void getDependentDialects(DialectRegistry ®istry) const override { 422 registry.insert<AffineDialect>(); 423 } 424 StringRef getArgument() const final { 425 return "test-vector-transfer-unrolling-patterns"; 426 } 427 StringRef getDescription() const final { 428 return "Test lowering patterns to unroll transfer ops in the vector " 429 "dialect"; 430 } 431 void runOnFunction() override { 432 MLIRContext *ctx = &getContext(); 433 RewritePatternSet patterns(ctx); 434 populateVectorUnrollPatterns( 435 patterns, 436 UnrollVectorOptions() 437 .setNativeShape(ArrayRef<int64_t>{2, 2}) 438 .setFilterConstraint([](Operation *op) { 439 return success( 440 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 441 })); 442 populateVectorToVectorCanonicalizationPatterns(patterns); 443 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 444 } 445 }; 446 447 struct TestVectorTransferFullPartialSplitPatterns 448 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 449 FunctionPass> { 450 StringRef getArgument() const final { 451 return "test-vector-transfer-full-partial-split"; 452 } 453 StringRef getDescription() const final { 454 return "Test lowering patterns to split " 455 "transfer ops via scf.if + linalg ops"; 456 } 457 TestVectorTransferFullPartialSplitPatterns() = default; 458 TestVectorTransferFullPartialSplitPatterns( 459 const TestVectorTransferFullPartialSplitPatterns &pass) {} 460 461 void getDependentDialects(DialectRegistry ®istry) const override { 462 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 463 scf::SCFDialect>(); 464 } 465 466 Option<bool> useLinalgOps{ 467 *this, "use-linalg-copy", 468 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 469 "linalg.copy operations."), 470 llvm::cl::init(false)}; 471 void runOnFunction() override { 472 MLIRContext *ctx = &getContext(); 473 RewritePatternSet patterns(ctx); 474 VectorTransformsOptions options; 475 if (useLinalgOps) 476 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 477 else 478 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 479 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 480 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 481 } 482 }; 483 484 struct TestVectorTransferOpt 485 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 486 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 487 StringRef getDescription() const final { 488 return "Test optimization transformations for transfer ops"; 489 } 490 void runOnFunction() override { transferOpflowOpt(getFunction()); } 491 }; 492 493 struct TestVectorTransferLoweringPatterns 494 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 495 void getDependentDialects(DialectRegistry ®istry) const override { 496 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 497 } 498 StringRef getArgument() const final { 499 return "test-vector-transfer-lowering-patterns"; 500 } 501 StringRef getDescription() const final { 502 return "Test lowering patterns to lower transfer ops to other vector ops"; 503 } 504 void runOnFunction() override { 505 RewritePatternSet patterns(&getContext()); 506 populateVectorTransferLoweringPatterns(patterns); 507 populateVectorTransferPermutationMapLoweringPatterns(patterns); 508 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 509 } 510 }; 511 512 struct TestVectorMultiReductionLoweringPatterns 513 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 514 FunctionPass> { 515 TestVectorMultiReductionLoweringPatterns() = default; 516 TestVectorMultiReductionLoweringPatterns( 517 const TestVectorMultiReductionLoweringPatterns &pass) {} 518 void getDependentDialects(DialectRegistry ®istry) const override { 519 registry.insert<memref::MemRefDialect>(); 520 } 521 StringRef getArgument() const final { 522 return "test-vector-multi-reduction-lowering-patterns"; 523 } 524 StringRef getDescription() const final { 525 return "Test lowering patterns to lower vector.multi_reduction to other " 526 "vector ops"; 527 } 528 Option<bool> useOuterReductions{ 529 *this, "use-outer-reductions", 530 llvm::cl::desc("Move reductions to outer most dimensions"), 531 llvm::cl::init(false)}; 532 void runOnFunction() override { 533 RewritePatternSet patterns(&getContext()); 534 populateVectorMultiReductionLoweringPatterns( 535 patterns, useOuterReductions 536 ? vector::VectorMultiReductionLowering::InnerParallel 537 : vector::VectorMultiReductionLowering::InnerReduction); 538 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 539 } 540 }; 541 542 struct TestVectorTransferCollapseInnerMostContiguousDims 543 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 544 FunctionPass> { 545 TestVectorTransferCollapseInnerMostContiguousDims() = default; 546 TestVectorTransferCollapseInnerMostContiguousDims( 547 const TestVectorTransferCollapseInnerMostContiguousDims &pass) {} 548 549 void getDependentDialects(DialectRegistry ®istry) const override { 550 registry.insert<memref::MemRefDialect, AffineDialect>(); 551 } 552 553 StringRef getArgument() const final { 554 return "test-vector-transfer-collapse-inner-most-dims"; 555 } 556 557 StringRef getDescription() const final { 558 return "Test lowering patterns that reducedes the rank of the vector " 559 "transfer memory and vector operands."; 560 } 561 562 void runOnFunction() override { 563 RewritePatternSet patterns(&getContext()); 564 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 565 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 566 } 567 }; 568 569 struct TestVectorReduceToContractPatternsPatterns 570 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 571 FunctionPass> { 572 StringRef getArgument() const final { 573 return "test-vector-reduction-to-contract-patterns"; 574 } 575 StringRef getDescription() const final { 576 return "Test patterns to convert multireduce op to contract and combine " 577 "broadcast/transpose to contract"; 578 } 579 void runOnFunction() override { 580 RewritePatternSet patterns(&getContext()); 581 populateVectorReductionToContractPatterns(patterns); 582 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 583 } 584 }; 585 586 struct TestVectorTransferDropUnitDimsPatterns 587 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, FunctionPass> { 588 StringRef getArgument() const final { 589 return "test-vector-transfer-drop-unit-dims-patterns"; 590 } 591 void getDependentDialects(DialectRegistry ®istry) const override { 592 registry.insert<memref::MemRefDialect>(); 593 } 594 void runOnFunction() override { 595 RewritePatternSet patterns(&getContext()); 596 populateVectorTransferDropUnitDimsPatterns(patterns); 597 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 598 } 599 }; 600 601 struct TestFlattenVectorTransferPatterns 602 : public PassWrapper<TestFlattenVectorTransferPatterns, FunctionPass> { 603 StringRef getArgument() const final { 604 return "test-vector-transfer-flatten-patterns"; 605 } 606 StringRef getDescription() const final { 607 return "Test patterns to rewrite contiguous row-major N-dimensional " 608 "vector.transfer_{read,write} ops into 1D transfers"; 609 } 610 void getDependentDialects(DialectRegistry ®istry) const override { 611 registry.insert<memref::MemRefDialect>(); 612 } 613 void runOnFunction() override { 614 RewritePatternSet patterns(&getContext()); 615 populateFlattenVectorTransferPatterns(patterns); 616 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 617 } 618 }; 619 620 } // namespace 621 622 namespace mlir { 623 namespace test { 624 void registerTestVectorLowerings() { 625 PassRegistration<TestVectorToVectorLowering>(); 626 627 PassRegistration<TestVectorContractionLowering>(); 628 629 PassRegistration<TestVectorTransposeLowering>(); 630 631 PassRegistration<TestVectorUnrollingPatterns>(); 632 633 PassRegistration<TestVectorTransferUnrollingPatterns>(); 634 635 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 636 637 PassRegistration<TestVectorDistributePatterns>(); 638 639 PassRegistration<TestVectorToLoopPatterns>(); 640 641 PassRegistration<TestVectorTransferOpt>(); 642 643 PassRegistration<TestVectorTransferLoweringPatterns>(); 644 645 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 646 647 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 648 649 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 650 651 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 652 653 PassRegistration<TestFlattenVectorTransferPatterns>(); 654 } 655 } // namespace test 656 } // namespace mlir 657