//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;

namespace {

struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)

  TestVectorToVectorLowering() = default;
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }

  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

private:
  // Return the target shape based on op type.
  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t, 4>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t, 4>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
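    // For example (illustrative values, not from the original source): a
    // vector.transfer_read whose users are all extract_strided_slice ops
    // producing vector<2x2xf32> gets the native shape [2, 2]; users with
    // mismatching result types make this hook bail out and return llvm::None.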
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *users : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(users);
        if (!extract)
          return llvm::None;
        auto vecType = extract.getResult().getType().cast<VectorType>();
        if (dstVec && dstVec != vecType)
          return llvm::None;
        dstVec = vecType;
      }
      return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
                                     dstVec.getShape().end());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return llvm::None;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
    }
    return llvm::None;
  }

  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
                       ContractionOp, TransferReadOp, TransferWriteOp>(op));
  }
};

struct TestVectorContractionLowering
    : public PassWrapper<TestVectorContractionLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)

  StringRef getArgument() const final {
    return "test-vector-contraction-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower contract ops in the vector "
           "dialect";
  }
  TestVectorContractionLowering() = default;
  TestVectorContractionLowering(const TestVectorContractionLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToFlatMatrix{
      *this, "vector-lower-matrix-intrinsics",
      llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
      llvm::cl::init(false)};
  Option<bool> lowerToOuterProduct{
      *this, "vector-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
      llvm::cl::init(false)};
  Option<bool> lowerToFilterOuterProduct{
      *this, "vector-filter-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
                     "vectors of size 4."),
      llvm::cl::init(false)};
  Option<bool> lowerToParallelArith{
      *this, "vector-parallel-arith",
      llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
      llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    if (lowerToOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(options,
                                                          &getContext());
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on one pattern in isolation.
    if (lowerToFilterOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(
          options, &getContext(), [](vector::ContractionOp op) {
            // Only lower vector.contract where the lhs has a type
            // vector<MxNx?> where M is not 4.
            if (op.getRhsType().getShape()[0] == 4)
              return failure();
            return success();
          });
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    if (lowerToParallelArith) {
      vector::populateVectorContractLoweringPatterns(
          patterns,
          vector::VectorTransformsOptions().setVectorTransformsOptions(
              vector::VectorContractLowering::ParallelArith));
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on all contract lowering patterns.
    VectorContractLowering contractLowering = VectorContractLowering::Dot;
    if (lowerToFlatMatrix)
      contractLowering = VectorContractLowering::Matmul;
    VectorMultiReductionLowering vectorMultiReductionLowering =
        VectorMultiReductionLowering::InnerParallel;
    VectorTransformsOptions options{contractLowering,
                                    vectorMultiReductionLowering,
                                    VectorTransposeLowering()};
    populateVectorBroadcastLoweringPatterns(patterns);
    populateVectorContractLoweringPatterns(patterns, options);
    populateVectorMaskOpLoweringPatterns(patterns);
    populateVectorShapeCastLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransposeLowering
    : public PassWrapper<TestVectorTransposeLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)

  StringRef getArgument() const final {
    return "test-vector-transpose-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower transpose ops in the vector "
           "dialect";
  }
  TestVectorTransposeLowering() = default;
  TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToEltwise{
      *this, "eltwise",
      llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
      llvm::cl::init(false)};
  Option<bool> lowerToFlatTranspose{
      *this, "flat",
      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
      llvm::cl::init(false)};
  Option<bool> lowerToShuffleTranspose{
      *this, "shuffle",
      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
      llvm::cl::init(false)};
  Option<bool> lowerToAvx2{
      *this, "avx2",
      llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
      llvm::cl::init(false)};

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    // Explicitly disable shape_cast lowering.
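    // The selected transpose strategy is carried through
    // VectorTransformsOptions and applied below by running
    // createLinalgStrategyLowerVectorsPass in a dynamic pass pipeline.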
    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
                                              .enableVectorTransposeLowering()
                                              .enableShapeCastLowering(false);
    if (lowerToEltwise) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::EltWise));
    }
    if (lowerToFlatTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Flat));
    }
    if (lowerToShuffleTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Shuffle));
    }
    if (lowerToAvx2) {
      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
          x86vector::avx2::LoweringOptions().setTransposeOptions(
              x86vector::avx2::TransposeLoweringOptions()
                  .lower4x8xf32()
                  .lower8x8xf32()));
    }

    OpPassManager dynamicPM("func.func");
    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
    if (failed(runPipeline(dynamicPM, getOperation())))
      return signalPassFailure();
  }
};

struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

    if (unrollBasedOnType) {
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
        if (auto floatType = contractOp.getLhsType()
                                 .getElementType()
                                 .dyn_cast<FloatType>()) {
          if (floatType.getWidth() == 16) {
            nativeShape[2] = 4;
          }
        }
        return nativeShape;
      };

      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });
      if (!unrollOrder.empty()) {
        opts.setUnrollTraversalOrderFn([this](Operation *op)
                                           -> Optional<SmallVector<int64_t>> {
          return SmallVector<int64_t>{unrollOrder.begin(), unrollOrder.end()};
        });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions()
              .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
              .setFilterConstraint([](Operation *op) {
                return success(isa<ContractionOp>(op));
              }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order"),
                                  llvm::cl::ZeroOrMore};

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};

struct TestVectorDistributePatterns
    : public PassWrapper<TestVectorDistributePatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)

  StringRef getArgument() const final {
    return "test-vector-distribute-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to distribute vector ops in the vector "
           "dialect";
  }
  TestVectorDistributePatterns() = default;
  TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  ListOption<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector")};

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    func::FuncOp func = getOperation();
    func.walk([&](arith::AddFOp op) {
      OpBuilder builder(op);
      if (auto vecType = op.getType().dyn_cast<VectorType>()) {
        SmallVector<int64_t, 2> mul;
        SmallVector<AffineExpr, 2> perm;
        SmallVector<Value, 2> ids;
        unsigned count = 0;
        // Remove the multiplicity of 1 and calculate the affine map based on
        // the multiplicity.
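        // Hypothetical example (values not from the original source): for a
        // vector<64x4xf32> addf and distribution-multiplicity=32,1, only
        // dim 0 is distributed, so mul = [32], ids = [first func argument],
        // and the resulting map is (d0, d1) -> (d0).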
        SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
        for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
          if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
            mul.push_back(m[i]);
            ids.push_back(func.getArgument(count++));
            perm.push_back(getAffineDimExpr(i, ctx));
          }
        }
        auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
                                  perm, ctx);
        Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
            builder, op.getOperation(), ids, mul, map);
        if (ops.hasValue()) {
          SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
          op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
                                              extractOp);
        }
      }
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorToLoopPatterns
    : public PassWrapper<TestVectorToLoopPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)

  StringRef getArgument() const final { return "test-vector-to-forloop"; }
  StringRef getDescription() const final {
    return "Test lowering patterns to break up a vector op into a for loop";
  }
  TestVectorToLoopPatterns() = default;
  TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  Option<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector"),
      llvm::cl::init(32)};
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    func::FuncOp func = getOperation();
    func.walk([&](arith::AddFOp op) {
      // Check that the operation type can be broken down into a loop.
      VectorType type = op.getType().dyn_cast<VectorType>();
      if (!type || type.getRank() != 1 ||
          type.getNumElements() % multiplicity != 0)
        return mlir::WalkResult::advance();
      auto filterAlloc = [](Operation *op) {
        return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
      };
      auto dependentOps = getSlice(op, filterAlloc);
      // Create a loop and move instructions from the Op slice into the loop.
      OpBuilder builder(op);
      auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
      auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
      auto numIter =
          builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
      auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
      for (Operation *it : dependentOps) {
        it->moveBefore(forOp.getBody()->getTerminator());
      }
      auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
      // Break up the original op and let the patterns propagate.
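      // The helper below produces an extract/insert pair (DistributeOps) keyed
      // on the loop induction variable; every other use of the original result
      // is then redirected to the insert so the propagation patterns populated
      // afterwards can take over.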
      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
          builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
          map);
      if (ops.hasValue()) {
        SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
      }
      return mlir::WalkResult::interrupt();
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferUnrollingPatterns)

  TestVectorTransferUnrollingPatterns() = default;
  TestVectorTransferUnrollingPatterns(
      const TestVectorTransferUnrollingPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    UnrollVectorOptions opts;
    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
        .setFilterConstraint([](Operation *op) {
          return success(
              isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
        });
    if (reverseUnrollOrder.getValue()) {
      opts.setUnrollTraversalOrderFn(
          [](Operation *op) -> Optional<SmallVector<int64_t>> {
            int64_t numLoops = 0;
            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
              numLoops = readOp.getVectorType().getRank();
            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
              numLoops = writeOp.getVectorType().getRank();
            else
              return None;
            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
            return llvm::to_vector(order);
          });
    }
    populateVectorUnrollPatterns(patterns, opts);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  Option<bool> reverseUnrollOrder{
      *this, "reverse-unroll-order",
      llvm::cl::desc(
          "reverse the order of unrolling of vector transfer operations"),
      llvm::cl::init(false)};
};

struct TestVectorTransferFullPartialSplitPatterns
    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferFullPartialSplitPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-full-partial-split";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to split "
           "transfer ops via scf.if + linalg ops";
  }
  TestVectorTransferFullPartialSplitPatterns() = default;
  TestVectorTransferFullPartialSplitPatterns(
      const TestVectorTransferFullPartialSplitPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
                    scf::SCFDialect>();
  }

  Option<bool> useLinalgOps{
      *this, "use-memref-copy",
      llvm::cl::desc("Split using unmasked vector.transfer + linalg.fill + "
                     "memref.copy operations."),
      llvm::cl::init(false)};
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    VectorTransformsOptions options;
    if (useLinalgOps)
      options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
    else
      options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
    patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)

  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }
  void runOnOperation() override { transferOpflowOpt(getOperation()); }
};

struct TestVectorTransferLoweringPatterns
    : public PassWrapper<TestVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferLoweringPatterns)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower transfer ops to other vector ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferLoweringPatterns(patterns);
    populateVectorTransferPermutationMapLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorMultiReductionLoweringPatterns
    : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorMultiReductionLoweringPatterns)

  TestVectorMultiReductionLoweringPatterns() = default;
  TestVectorMultiReductionLoweringPatterns(
      const TestVectorMultiReductionLoweringPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-multi-reduction-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower vector.multi_reduction to other "
           "vector ops";
  }
  Option<bool> useOuterReductions{
      *this, "use-outer-reductions",
      llvm::cl::desc("Move reductions to outermost dimensions"),
      llvm::cl::init(false)};
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMultiReductionLoweringPatterns(
        patterns, useOuterReductions
                      ? vector::VectorMultiReductionLowering::InnerParallel
                      : vector::VectorMultiReductionLowering::InnerReduction);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferCollapseInnerMostContiguousDims)

  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, AffineDialect>();
  }

  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorReduceToContractPatternsPatterns)

  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multireduce ops to contract and combine "
           "broadcast/transpose to contract";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorReductionToContractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferDropUnitDimsPatterns
    : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferDropUnitDimsPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-drop-unit-dims-patterns";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferDropUnitDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFlattenVectorTransferPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFlattenVectorTransferPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)

  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorScanLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer.
  MemRefType memrefType;
  if (auto vectorType = type.dyn_cast<VectorType>()) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name.
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}

struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>();
  }

  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
  }
  TestVectorDistribution() = default;
  TestVectorDistribution(const TestVectorDistribution &pass)
      : PassWrapper(pass) {}

  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane_0 to scf.if op"),
      llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    WarpExecuteOnLane0LoweringOptions options;
    options.warpAllocationFn = allocateGlobalSharedMemory;
    options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                      WarpExecuteOnLane0Op warpOp) {
      builder.create<gpu::BarrierOp>(loc);
    };
    // Test on one pattern in isolation.
    if (warpOpToSCF) {
      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }
  }
};

} // namespace

namespace mlir {
namespace test {
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();

  PassRegistration<TestVectorContractionLowering>();

  PassRegistration<TestVectorTransposeLowering>();

  PassRegistration<TestVectorUnrollingPatterns>();

  PassRegistration<TestVectorTransferUnrollingPatterns>();

  PassRegistration<TestVectorTransferFullPartialSplitPatterns>();

  PassRegistration<TestVectorDistributePatterns>();

  PassRegistration<TestVectorToLoopPatterns>();

  PassRegistration<TestVectorTransferOpt>();

  PassRegistration<TestVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorMultiReductionLoweringPatterns>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestVectorTransferDropUnitDimsPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();

  PassRegistration<TestVectorScanLowering>();

  PassRegistration<TestVectorDistribution>();
}
} // namespace test
} // namespace mlir
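
// Note: each pass above is exposed to mlir-opt under its getArgument()
// spelling (for example `-test-vector-scan-lowering`), with its Option and
// ListOption members parsed as that pass's sub-options.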