1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/SCF.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Pass/PassManager.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 using namespace mlir::vector; 28 29 namespace { 30 31 struct TestVectorToVectorLowering 32 : public PassWrapper<TestVectorToVectorLowering, OperationPass<FuncOp>> { 33 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) 34 35 TestVectorToVectorLowering() = default; 36 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 37 : PassWrapper(pass) {} 38 StringRef getArgument() const final { 39 return "test-vector-to-vector-lowering"; 40 } 41 StringRef getDescription() const final { 42 return "Test lowering patterns between ops in the vector dialect"; 43 } 44 45 void getDependentDialects(DialectRegistry ®istry) const override { 46 registry.insert<AffineDialect>(); 47 } 48 49 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 50 llvm::cl::init(false)}; 51 52 void runOnOperation() override { 53 auto *ctx = &getContext(); 54 RewritePatternSet patterns(ctx); 55 if (unroll) { 56 populateVectorUnrollPatterns( 57 patterns, 58 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 59 filter)); 60 } 61 populateVectorToVectorCanonicalizationPatterns(patterns); 62 populateBubbleVectorBitCastOpPatterns(patterns); 63 populateCastAwayVectorLeadingOneDimPatterns(patterns); 64 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 65 } 66 67 private: 68 // Return the target shape based on op type. 69 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 70 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 71 return SmallVector<int64_t, 4>(2, 2); 72 if (isa<vector::ContractionOp>(op)) 73 return SmallVector<int64_t, 4>(3, 2); 74 // For transfer ops, just propagate the shape coming from 75 // InsertStridedSlices/ExtractStridedSlices. 76 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 77 VectorType dstVec; 78 for (Operation *users : readOp->getUsers()) { 79 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 80 if (!extract) 81 return llvm::None; 82 auto vecType = extract.getResult().getType().cast<VectorType>(); 83 if (dstVec && dstVec != vecType) 84 return llvm::None; 85 dstVec = vecType; 86 } 87 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 88 dstVec.getShape().end()); 89 } 90 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 91 auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); 92 if (!insert) 93 return llvm::None; 94 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 95 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 96 } 97 return llvm::None; 98 } 99 100 static LogicalResult filter(Operation *op) { 101 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 102 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 103 } 104 }; 105 106 struct TestVectorContractionLowering 107 : public PassWrapper<TestVectorContractionLowering, OperationPass<FuncOp>> { 108 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering) 109 110 StringRef getArgument() const final { 111 return "test-vector-contraction-lowering"; 112 } 113 StringRef getDescription() const final { 114 return "Test lowering patterns that lower contract ops in the vector " 115 "dialect"; 116 } 117 TestVectorContractionLowering() = default; 118 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 119 : PassWrapper(pass) {} 120 121 Option<bool> lowerToFlatMatrix{ 122 *this, "vector-lower-matrix-intrinsics", 123 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 124 llvm::cl::init(false)}; 125 Option<bool> lowerToOuterProduct{ 126 *this, "vector-outerproduct", 127 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 128 llvm::cl::init(false)}; 129 Option<bool> lowerToFilterOuterProduct{ 130 *this, "vector-filter-outerproduct", 131 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 132 "vectors of size 4."), 133 llvm::cl::init(false)}; 134 135 void runOnOperation() override { 136 RewritePatternSet patterns(&getContext()); 137 138 // Test on one pattern in isolation. 139 if (lowerToOuterProduct) { 140 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 141 VectorTransformsOptions options{lowering}; 142 patterns.add<ContractionOpToOuterProductOpLowering>(options, 143 &getContext()); 144 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 145 return; 146 } 147 148 // Test on one pattern in isolation. 149 if (lowerToFilterOuterProduct) { 150 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 151 VectorTransformsOptions options{lowering}; 152 patterns.add<ContractionOpToOuterProductOpLowering>( 153 options, &getContext(), [](vector::ContractionOp op) { 154 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 155 // where M is not 4. 156 if (op.getRhsType().getShape()[0] == 4) 157 return failure(); 158 return success(); 159 }); 160 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 161 return; 162 } 163 164 // Test on all contract lowering patterns. 165 VectorContractLowering contractLowering = VectorContractLowering::Dot; 166 if (lowerToFlatMatrix) 167 contractLowering = VectorContractLowering::Matmul; 168 VectorMultiReductionLowering vectorMultiReductionLowering = 169 VectorMultiReductionLowering::InnerParallel; 170 VectorTransformsOptions options{contractLowering, 171 vectorMultiReductionLowering, 172 VectorTransposeLowering()}; 173 populateVectorBroadcastLoweringPatterns(patterns); 174 populateVectorContractLoweringPatterns(patterns, options); 175 populateVectorMaskOpLoweringPatterns(patterns); 176 populateVectorShapeCastLoweringPatterns(patterns); 177 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 178 } 179 }; 180 181 struct TestVectorTransposeLowering 182 : public PassWrapper<TestVectorTransposeLowering, OperationPass<FuncOp>> { 183 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering) 184 185 StringRef getArgument() const final { 186 return "test-vector-transpose-lowering"; 187 } 188 StringRef getDescription() const final { 189 return "Test lowering patterns that lower contract ops in the vector " 190 "dialect"; 191 } 192 TestVectorTransposeLowering() = default; 193 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 194 : PassWrapper(pass) {} 195 196 Option<bool> lowerToEltwise{ 197 *this, "eltwise", 198 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 199 llvm::cl::init(false)}; 200 Option<bool> lowerToFlatTranspose{ 201 *this, "flat", 202 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 203 llvm::cl::init(false)}; 204 Option<bool> lowerToShuffleTranspose{ 205 *this, "shuffle", 206 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 207 llvm::cl::init(false)}; 208 Option<bool> lowerToAvx2{ 209 *this, "avx2", 210 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 211 llvm::cl::init(false)}; 212 213 void getDependentDialects(DialectRegistry ®istry) const override { 214 registry.insert<LLVM::LLVMDialect>(); 215 } 216 217 void runOnOperation() override { 218 RewritePatternSet patterns(&getContext()); 219 220 // Test on one pattern in isolation. 221 // Explicitly disable shape_cast lowering. 222 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 223 .enableVectorTransposeLowering() 224 .enableShapeCastLowering(false); 225 if (lowerToEltwise) { 226 options = options.setVectorTransformsOptions( 227 VectorTransformsOptions().setVectorTransposeLowering( 228 VectorTransposeLowering::EltWise)); 229 } 230 if (lowerToFlatTranspose) { 231 options = options.setVectorTransformsOptions( 232 VectorTransformsOptions().setVectorTransposeLowering( 233 VectorTransposeLowering::Flat)); 234 } 235 if (lowerToShuffleTranspose) { 236 options = options.setVectorTransformsOptions( 237 VectorTransformsOptions().setVectorTransposeLowering( 238 VectorTransposeLowering::Shuffle)); 239 } 240 if (lowerToAvx2) { 241 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 242 x86vector::avx2::LoweringOptions().setTransposeOptions( 243 x86vector::avx2::TransposeLoweringOptions() 244 .lower4x8xf32() 245 .lower8x8xf32())); 246 } 247 248 OpPassManager dynamicPM("func.func"); 249 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 250 if (failed(runPipeline(dynamicPM, getOperation()))) 251 return signalPassFailure(); 252 } 253 }; 254 255 struct TestVectorUnrollingPatterns 256 : public PassWrapper<TestVectorUnrollingPatterns, OperationPass<FuncOp>> { 257 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) 258 259 StringRef getArgument() const final { 260 return "test-vector-unrolling-patterns"; 261 } 262 StringRef getDescription() const final { 263 return "Test lowering patterns to unroll contract ops in the vector " 264 "dialect"; 265 } 266 TestVectorUnrollingPatterns() = default; 267 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 268 : PassWrapper(pass) {} 269 void runOnOperation() override { 270 MLIRContext *ctx = &getContext(); 271 RewritePatternSet patterns(ctx); 272 populateVectorUnrollPatterns( 273 patterns, UnrollVectorOptions() 274 .setNativeShape(ArrayRef<int64_t>{2, 2}) 275 .setFilterConstraint([](Operation *op) { 276 return success(isa<arith::AddFOp, vector::FMAOp, 277 vector::MultiDimReductionOp>(op)); 278 })); 279 populateVectorUnrollPatterns( 280 patterns, UnrollVectorOptions() 281 .setNativeShape(ArrayRef<int64_t>{2}) 282 .setFilterConstraint([](Operation *op) { 283 return success(isa<vector::ReductionOp>(op)); 284 })); 285 populateVectorUnrollPatterns( 286 patterns, UnrollVectorOptions() 287 .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) 288 .setFilterConstraint([](Operation *op) { 289 return success(isa<vector::TransposeOp>(op)); 290 })); 291 292 if (unrollBasedOnType) { 293 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 294 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 295 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 296 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 297 if (auto floatType = contractOp.getLhsType() 298 .getElementType() 299 .dyn_cast<FloatType>()) { 300 if (floatType.getWidth() == 16) { 301 nativeShape[2] = 4; 302 } 303 } 304 return nativeShape; 305 }; 306 populateVectorUnrollPatterns(patterns, 307 UnrollVectorOptions() 308 .setNativeShapeFn(nativeShapeFn) 309 .setFilterConstraint([](Operation *op) { 310 return success(isa<ContractionOp>(op)); 311 })); 312 } else { 313 populateVectorUnrollPatterns( 314 patterns, UnrollVectorOptions() 315 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 316 .setFilterConstraint([](Operation *op) { 317 return success(isa<ContractionOp>(op)); 318 })); 319 } 320 populateVectorToVectorCanonicalizationPatterns(patterns); 321 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 322 } 323 324 Option<bool> unrollBasedOnType{ 325 *this, "unroll-based-on-type", 326 llvm::cl::desc("Set the unroll factor based on type of the operation"), 327 llvm::cl::init(false)}; 328 }; 329 330 struct TestVectorDistributePatterns 331 : public PassWrapper<TestVectorDistributePatterns, OperationPass<FuncOp>> { 332 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns) 333 334 StringRef getArgument() const final { 335 return "test-vector-distribute-patterns"; 336 } 337 StringRef getDescription() const final { 338 return "Test lowering patterns to distribute vector ops in the vector " 339 "dialect"; 340 } 341 TestVectorDistributePatterns() = default; 342 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 343 : PassWrapper(pass) {} 344 void getDependentDialects(DialectRegistry ®istry) const override { 345 registry.insert<VectorDialect>(); 346 registry.insert<AffineDialect>(); 347 } 348 ListOption<int32_t> multiplicity{ 349 *this, "distribution-multiplicity", 350 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 351 352 void runOnOperation() override { 353 MLIRContext *ctx = &getContext(); 354 RewritePatternSet patterns(ctx); 355 FuncOp func = getOperation(); 356 func.walk([&](arith::AddFOp op) { 357 OpBuilder builder(op); 358 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 359 SmallVector<int64_t, 2> mul; 360 SmallVector<AffineExpr, 2> perm; 361 SmallVector<Value, 2> ids; 362 unsigned count = 0; 363 // Remove the multiplicity of 1 and calculate the affine map based on 364 // the multiplicity. 365 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 366 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 367 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 368 mul.push_back(m[i]); 369 ids.push_back(func.getArgument(count++)); 370 perm.push_back(getAffineDimExpr(i, ctx)); 371 } 372 } 373 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 374 perm, ctx); 375 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 376 builder, op.getOperation(), ids, mul, map); 377 if (ops.hasValue()) { 378 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 379 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 380 extractOp); 381 } 382 } 383 }); 384 populatePropagateVectorDistributionPatterns(patterns); 385 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 386 } 387 }; 388 389 struct TestVectorToLoopPatterns 390 : public PassWrapper<TestVectorToLoopPatterns, OperationPass<FuncOp>> { 391 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns) 392 393 StringRef getArgument() const final { return "test-vector-to-forloop"; } 394 StringRef getDescription() const final { 395 return "Test lowering patterns to break up a vector op into a for loop"; 396 } 397 TestVectorToLoopPatterns() = default; 398 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 399 : PassWrapper(pass) {} 400 void getDependentDialects(DialectRegistry ®istry) const override { 401 registry.insert<VectorDialect>(); 402 registry.insert<AffineDialect>(); 403 } 404 Option<int32_t> multiplicity{ 405 *this, "distribution-multiplicity", 406 llvm::cl::desc("Set the multiplicity used for distributing vector"), 407 llvm::cl::init(32)}; 408 void runOnOperation() override { 409 MLIRContext *ctx = &getContext(); 410 RewritePatternSet patterns(ctx); 411 FuncOp func = getOperation(); 412 func.walk([&](arith::AddFOp op) { 413 // Check that the operation type can be broken down into a loop. 414 VectorType type = op.getType().dyn_cast<VectorType>(); 415 if (!type || type.getRank() != 1 || 416 type.getNumElements() % multiplicity != 0) 417 return mlir::WalkResult::advance(); 418 auto filterAlloc = [](Operation *op) { 419 return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op); 420 }; 421 auto dependentOps = getSlice(op, filterAlloc); 422 // Create a loop and move instructions from the Op slice into the loop. 423 OpBuilder builder(op); 424 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 425 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 426 auto numIter = 427 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 428 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 429 for (Operation *it : dependentOps) { 430 it->moveBefore(forOp.getBody()->getTerminator()); 431 } 432 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 433 // break up the original op and let the patterns propagate. 434 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 435 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 436 map); 437 if (ops.hasValue()) { 438 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 439 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 440 } 441 return mlir::WalkResult::interrupt(); 442 }); 443 populatePropagateVectorDistributionPatterns(patterns); 444 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 445 } 446 }; 447 448 struct TestVectorTransferUnrollingPatterns 449 : public PassWrapper<TestVectorTransferUnrollingPatterns, 450 OperationPass<FuncOp>> { 451 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 452 TestVectorTransferUnrollingPatterns) 453 454 void getDependentDialects(DialectRegistry ®istry) const override { 455 registry.insert<AffineDialect>(); 456 } 457 StringRef getArgument() const final { 458 return "test-vector-transfer-unrolling-patterns"; 459 } 460 StringRef getDescription() const final { 461 return "Test lowering patterns to unroll transfer ops in the vector " 462 "dialect"; 463 } 464 void runOnOperation() override { 465 MLIRContext *ctx = &getContext(); 466 RewritePatternSet patterns(ctx); 467 populateVectorUnrollPatterns( 468 patterns, 469 UnrollVectorOptions() 470 .setNativeShape(ArrayRef<int64_t>{2, 2}) 471 .setFilterConstraint([](Operation *op) { 472 return success( 473 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 474 })); 475 populateVectorToVectorCanonicalizationPatterns(patterns); 476 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 477 } 478 }; 479 480 struct TestVectorTransferFullPartialSplitPatterns 481 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 482 OperationPass<FuncOp>> { 483 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 484 TestVectorTransferFullPartialSplitPatterns) 485 486 StringRef getArgument() const final { 487 return "test-vector-transfer-full-partial-split"; 488 } 489 StringRef getDescription() const final { 490 return "Test lowering patterns to split " 491 "transfer ops via scf.if + linalg ops"; 492 } 493 TestVectorTransferFullPartialSplitPatterns() = default; 494 TestVectorTransferFullPartialSplitPatterns( 495 const TestVectorTransferFullPartialSplitPatterns &pass) 496 : PassWrapper(pass) {} 497 498 void getDependentDialects(DialectRegistry ®istry) const override { 499 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 500 scf::SCFDialect>(); 501 } 502 503 Option<bool> useLinalgOps{ 504 *this, "use-memref-copy", 505 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 506 "memref.copy operations."), 507 llvm::cl::init(false)}; 508 void runOnOperation() override { 509 MLIRContext *ctx = &getContext(); 510 RewritePatternSet patterns(ctx); 511 VectorTransformsOptions options; 512 if (useLinalgOps) 513 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 514 else 515 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 516 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 517 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 518 } 519 }; 520 521 struct TestVectorTransferOpt 522 : public PassWrapper<TestVectorTransferOpt, OperationPass<FuncOp>> { 523 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) 524 525 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 526 StringRef getDescription() const final { 527 return "Test optimization transformations for transfer ops"; 528 } 529 void runOnOperation() override { transferOpflowOpt(getOperation()); } 530 }; 531 532 struct TestVectorTransferLoweringPatterns 533 : public PassWrapper<TestVectorTransferLoweringPatterns, 534 OperationPass<FuncOp>> { 535 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 536 TestVectorTransferLoweringPatterns) 537 538 void getDependentDialects(DialectRegistry ®istry) const override { 539 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 540 } 541 StringRef getArgument() const final { 542 return "test-vector-transfer-lowering-patterns"; 543 } 544 StringRef getDescription() const final { 545 return "Test lowering patterns to lower transfer ops to other vector ops"; 546 } 547 void runOnOperation() override { 548 RewritePatternSet patterns(&getContext()); 549 populateVectorTransferLoweringPatterns(patterns); 550 populateVectorTransferPermutationMapLoweringPatterns(patterns); 551 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 552 } 553 }; 554 555 struct TestVectorMultiReductionLoweringPatterns 556 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 557 OperationPass<FuncOp>> { 558 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 559 TestVectorMultiReductionLoweringPatterns) 560 561 TestVectorMultiReductionLoweringPatterns() = default; 562 TestVectorMultiReductionLoweringPatterns( 563 const TestVectorMultiReductionLoweringPatterns &pass) 564 : PassWrapper(pass) {} 565 void getDependentDialects(DialectRegistry ®istry) const override { 566 registry.insert<memref::MemRefDialect>(); 567 } 568 StringRef getArgument() const final { 569 return "test-vector-multi-reduction-lowering-patterns"; 570 } 571 StringRef getDescription() const final { 572 return "Test lowering patterns to lower vector.multi_reduction to other " 573 "vector ops"; 574 } 575 Option<bool> useOuterReductions{ 576 *this, "use-outer-reductions", 577 llvm::cl::desc("Move reductions to outer most dimensions"), 578 llvm::cl::init(false)}; 579 void runOnOperation() override { 580 RewritePatternSet patterns(&getContext()); 581 populateVectorMultiReductionLoweringPatterns( 582 patterns, useOuterReductions 583 ? vector::VectorMultiReductionLowering::InnerParallel 584 : vector::VectorMultiReductionLowering::InnerReduction); 585 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 586 } 587 }; 588 589 struct TestVectorTransferCollapseInnerMostContiguousDims 590 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 591 OperationPass<FuncOp>> { 592 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 593 TestVectorTransferCollapseInnerMostContiguousDims) 594 595 TestVectorTransferCollapseInnerMostContiguousDims() = default; 596 TestVectorTransferCollapseInnerMostContiguousDims( 597 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 598 599 void getDependentDialects(DialectRegistry ®istry) const override { 600 registry.insert<memref::MemRefDialect, AffineDialect>(); 601 } 602 603 StringRef getArgument() const final { 604 return "test-vector-transfer-collapse-inner-most-dims"; 605 } 606 607 StringRef getDescription() const final { 608 return "Test lowering patterns that reducedes the rank of the vector " 609 "transfer memory and vector operands."; 610 } 611 612 void runOnOperation() override { 613 RewritePatternSet patterns(&getContext()); 614 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 615 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 616 } 617 }; 618 619 struct TestVectorReduceToContractPatternsPatterns 620 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 621 OperationPass<FuncOp>> { 622 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 623 TestVectorReduceToContractPatternsPatterns) 624 625 StringRef getArgument() const final { 626 return "test-vector-reduction-to-contract-patterns"; 627 } 628 StringRef getDescription() const final { 629 return "Test patterns to convert multireduce op to contract and combine " 630 "broadcast/transpose to contract"; 631 } 632 void runOnOperation() override { 633 RewritePatternSet patterns(&getContext()); 634 populateVectorReductionToContractPatterns(patterns); 635 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 636 } 637 }; 638 639 struct TestVectorTransferDropUnitDimsPatterns 640 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, 641 OperationPass<FuncOp>> { 642 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 643 TestVectorTransferDropUnitDimsPatterns) 644 645 StringRef getArgument() const final { 646 return "test-vector-transfer-drop-unit-dims-patterns"; 647 } 648 void getDependentDialects(DialectRegistry ®istry) const override { 649 registry.insert<memref::MemRefDialect>(); 650 } 651 void runOnOperation() override { 652 RewritePatternSet patterns(&getContext()); 653 populateVectorTransferDropUnitDimsPatterns(patterns); 654 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 655 } 656 }; 657 658 struct TestFlattenVectorTransferPatterns 659 : public PassWrapper<TestFlattenVectorTransferPatterns, 660 OperationPass<FuncOp>> { 661 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 662 TestFlattenVectorTransferPatterns) 663 664 StringRef getArgument() const final { 665 return "test-vector-transfer-flatten-patterns"; 666 } 667 StringRef getDescription() const final { 668 return "Test patterns to rewrite contiguous row-major N-dimensional " 669 "vector.transfer_{read,write} ops into 1D transfers"; 670 } 671 void getDependentDialects(DialectRegistry ®istry) const override { 672 registry.insert<memref::MemRefDialect>(); 673 } 674 void runOnOperation() override { 675 RewritePatternSet patterns(&getContext()); 676 populateFlattenVectorTransferPatterns(patterns); 677 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 678 } 679 }; 680 681 struct TestVectorScanLowering 682 : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> { 683 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) 684 685 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 686 StringRef getDescription() const final { 687 return "Test lowering patterns that lower the scan op in the vector " 688 "dialect"; 689 } 690 void runOnOperation() override { 691 RewritePatternSet patterns(&getContext()); 692 populateVectorScanLoweringPatterns(patterns); 693 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 694 } 695 }; 696 697 } // namespace 698 699 namespace mlir { 700 namespace test { 701 void registerTestVectorLowerings() { 702 PassRegistration<TestVectorToVectorLowering>(); 703 704 PassRegistration<TestVectorContractionLowering>(); 705 706 PassRegistration<TestVectorTransposeLowering>(); 707 708 PassRegistration<TestVectorUnrollingPatterns>(); 709 710 PassRegistration<TestVectorTransferUnrollingPatterns>(); 711 712 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 713 714 PassRegistration<TestVectorDistributePatterns>(); 715 716 PassRegistration<TestVectorToLoopPatterns>(); 717 718 PassRegistration<TestVectorTransferOpt>(); 719 720 PassRegistration<TestVectorTransferLoweringPatterns>(); 721 722 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 723 724 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 725 726 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 727 728 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 729 730 PassRegistration<TestFlattenVectorTransferPatterns>(); 731 732 PassRegistration<TestVectorScanLowering>(); 733 } 734 } // namespace test 735 } // namespace mlir 736