1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/GPU/GPUDialect.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/SCF.h" 21 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" 22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 using namespace mlir; 28 using namespace mlir::linalg; 29 using namespace mlir::vector; 30 31 namespace { 32 33 struct TestVectorToVectorLowering 34 : public PassWrapper<TestVectorToVectorLowering, 35 OperationPass<func::FuncOp>> { 36 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) 37 38 TestVectorToVectorLowering() = default; 39 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 40 : PassWrapper(pass) {} 41 StringRef getArgument() const final { 42 return "test-vector-to-vector-lowering"; 43 } 44 StringRef getDescription() const final { 45 return "Test lowering patterns between ops in the vector dialect"; 46 } 47 48 void getDependentDialects(DialectRegistry ®istry) const override { 49 registry.insert<AffineDialect>(); 50 } 51 52 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 53 llvm::cl::init(false)}; 54 55 void runOnOperation() override { 56 auto *ctx = &getContext(); 57 RewritePatternSet patterns(ctx); 58 if (unroll) { 59 populateVectorUnrollPatterns( 60 patterns, 61 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 62 filter)); 63 } 64 populateVectorToVectorCanonicalizationPatterns(patterns); 65 populateBubbleVectorBitCastOpPatterns(patterns); 66 populateCastAwayVectorLeadingOneDimPatterns(patterns); 67 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 68 } 69 70 private: 71 // Return the target shape based on op type. 72 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 73 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 74 return SmallVector<int64_t, 4>(2, 2); 75 if (isa<vector::ContractionOp>(op)) 76 return SmallVector<int64_t, 4>(3, 2); 77 // For transfer ops, just propagate the shape coming from 78 // InsertStridedSlices/ExtractStridedSlices. 79 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 80 VectorType dstVec; 81 for (Operation *users : readOp->getUsers()) { 82 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 83 if (!extract) 84 return llvm::None; 85 auto vecType = extract.getResult().getType().cast<VectorType>(); 86 if (dstVec && dstVec != vecType) 87 return llvm::None; 88 dstVec = vecType; 89 } 90 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 91 dstVec.getShape().end()); 92 } 93 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 94 auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); 95 if (!insert) 96 return llvm::None; 97 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 98 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 99 } 100 return llvm::None; 101 } 102 103 static LogicalResult filter(Operation *op) { 104 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 105 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 106 } 107 }; 108 109 struct TestVectorContractionLowering 110 : public PassWrapper<TestVectorContractionLowering, 111 OperationPass<func::FuncOp>> { 112 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering) 113 114 StringRef getArgument() const final { 115 return "test-vector-contraction-lowering"; 116 } 117 StringRef getDescription() const final { 118 return "Test lowering patterns that lower contract ops in the vector " 119 "dialect"; 120 } 121 TestVectorContractionLowering() = default; 122 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 123 : PassWrapper(pass) {} 124 125 Option<bool> lowerToFlatMatrix{ 126 *this, "vector-lower-matrix-intrinsics", 127 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 128 llvm::cl::init(false)}; 129 Option<bool> lowerToOuterProduct{ 130 *this, "vector-outerproduct", 131 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 132 llvm::cl::init(false)}; 133 Option<bool> lowerToFilterOuterProduct{ 134 *this, "vector-filter-outerproduct", 135 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 136 "vectors of size 4."), 137 llvm::cl::init(false)}; 138 139 void runOnOperation() override { 140 RewritePatternSet patterns(&getContext()); 141 142 // Test on one pattern in isolation. 143 if (lowerToOuterProduct) { 144 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 145 VectorTransformsOptions options{lowering}; 146 patterns.add<ContractionOpToOuterProductOpLowering>(options, 147 &getContext()); 148 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 149 return; 150 } 151 152 // Test on one pattern in isolation. 153 if (lowerToFilterOuterProduct) { 154 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 155 VectorTransformsOptions options{lowering}; 156 patterns.add<ContractionOpToOuterProductOpLowering>( 157 options, &getContext(), [](vector::ContractionOp op) { 158 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 159 // where M is not 4. 160 if (op.getRhsType().getShape()[0] == 4) 161 return failure(); 162 return success(); 163 }); 164 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 165 return; 166 } 167 168 // Test on all contract lowering patterns. 169 VectorContractLowering contractLowering = VectorContractLowering::Dot; 170 if (lowerToFlatMatrix) 171 contractLowering = VectorContractLowering::Matmul; 172 VectorMultiReductionLowering vectorMultiReductionLowering = 173 VectorMultiReductionLowering::InnerParallel; 174 VectorTransformsOptions options{contractLowering, 175 vectorMultiReductionLowering, 176 VectorTransposeLowering()}; 177 populateVectorBroadcastLoweringPatterns(patterns); 178 populateVectorContractLoweringPatterns(patterns, options); 179 populateVectorMaskOpLoweringPatterns(patterns); 180 populateVectorShapeCastLoweringPatterns(patterns); 181 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 182 } 183 }; 184 185 struct TestVectorTransposeLowering 186 : public PassWrapper<TestVectorTransposeLowering, 187 OperationPass<func::FuncOp>> { 188 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering) 189 190 StringRef getArgument() const final { 191 return "test-vector-transpose-lowering"; 192 } 193 StringRef getDescription() const final { 194 return "Test lowering patterns that lower contract ops in the vector " 195 "dialect"; 196 } 197 TestVectorTransposeLowering() = default; 198 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 199 : PassWrapper(pass) {} 200 201 Option<bool> lowerToEltwise{ 202 *this, "eltwise", 203 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 204 llvm::cl::init(false)}; 205 Option<bool> lowerToFlatTranspose{ 206 *this, "flat", 207 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 208 llvm::cl::init(false)}; 209 Option<bool> lowerToShuffleTranspose{ 210 *this, "shuffle", 211 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 212 llvm::cl::init(false)}; 213 Option<bool> lowerToAvx2{ 214 *this, "avx2", 215 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 216 llvm::cl::init(false)}; 217 218 void getDependentDialects(DialectRegistry ®istry) const override { 219 registry.insert<LLVM::LLVMDialect>(); 220 } 221 222 void runOnOperation() override { 223 RewritePatternSet patterns(&getContext()); 224 225 // Test on one pattern in isolation. 226 // Explicitly disable shape_cast lowering. 227 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 228 .enableVectorTransposeLowering() 229 .enableShapeCastLowering(false); 230 if (lowerToEltwise) { 231 options = options.setVectorTransformsOptions( 232 VectorTransformsOptions().setVectorTransposeLowering( 233 VectorTransposeLowering::EltWise)); 234 } 235 if (lowerToFlatTranspose) { 236 options = options.setVectorTransformsOptions( 237 VectorTransformsOptions().setVectorTransposeLowering( 238 VectorTransposeLowering::Flat)); 239 } 240 if (lowerToShuffleTranspose) { 241 options = options.setVectorTransformsOptions( 242 VectorTransformsOptions().setVectorTransposeLowering( 243 VectorTransposeLowering::Shuffle)); 244 } 245 if (lowerToAvx2) { 246 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 247 x86vector::avx2::LoweringOptions().setTransposeOptions( 248 x86vector::avx2::TransposeLoweringOptions() 249 .lower4x8xf32() 250 .lower8x8xf32())); 251 } 252 253 OpPassManager dynamicPM("func.func"); 254 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 255 if (failed(runPipeline(dynamicPM, getOperation()))) 256 return signalPassFailure(); 257 } 258 }; 259 260 struct TestVectorUnrollingPatterns 261 : public PassWrapper<TestVectorUnrollingPatterns, 262 OperationPass<func::FuncOp>> { 263 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) 264 265 StringRef getArgument() const final { 266 return "test-vector-unrolling-patterns"; 267 } 268 StringRef getDescription() const final { 269 return "Test lowering patterns to unroll contract ops in the vector " 270 "dialect"; 271 } 272 TestVectorUnrollingPatterns() = default; 273 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 274 : PassWrapper(pass) {} 275 void runOnOperation() override { 276 MLIRContext *ctx = &getContext(); 277 RewritePatternSet patterns(ctx); 278 populateVectorUnrollPatterns( 279 patterns, UnrollVectorOptions() 280 .setNativeShape(ArrayRef<int64_t>{2, 2}) 281 .setFilterConstraint([](Operation *op) { 282 return success(isa<arith::AddFOp, vector::FMAOp, 283 vector::MultiDimReductionOp>(op)); 284 })); 285 populateVectorUnrollPatterns( 286 patterns, UnrollVectorOptions() 287 .setNativeShape(ArrayRef<int64_t>{2}) 288 .setFilterConstraint([](Operation *op) { 289 return success(isa<vector::ReductionOp>(op)); 290 })); 291 populateVectorUnrollPatterns( 292 patterns, UnrollVectorOptions() 293 .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) 294 .setFilterConstraint([](Operation *op) { 295 return success(isa<vector::TransposeOp>(op)); 296 })); 297 298 if (unrollBasedOnType) { 299 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 300 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 301 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 302 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 303 if (auto floatType = contractOp.getLhsType() 304 .getElementType() 305 .dyn_cast<FloatType>()) { 306 if (floatType.getWidth() == 16) { 307 nativeShape[2] = 4; 308 } 309 } 310 return nativeShape; 311 }; 312 populateVectorUnrollPatterns(patterns, 313 UnrollVectorOptions() 314 .setNativeShapeFn(nativeShapeFn) 315 .setFilterConstraint([](Operation *op) { 316 return success(isa<ContractionOp>(op)); 317 })); 318 } else { 319 populateVectorUnrollPatterns( 320 patterns, UnrollVectorOptions() 321 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 322 .setFilterConstraint([](Operation *op) { 323 return success(isa<ContractionOp>(op)); 324 })); 325 } 326 populateVectorToVectorCanonicalizationPatterns(patterns); 327 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 328 } 329 330 Option<bool> unrollBasedOnType{ 331 *this, "unroll-based-on-type", 332 llvm::cl::desc("Set the unroll factor based on type of the operation"), 333 llvm::cl::init(false)}; 334 }; 335 336 struct TestVectorDistributePatterns 337 : public PassWrapper<TestVectorDistributePatterns, 338 OperationPass<func::FuncOp>> { 339 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns) 340 341 StringRef getArgument() const final { 342 return "test-vector-distribute-patterns"; 343 } 344 StringRef getDescription() const final { 345 return "Test lowering patterns to distribute vector ops in the vector " 346 "dialect"; 347 } 348 TestVectorDistributePatterns() = default; 349 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 350 : PassWrapper(pass) {} 351 void getDependentDialects(DialectRegistry ®istry) const override { 352 registry.insert<VectorDialect>(); 353 registry.insert<AffineDialect>(); 354 } 355 ListOption<int32_t> multiplicity{ 356 *this, "distribution-multiplicity", 357 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 358 359 void runOnOperation() override { 360 MLIRContext *ctx = &getContext(); 361 RewritePatternSet patterns(ctx); 362 func::FuncOp func = getOperation(); 363 func.walk([&](arith::AddFOp op) { 364 OpBuilder builder(op); 365 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 366 SmallVector<int64_t, 2> mul; 367 SmallVector<AffineExpr, 2> perm; 368 SmallVector<Value, 2> ids; 369 unsigned count = 0; 370 // Remove the multiplicity of 1 and calculate the affine map based on 371 // the multiplicity. 372 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 373 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 374 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 375 mul.push_back(m[i]); 376 ids.push_back(func.getArgument(count++)); 377 perm.push_back(getAffineDimExpr(i, ctx)); 378 } 379 } 380 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 381 perm, ctx); 382 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 383 builder, op.getOperation(), ids, mul, map); 384 if (ops.hasValue()) { 385 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 386 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 387 extractOp); 388 } 389 } 390 }); 391 populatePropagateVectorDistributionPatterns(patterns); 392 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 393 } 394 }; 395 396 struct TestVectorToLoopPatterns 397 : public PassWrapper<TestVectorToLoopPatterns, 398 OperationPass<func::FuncOp>> { 399 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns) 400 401 StringRef getArgument() const final { return "test-vector-to-forloop"; } 402 StringRef getDescription() const final { 403 return "Test lowering patterns to break up a vector op into a for loop"; 404 } 405 TestVectorToLoopPatterns() = default; 406 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 407 : PassWrapper(pass) {} 408 void getDependentDialects(DialectRegistry ®istry) const override { 409 registry.insert<VectorDialect>(); 410 registry.insert<AffineDialect>(); 411 } 412 Option<int32_t> multiplicity{ 413 *this, "distribution-multiplicity", 414 llvm::cl::desc("Set the multiplicity used for distributing vector"), 415 llvm::cl::init(32)}; 416 void runOnOperation() override { 417 MLIRContext *ctx = &getContext(); 418 RewritePatternSet patterns(ctx); 419 func::FuncOp func = getOperation(); 420 func.walk([&](arith::AddFOp op) { 421 // Check that the operation type can be broken down into a loop. 422 VectorType type = op.getType().dyn_cast<VectorType>(); 423 if (!type || type.getRank() != 1 || 424 type.getNumElements() % multiplicity != 0) 425 return mlir::WalkResult::advance(); 426 auto filterAlloc = [](Operation *op) { 427 return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op); 428 }; 429 auto dependentOps = getSlice(op, filterAlloc); 430 // Create a loop and move instructions from the Op slice into the loop. 431 OpBuilder builder(op); 432 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 433 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 434 auto numIter = 435 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 436 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 437 for (Operation *it : dependentOps) { 438 it->moveBefore(forOp.getBody()->getTerminator()); 439 } 440 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 441 // break up the original op and let the patterns propagate. 442 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 443 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 444 map); 445 if (ops.hasValue()) { 446 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 447 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 448 } 449 return mlir::WalkResult::interrupt(); 450 }); 451 populatePropagateVectorDistributionPatterns(patterns); 452 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 453 } 454 }; 455 456 struct TestVectorTransferUnrollingPatterns 457 : public PassWrapper<TestVectorTransferUnrollingPatterns, 458 OperationPass<func::FuncOp>> { 459 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 460 TestVectorTransferUnrollingPatterns) 461 462 void getDependentDialects(DialectRegistry ®istry) const override { 463 registry.insert<AffineDialect>(); 464 } 465 StringRef getArgument() const final { 466 return "test-vector-transfer-unrolling-patterns"; 467 } 468 StringRef getDescription() const final { 469 return "Test lowering patterns to unroll transfer ops in the vector " 470 "dialect"; 471 } 472 void runOnOperation() override { 473 MLIRContext *ctx = &getContext(); 474 RewritePatternSet patterns(ctx); 475 populateVectorUnrollPatterns( 476 patterns, 477 UnrollVectorOptions() 478 .setNativeShape(ArrayRef<int64_t>{2, 2}) 479 .setFilterConstraint([](Operation *op) { 480 return success( 481 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 482 })); 483 populateVectorToVectorCanonicalizationPatterns(patterns); 484 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 485 } 486 }; 487 488 struct TestVectorTransferFullPartialSplitPatterns 489 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 490 OperationPass<func::FuncOp>> { 491 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 492 TestVectorTransferFullPartialSplitPatterns) 493 494 StringRef getArgument() const final { 495 return "test-vector-transfer-full-partial-split"; 496 } 497 StringRef getDescription() const final { 498 return "Test lowering patterns to split " 499 "transfer ops via scf.if + linalg ops"; 500 } 501 TestVectorTransferFullPartialSplitPatterns() = default; 502 TestVectorTransferFullPartialSplitPatterns( 503 const TestVectorTransferFullPartialSplitPatterns &pass) 504 : PassWrapper(pass) {} 505 506 void getDependentDialects(DialectRegistry ®istry) const override { 507 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 508 scf::SCFDialect>(); 509 } 510 511 Option<bool> useLinalgOps{ 512 *this, "use-memref-copy", 513 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 514 "memref.copy operations."), 515 llvm::cl::init(false)}; 516 void runOnOperation() override { 517 MLIRContext *ctx = &getContext(); 518 RewritePatternSet patterns(ctx); 519 VectorTransformsOptions options; 520 if (useLinalgOps) 521 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 522 else 523 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 524 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 525 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 526 } 527 }; 528 529 struct TestVectorTransferOpt 530 : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> { 531 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) 532 533 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 534 StringRef getDescription() const final { 535 return "Test optimization transformations for transfer ops"; 536 } 537 void runOnOperation() override { transferOpflowOpt(getOperation()); } 538 }; 539 540 struct TestVectorTransferLoweringPatterns 541 : public PassWrapper<TestVectorTransferLoweringPatterns, 542 OperationPass<func::FuncOp>> { 543 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 544 TestVectorTransferLoweringPatterns) 545 546 void getDependentDialects(DialectRegistry ®istry) const override { 547 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 548 } 549 StringRef getArgument() const final { 550 return "test-vector-transfer-lowering-patterns"; 551 } 552 StringRef getDescription() const final { 553 return "Test lowering patterns to lower transfer ops to other vector ops"; 554 } 555 void runOnOperation() override { 556 RewritePatternSet patterns(&getContext()); 557 populateVectorTransferLoweringPatterns(patterns); 558 populateVectorTransferPermutationMapLoweringPatterns(patterns); 559 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 560 } 561 }; 562 563 struct TestVectorMultiReductionLoweringPatterns 564 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 565 OperationPass<func::FuncOp>> { 566 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 567 TestVectorMultiReductionLoweringPatterns) 568 569 TestVectorMultiReductionLoweringPatterns() = default; 570 TestVectorMultiReductionLoweringPatterns( 571 const TestVectorMultiReductionLoweringPatterns &pass) 572 : PassWrapper(pass) {} 573 void getDependentDialects(DialectRegistry ®istry) const override { 574 registry.insert<memref::MemRefDialect>(); 575 } 576 StringRef getArgument() const final { 577 return "test-vector-multi-reduction-lowering-patterns"; 578 } 579 StringRef getDescription() const final { 580 return "Test lowering patterns to lower vector.multi_reduction to other " 581 "vector ops"; 582 } 583 Option<bool> useOuterReductions{ 584 *this, "use-outer-reductions", 585 llvm::cl::desc("Move reductions to outer most dimensions"), 586 llvm::cl::init(false)}; 587 void runOnOperation() override { 588 RewritePatternSet patterns(&getContext()); 589 populateVectorMultiReductionLoweringPatterns( 590 patterns, useOuterReductions 591 ? vector::VectorMultiReductionLowering::InnerParallel 592 : vector::VectorMultiReductionLowering::InnerReduction); 593 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 594 } 595 }; 596 597 struct TestVectorTransferCollapseInnerMostContiguousDims 598 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 599 OperationPass<func::FuncOp>> { 600 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 601 TestVectorTransferCollapseInnerMostContiguousDims) 602 603 TestVectorTransferCollapseInnerMostContiguousDims() = default; 604 TestVectorTransferCollapseInnerMostContiguousDims( 605 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 606 607 void getDependentDialects(DialectRegistry ®istry) const override { 608 registry.insert<memref::MemRefDialect, AffineDialect>(); 609 } 610 611 StringRef getArgument() const final { 612 return "test-vector-transfer-collapse-inner-most-dims"; 613 } 614 615 StringRef getDescription() const final { 616 return "Test lowering patterns that reducedes the rank of the vector " 617 "transfer memory and vector operands."; 618 } 619 620 void runOnOperation() override { 621 RewritePatternSet patterns(&getContext()); 622 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 623 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 624 } 625 }; 626 627 struct TestVectorReduceToContractPatternsPatterns 628 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 629 OperationPass<func::FuncOp>> { 630 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 631 TestVectorReduceToContractPatternsPatterns) 632 633 StringRef getArgument() const final { 634 return "test-vector-reduction-to-contract-patterns"; 635 } 636 StringRef getDescription() const final { 637 return "Test patterns to convert multireduce op to contract and combine " 638 "broadcast/transpose to contract"; 639 } 640 void runOnOperation() override { 641 RewritePatternSet patterns(&getContext()); 642 populateVectorReductionToContractPatterns(patterns); 643 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 644 } 645 }; 646 647 struct TestVectorTransferDropUnitDimsPatterns 648 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, 649 OperationPass<func::FuncOp>> { 650 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 651 TestVectorTransferDropUnitDimsPatterns) 652 653 StringRef getArgument() const final { 654 return "test-vector-transfer-drop-unit-dims-patterns"; 655 } 656 void getDependentDialects(DialectRegistry ®istry) const override { 657 registry.insert<memref::MemRefDialect>(); 658 } 659 void runOnOperation() override { 660 RewritePatternSet patterns(&getContext()); 661 populateVectorTransferDropUnitDimsPatterns(patterns); 662 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 663 } 664 }; 665 666 struct TestFlattenVectorTransferPatterns 667 : public PassWrapper<TestFlattenVectorTransferPatterns, 668 OperationPass<func::FuncOp>> { 669 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 670 TestFlattenVectorTransferPatterns) 671 672 StringRef getArgument() const final { 673 return "test-vector-transfer-flatten-patterns"; 674 } 675 StringRef getDescription() const final { 676 return "Test patterns to rewrite contiguous row-major N-dimensional " 677 "vector.transfer_{read,write} ops into 1D transfers"; 678 } 679 void getDependentDialects(DialectRegistry ®istry) const override { 680 registry.insert<memref::MemRefDialect>(); 681 } 682 void runOnOperation() override { 683 RewritePatternSet patterns(&getContext()); 684 populateFlattenVectorTransferPatterns(patterns); 685 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 686 } 687 }; 688 689 struct TestVectorScanLowering 690 : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> { 691 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) 692 693 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 694 StringRef getDescription() const final { 695 return "Test lowering patterns that lower the scan op in the vector " 696 "dialect"; 697 } 698 void runOnOperation() override { 699 RewritePatternSet patterns(&getContext()); 700 populateVectorScanLoweringPatterns(patterns); 701 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 702 } 703 }; 704 705 /// Allocate shared memory for a single warp to test lowering of 706 /// WarpExecuteOnLane0Op. 707 static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, 708 WarpExecuteOnLane0Op warpOp, 709 Type type) { 710 static constexpr int64_t kSharedMemorySpace = 3; 711 // Compute type of shared memory buffer. 712 MemRefType memrefType; 713 if (auto vectorType = type.dyn_cast<VectorType>()) { 714 memrefType = 715 MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, 716 kSharedMemorySpace); 717 } else { 718 memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace); 719 } 720 721 // Get symbol table holding all shared memory globals. 722 ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>(); 723 SymbolTable symbolTable(moduleOp); 724 725 // Create a pretty name. 726 SmallString<64> buf; 727 llvm::raw_svector_ostream os(buf); 728 interleave(memrefType.getShape(), os, "x"); 729 os << "x" << memrefType.getElementType(); 730 std::string symbolName = (Twine("__shared_") + os.str()).str(); 731 732 auto ip = builder.saveInsertionPoint(); 733 builder.setInsertionPoint(moduleOp); 734 auto global = builder.create<memref::GlobalOp>( 735 loc, 736 /*sym_name=*/symbolName, 737 /*sym_visibility=*/builder.getStringAttr("private"), 738 /*type=*/memrefType, 739 /*initial_value=*/Attribute(), 740 /*constant=*/false, 741 /*alignment=*/IntegerAttr()); 742 symbolTable.insert(global); 743 // The symbol table inserts at the end of the module, but globals are a bit 744 // nicer if they are at the beginning. 745 global->moveBefore(&moduleOp.front()); 746 747 builder.restoreInsertionPoint(ip); 748 return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName); 749 } 750 751 struct TestVectorDistribution 752 : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> { 753 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) 754 755 void getDependentDialects(DialectRegistry ®istry) const override { 756 registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>(); 757 } 758 759 StringRef getArgument() const final { return "test-vector-warp-distribute"; } 760 StringRef getDescription() const final { 761 return "Test vector warp distribute transformation and lowering patterns"; 762 } 763 TestVectorDistribution() = default; 764 TestVectorDistribution(const TestVectorDistribution &pass) 765 : PassWrapper(pass) {} 766 767 Option<bool> warpOpToSCF{ 768 *this, "rewrite-warp-ops-to-scf-if", 769 llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"), 770 llvm::cl::init(false)}; 771 772 void runOnOperation() override { 773 RewritePatternSet patterns(&getContext()); 774 WarpExecuteOnLane0LoweringOptions options; 775 options.warpAllocationFn = allocateGlobalSharedMemory; 776 options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, 777 WarpExecuteOnLane0Op warpOp) { 778 builder.create<gpu::BarrierOp>(loc); 779 }; 780 // Test on one pattern in isolation. 781 if (warpOpToSCF) { 782 populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); 783 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 784 return; 785 } 786 } 787 }; 788 789 } // namespace 790 791 namespace mlir { 792 namespace test { 793 void registerTestVectorLowerings() { 794 PassRegistration<TestVectorToVectorLowering>(); 795 796 PassRegistration<TestVectorContractionLowering>(); 797 798 PassRegistration<TestVectorTransposeLowering>(); 799 800 PassRegistration<TestVectorUnrollingPatterns>(); 801 802 PassRegistration<TestVectorTransferUnrollingPatterns>(); 803 804 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 805 806 PassRegistration<TestVectorDistributePatterns>(); 807 808 PassRegistration<TestVectorToLoopPatterns>(); 809 810 PassRegistration<TestVectorTransferOpt>(); 811 812 PassRegistration<TestVectorTransferLoweringPatterns>(); 813 814 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 815 816 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 817 818 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 819 820 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 821 822 PassRegistration<TestFlattenVectorTransferPatterns>(); 823 824 PassRegistration<TestVectorScanLowering>(); 825 826 PassRegistration<TestVectorDistribution>(); 827 } 828 } // namespace test 829 } // namespace mlir 830