1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/IR/SCF.h" 21 #include "mlir/Dialect/Vector/IR/VectorOps.h" 22 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" 23 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Pass/PassManager.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 using namespace mlir::vector; 32 33 namespace { 34 35 struct TestVectorToVectorLowering 36 : public PassWrapper<TestVectorToVectorLowering, 37 OperationPass<func::FuncOp>> { 38 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) 39 40 TestVectorToVectorLowering() = default; 41 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 42 : PassWrapper(pass) {} 43 StringRef getArgument() const final { 44 return "test-vector-to-vector-lowering"; 45 } 46 StringRef getDescription() const final { 47 return "Test lowering patterns between ops in the vector dialect"; 48 } 49 50 void getDependentDialects(DialectRegistry ®istry) const override { 51 registry.insert<AffineDialect>(); 52 } 53 54 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 55 llvm::cl::init(false)}; 56 57 void runOnOperation() override { 58 auto *ctx = &getContext(); 59 RewritePatternSet patterns(ctx); 60 if (unroll) { 61 populateVectorUnrollPatterns( 62 patterns, 63 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 64 filter)); 65 } 66 populateVectorToVectorCanonicalizationPatterns(patterns); 67 populateBubbleVectorBitCastOpPatterns(patterns); 68 populateCastAwayVectorLeadingOneDimPatterns(patterns); 69 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 70 } 71 72 private: 73 // Return the target shape based on op type. 74 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 75 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 76 return SmallVector<int64_t, 4>(2, 2); 77 if (isa<vector::ContractionOp>(op)) 78 return SmallVector<int64_t, 4>(3, 2); 79 // For transfer ops, just propagate the shape coming from 80 // InsertStridedSlices/ExtractStridedSlices. 81 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 82 VectorType dstVec; 83 for (Operation *users : readOp->getUsers()) { 84 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 85 if (!extract) 86 return llvm::None; 87 auto vecType = extract.getResult().getType().cast<VectorType>(); 88 if (dstVec && dstVec != vecType) 89 return llvm::None; 90 dstVec = vecType; 91 } 92 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 93 dstVec.getShape().end()); 94 } 95 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 96 auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); 97 if (!insert) 98 return llvm::None; 99 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 100 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 101 } 102 return llvm::None; 103 } 104 105 static LogicalResult filter(Operation *op) { 106 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 107 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 108 } 109 }; 110 111 struct TestVectorContractionLowering 112 : public PassWrapper<TestVectorContractionLowering, 113 OperationPass<func::FuncOp>> { 114 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering) 115 116 StringRef getArgument() const final { 117 return "test-vector-contraction-lowering"; 118 } 119 StringRef getDescription() const final { 120 return "Test lowering patterns that lower contract ops in the vector " 121 "dialect"; 122 } 123 TestVectorContractionLowering() = default; 124 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 125 : PassWrapper(pass) {} 126 127 Option<bool> lowerToFlatMatrix{ 128 *this, "vector-lower-matrix-intrinsics", 129 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 130 llvm::cl::init(false)}; 131 Option<bool> lowerToOuterProduct{ 132 *this, "vector-outerproduct", 133 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 134 llvm::cl::init(false)}; 135 Option<bool> lowerToFilterOuterProduct{ 136 *this, "vector-filter-outerproduct", 137 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 138 "vectors of size 4."), 139 llvm::cl::init(false)}; 140 Option<bool> lowerToParallelArith{ 141 *this, "vector-parallel-arith", 142 llvm::cl::desc("Lower vector.contract to elementwise vector ops."), 143 llvm::cl::init(false)}; 144 145 void runOnOperation() override { 146 RewritePatternSet patterns(&getContext()); 147 148 // Test on one pattern in isolation. 149 if (lowerToOuterProduct) { 150 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 151 VectorTransformsOptions options{lowering}; 152 patterns.add<ContractionOpToOuterProductOpLowering>(options, 153 &getContext()); 154 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 155 return; 156 } 157 158 // Test on one pattern in isolation. 159 if (lowerToFilterOuterProduct) { 160 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 161 VectorTransformsOptions options{lowering}; 162 patterns.add<ContractionOpToOuterProductOpLowering>( 163 options, &getContext(), [](vector::ContractionOp op) { 164 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 165 // where M is not 4. 166 if (op.getRhsType().getShape()[0] == 4) 167 return failure(); 168 return success(); 169 }); 170 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 171 return; 172 } 173 174 if (lowerToParallelArith) { 175 vector::populateVectorContractLoweringPatterns( 176 patterns, 177 vector::VectorTransformsOptions().setVectorTransformsOptions( 178 vector::VectorContractLowering::ParallelArith)); 179 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 180 return; 181 } 182 183 // Test on all contract lowering patterns. 184 VectorContractLowering contractLowering = VectorContractLowering::Dot; 185 if (lowerToFlatMatrix) 186 contractLowering = VectorContractLowering::Matmul; 187 VectorMultiReductionLowering vectorMultiReductionLowering = 188 VectorMultiReductionLowering::InnerParallel; 189 VectorTransformsOptions options{contractLowering, 190 vectorMultiReductionLowering, 191 VectorTransposeLowering()}; 192 populateVectorBroadcastLoweringPatterns(patterns); 193 populateVectorContractLoweringPatterns(patterns, options); 194 populateVectorMaskOpLoweringPatterns(patterns); 195 populateVectorShapeCastLoweringPatterns(patterns); 196 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 197 } 198 }; 199 200 struct TestVectorTransposeLowering 201 : public PassWrapper<TestVectorTransposeLowering, 202 OperationPass<func::FuncOp>> { 203 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering) 204 205 StringRef getArgument() const final { 206 return "test-vector-transpose-lowering"; 207 } 208 StringRef getDescription() const final { 209 return "Test lowering patterns that lower contract ops in the vector " 210 "dialect"; 211 } 212 TestVectorTransposeLowering() = default; 213 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 214 : PassWrapper(pass) {} 215 216 Option<bool> lowerToEltwise{ 217 *this, "eltwise", 218 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 219 llvm::cl::init(false)}; 220 Option<bool> lowerToFlatTranspose{ 221 *this, "flat", 222 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 223 llvm::cl::init(false)}; 224 Option<bool> lowerToShuffleTranspose{ 225 *this, "shuffle", 226 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 227 llvm::cl::init(false)}; 228 Option<bool> lowerToAvx2{ 229 *this, "avx2", 230 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 231 llvm::cl::init(false)}; 232 233 void getDependentDialects(DialectRegistry ®istry) const override { 234 registry.insert<LLVM::LLVMDialect>(); 235 } 236 237 void runOnOperation() override { 238 RewritePatternSet patterns(&getContext()); 239 240 // Test on one pattern in isolation. 241 // Explicitly disable shape_cast lowering. 242 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 243 .enableVectorTransposeLowering() 244 .enableShapeCastLowering(false); 245 if (lowerToEltwise) { 246 options = options.setVectorTransformsOptions( 247 VectorTransformsOptions().setVectorTransposeLowering( 248 VectorTransposeLowering::EltWise)); 249 } 250 if (lowerToFlatTranspose) { 251 options = options.setVectorTransformsOptions( 252 VectorTransformsOptions().setVectorTransposeLowering( 253 VectorTransposeLowering::Flat)); 254 } 255 if (lowerToShuffleTranspose) { 256 options = options.setVectorTransformsOptions( 257 VectorTransformsOptions().setVectorTransposeLowering( 258 VectorTransposeLowering::Shuffle)); 259 } 260 if (lowerToAvx2) { 261 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 262 x86vector::avx2::LoweringOptions().setTransposeOptions( 263 x86vector::avx2::TransposeLoweringOptions() 264 .lower4x8xf32() 265 .lower8x8xf32())); 266 } 267 268 OpPassManager dynamicPM("func.func"); 269 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 270 if (failed(runPipeline(dynamicPM, getOperation()))) 271 return signalPassFailure(); 272 } 273 }; 274 275 struct TestVectorUnrollingPatterns 276 : public PassWrapper<TestVectorUnrollingPatterns, 277 OperationPass<func::FuncOp>> { 278 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) 279 280 StringRef getArgument() const final { 281 return "test-vector-unrolling-patterns"; 282 } 283 StringRef getDescription() const final { 284 return "Test lowering patterns to unroll contract ops in the vector " 285 "dialect"; 286 } 287 TestVectorUnrollingPatterns() = default; 288 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 289 : PassWrapper(pass) {} 290 void runOnOperation() override { 291 MLIRContext *ctx = &getContext(); 292 RewritePatternSet patterns(ctx); 293 populateVectorUnrollPatterns( 294 patterns, UnrollVectorOptions() 295 .setNativeShape(ArrayRef<int64_t>{2, 2}) 296 .setFilterConstraint([](Operation *op) { 297 return success(isa<arith::AddFOp, vector::FMAOp, 298 vector::MultiDimReductionOp>(op)); 299 })); 300 populateVectorUnrollPatterns( 301 patterns, UnrollVectorOptions() 302 .setNativeShape(ArrayRef<int64_t>{2}) 303 .setFilterConstraint([](Operation *op) { 304 return success(isa<vector::ReductionOp>(op)); 305 })); 306 populateVectorUnrollPatterns( 307 patterns, UnrollVectorOptions() 308 .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) 309 .setFilterConstraint([](Operation *op) { 310 return success(isa<vector::TransposeOp>(op)); 311 })); 312 313 if (unrollBasedOnType) { 314 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 315 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 316 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 317 SmallVector<int64_t, 4> nativeShape( 318 contractOp.getIteratorTypes().size(), 4); 319 Type lhsType = contractOp.getLhsType().getElementType(); 320 nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2; 321 return nativeShape; 322 }; 323 324 UnrollVectorOptions opts; 325 opts.setNativeShapeFn(nativeShapeFn) 326 .setFilterConstraint( 327 [](Operation *op) { return success(isa<ContractionOp>(op)); }); 328 329 if (!unrollOrder.empty()) { 330 opts.setUnrollTraversalOrderFn([this](Operation *op) 331 -> Optional<SmallVector<int64_t>> { 332 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 333 if (contractOp.getIteratorTypes().size() == unrollOrder.size()) 334 return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end()); 335 return None; 336 }); 337 } 338 populateVectorUnrollPatterns(patterns, opts); 339 } else { 340 auto nativeShapeFn = 341 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 342 auto contractOp = dyn_cast<ContractionOp>(op); 343 if (!contractOp) 344 return None; 345 return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2); 346 }; 347 populateVectorUnrollPatterns(patterns, 348 UnrollVectorOptions() 349 .setNativeShapeFn(nativeShapeFn) 350 .setFilterConstraint([](Operation *op) { 351 return success(isa<ContractionOp>(op)); 352 })); 353 } 354 populateVectorToVectorCanonicalizationPatterns(patterns); 355 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 356 } 357 358 ListOption<int64_t> unrollOrder{*this, "unroll-order", 359 llvm::cl::desc("set the unroll order")}; 360 361 Option<bool> unrollBasedOnType{ 362 *this, "unroll-based-on-type", 363 llvm::cl::desc("Set the unroll factor based on type of the operation"), 364 llvm::cl::init(false)}; 365 }; 366 367 struct TestVectorDistributePatterns 368 : public PassWrapper<TestVectorDistributePatterns, 369 OperationPass<func::FuncOp>> { 370 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns) 371 372 StringRef getArgument() const final { 373 return "test-vector-distribute-patterns"; 374 } 375 StringRef getDescription() const final { 376 return "Test lowering patterns to distribute vector ops in the vector " 377 "dialect"; 378 } 379 TestVectorDistributePatterns() = default; 380 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 381 : PassWrapper(pass) {} 382 void getDependentDialects(DialectRegistry ®istry) const override { 383 registry.insert<VectorDialect>(); 384 registry.insert<AffineDialect>(); 385 } 386 ListOption<int32_t> multiplicity{ 387 *this, "distribution-multiplicity", 388 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 389 390 void runOnOperation() override { 391 MLIRContext *ctx = &getContext(); 392 RewritePatternSet patterns(ctx); 393 func::FuncOp func = getOperation(); 394 func.walk([&](arith::AddFOp op) { 395 OpBuilder builder(op); 396 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 397 SmallVector<int64_t, 2> mul; 398 SmallVector<AffineExpr, 2> perm; 399 SmallVector<Value, 2> ids; 400 unsigned count = 0; 401 // Remove the multiplicity of 1 and calculate the affine map based on 402 // the multiplicity. 403 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 404 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 405 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 406 mul.push_back(m[i]); 407 ids.push_back(func.getArgument(count++)); 408 perm.push_back(getAffineDimExpr(i, ctx)); 409 } 410 } 411 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 412 perm, ctx); 413 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 414 builder, op.getOperation(), ids, mul, map); 415 if (ops) { 416 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 417 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 418 extractOp); 419 } 420 } 421 }); 422 populatePropagateVectorDistributionPatterns(patterns); 423 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 424 } 425 }; 426 427 struct TestVectorToLoopPatterns 428 : public PassWrapper<TestVectorToLoopPatterns, 429 OperationPass<func::FuncOp>> { 430 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns) 431 432 StringRef getArgument() const final { return "test-vector-to-forloop"; } 433 StringRef getDescription() const final { 434 return "Test lowering patterns to break up a vector op into a for loop"; 435 } 436 TestVectorToLoopPatterns() = default; 437 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 438 : PassWrapper(pass) {} 439 void getDependentDialects(DialectRegistry ®istry) const override { 440 registry.insert<VectorDialect>(); 441 registry.insert<AffineDialect>(); 442 } 443 Option<int32_t> multiplicity{ 444 *this, "distribution-multiplicity", 445 llvm::cl::desc("Set the multiplicity used for distributing vector"), 446 llvm::cl::init(32)}; 447 void runOnOperation() override { 448 MLIRContext *ctx = &getContext(); 449 RewritePatternSet patterns(ctx); 450 func::FuncOp func = getOperation(); 451 func.walk([&](arith::AddFOp op) { 452 // Check that the operation type can be broken down into a loop. 453 VectorType type = op.getType().dyn_cast<VectorType>(); 454 if (!type || type.getRank() != 1 || 455 type.getNumElements() % multiplicity != 0) 456 return mlir::WalkResult::advance(); 457 auto filterAlloc = [](Operation *op) { 458 return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op); 459 }; 460 auto dependentOps = getSlice(op, filterAlloc); 461 // Create a loop and move instructions from the Op slice into the loop. 462 OpBuilder builder(op); 463 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 464 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 465 auto numIter = 466 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 467 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 468 for (Operation *it : dependentOps) { 469 it->moveBefore(forOp.getBody()->getTerminator()); 470 } 471 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 472 // break up the original op and let the patterns propagate. 473 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 474 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 475 map); 476 if (ops) { 477 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 478 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 479 } 480 return mlir::WalkResult::interrupt(); 481 }); 482 populatePropagateVectorDistributionPatterns(patterns); 483 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 484 } 485 }; 486 487 struct TestVectorTransferUnrollingPatterns 488 : public PassWrapper<TestVectorTransferUnrollingPatterns, 489 OperationPass<func::FuncOp>> { 490 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 491 TestVectorTransferUnrollingPatterns) 492 493 TestVectorTransferUnrollingPatterns() = default; 494 TestVectorTransferUnrollingPatterns( 495 const TestVectorTransferUnrollingPatterns &pass) 496 : PassWrapper(pass) {} 497 498 void getDependentDialects(DialectRegistry ®istry) const override { 499 registry.insert<AffineDialect>(); 500 } 501 StringRef getArgument() const final { 502 return "test-vector-transfer-unrolling-patterns"; 503 } 504 StringRef getDescription() const final { 505 return "Test lowering patterns to unroll transfer ops in the vector " 506 "dialect"; 507 } 508 void runOnOperation() override { 509 MLIRContext *ctx = &getContext(); 510 RewritePatternSet patterns(ctx); 511 UnrollVectorOptions opts; 512 opts.setNativeShape(ArrayRef<int64_t>{2, 2}) 513 .setFilterConstraint([](Operation *op) { 514 return success( 515 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 516 }); 517 if (reverseUnrollOrder.getValue()) { 518 opts.setUnrollTraversalOrderFn( 519 [](Operation *op) -> Optional<SmallVector<int64_t>> { 520 int64_t numLoops = 0; 521 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) 522 numLoops = readOp.getVectorType().getRank(); 523 else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) 524 numLoops = writeOp.getVectorType().getRank(); 525 else 526 return None; 527 auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops)); 528 return llvm::to_vector(order); 529 }); 530 } 531 populateVectorUnrollPatterns(patterns, opts); 532 populateVectorToVectorCanonicalizationPatterns(patterns); 533 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 534 } 535 536 Option<bool> reverseUnrollOrder{ 537 *this, "reverse-unroll-order", 538 llvm::cl::desc( 539 "reverse the order of unrolling of vector transfer operations"), 540 llvm::cl::init(false)}; 541 }; 542 543 struct TestVectorTransferFullPartialSplitPatterns 544 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 545 OperationPass<func::FuncOp>> { 546 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 547 TestVectorTransferFullPartialSplitPatterns) 548 549 StringRef getArgument() const final { 550 return "test-vector-transfer-full-partial-split"; 551 } 552 StringRef getDescription() const final { 553 return "Test lowering patterns to split " 554 "transfer ops via scf.if + linalg ops"; 555 } 556 TestVectorTransferFullPartialSplitPatterns() = default; 557 TestVectorTransferFullPartialSplitPatterns( 558 const TestVectorTransferFullPartialSplitPatterns &pass) 559 : PassWrapper(pass) {} 560 561 void getDependentDialects(DialectRegistry ®istry) const override { 562 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 563 scf::SCFDialect>(); 564 } 565 566 Option<bool> useLinalgOps{ 567 *this, "use-memref-copy", 568 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 569 "memref.copy operations."), 570 llvm::cl::init(false)}; 571 void runOnOperation() override { 572 MLIRContext *ctx = &getContext(); 573 RewritePatternSet patterns(ctx); 574 VectorTransformsOptions options; 575 if (useLinalgOps) 576 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 577 else 578 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 579 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 580 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 581 } 582 }; 583 584 struct TestVectorTransferOpt 585 : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> { 586 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) 587 588 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 589 StringRef getDescription() const final { 590 return "Test optimization transformations for transfer ops"; 591 } 592 void runOnOperation() override { transferOpflowOpt(getOperation()); } 593 }; 594 595 struct TestVectorTransferLoweringPatterns 596 : public PassWrapper<TestVectorTransferLoweringPatterns, 597 OperationPass<func::FuncOp>> { 598 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 599 TestVectorTransferLoweringPatterns) 600 601 void getDependentDialects(DialectRegistry ®istry) const override { 602 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 603 } 604 StringRef getArgument() const final { 605 return "test-vector-transfer-lowering-patterns"; 606 } 607 StringRef getDescription() const final { 608 return "Test lowering patterns to lower transfer ops to other vector ops"; 609 } 610 void runOnOperation() override { 611 RewritePatternSet patterns(&getContext()); 612 populateVectorTransferLoweringPatterns(patterns); 613 populateVectorTransferPermutationMapLoweringPatterns(patterns); 614 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 615 } 616 }; 617 618 struct TestVectorMultiReductionLoweringPatterns 619 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 620 OperationPass<func::FuncOp>> { 621 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 622 TestVectorMultiReductionLoweringPatterns) 623 624 TestVectorMultiReductionLoweringPatterns() = default; 625 TestVectorMultiReductionLoweringPatterns( 626 const TestVectorMultiReductionLoweringPatterns &pass) 627 : PassWrapper(pass) {} 628 void getDependentDialects(DialectRegistry ®istry) const override { 629 registry.insert<memref::MemRefDialect>(); 630 } 631 StringRef getArgument() const final { 632 return "test-vector-multi-reduction-lowering-patterns"; 633 } 634 StringRef getDescription() const final { 635 return "Test lowering patterns to lower vector.multi_reduction to other " 636 "vector ops"; 637 } 638 Option<bool> useOuterReductions{ 639 *this, "use-outer-reductions", 640 llvm::cl::desc("Move reductions to outer most dimensions"), 641 llvm::cl::init(false)}; 642 void runOnOperation() override { 643 RewritePatternSet patterns(&getContext()); 644 populateVectorMultiReductionLoweringPatterns( 645 patterns, useOuterReductions 646 ? vector::VectorMultiReductionLowering::InnerParallel 647 : vector::VectorMultiReductionLowering::InnerReduction); 648 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 649 } 650 }; 651 652 struct TestVectorTransferCollapseInnerMostContiguousDims 653 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 654 OperationPass<func::FuncOp>> { 655 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 656 TestVectorTransferCollapseInnerMostContiguousDims) 657 658 TestVectorTransferCollapseInnerMostContiguousDims() = default; 659 TestVectorTransferCollapseInnerMostContiguousDims( 660 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 661 662 void getDependentDialects(DialectRegistry ®istry) const override { 663 registry.insert<memref::MemRefDialect, AffineDialect>(); 664 } 665 666 StringRef getArgument() const final { 667 return "test-vector-transfer-collapse-inner-most-dims"; 668 } 669 670 StringRef getDescription() const final { 671 return "Test lowering patterns that reducedes the rank of the vector " 672 "transfer memory and vector operands."; 673 } 674 675 void runOnOperation() override { 676 RewritePatternSet patterns(&getContext()); 677 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 678 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 679 } 680 }; 681 682 struct TestVectorReduceToContractPatternsPatterns 683 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 684 OperationPass<func::FuncOp>> { 685 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 686 TestVectorReduceToContractPatternsPatterns) 687 688 StringRef getArgument() const final { 689 return "test-vector-reduction-to-contract-patterns"; 690 } 691 StringRef getDescription() const final { 692 return "Test patterns to convert multireduce op to contract and combine " 693 "broadcast/transpose to contract"; 694 } 695 void runOnOperation() override { 696 RewritePatternSet patterns(&getContext()); 697 populateVectorReductionToContractPatterns(patterns); 698 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 699 } 700 }; 701 702 struct TestVectorTransferDropUnitDimsPatterns 703 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, 704 OperationPass<func::FuncOp>> { 705 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 706 TestVectorTransferDropUnitDimsPatterns) 707 708 StringRef getArgument() const final { 709 return "test-vector-transfer-drop-unit-dims-patterns"; 710 } 711 void getDependentDialects(DialectRegistry ®istry) const override { 712 registry.insert<memref::MemRefDialect>(); 713 } 714 void runOnOperation() override { 715 RewritePatternSet patterns(&getContext()); 716 populateVectorTransferDropUnitDimsPatterns(patterns); 717 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 718 } 719 }; 720 721 struct TestFlattenVectorTransferPatterns 722 : public PassWrapper<TestFlattenVectorTransferPatterns, 723 OperationPass<func::FuncOp>> { 724 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 725 TestFlattenVectorTransferPatterns) 726 727 StringRef getArgument() const final { 728 return "test-vector-transfer-flatten-patterns"; 729 } 730 StringRef getDescription() const final { 731 return "Test patterns to rewrite contiguous row-major N-dimensional " 732 "vector.transfer_{read,write} ops into 1D transfers"; 733 } 734 void getDependentDialects(DialectRegistry ®istry) const override { 735 registry.insert<memref::MemRefDialect>(); 736 } 737 void runOnOperation() override { 738 RewritePatternSet patterns(&getContext()); 739 populateFlattenVectorTransferPatterns(patterns); 740 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 741 } 742 }; 743 744 struct TestVectorScanLowering 745 : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> { 746 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) 747 748 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 749 StringRef getDescription() const final { 750 return "Test lowering patterns that lower the scan op in the vector " 751 "dialect"; 752 } 753 void runOnOperation() override { 754 RewritePatternSet patterns(&getContext()); 755 populateVectorScanLoweringPatterns(patterns); 756 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 757 } 758 }; 759 760 /// Allocate shared memory for a single warp to test lowering of 761 /// WarpExecuteOnLane0Op. 762 static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, 763 WarpExecuteOnLane0Op warpOp, 764 Type type) { 765 static constexpr int64_t kSharedMemorySpace = 3; 766 // Compute type of shared memory buffer. 767 MemRefType memrefType; 768 if (auto vectorType = type.dyn_cast<VectorType>()) { 769 memrefType = 770 MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, 771 kSharedMemorySpace); 772 } else { 773 memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace); 774 } 775 776 // Get symbol table holding all shared memory globals. 777 ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>(); 778 SymbolTable symbolTable(moduleOp); 779 780 // Create a pretty name. 781 SmallString<64> buf; 782 llvm::raw_svector_ostream os(buf); 783 interleave(memrefType.getShape(), os, "x"); 784 os << "x" << memrefType.getElementType(); 785 std::string symbolName = (Twine("__shared_") + os.str()).str(); 786 787 auto ip = builder.saveInsertionPoint(); 788 builder.setInsertionPoint(moduleOp); 789 auto global = builder.create<memref::GlobalOp>( 790 loc, 791 /*sym_name=*/symbolName, 792 /*sym_visibility=*/builder.getStringAttr("private"), 793 /*type=*/memrefType, 794 /*initial_value=*/Attribute(), 795 /*constant=*/false, 796 /*alignment=*/IntegerAttr()); 797 symbolTable.insert(global); 798 // The symbol table inserts at the end of the module, but globals are a bit 799 // nicer if they are at the beginning. 800 global->moveBefore(&moduleOp.front()); 801 802 builder.restoreInsertionPoint(ip); 803 return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName); 804 } 805 806 static Value warpReduction(Location loc, OpBuilder &builder, Value input, 807 CombiningKind kind, uint32_t size) { 808 Value laneVal = input; 809 // Parallel reduction using butterfly shuffles. 810 for (uint64_t i = 1; i < size; i <<= 1) { 811 Value shuffled = builder 812 .create<gpu::ShuffleOp>(loc, laneVal, i, 813 /*width=*/size, 814 /*mode=*/gpu::ShuffleMode::XOR) 815 .result(); 816 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); 817 } 818 return laneVal; 819 } 820 821 struct TestVectorDistribution 822 : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> { 823 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) 824 825 void getDependentDialects(DialectRegistry ®istry) const override { 826 registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect, 827 AffineDialect>(); 828 } 829 830 StringRef getArgument() const final { return "test-vector-warp-distribute"; } 831 StringRef getDescription() const final { 832 return "Test vector warp distribute transformation and lowering patterns"; 833 } 834 TestVectorDistribution() = default; 835 TestVectorDistribution(const TestVectorDistribution &pass) 836 : PassWrapper(pass) {} 837 838 Option<bool> warpOpToSCF{ 839 *this, "rewrite-warp-ops-to-scf-if", 840 llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"), 841 llvm::cl::init(false)}; 842 843 Option<bool> distributeTransferWriteOps{ 844 *this, "distribute-transfer-write", 845 llvm::cl::desc("Test distribution of transfer write"), 846 llvm::cl::init(false)}; 847 848 Option<bool> hoistUniform{*this, "hoist-uniform", 849 llvm::cl::desc("Test hoist uniform"), 850 llvm::cl::init(false)}; 851 852 Option<bool> propagateDistribution{ 853 *this, "propagate-distribution", 854 llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)}; 855 856 void runOnOperation() override { 857 RewritePatternSet patterns(&getContext()); 858 859 getOperation().walk([&](Operation *op) { 860 if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) { 861 if (hoistUniform) { 862 moveScalarUniformCode(warpOp); 863 } 864 WalkResult::interrupt(); 865 } 866 }); 867 MLIRContext *ctx = &getContext(); 868 if (distributeTransferWriteOps) { 869 auto distributionFn = [](vector::TransferWriteOp writeOp) { 870 // Create a map (d0, d1) -> (d1) to distribute along the inner 871 // dimension. Once we support n-d distribution we can add more 872 // complex cases. 873 int64_t vecRank = writeOp.getVectorType().getRank(); 874 OpBuilder builder(writeOp.getContext()); 875 auto map = 876 AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1)); 877 return map; 878 }; 879 RewritePatternSet patterns(ctx); 880 populateDistributeTransferWriteOpPatterns(patterns, distributionFn); 881 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 882 } 883 if (propagateDistribution) { 884 RewritePatternSet patterns(ctx); 885 vector::populatePropagateWarpVectorDistributionPatterns(patterns); 886 vector::populateDistributeReduction(patterns, warpReduction); 887 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 888 } 889 WarpExecuteOnLane0LoweringOptions options; 890 options.warpAllocationFn = allocateGlobalSharedMemory; 891 options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, 892 WarpExecuteOnLane0Op warpOp) { 893 builder.create<gpu::BarrierOp>(loc); 894 }; 895 // Test on one pattern in isolation. 896 if (warpOpToSCF) { 897 populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); 898 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 899 return; 900 } 901 } 902 }; 903 904 } // namespace 905 906 namespace mlir { 907 namespace test { 908 void registerTestVectorLowerings() { 909 PassRegistration<TestVectorToVectorLowering>(); 910 911 PassRegistration<TestVectorContractionLowering>(); 912 913 PassRegistration<TestVectorTransposeLowering>(); 914 915 PassRegistration<TestVectorUnrollingPatterns>(); 916 917 PassRegistration<TestVectorTransferUnrollingPatterns>(); 918 919 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 920 921 PassRegistration<TestVectorDistributePatterns>(); 922 923 PassRegistration<TestVectorToLoopPatterns>(); 924 925 PassRegistration<TestVectorTransferOpt>(); 926 927 PassRegistration<TestVectorTransferLoweringPatterns>(); 928 929 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 930 931 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 932 933 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 934 935 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 936 937 PassRegistration<TestFlattenVectorTransferPatterns>(); 938 939 PassRegistration<TestVectorScanLowering>(); 940 941 PassRegistration<TestVectorDistribution>(); 942 } 943 } // namespace test 944 } // namespace mlir 945