1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/SCF.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Pass/PassManager.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 using namespace mlir::vector; 28 29 namespace { 30 31 struct TestVectorToVectorLowering 32 : public PassWrapper<TestVectorToVectorLowering, 33 OperationPass<func::FuncOp>> { 34 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) 35 36 TestVectorToVectorLowering() = default; 37 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 38 : PassWrapper(pass) {} 39 StringRef getArgument() const final { 40 return "test-vector-to-vector-lowering"; 41 } 42 StringRef getDescription() const final { 43 return "Test lowering patterns between ops in the vector dialect"; 44 } 45 46 void getDependentDialects(DialectRegistry ®istry) const override { 47 registry.insert<AffineDialect>(); 48 } 49 50 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 51 llvm::cl::init(false)}; 52 53 void runOnOperation() override { 54 auto *ctx = &getContext(); 55 RewritePatternSet patterns(ctx); 56 if (unroll) { 57 populateVectorUnrollPatterns( 58 patterns, 59 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 60 filter)); 61 } 62 populateVectorToVectorCanonicalizationPatterns(patterns); 63 populateBubbleVectorBitCastOpPatterns(patterns); 64 populateCastAwayVectorLeadingOneDimPatterns(patterns); 65 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 66 } 67 68 private: 69 // Return the target shape based on op type. 70 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 71 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 72 return SmallVector<int64_t, 4>(2, 2); 73 if (isa<vector::ContractionOp>(op)) 74 return SmallVector<int64_t, 4>(3, 2); 75 // For transfer ops, just propagate the shape coming from 76 // InsertStridedSlices/ExtractStridedSlices. 77 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 78 VectorType dstVec; 79 for (Operation *users : readOp->getUsers()) { 80 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 81 if (!extract) 82 return llvm::None; 83 auto vecType = extract.getResult().getType().cast<VectorType>(); 84 if (dstVec && dstVec != vecType) 85 return llvm::None; 86 dstVec = vecType; 87 } 88 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 89 dstVec.getShape().end()); 90 } 91 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 92 auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); 93 if (!insert) 94 return llvm::None; 95 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 96 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 97 } 98 return llvm::None; 99 } 100 101 static LogicalResult filter(Operation *op) { 102 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 103 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 104 } 105 }; 106 107 struct TestVectorContractionLowering 108 : public PassWrapper<TestVectorContractionLowering, 109 OperationPass<func::FuncOp>> { 110 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering) 111 112 StringRef getArgument() const final { 113 return "test-vector-contraction-lowering"; 114 } 115 StringRef getDescription() const final { 116 return "Test lowering patterns that lower contract ops in the vector " 117 "dialect"; 118 } 119 TestVectorContractionLowering() = default; 120 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 121 : PassWrapper(pass) {} 122 123 Option<bool> lowerToFlatMatrix{ 124 *this, "vector-lower-matrix-intrinsics", 125 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 126 llvm::cl::init(false)}; 127 Option<bool> lowerToOuterProduct{ 128 *this, "vector-outerproduct", 129 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 130 llvm::cl::init(false)}; 131 Option<bool> lowerToFilterOuterProduct{ 132 *this, "vector-filter-outerproduct", 133 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 134 "vectors of size 4."), 135 llvm::cl::init(false)}; 136 137 void runOnOperation() override { 138 RewritePatternSet patterns(&getContext()); 139 140 // Test on one pattern in isolation. 141 if (lowerToOuterProduct) { 142 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 143 VectorTransformsOptions options{lowering}; 144 patterns.add<ContractionOpToOuterProductOpLowering>(options, 145 &getContext()); 146 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 147 return; 148 } 149 150 // Test on one pattern in isolation. 151 if (lowerToFilterOuterProduct) { 152 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 153 VectorTransformsOptions options{lowering}; 154 patterns.add<ContractionOpToOuterProductOpLowering>( 155 options, &getContext(), [](vector::ContractionOp op) { 156 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 157 // where M is not 4. 158 if (op.getRhsType().getShape()[0] == 4) 159 return failure(); 160 return success(); 161 }); 162 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 163 return; 164 } 165 166 // Test on all contract lowering patterns. 167 VectorContractLowering contractLowering = VectorContractLowering::Dot; 168 if (lowerToFlatMatrix) 169 contractLowering = VectorContractLowering::Matmul; 170 VectorMultiReductionLowering vectorMultiReductionLowering = 171 VectorMultiReductionLowering::InnerParallel; 172 VectorTransformsOptions options{contractLowering, 173 vectorMultiReductionLowering, 174 VectorTransposeLowering()}; 175 populateVectorBroadcastLoweringPatterns(patterns); 176 populateVectorContractLoweringPatterns(patterns, options); 177 populateVectorMaskOpLoweringPatterns(patterns); 178 populateVectorShapeCastLoweringPatterns(patterns); 179 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 180 } 181 }; 182 183 struct TestVectorTransposeLowering 184 : public PassWrapper<TestVectorTransposeLowering, 185 OperationPass<func::FuncOp>> { 186 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering) 187 188 StringRef getArgument() const final { 189 return "test-vector-transpose-lowering"; 190 } 191 StringRef getDescription() const final { 192 return "Test lowering patterns that lower contract ops in the vector " 193 "dialect"; 194 } 195 TestVectorTransposeLowering() = default; 196 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 197 : PassWrapper(pass) {} 198 199 Option<bool> lowerToEltwise{ 200 *this, "eltwise", 201 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 202 llvm::cl::init(false)}; 203 Option<bool> lowerToFlatTranspose{ 204 *this, "flat", 205 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 206 llvm::cl::init(false)}; 207 Option<bool> lowerToShuffleTranspose{ 208 *this, "shuffle", 209 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 210 llvm::cl::init(false)}; 211 Option<bool> lowerToAvx2{ 212 *this, "avx2", 213 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 214 llvm::cl::init(false)}; 215 216 void getDependentDialects(DialectRegistry ®istry) const override { 217 registry.insert<LLVM::LLVMDialect>(); 218 } 219 220 void runOnOperation() override { 221 RewritePatternSet patterns(&getContext()); 222 223 // Test on one pattern in isolation. 224 // Explicitly disable shape_cast lowering. 225 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 226 .enableVectorTransposeLowering() 227 .enableShapeCastLowering(false); 228 if (lowerToEltwise) { 229 options = options.setVectorTransformsOptions( 230 VectorTransformsOptions().setVectorTransposeLowering( 231 VectorTransposeLowering::EltWise)); 232 } 233 if (lowerToFlatTranspose) { 234 options = options.setVectorTransformsOptions( 235 VectorTransformsOptions().setVectorTransposeLowering( 236 VectorTransposeLowering::Flat)); 237 } 238 if (lowerToShuffleTranspose) { 239 options = options.setVectorTransformsOptions( 240 VectorTransformsOptions().setVectorTransposeLowering( 241 VectorTransposeLowering::Shuffle)); 242 } 243 if (lowerToAvx2) { 244 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 245 x86vector::avx2::LoweringOptions().setTransposeOptions( 246 x86vector::avx2::TransposeLoweringOptions() 247 .lower4x8xf32() 248 .lower8x8xf32())); 249 } 250 251 OpPassManager dynamicPM("func.func"); 252 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 253 if (failed(runPipeline(dynamicPM, getOperation()))) 254 return signalPassFailure(); 255 } 256 }; 257 258 struct TestVectorUnrollingPatterns 259 : public PassWrapper<TestVectorUnrollingPatterns, 260 OperationPass<func::FuncOp>> { 261 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) 262 263 StringRef getArgument() const final { 264 return "test-vector-unrolling-patterns"; 265 } 266 StringRef getDescription() const final { 267 return "Test lowering patterns to unroll contract ops in the vector " 268 "dialect"; 269 } 270 TestVectorUnrollingPatterns() = default; 271 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 272 : PassWrapper(pass) {} 273 void runOnOperation() override { 274 MLIRContext *ctx = &getContext(); 275 RewritePatternSet patterns(ctx); 276 populateVectorUnrollPatterns( 277 patterns, UnrollVectorOptions() 278 .setNativeShape(ArrayRef<int64_t>{2, 2}) 279 .setFilterConstraint([](Operation *op) { 280 return success(isa<arith::AddFOp, vector::FMAOp, 281 vector::MultiDimReductionOp>(op)); 282 })); 283 populateVectorUnrollPatterns( 284 patterns, UnrollVectorOptions() 285 .setNativeShape(ArrayRef<int64_t>{2}) 286 .setFilterConstraint([](Operation *op) { 287 return success(isa<vector::ReductionOp>(op)); 288 })); 289 populateVectorUnrollPatterns( 290 patterns, UnrollVectorOptions() 291 .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) 292 .setFilterConstraint([](Operation *op) { 293 return success(isa<vector::TransposeOp>(op)); 294 })); 295 296 if (unrollBasedOnType) { 297 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 298 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 299 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 300 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 301 if (auto floatType = contractOp.getLhsType() 302 .getElementType() 303 .dyn_cast<FloatType>()) { 304 if (floatType.getWidth() == 16) { 305 nativeShape[2] = 4; 306 } 307 } 308 return nativeShape; 309 }; 310 populateVectorUnrollPatterns(patterns, 311 UnrollVectorOptions() 312 .setNativeShapeFn(nativeShapeFn) 313 .setFilterConstraint([](Operation *op) { 314 return success(isa<ContractionOp>(op)); 315 })); 316 } else { 317 populateVectorUnrollPatterns( 318 patterns, UnrollVectorOptions() 319 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 320 .setFilterConstraint([](Operation *op) { 321 return success(isa<ContractionOp>(op)); 322 })); 323 } 324 populateVectorToVectorCanonicalizationPatterns(patterns); 325 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 326 } 327 328 Option<bool> unrollBasedOnType{ 329 *this, "unroll-based-on-type", 330 llvm::cl::desc("Set the unroll factor based on type of the operation"), 331 llvm::cl::init(false)}; 332 }; 333 334 struct TestVectorDistributePatterns 335 : public PassWrapper<TestVectorDistributePatterns, 336 OperationPass<func::FuncOp>> { 337 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns) 338 339 StringRef getArgument() const final { 340 return "test-vector-distribute-patterns"; 341 } 342 StringRef getDescription() const final { 343 return "Test lowering patterns to distribute vector ops in the vector " 344 "dialect"; 345 } 346 TestVectorDistributePatterns() = default; 347 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 348 : PassWrapper(pass) {} 349 void getDependentDialects(DialectRegistry ®istry) const override { 350 registry.insert<VectorDialect>(); 351 registry.insert<AffineDialect>(); 352 } 353 ListOption<int32_t> multiplicity{ 354 *this, "distribution-multiplicity", 355 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 356 357 void runOnOperation() override { 358 MLIRContext *ctx = &getContext(); 359 RewritePatternSet patterns(ctx); 360 func::FuncOp func = getOperation(); 361 func.walk([&](arith::AddFOp op) { 362 OpBuilder builder(op); 363 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 364 SmallVector<int64_t, 2> mul; 365 SmallVector<AffineExpr, 2> perm; 366 SmallVector<Value, 2> ids; 367 unsigned count = 0; 368 // Remove the multiplicity of 1 and calculate the affine map based on 369 // the multiplicity. 370 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 371 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 372 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 373 mul.push_back(m[i]); 374 ids.push_back(func.getArgument(count++)); 375 perm.push_back(getAffineDimExpr(i, ctx)); 376 } 377 } 378 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 379 perm, ctx); 380 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 381 builder, op.getOperation(), ids, mul, map); 382 if (ops.hasValue()) { 383 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 384 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 385 extractOp); 386 } 387 } 388 }); 389 populatePropagateVectorDistributionPatterns(patterns); 390 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 391 } 392 }; 393 394 struct TestVectorToLoopPatterns 395 : public PassWrapper<TestVectorToLoopPatterns, 396 OperationPass<func::FuncOp>> { 397 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns) 398 399 StringRef getArgument() const final { return "test-vector-to-forloop"; } 400 StringRef getDescription() const final { 401 return "Test lowering patterns to break up a vector op into a for loop"; 402 } 403 TestVectorToLoopPatterns() = default; 404 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 405 : PassWrapper(pass) {} 406 void getDependentDialects(DialectRegistry ®istry) const override { 407 registry.insert<VectorDialect>(); 408 registry.insert<AffineDialect>(); 409 } 410 Option<int32_t> multiplicity{ 411 *this, "distribution-multiplicity", 412 llvm::cl::desc("Set the multiplicity used for distributing vector"), 413 llvm::cl::init(32)}; 414 void runOnOperation() override { 415 MLIRContext *ctx = &getContext(); 416 RewritePatternSet patterns(ctx); 417 func::FuncOp func = getOperation(); 418 func.walk([&](arith::AddFOp op) { 419 // Check that the operation type can be broken down into a loop. 420 VectorType type = op.getType().dyn_cast<VectorType>(); 421 if (!type || type.getRank() != 1 || 422 type.getNumElements() % multiplicity != 0) 423 return mlir::WalkResult::advance(); 424 auto filterAlloc = [](Operation *op) { 425 return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op); 426 }; 427 auto dependentOps = getSlice(op, filterAlloc); 428 // Create a loop and move instructions from the Op slice into the loop. 429 OpBuilder builder(op); 430 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 431 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 432 auto numIter = 433 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 434 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 435 for (Operation *it : dependentOps) { 436 it->moveBefore(forOp.getBody()->getTerminator()); 437 } 438 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 439 // break up the original op and let the patterns propagate. 440 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 441 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 442 map); 443 if (ops.hasValue()) { 444 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 445 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 446 } 447 return mlir::WalkResult::interrupt(); 448 }); 449 populatePropagateVectorDistributionPatterns(patterns); 450 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 451 } 452 }; 453 454 struct TestVectorTransferUnrollingPatterns 455 : public PassWrapper<TestVectorTransferUnrollingPatterns, 456 OperationPass<func::FuncOp>> { 457 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 458 TestVectorTransferUnrollingPatterns) 459 460 void getDependentDialects(DialectRegistry ®istry) const override { 461 registry.insert<AffineDialect>(); 462 } 463 StringRef getArgument() const final { 464 return "test-vector-transfer-unrolling-patterns"; 465 } 466 StringRef getDescription() const final { 467 return "Test lowering patterns to unroll transfer ops in the vector " 468 "dialect"; 469 } 470 void runOnOperation() override { 471 MLIRContext *ctx = &getContext(); 472 RewritePatternSet patterns(ctx); 473 populateVectorUnrollPatterns( 474 patterns, 475 UnrollVectorOptions() 476 .setNativeShape(ArrayRef<int64_t>{2, 2}) 477 .setFilterConstraint([](Operation *op) { 478 return success( 479 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 480 })); 481 populateVectorToVectorCanonicalizationPatterns(patterns); 482 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 483 } 484 }; 485 486 struct TestVectorTransferFullPartialSplitPatterns 487 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 488 OperationPass<func::FuncOp>> { 489 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 490 TestVectorTransferFullPartialSplitPatterns) 491 492 StringRef getArgument() const final { 493 return "test-vector-transfer-full-partial-split"; 494 } 495 StringRef getDescription() const final { 496 return "Test lowering patterns to split " 497 "transfer ops via scf.if + linalg ops"; 498 } 499 TestVectorTransferFullPartialSplitPatterns() = default; 500 TestVectorTransferFullPartialSplitPatterns( 501 const TestVectorTransferFullPartialSplitPatterns &pass) 502 : PassWrapper(pass) {} 503 504 void getDependentDialects(DialectRegistry ®istry) const override { 505 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 506 scf::SCFDialect>(); 507 } 508 509 Option<bool> useLinalgOps{ 510 *this, "use-memref-copy", 511 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 512 "memref.copy operations."), 513 llvm::cl::init(false)}; 514 void runOnOperation() override { 515 MLIRContext *ctx = &getContext(); 516 RewritePatternSet patterns(ctx); 517 VectorTransformsOptions options; 518 if (useLinalgOps) 519 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 520 else 521 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 522 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 523 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 524 } 525 }; 526 527 struct TestVectorTransferOpt 528 : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> { 529 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) 530 531 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 532 StringRef getDescription() const final { 533 return "Test optimization transformations for transfer ops"; 534 } 535 void runOnOperation() override { transferOpflowOpt(getOperation()); } 536 }; 537 538 struct TestVectorTransferLoweringPatterns 539 : public PassWrapper<TestVectorTransferLoweringPatterns, 540 OperationPass<func::FuncOp>> { 541 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 542 TestVectorTransferLoweringPatterns) 543 544 void getDependentDialects(DialectRegistry ®istry) const override { 545 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 546 } 547 StringRef getArgument() const final { 548 return "test-vector-transfer-lowering-patterns"; 549 } 550 StringRef getDescription() const final { 551 return "Test lowering patterns to lower transfer ops to other vector ops"; 552 } 553 void runOnOperation() override { 554 RewritePatternSet patterns(&getContext()); 555 populateVectorTransferLoweringPatterns(patterns); 556 populateVectorTransferPermutationMapLoweringPatterns(patterns); 557 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 558 } 559 }; 560 561 struct TestVectorMultiReductionLoweringPatterns 562 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 563 OperationPass<func::FuncOp>> { 564 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 565 TestVectorMultiReductionLoweringPatterns) 566 567 TestVectorMultiReductionLoweringPatterns() = default; 568 TestVectorMultiReductionLoweringPatterns( 569 const TestVectorMultiReductionLoweringPatterns &pass) 570 : PassWrapper(pass) {} 571 void getDependentDialects(DialectRegistry ®istry) const override { 572 registry.insert<memref::MemRefDialect>(); 573 } 574 StringRef getArgument() const final { 575 return "test-vector-multi-reduction-lowering-patterns"; 576 } 577 StringRef getDescription() const final { 578 return "Test lowering patterns to lower vector.multi_reduction to other " 579 "vector ops"; 580 } 581 Option<bool> useOuterReductions{ 582 *this, "use-outer-reductions", 583 llvm::cl::desc("Move reductions to outer most dimensions"), 584 llvm::cl::init(false)}; 585 void runOnOperation() override { 586 RewritePatternSet patterns(&getContext()); 587 populateVectorMultiReductionLoweringPatterns( 588 patterns, useOuterReductions 589 ? vector::VectorMultiReductionLowering::InnerParallel 590 : vector::VectorMultiReductionLowering::InnerReduction); 591 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 592 } 593 }; 594 595 struct TestVectorTransferCollapseInnerMostContiguousDims 596 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 597 OperationPass<func::FuncOp>> { 598 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 599 TestVectorTransferCollapseInnerMostContiguousDims) 600 601 TestVectorTransferCollapseInnerMostContiguousDims() = default; 602 TestVectorTransferCollapseInnerMostContiguousDims( 603 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 604 605 void getDependentDialects(DialectRegistry ®istry) const override { 606 registry.insert<memref::MemRefDialect, AffineDialect>(); 607 } 608 609 StringRef getArgument() const final { 610 return "test-vector-transfer-collapse-inner-most-dims"; 611 } 612 613 StringRef getDescription() const final { 614 return "Test lowering patterns that reducedes the rank of the vector " 615 "transfer memory and vector operands."; 616 } 617 618 void runOnOperation() override { 619 RewritePatternSet patterns(&getContext()); 620 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 621 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 622 } 623 }; 624 625 struct TestVectorReduceToContractPatternsPatterns 626 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 627 OperationPass<func::FuncOp>> { 628 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 629 TestVectorReduceToContractPatternsPatterns) 630 631 StringRef getArgument() const final { 632 return "test-vector-reduction-to-contract-patterns"; 633 } 634 StringRef getDescription() const final { 635 return "Test patterns to convert multireduce op to contract and combine " 636 "broadcast/transpose to contract"; 637 } 638 void runOnOperation() override { 639 RewritePatternSet patterns(&getContext()); 640 populateVectorReductionToContractPatterns(patterns); 641 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 642 } 643 }; 644 645 struct TestVectorTransferDropUnitDimsPatterns 646 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, 647 OperationPass<func::FuncOp>> { 648 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 649 TestVectorTransferDropUnitDimsPatterns) 650 651 StringRef getArgument() const final { 652 return "test-vector-transfer-drop-unit-dims-patterns"; 653 } 654 void getDependentDialects(DialectRegistry ®istry) const override { 655 registry.insert<memref::MemRefDialect>(); 656 } 657 void runOnOperation() override { 658 RewritePatternSet patterns(&getContext()); 659 populateVectorTransferDropUnitDimsPatterns(patterns); 660 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 661 } 662 }; 663 664 struct TestFlattenVectorTransferPatterns 665 : public PassWrapper<TestFlattenVectorTransferPatterns, 666 OperationPass<func::FuncOp>> { 667 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 668 TestFlattenVectorTransferPatterns) 669 670 StringRef getArgument() const final { 671 return "test-vector-transfer-flatten-patterns"; 672 } 673 StringRef getDescription() const final { 674 return "Test patterns to rewrite contiguous row-major N-dimensional " 675 "vector.transfer_{read,write} ops into 1D transfers"; 676 } 677 void getDependentDialects(DialectRegistry ®istry) const override { 678 registry.insert<memref::MemRefDialect>(); 679 } 680 void runOnOperation() override { 681 RewritePatternSet patterns(&getContext()); 682 populateFlattenVectorTransferPatterns(patterns); 683 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 684 } 685 }; 686 687 struct TestVectorScanLowering 688 : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> { 689 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) 690 691 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 692 StringRef getDescription() const final { 693 return "Test lowering patterns that lower the scan op in the vector " 694 "dialect"; 695 } 696 void runOnOperation() override { 697 RewritePatternSet patterns(&getContext()); 698 populateVectorScanLoweringPatterns(patterns); 699 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 700 } 701 }; 702 703 } // namespace 704 705 namespace mlir { 706 namespace test { 707 void registerTestVectorLowerings() { 708 PassRegistration<TestVectorToVectorLowering>(); 709 710 PassRegistration<TestVectorContractionLowering>(); 711 712 PassRegistration<TestVectorTransposeLowering>(); 713 714 PassRegistration<TestVectorUnrollingPatterns>(); 715 716 PassRegistration<TestVectorTransferUnrollingPatterns>(); 717 718 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 719 720 PassRegistration<TestVectorDistributePatterns>(); 721 722 PassRegistration<TestVectorToLoopPatterns>(); 723 724 PassRegistration<TestVectorTransferOpt>(); 725 726 PassRegistration<TestVectorTransferLoweringPatterns>(); 727 728 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 729 730 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 731 732 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 733 734 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 735 736 PassRegistration<TestFlattenVectorTransferPatterns>(); 737 738 PassRegistration<TestVectorScanLowering>(); 739 } 740 } // namespace test 741 } // namespace mlir 742