1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/SCF.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Pass/PassManager.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 using namespace mlir::vector; 28 29 namespace { 30 31 struct TestVectorToVectorLowering 32 : public PassWrapper<TestVectorToVectorLowering, OperationPass<FuncOp>> { 33 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) 34 35 TestVectorToVectorLowering() = default; 36 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 37 : PassWrapper(pass) {} 38 StringRef getArgument() const final { 39 return "test-vector-to-vector-lowering"; 40 } 41 StringRef getDescription() const final { 42 return "Test lowering patterns between ops in the vector dialect"; 43 } 44 45 void getDependentDialects(DialectRegistry ®istry) const override { 46 registry.insert<AffineDialect>(); 47 } 48 49 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 50 llvm::cl::init(false)}; 51 52 void runOnOperation() override { 53 auto *ctx = &getContext(); 54 RewritePatternSet patterns(ctx); 55 if (unroll) { 56 populateVectorUnrollPatterns( 57 patterns, 58 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 59 filter)); 60 } 61 populateVectorToVectorCanonicalizationPatterns(patterns); 62 populateBubbleVectorBitCastOpPatterns(patterns); 63 populateCastAwayVectorLeadingOneDimPatterns(patterns); 64 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 65 } 66 67 private: 68 // Return the target shape based on op type. 69 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 70 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 71 return SmallVector<int64_t, 4>(2, 2); 72 if (isa<vector::ContractionOp>(op)) 73 return SmallVector<int64_t, 4>(3, 2); 74 // For transfer ops, just propagate the shape coming from 75 // InsertStridedSlices/ExtractStridedSlices. 76 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 77 VectorType dstVec; 78 for (Operation *users : readOp->getUsers()) { 79 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 80 if (!extract) 81 return llvm::None; 82 auto vecType = extract.getResult().getType().cast<VectorType>(); 83 if (dstVec && dstVec != vecType) 84 return llvm::None; 85 dstVec = vecType; 86 } 87 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 88 dstVec.getShape().end()); 89 } 90 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 91 auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); 92 if (!insert) 93 return llvm::None; 94 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 95 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 96 } 97 return llvm::None; 98 } 99 100 static LogicalResult filter(Operation *op) { 101 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 102 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 103 } 104 }; 105 106 struct TestVectorContractionLowering 107 : public PassWrapper<TestVectorContractionLowering, OperationPass<FuncOp>> { 108 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering) 109 110 StringRef getArgument() const final { 111 return "test-vector-contraction-lowering"; 112 } 113 StringRef getDescription() const final { 114 return "Test lowering patterns that lower contract ops in the vector " 115 "dialect"; 116 } 117 TestVectorContractionLowering() = default; 118 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 119 : PassWrapper(pass) {} 120 121 Option<bool> lowerToFlatMatrix{ 122 *this, "vector-lower-matrix-intrinsics", 123 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 124 llvm::cl::init(false)}; 125 Option<bool> lowerToOuterProduct{ 126 *this, "vector-outerproduct", 127 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 128 llvm::cl::init(false)}; 129 Option<bool> lowerToFilterOuterProduct{ 130 *this, "vector-filter-outerproduct", 131 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 132 "vectors of size 4."), 133 llvm::cl::init(false)}; 134 135 void runOnOperation() override { 136 RewritePatternSet patterns(&getContext()); 137 138 // Test on one pattern in isolation. 139 if (lowerToOuterProduct) { 140 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 141 VectorTransformsOptions options{lowering}; 142 patterns.add<ContractionOpToOuterProductOpLowering>(options, 143 &getContext()); 144 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 145 return; 146 } 147 148 // Test on one pattern in isolation. 149 if (lowerToFilterOuterProduct) { 150 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 151 VectorTransformsOptions options{lowering}; 152 patterns.add<ContractionOpToOuterProductOpLowering>( 153 options, &getContext(), [](vector::ContractionOp op) { 154 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 155 // where M is not 4. 156 if (op.getRhsType().getShape()[0] == 4) 157 return failure(); 158 return success(); 159 }); 160 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 161 return; 162 } 163 164 // Test on all contract lowering patterns. 165 VectorContractLowering contractLowering = VectorContractLowering::Dot; 166 if (lowerToFlatMatrix) 167 contractLowering = VectorContractLowering::Matmul; 168 VectorMultiReductionLowering vectorMultiReductionLowering = 169 VectorMultiReductionLowering::InnerParallel; 170 VectorTransformsOptions options{contractLowering, 171 vectorMultiReductionLowering, 172 VectorTransposeLowering()}; 173 populateVectorBroadcastLoweringPatterns(patterns); 174 populateVectorContractLoweringPatterns(patterns, options); 175 populateVectorMaskOpLoweringPatterns(patterns); 176 populateVectorShapeCastLoweringPatterns(patterns); 177 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 178 } 179 }; 180 181 struct TestVectorTransposeLowering 182 : public PassWrapper<TestVectorTransposeLowering, OperationPass<FuncOp>> { 183 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering) 184 185 StringRef getArgument() const final { 186 return "test-vector-transpose-lowering"; 187 } 188 StringRef getDescription() const final { 189 return "Test lowering patterns that lower contract ops in the vector " 190 "dialect"; 191 } 192 TestVectorTransposeLowering() = default; 193 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 194 : PassWrapper(pass) {} 195 196 Option<bool> lowerToEltwise{ 197 *this, "eltwise", 198 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 199 llvm::cl::init(false)}; 200 Option<bool> lowerToFlatTranspose{ 201 *this, "flat", 202 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 203 llvm::cl::init(false)}; 204 Option<bool> lowerToShuffleTranspose{ 205 *this, "shuffle", 206 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 207 llvm::cl::init(false)}; 208 Option<bool> lowerToAvx2{ 209 *this, "avx2", 210 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 211 llvm::cl::init(false)}; 212 213 void getDependentDialects(DialectRegistry ®istry) const override { 214 registry.insert<LLVM::LLVMDialect>(); 215 } 216 217 void runOnOperation() override { 218 RewritePatternSet patterns(&getContext()); 219 220 // Test on one pattern in isolation. 221 // Explicitly disable shape_cast lowering. 222 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 223 .enableVectorTransposeLowering() 224 .enableShapeCastLowering(false); 225 if (lowerToEltwise) { 226 options = options.setVectorTransformsOptions( 227 VectorTransformsOptions().setVectorTransposeLowering( 228 VectorTransposeLowering::EltWise)); 229 } 230 if (lowerToFlatTranspose) { 231 options = options.setVectorTransformsOptions( 232 VectorTransformsOptions().setVectorTransposeLowering( 233 VectorTransposeLowering::Flat)); 234 } 235 if (lowerToShuffleTranspose) { 236 options = options.setVectorTransformsOptions( 237 VectorTransformsOptions().setVectorTransposeLowering( 238 VectorTransposeLowering::Shuffle)); 239 } 240 if (lowerToAvx2) { 241 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 242 x86vector::avx2::LoweringOptions().setTransposeOptions( 243 x86vector::avx2::TransposeLoweringOptions() 244 .lower4x8xf32() 245 .lower8x8xf32())); 246 } 247 248 OpPassManager dynamicPM("func.func"); 249 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 250 if (failed(runPipeline(dynamicPM, getOperation()))) 251 return signalPassFailure(); 252 } 253 }; 254 255 struct TestVectorUnrollingPatterns 256 : public PassWrapper<TestVectorUnrollingPatterns, OperationPass<FuncOp>> { 257 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) 258 259 StringRef getArgument() const final { 260 return "test-vector-unrolling-patterns"; 261 } 262 StringRef getDescription() const final { 263 return "Test lowering patterns to unroll contract ops in the vector " 264 "dialect"; 265 } 266 TestVectorUnrollingPatterns() = default; 267 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 268 : PassWrapper(pass) {} 269 void runOnOperation() override { 270 MLIRContext *ctx = &getContext(); 271 RewritePatternSet patterns(ctx); 272 populateVectorUnrollPatterns( 273 patterns, UnrollVectorOptions() 274 .setNativeShape(ArrayRef<int64_t>{2, 2}) 275 .setFilterConstraint([](Operation *op) { 276 return success(isa<arith::AddFOp, vector::FMAOp, 277 vector::MultiDimReductionOp>(op)); 278 })); 279 populateVectorUnrollPatterns( 280 patterns, UnrollVectorOptions() 281 .setNativeShape(ArrayRef<int64_t>{2}) 282 .setFilterConstraint([](Operation *op) { 283 return success(isa<vector::ReductionOp>(op)); 284 })); 285 286 if (unrollBasedOnType) { 287 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 288 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 289 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 290 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 291 if (auto floatType = contractOp.getLhsType() 292 .getElementType() 293 .dyn_cast<FloatType>()) { 294 if (floatType.getWidth() == 16) { 295 nativeShape[2] = 4; 296 } 297 } 298 return nativeShape; 299 }; 300 populateVectorUnrollPatterns(patterns, 301 UnrollVectorOptions() 302 .setNativeShapeFn(nativeShapeFn) 303 .setFilterConstraint([](Operation *op) { 304 return success(isa<ContractionOp>(op)); 305 })); 306 } else { 307 populateVectorUnrollPatterns( 308 patterns, UnrollVectorOptions() 309 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 310 .setFilterConstraint([](Operation *op) { 311 return success(isa<ContractionOp>(op)); 312 })); 313 } 314 populateVectorToVectorCanonicalizationPatterns(patterns); 315 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 316 } 317 318 Option<bool> unrollBasedOnType{ 319 *this, "unroll-based-on-type", 320 llvm::cl::desc("Set the unroll factor based on type of the operation"), 321 llvm::cl::init(false)}; 322 }; 323 324 struct TestVectorDistributePatterns 325 : public PassWrapper<TestVectorDistributePatterns, OperationPass<FuncOp>> { 326 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns) 327 328 StringRef getArgument() const final { 329 return "test-vector-distribute-patterns"; 330 } 331 StringRef getDescription() const final { 332 return "Test lowering patterns to distribute vector ops in the vector " 333 "dialect"; 334 } 335 TestVectorDistributePatterns() = default; 336 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 337 : PassWrapper(pass) {} 338 void getDependentDialects(DialectRegistry ®istry) const override { 339 registry.insert<VectorDialect>(); 340 registry.insert<AffineDialect>(); 341 } 342 ListOption<int32_t> multiplicity{ 343 *this, "distribution-multiplicity", 344 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 345 346 void runOnOperation() override { 347 MLIRContext *ctx = &getContext(); 348 RewritePatternSet patterns(ctx); 349 FuncOp func = getOperation(); 350 func.walk([&](arith::AddFOp op) { 351 OpBuilder builder(op); 352 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 353 SmallVector<int64_t, 2> mul; 354 SmallVector<AffineExpr, 2> perm; 355 SmallVector<Value, 2> ids; 356 unsigned count = 0; 357 // Remove the multiplicity of 1 and calculate the affine map based on 358 // the multiplicity. 359 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 360 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 361 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 362 mul.push_back(m[i]); 363 ids.push_back(func.getArgument(count++)); 364 perm.push_back(getAffineDimExpr(i, ctx)); 365 } 366 } 367 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 368 perm, ctx); 369 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 370 builder, op.getOperation(), ids, mul, map); 371 if (ops.hasValue()) { 372 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 373 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 374 extractOp); 375 } 376 } 377 }); 378 populatePropagateVectorDistributionPatterns(patterns); 379 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 380 } 381 }; 382 383 struct TestVectorToLoopPatterns 384 : public PassWrapper<TestVectorToLoopPatterns, OperationPass<FuncOp>> { 385 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns) 386 387 StringRef getArgument() const final { return "test-vector-to-forloop"; } 388 StringRef getDescription() const final { 389 return "Test lowering patterns to break up a vector op into a for loop"; 390 } 391 TestVectorToLoopPatterns() = default; 392 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 393 : PassWrapper(pass) {} 394 void getDependentDialects(DialectRegistry ®istry) const override { 395 registry.insert<VectorDialect>(); 396 registry.insert<AffineDialect>(); 397 } 398 Option<int32_t> multiplicity{ 399 *this, "distribution-multiplicity", 400 llvm::cl::desc("Set the multiplicity used for distributing vector"), 401 llvm::cl::init(32)}; 402 void runOnOperation() override { 403 MLIRContext *ctx = &getContext(); 404 RewritePatternSet patterns(ctx); 405 FuncOp func = getOperation(); 406 func.walk([&](arith::AddFOp op) { 407 // Check that the operation type can be broken down into a loop. 408 VectorType type = op.getType().dyn_cast<VectorType>(); 409 if (!type || type.getRank() != 1 || 410 type.getNumElements() % multiplicity != 0) 411 return mlir::WalkResult::advance(); 412 auto filterAlloc = [](Operation *op) { 413 return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op); 414 }; 415 auto dependentOps = getSlice(op, filterAlloc); 416 // Create a loop and move instructions from the Op slice into the loop. 417 OpBuilder builder(op); 418 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 419 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 420 auto numIter = 421 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 422 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 423 for (Operation *it : dependentOps) { 424 it->moveBefore(forOp.getBody()->getTerminator()); 425 } 426 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 427 // break up the original op and let the patterns propagate. 428 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 429 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 430 map); 431 if (ops.hasValue()) { 432 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 433 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 434 } 435 return mlir::WalkResult::interrupt(); 436 }); 437 populatePropagateVectorDistributionPatterns(patterns); 438 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 439 } 440 }; 441 442 struct TestVectorTransferUnrollingPatterns 443 : public PassWrapper<TestVectorTransferUnrollingPatterns, 444 OperationPass<FuncOp>> { 445 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 446 TestVectorTransferUnrollingPatterns) 447 448 void getDependentDialects(DialectRegistry ®istry) const override { 449 registry.insert<AffineDialect>(); 450 } 451 StringRef getArgument() const final { 452 return "test-vector-transfer-unrolling-patterns"; 453 } 454 StringRef getDescription() const final { 455 return "Test lowering patterns to unroll transfer ops in the vector " 456 "dialect"; 457 } 458 void runOnOperation() override { 459 MLIRContext *ctx = &getContext(); 460 RewritePatternSet patterns(ctx); 461 populateVectorUnrollPatterns( 462 patterns, 463 UnrollVectorOptions() 464 .setNativeShape(ArrayRef<int64_t>{2, 2}) 465 .setFilterConstraint([](Operation *op) { 466 return success( 467 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 468 })); 469 populateVectorToVectorCanonicalizationPatterns(patterns); 470 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 471 } 472 }; 473 474 struct TestVectorTransferFullPartialSplitPatterns 475 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 476 OperationPass<FuncOp>> { 477 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 478 TestVectorTransferFullPartialSplitPatterns) 479 480 StringRef getArgument() const final { 481 return "test-vector-transfer-full-partial-split"; 482 } 483 StringRef getDescription() const final { 484 return "Test lowering patterns to split " 485 "transfer ops via scf.if + linalg ops"; 486 } 487 TestVectorTransferFullPartialSplitPatterns() = default; 488 TestVectorTransferFullPartialSplitPatterns( 489 const TestVectorTransferFullPartialSplitPatterns &pass) 490 : PassWrapper(pass) {} 491 492 void getDependentDialects(DialectRegistry ®istry) const override { 493 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 494 scf::SCFDialect>(); 495 } 496 497 Option<bool> useLinalgOps{ 498 *this, "use-memref-copy", 499 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 500 "memref.copy operations."), 501 llvm::cl::init(false)}; 502 void runOnOperation() override { 503 MLIRContext *ctx = &getContext(); 504 RewritePatternSet patterns(ctx); 505 VectorTransformsOptions options; 506 if (useLinalgOps) 507 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 508 else 509 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 510 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 511 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 512 } 513 }; 514 515 struct TestVectorTransferOpt 516 : public PassWrapper<TestVectorTransferOpt, OperationPass<FuncOp>> { 517 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) 518 519 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 520 StringRef getDescription() const final { 521 return "Test optimization transformations for transfer ops"; 522 } 523 void runOnOperation() override { transferOpflowOpt(getOperation()); } 524 }; 525 526 struct TestVectorTransferLoweringPatterns 527 : public PassWrapper<TestVectorTransferLoweringPatterns, 528 OperationPass<FuncOp>> { 529 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 530 TestVectorTransferLoweringPatterns) 531 532 void getDependentDialects(DialectRegistry ®istry) const override { 533 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 534 } 535 StringRef getArgument() const final { 536 return "test-vector-transfer-lowering-patterns"; 537 } 538 StringRef getDescription() const final { 539 return "Test lowering patterns to lower transfer ops to other vector ops"; 540 } 541 void runOnOperation() override { 542 RewritePatternSet patterns(&getContext()); 543 populateVectorTransferLoweringPatterns(patterns); 544 populateVectorTransferPermutationMapLoweringPatterns(patterns); 545 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 546 } 547 }; 548 549 struct TestVectorMultiReductionLoweringPatterns 550 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 551 OperationPass<FuncOp>> { 552 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 553 TestVectorMultiReductionLoweringPatterns) 554 555 TestVectorMultiReductionLoweringPatterns() = default; 556 TestVectorMultiReductionLoweringPatterns( 557 const TestVectorMultiReductionLoweringPatterns &pass) 558 : PassWrapper(pass) {} 559 void getDependentDialects(DialectRegistry ®istry) const override { 560 registry.insert<memref::MemRefDialect>(); 561 } 562 StringRef getArgument() const final { 563 return "test-vector-multi-reduction-lowering-patterns"; 564 } 565 StringRef getDescription() const final { 566 return "Test lowering patterns to lower vector.multi_reduction to other " 567 "vector ops"; 568 } 569 Option<bool> useOuterReductions{ 570 *this, "use-outer-reductions", 571 llvm::cl::desc("Move reductions to outer most dimensions"), 572 llvm::cl::init(false)}; 573 void runOnOperation() override { 574 RewritePatternSet patterns(&getContext()); 575 populateVectorMultiReductionLoweringPatterns( 576 patterns, useOuterReductions 577 ? vector::VectorMultiReductionLowering::InnerParallel 578 : vector::VectorMultiReductionLowering::InnerReduction); 579 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 580 } 581 }; 582 583 struct TestVectorTransferCollapseInnerMostContiguousDims 584 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 585 OperationPass<FuncOp>> { 586 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 587 TestVectorTransferCollapseInnerMostContiguousDims) 588 589 TestVectorTransferCollapseInnerMostContiguousDims() = default; 590 TestVectorTransferCollapseInnerMostContiguousDims( 591 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 592 593 void getDependentDialects(DialectRegistry ®istry) const override { 594 registry.insert<memref::MemRefDialect, AffineDialect>(); 595 } 596 597 StringRef getArgument() const final { 598 return "test-vector-transfer-collapse-inner-most-dims"; 599 } 600 601 StringRef getDescription() const final { 602 return "Test lowering patterns that reducedes the rank of the vector " 603 "transfer memory and vector operands."; 604 } 605 606 void runOnOperation() override { 607 RewritePatternSet patterns(&getContext()); 608 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 609 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 610 } 611 }; 612 613 struct TestVectorReduceToContractPatternsPatterns 614 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 615 OperationPass<FuncOp>> { 616 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 617 TestVectorReduceToContractPatternsPatterns) 618 619 StringRef getArgument() const final { 620 return "test-vector-reduction-to-contract-patterns"; 621 } 622 StringRef getDescription() const final { 623 return "Test patterns to convert multireduce op to contract and combine " 624 "broadcast/transpose to contract"; 625 } 626 void runOnOperation() override { 627 RewritePatternSet patterns(&getContext()); 628 populateVectorReductionToContractPatterns(patterns); 629 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 630 } 631 }; 632 633 struct TestVectorTransferDropUnitDimsPatterns 634 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, 635 OperationPass<FuncOp>> { 636 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 637 TestVectorTransferDropUnitDimsPatterns) 638 639 StringRef getArgument() const final { 640 return "test-vector-transfer-drop-unit-dims-patterns"; 641 } 642 void getDependentDialects(DialectRegistry ®istry) const override { 643 registry.insert<memref::MemRefDialect>(); 644 } 645 void runOnOperation() override { 646 RewritePatternSet patterns(&getContext()); 647 populateVectorTransferDropUnitDimsPatterns(patterns); 648 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 649 } 650 }; 651 652 struct TestFlattenVectorTransferPatterns 653 : public PassWrapper<TestFlattenVectorTransferPatterns, 654 OperationPass<FuncOp>> { 655 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 656 TestFlattenVectorTransferPatterns) 657 658 StringRef getArgument() const final { 659 return "test-vector-transfer-flatten-patterns"; 660 } 661 StringRef getDescription() const final { 662 return "Test patterns to rewrite contiguous row-major N-dimensional " 663 "vector.transfer_{read,write} ops into 1D transfers"; 664 } 665 void getDependentDialects(DialectRegistry ®istry) const override { 666 registry.insert<memref::MemRefDialect>(); 667 } 668 void runOnOperation() override { 669 RewritePatternSet patterns(&getContext()); 670 populateFlattenVectorTransferPatterns(patterns); 671 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 672 } 673 }; 674 675 struct TestVectorScanLowering 676 : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> { 677 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) 678 679 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 680 StringRef getDescription() const final { 681 return "Test lowering patterns that lower the scan op in the vector " 682 "dialect"; 683 } 684 void runOnOperation() override { 685 RewritePatternSet patterns(&getContext()); 686 populateVectorScanLoweringPatterns(patterns); 687 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 688 } 689 }; 690 691 } // namespace 692 693 namespace mlir { 694 namespace test { 695 void registerTestVectorLowerings() { 696 PassRegistration<TestVectorToVectorLowering>(); 697 698 PassRegistration<TestVectorContractionLowering>(); 699 700 PassRegistration<TestVectorTransposeLowering>(); 701 702 PassRegistration<TestVectorUnrollingPatterns>(); 703 704 PassRegistration<TestVectorTransferUnrollingPatterns>(); 705 706 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 707 708 PassRegistration<TestVectorDistributePatterns>(); 709 710 PassRegistration<TestVectorToLoopPatterns>(); 711 712 PassRegistration<TestVectorTransferOpt>(); 713 714 PassRegistration<TestVectorTransferLoweringPatterns>(); 715 716 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 717 718 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 719 720 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 721 722 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 723 724 PassRegistration<TestFlattenVectorTransferPatterns>(); 725 726 PassRegistration<TestVectorScanLowering>(); 727 } 728 } // namespace test 729 } // namespace mlir 730