//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;

namespace {

struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)

  TestVectorToVectorLowering() = default;
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }

  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

private:
  // Return the target shape based on op type.
  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t, 4>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t, 4>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
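    // A transfer_read is unrolled to the common result type of its
    // extract_strided_slice users; a transfer_write takes the source type of
    // the insert_strided_slice feeding it. If no single such shape exists,
    // llvm::None is returned and the op is left alone.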
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *users : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(users);
        if (!extract)
          return llvm::None;
        auto vecType = extract.getResult().getType().cast<VectorType>();
        if (dstVec && dstVec != vecType)
          return llvm::None;
        dstVec = vecType;
      }
      return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
                                     dstVec.getShape().end());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return llvm::None;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
    }
    return llvm::None;
  }

  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
                       ContractionOp, TransferReadOp, TransferWriteOp>(op));
  }
};

struct TestVectorContractionLowering
    : public PassWrapper<TestVectorContractionLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)

  StringRef getArgument() const final {
    return "test-vector-contraction-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower contract ops in the vector "
           "dialect";
  }
  TestVectorContractionLowering() = default;
  TestVectorContractionLowering(const TestVectorContractionLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToFlatMatrix{
      *this, "vector-lower-matrix-intrinsics",
      llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
      llvm::cl::init(false)};
  Option<bool> lowerToOuterProduct{
      *this, "vector-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
      llvm::cl::init(false)};
  Option<bool> lowerToFilterOuterProduct{
      *this, "vector-filter-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
                     "vectors of size 4."),
      llvm::cl::init(false)};
  Option<bool> lowerToParallelArith{
      *this, "vector-parallel-arith",
      llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
      llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    if (lowerToOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(options,
                                                          &getContext());
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on one pattern in isolation.
    if (lowerToFilterOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(
          options, &getContext(), [](vector::ContractionOp op) {
            // Only lower vector.contract ops where the first dimension of the
            // rhs operand type is not 4.
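            // Returning failure() from this constraint makes the pattern skip
            // the op, which exercises the filter hook of the lowering.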
            if (op.getRhsType().getShape()[0] == 4)
              return failure();
            return success();
          });
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    if (lowerToParallelArith) {
      vector::populateVectorContractLoweringPatterns(
          patterns,
          vector::VectorTransformsOptions().setVectorTransformsOptions(
              vector::VectorContractLowering::ParallelArith));
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on all contract lowering patterns.
    VectorContractLowering contractLowering = VectorContractLowering::Dot;
    if (lowerToFlatMatrix)
      contractLowering = VectorContractLowering::Matmul;
    VectorMultiReductionLowering vectorMultiReductionLowering =
        VectorMultiReductionLowering::InnerParallel;
    VectorTransformsOptions options{contractLowering,
                                    vectorMultiReductionLowering,
                                    VectorTransposeLowering()};
    populateVectorBroadcastLoweringPatterns(patterns);
    populateVectorContractLoweringPatterns(patterns, options);
    populateVectorMaskOpLoweringPatterns(patterns);
    populateVectorShapeCastLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransposeLowering
    : public PassWrapper<TestVectorTransposeLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)

  StringRef getArgument() const final {
    return "test-vector-transpose-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the transpose op in the vector "
           "dialect";
  }
  TestVectorTransposeLowering() = default;
  TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToEltwise{
      *this, "eltwise",
      llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
      llvm::cl::init(false)};
  Option<bool> lowerToFlatTranspose{
      *this, "flat",
      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
      llvm::cl::init(false)};
  Option<bool> lowerToShuffleTranspose{
      *this, "shuffle",
      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
      llvm::cl::init(false)};
  Option<bool> lowerToAvx2{
      *this, "avx2",
      llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
      llvm::cl::init(false)};

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    // Explicitly disable shape_cast lowering.
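    // Each flag below selects a different VectorTransposeLowering strategy;
    // the avx2 flag additionally enables the x86vector-specific 4x8xf32 and
    // 8x8xf32 transpose lowerings on top of whatever strategy was picked.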
    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
                                              .enableVectorTransposeLowering()
                                              .enableShapeCastLowering(false);
    if (lowerToEltwise) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::EltWise));
    }
    if (lowerToFlatTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Flat));
    }
    if (lowerToShuffleTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Shuffle));
    }
    if (lowerToAvx2) {
      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
          x86vector::avx2::LoweringOptions().setTransposeOptions(
              x86vector::avx2::TransposeLoweringOptions()
                  .lower4x8xf32()
                  .lower8x8xf32()));
    }

    OpPassManager dynamicPM("func.func");
    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
    if (failed(runPipeline(dynamicPM, getOperation())))
      return signalPassFailure();
  }
};

struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

    if (unrollBasedOnType) {
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
        if (auto floatType = contractOp.getLhsType()
                                 .getElementType()
                                 .dyn_cast<FloatType>()) {
          if (floatType.getWidth() == 16) {
            nativeShape[2] = 4;
          }
        }
        return nativeShape;
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    } else {
      populateVectorUnrollPatterns(
          patterns, UnrollVectorOptions()
                        .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
                        .setFilterConstraint([](Operation *op) {
                          return success(isa<ContractionOp>(op));
                        }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};

struct TestVectorDistributePatterns
    : public PassWrapper<TestVectorDistributePatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)

  StringRef getArgument() const final {
    return "test-vector-distribute-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to distribute vector ops in the vector "
           "dialect";
  }
  TestVectorDistributePatterns() = default;
  TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  ListOption<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector")};

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    func::FuncOp func = getOperation();
    func.walk([&](arith::AddFOp op) {
      OpBuilder builder(op);
      if (auto vecType = op.getType().dyn_cast<VectorType>()) {
        SmallVector<int64_t, 2> mul;
        SmallVector<AffineExpr, 2> perm;
        SmallVector<Value, 2> ids;
        unsigned count = 0;
        // Drop multiplicities of 1 and compute the affine map based on the
        // remaining multiplicities.
        SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
        for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
          if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
            mul.push_back(m[i]);
            ids.push_back(func.getArgument(count++));
            perm.push_back(getAffineDimExpr(i, ctx));
          }
        }
        auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
                                  perm, ctx);
        Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
            builder, op.getOperation(), ids, mul, map);
        if (ops.hasValue()) {
          SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
          op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
                                              extractOp);
        }
      }
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorToLoopPatterns
    : public PassWrapper<TestVectorToLoopPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)

  StringRef getArgument() const final { return "test-vector-to-forloop"; }
  StringRef getDescription() const final {
    return "Test lowering patterns to break up a vector op into a for loop";
  }
  TestVectorToLoopPatterns() = default;
  TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  Option<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector"),
      llvm::cl::init(32)};
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
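    // Find an arith.addf on a 1-D vector whose length is a multiple of
    // `multiplicity`, move its slice into a freshly created scf.for loop of
    // `multiplicity` iterations, and distribute the addf across the loop's
    // induction variable; the propagation patterns applied below then
    // distribute the surrounding ops.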
    func::FuncOp func = getOperation();
    func.walk([&](arith::AddFOp op) {
      // Check that the operation type can be broken down into a loop.
      VectorType type = op.getType().dyn_cast<VectorType>();
      if (!type || type.getRank() != 1 ||
          type.getNumElements() % multiplicity != 0)
        return mlir::WalkResult::advance();
      auto filterAlloc = [](Operation *op) {
        return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
      };
      auto dependentOps = getSlice(op, filterAlloc);
      // Create a loop and move instructions from the Op slice into the loop.
      OpBuilder builder(op);
      auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
      auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
      auto numIter =
          builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
      auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
      for (Operation *it : dependentOps) {
        it->moveBefore(forOp.getBody()->getTerminator());
      }
      auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
      // Break up the original op and let the patterns propagate.
      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
          builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
          map);
      if (ops.hasValue()) {
        SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
      }
      return mlir::WalkResult::interrupt();
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferUnrollingPatterns)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns,
        UnrollVectorOptions()
            .setNativeShape(ArrayRef<int64_t>{2, 2})
            .setFilterConstraint([](Operation *op) {
              return success(
                  isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
            }));
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferFullPartialSplitPatterns
    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferFullPartialSplitPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-full-partial-split";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to split "
           "transfer ops via scf.if + linalg ops";
  }
  TestVectorTransferFullPartialSplitPatterns() = default;
  TestVectorTransferFullPartialSplitPatterns(
      const TestVectorTransferFullPartialSplitPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
                    scf::SCFDialect>();
  }

  Option<bool> useLinalgOps{
      *this, "use-memref-copy",
      llvm::cl::desc("Split using an unmasked vector.transfer + linalg.fill + "
                     "memref.copy operations."),
      llvm::cl::init(false)};
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    VectorTransformsOptions options;
    if (useLinalgOps)
      options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
    else
      options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
    patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)

  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }
  void runOnOperation() override { transferOpflowOpt(getOperation()); }
};

struct TestVectorTransferLoweringPatterns
    : public PassWrapper<TestVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferLoweringPatterns)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower transfer ops to other vector ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferLoweringPatterns(patterns);
    populateVectorTransferPermutationMapLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorMultiReductionLoweringPatterns
    : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorMultiReductionLoweringPatterns)

  TestVectorMultiReductionLoweringPatterns() = default;
  TestVectorMultiReductionLoweringPatterns(
      const TestVectorMultiReductionLoweringPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-multi-reduction-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower vector.multi_reduction to other "
           "vector ops";
  }
  Option<bool> useOuterReductions{
      *this, "use-outer-reductions",
      llvm::cl::desc("Move reductions to outer most dimensions"),
      llvm::cl::init(false)};
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMultiReductionLoweringPatterns(
        patterns, useOuterReductions
                      ? vector::VectorMultiReductionLowering::InnerParallel
                      : vector::VectorMultiReductionLowering::InnerReduction);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferCollapseInnerMostContiguousDims)

  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, AffineDialect>();
  }

  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorReduceToContractPatternsPatterns)

  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multireduce op to contract and combine "
           "broadcast/transpose to contract";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorReductionToContractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferDropUnitDimsPatterns
    : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferDropUnitDimsPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-drop-unit-dims-patterns";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferDropUnitDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFlattenVectorTransferPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFlattenVectorTransferPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)

  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorScanLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer.
  MemRefType memrefType;
  if (auto vectorType = type.dyn_cast<VectorType>()) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name.
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}

struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>();
  }

  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
  }
  TestVectorDistribution() = default;
  TestVectorDistribution(const TestVectorDistribution &pass)
      : PassWrapper(pass) {}

  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
      llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    WarpExecuteOnLane0LoweringOptions options;
    options.warpAllocationFn = allocateGlobalSharedMemory;
    options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                      WarpExecuteOnLane0Op warpOp) {
      builder.create<gpu::BarrierOp>(loc);
    };
    // Test on one pattern in isolation.
    if (warpOpToSCF) {
      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }
  }
};

} // namespace

namespace mlir {
namespace test {
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();

  PassRegistration<TestVectorContractionLowering>();

  PassRegistration<TestVectorTransposeLowering>();

  PassRegistration<TestVectorUnrollingPatterns>();

  PassRegistration<TestVectorTransferUnrollingPatterns>();

  PassRegistration<TestVectorTransferFullPartialSplitPatterns>();

  PassRegistration<TestVectorDistributePatterns>();

  PassRegistration<TestVectorToLoopPatterns>();

  PassRegistration<TestVectorTransferOpt>();

  PassRegistration<TestVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorMultiReductionLoweringPatterns>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestVectorTransferDropUnitDimsPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();

  PassRegistration<TestVectorScanLowering>();

  PassRegistration<TestVectorDistribution>();
}
} // namespace test
} // namespace mlir