1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/Linalg/Passes.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/Dialect/Vector/VectorTransforms.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Pass/PassManager.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 using namespace mlir; 25 using namespace mlir::linalg; 26 using namespace mlir::vector; 27 28 namespace { 29 30 struct TestVectorToVectorLowering 31 : public PassWrapper<TestVectorToVectorLowering, FunctionPass> { 32 TestVectorToVectorLowering() = default; 33 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) {} 34 StringRef getArgument() const final { 35 return "test-vector-to-vector-lowering"; 36 } 37 StringRef getDescription() const final { 38 return "Test lowering patterns between ops in the vector dialect"; 39 } 40 41 void getDependentDialects(DialectRegistry ®istry) const override { 42 registry.insert<AffineDialect>(); 43 } 44 45 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 46 llvm::cl::init(false)}; 47 48 void runOnFunction() override { 49 auto *ctx = &getContext(); 50 RewritePatternSet patterns(ctx); 51 if (unroll) { 52 populateVectorUnrollPatterns( 53 patterns, 54 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 55 filter)); 56 } 57 populateVectorToVectorCanonicalizationPatterns(patterns); 58 populateBubbleVectorBitCastOpPatterns(patterns); 59 populateCastAwayVectorLeadingOneDimPatterns(patterns); 60 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 61 } 62 63 private: 64 // Return the target shape based on op type. 65 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 66 if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op)) 67 return SmallVector<int64_t, 4>(2, 2); 68 if (isa<vector::ContractionOp>(op)) 69 return SmallVector<int64_t, 4>(3, 2); 70 // For transfer ops, just propagate the shape coming from 71 // InsertStridedSlices/ExtractStridedSlices. 72 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 73 VectorType dstVec; 74 for (Operation *users : readOp->getUsers()) { 75 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 76 if (!extract) 77 return llvm::None; 78 auto vecType = extract.getResult().getType().cast<VectorType>(); 79 if (dstVec && dstVec != vecType) 80 return llvm::None; 81 dstVec = vecType; 82 } 83 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 84 dstVec.getShape().end()); 85 } 86 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 87 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 88 if (!insert) 89 return llvm::None; 90 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 91 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 92 } 93 return llvm::None; 94 } 95 96 static LogicalResult filter(Operation *op) { 97 return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp, 98 TransferReadOp, TransferWriteOp>(op)); 99 } 100 }; 101 102 struct TestVectorContractionLowering 103 : public PassWrapper<TestVectorContractionLowering, FunctionPass> { 104 StringRef getArgument() const final { 105 return "test-vector-contraction-lowering"; 106 } 107 StringRef getDescription() const final { 108 return "Test lowering patterns that lower contract ops in the vector " 109 "dialect"; 110 } 111 TestVectorContractionLowering() = default; 112 TestVectorContractionLowering(const TestVectorContractionLowering &pass) {} 113 114 Option<bool> lowerToFlatMatrix{ 115 *this, "vector-lower-matrix-intrinsics", 116 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 117 llvm::cl::init(false)}; 118 Option<bool> lowerToOuterProduct{ 119 *this, "vector-outerproduct", 120 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 121 llvm::cl::init(false)}; 122 Option<bool> lowerToFilterOuterProduct{ 123 *this, "vector-filter-outerproduct", 124 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 125 "vectors of size 4."), 126 llvm::cl::init(false)}; 127 128 void runOnFunction() override { 129 RewritePatternSet patterns(&getContext()); 130 131 // Test on one pattern in isolation. 132 if (lowerToOuterProduct) { 133 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 134 VectorTransformsOptions options{lowering}; 135 patterns.add<ContractionOpToOuterProductOpLowering>(options, 136 &getContext()); 137 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 138 return; 139 } 140 141 // Test on one pattern in isolation. 142 if (lowerToFilterOuterProduct) { 143 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 144 VectorTransformsOptions options{lowering}; 145 patterns.add<ContractionOpToOuterProductOpLowering>( 146 options, &getContext(), [](vector::ContractionOp op) { 147 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 148 // where M is not 4. 149 if (op.getRhsType().getShape()[0] == 4) 150 return failure(); 151 return success(); 152 }); 153 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 154 return; 155 } 156 157 // Test on all contract lowering patterns. 158 VectorContractLowering contractLowering = VectorContractLowering::Dot; 159 if (lowerToFlatMatrix) 160 contractLowering = VectorContractLowering::Matmul; 161 VectorMultiReductionLowering vectorMultiReductionLowering = 162 VectorMultiReductionLowering::InnerParallel; 163 VectorTransformsOptions options{contractLowering, 164 vectorMultiReductionLowering, 165 VectorTransposeLowering()}; 166 populateVectorBroadcastLoweringPatterns(patterns); 167 populateVectorContractLoweringPatterns(patterns, options); 168 populateVectorMaskOpLoweringPatterns(patterns); 169 populateVectorShapeCastLoweringPatterns(patterns); 170 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 171 } 172 }; 173 174 struct TestVectorTransposeLowering 175 : public PassWrapper<TestVectorTransposeLowering, FunctionPass> { 176 StringRef getArgument() const final { 177 return "test-vector-transpose-lowering"; 178 } 179 StringRef getDescription() const final { 180 return "Test lowering patterns that lower contract ops in the vector " 181 "dialect"; 182 } 183 TestVectorTransposeLowering() = default; 184 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) {} 185 186 Option<bool> lowerToEltwise{ 187 *this, "eltwise", 188 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 189 llvm::cl::init(false)}; 190 Option<bool> lowerToFlatTranspose{ 191 *this, "flat", 192 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 193 llvm::cl::init(false)}; 194 Option<bool> lowerToShuffleTranspose{ 195 *this, "shuffle", 196 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 197 llvm::cl::init(false)}; 198 Option<bool> lowerToAvx2{ 199 *this, "avx2", 200 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 201 llvm::cl::init(false)}; 202 203 void runOnFunction() override { 204 RewritePatternSet patterns(&getContext()); 205 206 // Test on one pattern in isolation. 207 // Explicitly disable shape_cast lowering. 208 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 209 .enableVectorTransposeLowering() 210 .enableShapeCastLowering(false); 211 if (lowerToEltwise) { 212 options = options.setVectorTransformsOptions( 213 VectorTransformsOptions().setVectorTransposeLowering( 214 VectorTransposeLowering::EltWise)); 215 } 216 if (lowerToFlatTranspose) { 217 options = options.setVectorTransformsOptions( 218 VectorTransformsOptions().setVectorTransposeLowering( 219 VectorTransposeLowering::Flat)); 220 } 221 if (lowerToShuffleTranspose) { 222 options = options.setVectorTransformsOptions( 223 VectorTransformsOptions().setVectorTransposeLowering( 224 VectorTransposeLowering::Shuffle)); 225 } 226 if (lowerToAvx2) { 227 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 228 x86vector::avx2::LoweringOptions().setTransposeOptions( 229 x86vector::avx2::TransposeLoweringOptions() 230 .lower4x8xf32() 231 .lower8x8xf32())); 232 } 233 234 OpPassManager dynamicPM("builtin.func"); 235 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 236 if (failed(runPipeline(dynamicPM, getFunction()))) 237 return signalPassFailure(); 238 } 239 }; 240 241 struct TestVectorUnrollingPatterns 242 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 243 StringRef getArgument() const final { 244 return "test-vector-unrolling-patterns"; 245 } 246 StringRef getDescription() const final { 247 return "Test lowering patterns to unroll contract ops in the vector " 248 "dialect"; 249 } 250 TestVectorUnrollingPatterns() = default; 251 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 252 void runOnFunction() override { 253 MLIRContext *ctx = &getContext(); 254 RewritePatternSet patterns(ctx); 255 populateVectorUnrollPatterns( 256 patterns, UnrollVectorOptions() 257 .setNativeShape(ArrayRef<int64_t>{2, 2}) 258 .setFilterConstraint([](Operation *op) { 259 return success(isa<arith::AddFOp, vector::FMAOp>(op)); 260 })); 261 262 if (unrollBasedOnType) { 263 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 264 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 265 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 266 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 267 if (auto floatType = contractOp.getLhsType() 268 .getElementType() 269 .dyn_cast<FloatType>()) { 270 if (floatType.getWidth() == 16) { 271 nativeShape[2] = 4; 272 } 273 } 274 return nativeShape; 275 }; 276 populateVectorUnrollPatterns(patterns, 277 UnrollVectorOptions() 278 .setNativeShapeFn(nativeShapeFn) 279 .setFilterConstraint([](Operation *op) { 280 return success(isa<ContractionOp>(op)); 281 })); 282 } else { 283 populateVectorUnrollPatterns( 284 patterns, UnrollVectorOptions() 285 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 286 .setFilterConstraint([](Operation *op) { 287 return success(isa<ContractionOp>(op)); 288 })); 289 } 290 populateVectorToVectorCanonicalizationPatterns(patterns); 291 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 292 } 293 294 Option<bool> unrollBasedOnType{ 295 *this, "unroll-based-on-type", 296 llvm::cl::desc("Set the unroll factor based on type of the operation"), 297 llvm::cl::init(false)}; 298 }; 299 300 struct TestVectorDistributePatterns 301 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 302 StringRef getArgument() const final { 303 return "test-vector-distribute-patterns"; 304 } 305 StringRef getDescription() const final { 306 return "Test lowering patterns to distribute vector ops in the vector " 307 "dialect"; 308 } 309 TestVectorDistributePatterns() = default; 310 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 311 void getDependentDialects(DialectRegistry ®istry) const override { 312 registry.insert<VectorDialect>(); 313 registry.insert<AffineDialect>(); 314 } 315 ListOption<int32_t> multiplicity{ 316 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 317 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 318 319 void runOnFunction() override { 320 MLIRContext *ctx = &getContext(); 321 RewritePatternSet patterns(ctx); 322 FuncOp func = getFunction(); 323 func.walk([&](arith::AddFOp op) { 324 OpBuilder builder(op); 325 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 326 SmallVector<int64_t, 2> mul; 327 SmallVector<AffineExpr, 2> perm; 328 SmallVector<Value, 2> ids; 329 unsigned count = 0; 330 // Remove the multiplicity of 1 and calculate the affine map based on 331 // the multiplicity. 332 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 333 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 334 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 335 mul.push_back(m[i]); 336 ids.push_back(func.getArgument(count++)); 337 perm.push_back(getAffineDimExpr(i, ctx)); 338 } 339 } 340 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 341 perm, ctx); 342 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 343 builder, op.getOperation(), ids, mul, map); 344 if (ops.hasValue()) { 345 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 346 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 347 extractOp); 348 } 349 } 350 }); 351 populatePropagateVectorDistributionPatterns(patterns); 352 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 353 } 354 }; 355 356 struct TestVectorToLoopPatterns 357 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 358 StringRef getArgument() const final { return "test-vector-to-forloop"; } 359 StringRef getDescription() const final { 360 return "Test lowering patterns to break up a vector op into a for loop"; 361 } 362 TestVectorToLoopPatterns() = default; 363 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 364 void getDependentDialects(DialectRegistry ®istry) const override { 365 registry.insert<VectorDialect>(); 366 registry.insert<AffineDialect>(); 367 } 368 Option<int32_t> multiplicity{ 369 *this, "distribution-multiplicity", 370 llvm::cl::desc("Set the multiplicity used for distributing vector"), 371 llvm::cl::init(32)}; 372 void runOnFunction() override { 373 MLIRContext *ctx = &getContext(); 374 RewritePatternSet patterns(ctx); 375 FuncOp func = getFunction(); 376 func.walk([&](arith::AddFOp op) { 377 // Check that the operation type can be broken down into a loop. 378 VectorType type = op.getType().dyn_cast<VectorType>(); 379 if (!type || type.getRank() != 1 || 380 type.getNumElements() % multiplicity != 0) 381 return mlir::WalkResult::advance(); 382 auto filterAlloc = [](Operation *op) { 383 if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op)) 384 return false; 385 return true; 386 }; 387 auto dependentOps = getSlice(op, filterAlloc); 388 // Create a loop and move instructions from the Op slice into the loop. 389 OpBuilder builder(op); 390 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 391 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 392 auto numIter = 393 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 394 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 395 for (Operation *it : dependentOps) { 396 it->moveBefore(forOp.getBody()->getTerminator()); 397 } 398 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 399 // break up the original op and let the patterns propagate. 400 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 401 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 402 map); 403 if (ops.hasValue()) { 404 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 405 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 406 } 407 return mlir::WalkResult::interrupt(); 408 }); 409 populatePropagateVectorDistributionPatterns(patterns); 410 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 411 } 412 }; 413 414 struct TestVectorTransferUnrollingPatterns 415 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 416 void getDependentDialects(DialectRegistry ®istry) const override { 417 registry.insert<AffineDialect>(); 418 } 419 StringRef getArgument() const final { 420 return "test-vector-transfer-unrolling-patterns"; 421 } 422 StringRef getDescription() const final { 423 return "Test lowering patterns to unroll transfer ops in the vector " 424 "dialect"; 425 } 426 void runOnFunction() override { 427 MLIRContext *ctx = &getContext(); 428 RewritePatternSet patterns(ctx); 429 populateVectorUnrollPatterns( 430 patterns, 431 UnrollVectorOptions() 432 .setNativeShape(ArrayRef<int64_t>{2, 2}) 433 .setFilterConstraint([](Operation *op) { 434 return success( 435 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 436 })); 437 populateVectorToVectorCanonicalizationPatterns(patterns); 438 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 439 } 440 }; 441 442 struct TestVectorTransferFullPartialSplitPatterns 443 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 444 FunctionPass> { 445 StringRef getArgument() const final { 446 return "test-vector-transfer-full-partial-split"; 447 } 448 StringRef getDescription() const final { 449 return "Test lowering patterns to split " 450 "transfer ops via scf.if + linalg ops"; 451 } 452 TestVectorTransferFullPartialSplitPatterns() = default; 453 TestVectorTransferFullPartialSplitPatterns( 454 const TestVectorTransferFullPartialSplitPatterns &pass) {} 455 456 void getDependentDialects(DialectRegistry ®istry) const override { 457 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 458 scf::SCFDialect>(); 459 } 460 461 Option<bool> useLinalgOps{ 462 *this, "use-linalg-copy", 463 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 464 "linalg.copy operations."), 465 llvm::cl::init(false)}; 466 void runOnFunction() override { 467 MLIRContext *ctx = &getContext(); 468 RewritePatternSet patterns(ctx); 469 VectorTransformsOptions options; 470 if (useLinalgOps) 471 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 472 else 473 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 474 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 475 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 476 } 477 }; 478 479 struct TestVectorTransferOpt 480 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 481 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 482 StringRef getDescription() const final { 483 return "Test optimization transformations for transfer ops"; 484 } 485 void runOnFunction() override { transferOpflowOpt(getFunction()); } 486 }; 487 488 struct TestVectorTransferLoweringPatterns 489 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 490 void getDependentDialects(DialectRegistry ®istry) const override { 491 registry.insert<memref::MemRefDialect>(); 492 } 493 StringRef getArgument() const final { 494 return "test-vector-transfer-lowering-patterns"; 495 } 496 StringRef getDescription() const final { 497 return "Test lowering patterns to lower transfer ops to other vector ops"; 498 } 499 void runOnFunction() override { 500 RewritePatternSet patterns(&getContext()); 501 populateVectorTransferLoweringPatterns(patterns); 502 populateVectorTransferPermutationMapLoweringPatterns(patterns); 503 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 504 } 505 }; 506 507 struct TestVectorMultiReductionLoweringPatterns 508 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 509 FunctionPass> { 510 TestVectorMultiReductionLoweringPatterns() = default; 511 TestVectorMultiReductionLoweringPatterns( 512 const TestVectorMultiReductionLoweringPatterns &pass) {} 513 void getDependentDialects(DialectRegistry ®istry) const override { 514 registry.insert<memref::MemRefDialect>(); 515 } 516 StringRef getArgument() const final { 517 return "test-vector-multi-reduction-lowering-patterns"; 518 } 519 StringRef getDescription() const final { 520 return "Test lowering patterns to lower vector.multi_reduction to other " 521 "vector ops"; 522 } 523 Option<bool> useOuterReductions{ 524 *this, "use-outer-reductions", 525 llvm::cl::desc("Move reductions to outer most dimensions"), 526 llvm::cl::init(false)}; 527 void runOnFunction() override { 528 RewritePatternSet patterns(&getContext()); 529 populateVectorMultiReductionLoweringPatterns( 530 patterns, useOuterReductions 531 ? vector::VectorMultiReductionLowering::InnerParallel 532 : vector::VectorMultiReductionLowering::InnerReduction); 533 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 534 } 535 }; 536 537 struct TestVectorTransferCollapseInnerMostContiguousDims 538 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 539 FunctionPass> { 540 TestVectorTransferCollapseInnerMostContiguousDims() = default; 541 TestVectorTransferCollapseInnerMostContiguousDims( 542 const TestVectorTransferCollapseInnerMostContiguousDims &pass) {} 543 544 void getDependentDialects(DialectRegistry ®istry) const override { 545 registry.insert<memref::MemRefDialect, AffineDialect>(); 546 } 547 548 StringRef getArgument() const final { 549 return "test-vector-transfer-collapse-inner-most-dims"; 550 } 551 552 StringRef getDescription() const final { 553 return "Test lowering patterns that reducedes the rank of the vector " 554 "transfer memory and vector operands."; 555 } 556 557 void runOnFunction() override { 558 RewritePatternSet patterns(&getContext()); 559 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 560 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 561 } 562 }; 563 564 struct TestVectorReduceToContractPatternsPatterns 565 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 566 FunctionPass> { 567 StringRef getArgument() const final { 568 return "test-vector-reduction-to-contract-patterns"; 569 } 570 StringRef getDescription() const final { 571 return "Test patterns to convert multireduce op to contract and combine " 572 "broadcast/transpose to contract"; 573 } 574 void runOnFunction() override { 575 RewritePatternSet patterns(&getContext()); 576 populateVectorReductionToContractPatterns(patterns); 577 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 578 } 579 }; 580 581 } // end anonymous namespace 582 583 namespace mlir { 584 namespace test { 585 void registerTestVectorLowerings() { 586 PassRegistration<TestVectorToVectorLowering>(); 587 588 PassRegistration<TestVectorContractionLowering>(); 589 590 PassRegistration<TestVectorTransposeLowering>(); 591 592 PassRegistration<TestVectorUnrollingPatterns>(); 593 594 PassRegistration<TestVectorTransferUnrollingPatterns>(); 595 596 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 597 598 PassRegistration<TestVectorDistributePatterns>(); 599 600 PassRegistration<TestVectorToLoopPatterns>(); 601 602 PassRegistration<TestVectorTransferOpt>(); 603 604 PassRegistration<TestVectorTransferLoweringPatterns>(); 605 606 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 607 608 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 609 610 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 611 } 612 } // namespace test 613 } // namespace mlir 614