1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/IR/SCF.h" 21 #include "mlir/Dialect/Vector/IR/VectorOps.h" 22 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" 23 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Pass/PassManager.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 using namespace mlir::vector; 32 33 namespace { 34 35 struct TestVectorToVectorLowering 36 : public PassWrapper<TestVectorToVectorLowering, 37 OperationPass<func::FuncOp>> { 38 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) 39 40 TestVectorToVectorLowering() = default; 41 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 42 : PassWrapper(pass) {} 43 StringRef getArgument() const final { 44 return "test-vector-to-vector-lowering"; 45 } 46 StringRef getDescription() const final { 47 return "Test lowering patterns between ops in the vector dialect"; 48 } 49 50 void getDependentDialects(DialectRegistry ®istry) const override { 51 registry.insert<AffineDialect>(); 52 } 53 54 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 55 llvm::cl::init(false)}; 56 57 void runOnOperation() override { 58 auto *ctx = &getContext(); 59 RewritePatternSet patterns(ctx); 60 if (unroll) { 61 populateVectorUnrollPatterns( 62 patterns, 63 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 64 filter)); 65 } 66 populateVectorToVectorCanonicalizationPatterns(patterns); 67 populateBubbleVectorBitCastOpPatterns(patterns); 68 populateCastAwayVectorLeadingOneDimPatterns(patterns); 69 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 70 } 71 72 private: 73 // Return the target shape based on op type. 74 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 75 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 76 return SmallVector<int64_t, 4>(2, 2); 77 if (isa<vector::ContractionOp>(op)) 78 return SmallVector<int64_t, 4>(3, 2); 79 // For transfer ops, just propagate the shape coming from 80 // InsertStridedSlices/ExtractStridedSlices. 81 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 82 VectorType dstVec; 83 for (Operation *users : readOp->getUsers()) { 84 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 85 if (!extract) 86 return llvm::None; 87 auto vecType = extract.getResult().getType().cast<VectorType>(); 88 if (dstVec && dstVec != vecType) 89 return llvm::None; 90 dstVec = vecType; 91 } 92 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 93 dstVec.getShape().end()); 94 } 95 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 96 auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); 97 if (!insert) 98 return llvm::None; 99 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 100 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 101 } 102 return llvm::None; 103 } 104 105 static LogicalResult filter(Operation *op) { 106 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 107 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 108 } 109 }; 110 111 struct TestVectorContractionLowering 112 : public PassWrapper<TestVectorContractionLowering, 113 OperationPass<func::FuncOp>> { 114 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering) 115 116 StringRef getArgument() const final { 117 return "test-vector-contraction-lowering"; 118 } 119 StringRef getDescription() const final { 120 return "Test lowering patterns that lower contract ops in the vector " 121 "dialect"; 122 } 123 TestVectorContractionLowering() = default; 124 TestVectorContractionLowering(const TestVectorContractionLowering &pass) 125 : PassWrapper(pass) {} 126 127 Option<bool> lowerToFlatMatrix{ 128 *this, "vector-lower-matrix-intrinsics", 129 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 130 llvm::cl::init(false)}; 131 Option<bool> lowerToOuterProduct{ 132 *this, "vector-outerproduct", 133 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 134 llvm::cl::init(false)}; 135 Option<bool> lowerToFilterOuterProduct{ 136 *this, "vector-filter-outerproduct", 137 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 138 "vectors of size 4."), 139 llvm::cl::init(false)}; 140 Option<bool> lowerToParallelArith{ 141 *this, "vector-parallel-arith", 142 llvm::cl::desc("Lower vector.contract to elementwise vector ops."), 143 llvm::cl::init(false)}; 144 145 void runOnOperation() override { 146 RewritePatternSet patterns(&getContext()); 147 148 // Test on one pattern in isolation. 149 if (lowerToOuterProduct) { 150 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 151 VectorTransformsOptions options{lowering}; 152 patterns.add<ContractionOpToOuterProductOpLowering>(options, 153 &getContext()); 154 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 155 return; 156 } 157 158 // Test on one pattern in isolation. 159 if (lowerToFilterOuterProduct) { 160 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 161 VectorTransformsOptions options{lowering}; 162 patterns.add<ContractionOpToOuterProductOpLowering>( 163 options, &getContext(), [](vector::ContractionOp op) { 164 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 165 // where M is not 4. 166 if (op.getRhsType().getShape()[0] == 4) 167 return failure(); 168 return success(); 169 }); 170 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 171 return; 172 } 173 174 if (lowerToParallelArith) { 175 vector::populateVectorContractLoweringPatterns( 176 patterns, 177 vector::VectorTransformsOptions().setVectorTransformsOptions( 178 vector::VectorContractLowering::ParallelArith)); 179 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 180 return; 181 } 182 183 // Test on all contract lowering patterns. 184 VectorContractLowering contractLowering = VectorContractLowering::Dot; 185 if (lowerToFlatMatrix) 186 contractLowering = VectorContractLowering::Matmul; 187 VectorMultiReductionLowering vectorMultiReductionLowering = 188 VectorMultiReductionLowering::InnerParallel; 189 VectorTransformsOptions options{contractLowering, 190 vectorMultiReductionLowering, 191 VectorTransposeLowering()}; 192 populateVectorBroadcastLoweringPatterns(patterns); 193 populateVectorContractLoweringPatterns(patterns, options); 194 populateVectorMaskOpLoweringPatterns(patterns); 195 populateVectorShapeCastLoweringPatterns(patterns); 196 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 197 } 198 }; 199 200 struct TestVectorTransposeLowering 201 : public PassWrapper<TestVectorTransposeLowering, 202 OperationPass<func::FuncOp>> { 203 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering) 204 205 StringRef getArgument() const final { 206 return "test-vector-transpose-lowering"; 207 } 208 StringRef getDescription() const final { 209 return "Test lowering patterns that lower contract ops in the vector " 210 "dialect"; 211 } 212 TestVectorTransposeLowering() = default; 213 TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) 214 : PassWrapper(pass) {} 215 216 Option<bool> lowerToEltwise{ 217 *this, "eltwise", 218 llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), 219 llvm::cl::init(false)}; 220 Option<bool> lowerToFlatTranspose{ 221 *this, "flat", 222 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 223 llvm::cl::init(false)}; 224 Option<bool> lowerToShuffleTranspose{ 225 *this, "shuffle", 226 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 227 llvm::cl::init(false)}; 228 Option<bool> lowerToAvx2{ 229 *this, "avx2", 230 llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), 231 llvm::cl::init(false)}; 232 233 void getDependentDialects(DialectRegistry ®istry) const override { 234 registry.insert<LLVM::LLVMDialect>(); 235 } 236 237 void runOnOperation() override { 238 RewritePatternSet patterns(&getContext()); 239 240 // Test on one pattern in isolation. 241 // Explicitly disable shape_cast lowering. 242 LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() 243 .enableVectorTransposeLowering() 244 .enableShapeCastLowering(false); 245 if (lowerToEltwise) { 246 options = options.setVectorTransformsOptions( 247 VectorTransformsOptions().setVectorTransposeLowering( 248 VectorTransposeLowering::EltWise)); 249 } 250 if (lowerToFlatTranspose) { 251 options = options.setVectorTransformsOptions( 252 VectorTransformsOptions().setVectorTransposeLowering( 253 VectorTransposeLowering::Flat)); 254 } 255 if (lowerToShuffleTranspose) { 256 options = options.setVectorTransformsOptions( 257 VectorTransformsOptions().setVectorTransposeLowering( 258 VectorTransposeLowering::Shuffle)); 259 } 260 if (lowerToAvx2) { 261 options = options.enableAVX2Lowering().setAVX2LoweringOptions( 262 x86vector::avx2::LoweringOptions().setTransposeOptions( 263 x86vector::avx2::TransposeLoweringOptions() 264 .lower4x8xf32() 265 .lower8x8xf32())); 266 } 267 268 OpPassManager dynamicPM("func.func"); 269 dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); 270 if (failed(runPipeline(dynamicPM, getOperation()))) 271 return signalPassFailure(); 272 } 273 }; 274 275 struct TestVectorUnrollingPatterns 276 : public PassWrapper<TestVectorUnrollingPatterns, 277 OperationPass<func::FuncOp>> { 278 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) 279 280 StringRef getArgument() const final { 281 return "test-vector-unrolling-patterns"; 282 } 283 StringRef getDescription() const final { 284 return "Test lowering patterns to unroll contract ops in the vector " 285 "dialect"; 286 } 287 TestVectorUnrollingPatterns() = default; 288 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 289 : PassWrapper(pass) {} 290 void runOnOperation() override { 291 MLIRContext *ctx = &getContext(); 292 RewritePatternSet patterns(ctx); 293 populateVectorUnrollPatterns( 294 patterns, UnrollVectorOptions() 295 .setNativeShape(ArrayRef<int64_t>{2, 2}) 296 .setFilterConstraint([](Operation *op) { 297 return success(isa<arith::AddFOp, vector::FMAOp, 298 vector::MultiDimReductionOp>(op)); 299 })); 300 populateVectorUnrollPatterns( 301 patterns, UnrollVectorOptions() 302 .setNativeShape(ArrayRef<int64_t>{2}) 303 .setFilterConstraint([](Operation *op) { 304 return success(isa<vector::ReductionOp>(op)); 305 })); 306 populateVectorUnrollPatterns( 307 patterns, UnrollVectorOptions() 308 .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) 309 .setFilterConstraint([](Operation *op) { 310 return success(isa<vector::TransposeOp>(op)); 311 })); 312 313 if (unrollBasedOnType) { 314 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 315 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 316 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 317 SmallVector<int64_t, 4> nativeShape( 318 contractOp.getIteratorTypes().size(), 4); 319 Type lhsType = contractOp.getLhsType().getElementType(); 320 nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2; 321 return nativeShape; 322 }; 323 324 UnrollVectorOptions opts; 325 opts.setNativeShapeFn(nativeShapeFn) 326 .setFilterConstraint( 327 [](Operation *op) { return success(isa<ContractionOp>(op)); }); 328 329 if (!unrollOrder.empty()) { 330 opts.setUnrollTraversalOrderFn([this](Operation *op) 331 -> Optional<SmallVector<int64_t>> { 332 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 333 if (contractOp.getIteratorTypes().size() == unrollOrder.size()) 334 return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end()); 335 return None; 336 }); 337 } 338 populateVectorUnrollPatterns(patterns, opts); 339 } else { 340 auto nativeShapeFn = 341 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 342 auto contractOp = dyn_cast<ContractionOp>(op); 343 if (!contractOp) 344 return None; 345 return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2); 346 }; 347 populateVectorUnrollPatterns(patterns, 348 UnrollVectorOptions() 349 .setNativeShapeFn(nativeShapeFn) 350 .setFilterConstraint([](Operation *op) { 351 return success(isa<ContractionOp>(op)); 352 })); 353 } 354 populateVectorToVectorCanonicalizationPatterns(patterns); 355 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 356 } 357 358 ListOption<int64_t> unrollOrder{*this, "unroll-order", 359 llvm::cl::desc("set the unroll order"), 360 llvm::cl::ZeroOrMore}; 361 362 Option<bool> unrollBasedOnType{ 363 *this, "unroll-based-on-type", 364 llvm::cl::desc("Set the unroll factor based on type of the operation"), 365 llvm::cl::init(false)}; 366 }; 367 368 struct TestVectorDistributePatterns 369 : public PassWrapper<TestVectorDistributePatterns, 370 OperationPass<func::FuncOp>> { 371 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns) 372 373 StringRef getArgument() const final { 374 return "test-vector-distribute-patterns"; 375 } 376 StringRef getDescription() const final { 377 return "Test lowering patterns to distribute vector ops in the vector " 378 "dialect"; 379 } 380 TestVectorDistributePatterns() = default; 381 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) 382 : PassWrapper(pass) {} 383 void getDependentDialects(DialectRegistry ®istry) const override { 384 registry.insert<VectorDialect>(); 385 registry.insert<AffineDialect>(); 386 } 387 ListOption<int32_t> multiplicity{ 388 *this, "distribution-multiplicity", 389 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 390 391 void runOnOperation() override { 392 MLIRContext *ctx = &getContext(); 393 RewritePatternSet patterns(ctx); 394 func::FuncOp func = getOperation(); 395 func.walk([&](arith::AddFOp op) { 396 OpBuilder builder(op); 397 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 398 SmallVector<int64_t, 2> mul; 399 SmallVector<AffineExpr, 2> perm; 400 SmallVector<Value, 2> ids; 401 unsigned count = 0; 402 // Remove the multiplicity of 1 and calculate the affine map based on 403 // the multiplicity. 404 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 405 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 406 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 407 mul.push_back(m[i]); 408 ids.push_back(func.getArgument(count++)); 409 perm.push_back(getAffineDimExpr(i, ctx)); 410 } 411 } 412 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 413 perm, ctx); 414 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 415 builder, op.getOperation(), ids, mul, map); 416 if (ops.hasValue()) { 417 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 418 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 419 extractOp); 420 } 421 } 422 }); 423 populatePropagateVectorDistributionPatterns(patterns); 424 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 425 } 426 }; 427 428 struct TestVectorToLoopPatterns 429 : public PassWrapper<TestVectorToLoopPatterns, 430 OperationPass<func::FuncOp>> { 431 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns) 432 433 StringRef getArgument() const final { return "test-vector-to-forloop"; } 434 StringRef getDescription() const final { 435 return "Test lowering patterns to break up a vector op into a for loop"; 436 } 437 TestVectorToLoopPatterns() = default; 438 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) 439 : PassWrapper(pass) {} 440 void getDependentDialects(DialectRegistry ®istry) const override { 441 registry.insert<VectorDialect>(); 442 registry.insert<AffineDialect>(); 443 } 444 Option<int32_t> multiplicity{ 445 *this, "distribution-multiplicity", 446 llvm::cl::desc("Set the multiplicity used for distributing vector"), 447 llvm::cl::init(32)}; 448 void runOnOperation() override { 449 MLIRContext *ctx = &getContext(); 450 RewritePatternSet patterns(ctx); 451 func::FuncOp func = getOperation(); 452 func.walk([&](arith::AddFOp op) { 453 // Check that the operation type can be broken down into a loop. 454 VectorType type = op.getType().dyn_cast<VectorType>(); 455 if (!type || type.getRank() != 1 || 456 type.getNumElements() % multiplicity != 0) 457 return mlir::WalkResult::advance(); 458 auto filterAlloc = [](Operation *op) { 459 return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op); 460 }; 461 auto dependentOps = getSlice(op, filterAlloc); 462 // Create a loop and move instructions from the Op slice into the loop. 463 OpBuilder builder(op); 464 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 465 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 466 auto numIter = 467 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 468 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 469 for (Operation *it : dependentOps) { 470 it->moveBefore(forOp.getBody()->getTerminator()); 471 } 472 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 473 // break up the original op and let the patterns propagate. 474 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 475 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 476 map); 477 if (ops.hasValue()) { 478 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 479 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 480 } 481 return mlir::WalkResult::interrupt(); 482 }); 483 populatePropagateVectorDistributionPatterns(patterns); 484 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 485 } 486 }; 487 488 struct TestVectorTransferUnrollingPatterns 489 : public PassWrapper<TestVectorTransferUnrollingPatterns, 490 OperationPass<func::FuncOp>> { 491 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 492 TestVectorTransferUnrollingPatterns) 493 494 TestVectorTransferUnrollingPatterns() = default; 495 TestVectorTransferUnrollingPatterns( 496 const TestVectorTransferUnrollingPatterns &pass) 497 : PassWrapper(pass) {} 498 499 void getDependentDialects(DialectRegistry ®istry) const override { 500 registry.insert<AffineDialect>(); 501 } 502 StringRef getArgument() const final { 503 return "test-vector-transfer-unrolling-patterns"; 504 } 505 StringRef getDescription() const final { 506 return "Test lowering patterns to unroll transfer ops in the vector " 507 "dialect"; 508 } 509 void runOnOperation() override { 510 MLIRContext *ctx = &getContext(); 511 RewritePatternSet patterns(ctx); 512 UnrollVectorOptions opts; 513 opts.setNativeShape(ArrayRef<int64_t>{2, 2}) 514 .setFilterConstraint([](Operation *op) { 515 return success( 516 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 517 }); 518 if (reverseUnrollOrder.getValue()) { 519 opts.setUnrollTraversalOrderFn( 520 [](Operation *op) -> Optional<SmallVector<int64_t>> { 521 int64_t numLoops = 0; 522 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) 523 numLoops = readOp.getVectorType().getRank(); 524 else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) 525 numLoops = writeOp.getVectorType().getRank(); 526 else 527 return None; 528 auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops)); 529 return llvm::to_vector(order); 530 }); 531 } 532 populateVectorUnrollPatterns(patterns, opts); 533 populateVectorToVectorCanonicalizationPatterns(patterns); 534 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 535 } 536 537 Option<bool> reverseUnrollOrder{ 538 *this, "reverse-unroll-order", 539 llvm::cl::desc( 540 "reverse the order of unrolling of vector transfer operations"), 541 llvm::cl::init(false)}; 542 }; 543 544 struct TestVectorTransferFullPartialSplitPatterns 545 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 546 OperationPass<func::FuncOp>> { 547 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 548 TestVectorTransferFullPartialSplitPatterns) 549 550 StringRef getArgument() const final { 551 return "test-vector-transfer-full-partial-split"; 552 } 553 StringRef getDescription() const final { 554 return "Test lowering patterns to split " 555 "transfer ops via scf.if + linalg ops"; 556 } 557 TestVectorTransferFullPartialSplitPatterns() = default; 558 TestVectorTransferFullPartialSplitPatterns( 559 const TestVectorTransferFullPartialSplitPatterns &pass) 560 : PassWrapper(pass) {} 561 562 void getDependentDialects(DialectRegistry ®istry) const override { 563 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 564 scf::SCFDialect>(); 565 } 566 567 Option<bool> useLinalgOps{ 568 *this, "use-memref-copy", 569 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 570 "memref.copy operations."), 571 llvm::cl::init(false)}; 572 void runOnOperation() override { 573 MLIRContext *ctx = &getContext(); 574 RewritePatternSet patterns(ctx); 575 VectorTransformsOptions options; 576 if (useLinalgOps) 577 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 578 else 579 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 580 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 581 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 582 } 583 }; 584 585 struct TestVectorTransferOpt 586 : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> { 587 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) 588 589 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 590 StringRef getDescription() const final { 591 return "Test optimization transformations for transfer ops"; 592 } 593 void runOnOperation() override { transferOpflowOpt(getOperation()); } 594 }; 595 596 struct TestVectorTransferLoweringPatterns 597 : public PassWrapper<TestVectorTransferLoweringPatterns, 598 OperationPass<func::FuncOp>> { 599 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 600 TestVectorTransferLoweringPatterns) 601 602 void getDependentDialects(DialectRegistry ®istry) const override { 603 registry.insert<tensor::TensorDialect, memref::MemRefDialect>(); 604 } 605 StringRef getArgument() const final { 606 return "test-vector-transfer-lowering-patterns"; 607 } 608 StringRef getDescription() const final { 609 return "Test lowering patterns to lower transfer ops to other vector ops"; 610 } 611 void runOnOperation() override { 612 RewritePatternSet patterns(&getContext()); 613 populateVectorTransferLoweringPatterns(patterns); 614 populateVectorTransferPermutationMapLoweringPatterns(patterns); 615 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 616 } 617 }; 618 619 struct TestVectorMultiReductionLoweringPatterns 620 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 621 OperationPass<func::FuncOp>> { 622 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 623 TestVectorMultiReductionLoweringPatterns) 624 625 TestVectorMultiReductionLoweringPatterns() = default; 626 TestVectorMultiReductionLoweringPatterns( 627 const TestVectorMultiReductionLoweringPatterns &pass) 628 : PassWrapper(pass) {} 629 void getDependentDialects(DialectRegistry ®istry) const override { 630 registry.insert<memref::MemRefDialect>(); 631 } 632 StringRef getArgument() const final { 633 return "test-vector-multi-reduction-lowering-patterns"; 634 } 635 StringRef getDescription() const final { 636 return "Test lowering patterns to lower vector.multi_reduction to other " 637 "vector ops"; 638 } 639 Option<bool> useOuterReductions{ 640 *this, "use-outer-reductions", 641 llvm::cl::desc("Move reductions to outer most dimensions"), 642 llvm::cl::init(false)}; 643 void runOnOperation() override { 644 RewritePatternSet patterns(&getContext()); 645 populateVectorMultiReductionLoweringPatterns( 646 patterns, useOuterReductions 647 ? vector::VectorMultiReductionLowering::InnerParallel 648 : vector::VectorMultiReductionLowering::InnerReduction); 649 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 650 } 651 }; 652 653 struct TestVectorTransferCollapseInnerMostContiguousDims 654 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 655 OperationPass<func::FuncOp>> { 656 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 657 TestVectorTransferCollapseInnerMostContiguousDims) 658 659 TestVectorTransferCollapseInnerMostContiguousDims() = default; 660 TestVectorTransferCollapseInnerMostContiguousDims( 661 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 662 663 void getDependentDialects(DialectRegistry ®istry) const override { 664 registry.insert<memref::MemRefDialect, AffineDialect>(); 665 } 666 667 StringRef getArgument() const final { 668 return "test-vector-transfer-collapse-inner-most-dims"; 669 } 670 671 StringRef getDescription() const final { 672 return "Test lowering patterns that reducedes the rank of the vector " 673 "transfer memory and vector operands."; 674 } 675 676 void runOnOperation() override { 677 RewritePatternSet patterns(&getContext()); 678 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 679 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 680 } 681 }; 682 683 struct TestVectorReduceToContractPatternsPatterns 684 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 685 OperationPass<func::FuncOp>> { 686 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 687 TestVectorReduceToContractPatternsPatterns) 688 689 StringRef getArgument() const final { 690 return "test-vector-reduction-to-contract-patterns"; 691 } 692 StringRef getDescription() const final { 693 return "Test patterns to convert multireduce op to contract and combine " 694 "broadcast/transpose to contract"; 695 } 696 void runOnOperation() override { 697 RewritePatternSet patterns(&getContext()); 698 populateVectorReductionToContractPatterns(patterns); 699 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 700 } 701 }; 702 703 struct TestVectorTransferDropUnitDimsPatterns 704 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, 705 OperationPass<func::FuncOp>> { 706 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 707 TestVectorTransferDropUnitDimsPatterns) 708 709 StringRef getArgument() const final { 710 return "test-vector-transfer-drop-unit-dims-patterns"; 711 } 712 void getDependentDialects(DialectRegistry ®istry) const override { 713 registry.insert<memref::MemRefDialect>(); 714 } 715 void runOnOperation() override { 716 RewritePatternSet patterns(&getContext()); 717 populateVectorTransferDropUnitDimsPatterns(patterns); 718 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 719 } 720 }; 721 722 struct TestFlattenVectorTransferPatterns 723 : public PassWrapper<TestFlattenVectorTransferPatterns, 724 OperationPass<func::FuncOp>> { 725 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 726 TestFlattenVectorTransferPatterns) 727 728 StringRef getArgument() const final { 729 return "test-vector-transfer-flatten-patterns"; 730 } 731 StringRef getDescription() const final { 732 return "Test patterns to rewrite contiguous row-major N-dimensional " 733 "vector.transfer_{read,write} ops into 1D transfers"; 734 } 735 void getDependentDialects(DialectRegistry ®istry) const override { 736 registry.insert<memref::MemRefDialect>(); 737 } 738 void runOnOperation() override { 739 RewritePatternSet patterns(&getContext()); 740 populateFlattenVectorTransferPatterns(patterns); 741 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 742 } 743 }; 744 745 struct TestVectorScanLowering 746 : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> { 747 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) 748 749 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 750 StringRef getDescription() const final { 751 return "Test lowering patterns that lower the scan op in the vector " 752 "dialect"; 753 } 754 void runOnOperation() override { 755 RewritePatternSet patterns(&getContext()); 756 populateVectorScanLoweringPatterns(patterns); 757 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 758 } 759 }; 760 761 /// Allocate shared memory for a single warp to test lowering of 762 /// WarpExecuteOnLane0Op. 763 static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, 764 WarpExecuteOnLane0Op warpOp, 765 Type type) { 766 static constexpr int64_t kSharedMemorySpace = 3; 767 // Compute type of shared memory buffer. 768 MemRefType memrefType; 769 if (auto vectorType = type.dyn_cast<VectorType>()) { 770 memrefType = 771 MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, 772 kSharedMemorySpace); 773 } else { 774 memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace); 775 } 776 777 // Get symbol table holding all shared memory globals. 778 ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>(); 779 SymbolTable symbolTable(moduleOp); 780 781 // Create a pretty name. 782 SmallString<64> buf; 783 llvm::raw_svector_ostream os(buf); 784 interleave(memrefType.getShape(), os, "x"); 785 os << "x" << memrefType.getElementType(); 786 std::string symbolName = (Twine("__shared_") + os.str()).str(); 787 788 auto ip = builder.saveInsertionPoint(); 789 builder.setInsertionPoint(moduleOp); 790 auto global = builder.create<memref::GlobalOp>( 791 loc, 792 /*sym_name=*/symbolName, 793 /*sym_visibility=*/builder.getStringAttr("private"), 794 /*type=*/memrefType, 795 /*initial_value=*/Attribute(), 796 /*constant=*/false, 797 /*alignment=*/IntegerAttr()); 798 symbolTable.insert(global); 799 // The symbol table inserts at the end of the module, but globals are a bit 800 // nicer if they are at the beginning. 801 global->moveBefore(&moduleOp.front()); 802 803 builder.restoreInsertionPoint(ip); 804 return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName); 805 } 806 807 static Value warpReduction(Location loc, OpBuilder &builder, Value input, 808 CombiningKind kind, uint32_t size) { 809 Value laneVal = input; 810 // Parallel reduction using butterfly shuffles. 811 for (uint64_t i = 1; i < size; i <<= 1) { 812 Value shuffled = builder 813 .create<gpu::ShuffleOp>(loc, laneVal, i, 814 /*width=*/size, 815 /*mode=*/gpu::ShuffleMode::XOR) 816 .result(); 817 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); 818 } 819 return laneVal; 820 } 821 822 struct TestVectorDistribution 823 : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> { 824 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) 825 826 void getDependentDialects(DialectRegistry ®istry) const override { 827 registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect, 828 AffineDialect>(); 829 } 830 831 StringRef getArgument() const final { return "test-vector-warp-distribute"; } 832 StringRef getDescription() const final { 833 return "Test vector warp distribute transformation and lowering patterns"; 834 } 835 TestVectorDistribution() = default; 836 TestVectorDistribution(const TestVectorDistribution &pass) 837 : PassWrapper(pass) {} 838 839 Option<bool> warpOpToSCF{ 840 *this, "rewrite-warp-ops-to-scf-if", 841 llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"), 842 llvm::cl::init(false)}; 843 844 Option<bool> distributeTransferWriteOps{ 845 *this, "distribute-transfer-write", 846 llvm::cl::desc("Test distribution of transfer write"), 847 llvm::cl::init(false)}; 848 849 Option<bool> hoistUniform{*this, "hoist-uniform", 850 llvm::cl::desc("Test hoist uniform"), 851 llvm::cl::init(false)}; 852 853 Option<bool> propagateDistribution{ 854 *this, "propagate-distribution", 855 llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)}; 856 857 void runOnOperation() override { 858 RewritePatternSet patterns(&getContext()); 859 860 getOperation().walk([&](Operation *op) { 861 if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) { 862 if (hoistUniform) { 863 moveScalarUniformCode(warpOp); 864 } 865 WalkResult::interrupt(); 866 } 867 }); 868 MLIRContext *ctx = &getContext(); 869 if (distributeTransferWriteOps) { 870 auto distributionFn = [](vector::TransferWriteOp writeOp) { 871 // Create a map (d0, d1) -> (d1) to distribute along the inner 872 // dimension. Once we support n-d distribution we can add more 873 // complex cases. 874 int64_t vecRank = writeOp.getVectorType().getRank(); 875 OpBuilder builder(writeOp.getContext()); 876 auto map = 877 AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1)); 878 return map; 879 }; 880 RewritePatternSet patterns(ctx); 881 populateDistributeTransferWriteOpPatterns(patterns, distributionFn); 882 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 883 } 884 if (propagateDistribution) { 885 RewritePatternSet patterns(ctx); 886 vector::populatePropagateWarpVectorDistributionPatterns(patterns); 887 vector::populateDistributeReduction(patterns, warpReduction); 888 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 889 } 890 WarpExecuteOnLane0LoweringOptions options; 891 options.warpAllocationFn = allocateGlobalSharedMemory; 892 options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, 893 WarpExecuteOnLane0Op warpOp) { 894 builder.create<gpu::BarrierOp>(loc); 895 }; 896 // Test on one pattern in isolation. 897 if (warpOpToSCF) { 898 populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); 899 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 900 return; 901 } 902 } 903 }; 904 905 } // namespace 906 907 namespace mlir { 908 namespace test { 909 void registerTestVectorLowerings() { 910 PassRegistration<TestVectorToVectorLowering>(); 911 912 PassRegistration<TestVectorContractionLowering>(); 913 914 PassRegistration<TestVectorTransposeLowering>(); 915 916 PassRegistration<TestVectorUnrollingPatterns>(); 917 918 PassRegistration<TestVectorTransferUnrollingPatterns>(); 919 920 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 921 922 PassRegistration<TestVectorDistributePatterns>(); 923 924 PassRegistration<TestVectorToLoopPatterns>(); 925 926 PassRegistration<TestVectorTransferOpt>(); 927 928 PassRegistration<TestVectorTransferLoweringPatterns>(); 929 930 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 931 932 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 933 934 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 935 936 PassRegistration<TestVectorTransferDropUnitDimsPatterns>(); 937 938 PassRegistration<TestFlattenVectorTransferPatterns>(); 939 940 PassRegistration<TestVectorScanLowering>(); 941 942 PassRegistration<TestVectorDistribution>(); 943 } 944 } // namespace test 945 } // namespace mlir 946