1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorTransforms.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace mlir::vector; 23 24 namespace { 25 26 struct TestVectorToVectorConversion 27 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> { 28 TestVectorToVectorConversion() = default; 29 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} 30 StringRef getArgument() const final { 31 return "test-vector-to-vector-conversion"; 32 } 33 StringRef getDescription() const final { 34 return "Test conversion patterns between ops in the vector dialect"; 35 } 36 37 void getDependentDialects(DialectRegistry ®istry) const override { 38 registry.insert<AffineDialect>(); 39 } 40 41 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 42 llvm::cl::init(false)}; 43 44 void runOnFunction() override { 45 auto *ctx = &getContext(); 46 RewritePatternSet patterns(ctx); 47 if (unroll) { 48 populateVectorUnrollPatterns( 49 patterns, 50 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 51 filter)); 52 } 53 populateVectorToVectorCanonicalizationPatterns(patterns); 54 populateBubbleVectorBitCastOpPatterns(patterns); 55 populateCastAwayVectorLeadingOneDimPatterns(patterns); 56 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 57 } 58 59 private: 60 // Return the target shape based on op type. 61 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 62 if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op)) 63 return SmallVector<int64_t, 4>(2, 2); 64 if (isa<vector::ContractionOp>(op)) 65 return SmallVector<int64_t, 4>(3, 2); 66 // For transfer ops, just propagate the shape coming from 67 // InsertStridedSlices/ExtractStridedSlices. 68 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 69 VectorType dstVec; 70 for (Operation *users : readOp->getUsers()) { 71 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 72 if (!extract) 73 return llvm::None; 74 auto vecType = extract.getResult().getType().cast<VectorType>(); 75 if (dstVec && dstVec != vecType) 76 return llvm::None; 77 dstVec = vecType; 78 } 79 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 80 dstVec.getShape().end()); 81 } 82 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 83 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 84 if (!insert) 85 return llvm::None; 86 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 87 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 88 } 89 return llvm::None; 90 } 91 92 static LogicalResult filter(Operation *op) { 93 return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp, 94 TransferReadOp, TransferWriteOp>(op)); 95 } 96 }; 97 98 struct TestVectorContractionConversion 99 : public PassWrapper<TestVectorContractionConversion, FunctionPass> { 100 StringRef getArgument() const final { 101 return "test-vector-contraction-conversion"; 102 } 103 StringRef getDescription() const final { 104 return "Test conversion patterns that lower contract ops in the vector " 105 "dialect"; 106 } 107 TestVectorContractionConversion() = default; 108 TestVectorContractionConversion(const TestVectorContractionConversion &pass) { 109 } 110 111 Option<bool> lowerToFlatMatrix{ 112 *this, "vector-lower-matrix-intrinsics", 113 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 114 llvm::cl::init(false)}; 115 Option<bool> lowerToFlatTranspose{ 116 *this, "vector-flat-transpose", 117 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 118 llvm::cl::init(false)}; 119 Option<bool> lowerToShuffleTranspose{ 120 *this, "vector-shuffle-transpose", 121 llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), 122 llvm::cl::init(false)}; 123 Option<bool> lowerToOuterProduct{ 124 *this, "vector-outerproduct", 125 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 126 llvm::cl::init(false)}; 127 Option<bool> lowerToFilterOuterProduct{ 128 *this, "vector-filter-outerproduct", 129 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 130 "vectors of size 4."), 131 llvm::cl::init(false)}; 132 133 void runOnFunction() override { 134 RewritePatternSet patterns(&getContext()); 135 136 // Test on one pattern in isolation. 137 if (lowerToOuterProduct) { 138 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 139 VectorTransformsOptions options{lowering}; 140 patterns.add<ContractionOpToOuterProductOpLowering>(options, 141 &getContext()); 142 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 143 return; 144 } 145 146 // Test on one pattern in isolation. 147 if (lowerToFilterOuterProduct) { 148 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 149 VectorTransformsOptions options{lowering}; 150 patterns.add<ContractionOpToOuterProductOpLowering>( 151 options, &getContext(), [](vector::ContractionOp op) { 152 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 153 // where M is not 4. 154 if (op.getRhsType().getShape()[0] == 4) 155 return failure(); 156 return success(); 157 }); 158 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 159 return; 160 } 161 162 // Test on all contract lowering patterns. 163 VectorContractLowering contractLowering = VectorContractLowering::Dot; 164 if (lowerToFlatMatrix) 165 contractLowering = VectorContractLowering::Matmul; 166 VectorMultiReductionLowering vectorMultiReductionLowering = 167 VectorMultiReductionLowering::InnerParallel; 168 VectorTransposeLowering transposeLowering = 169 VectorTransposeLowering::EltWise; 170 if (lowerToFlatTranspose) 171 transposeLowering = VectorTransposeLowering::Flat; 172 if (lowerToShuffleTranspose) 173 transposeLowering = VectorTransposeLowering::Shuffle; 174 VectorTransformsOptions options{ 175 contractLowering, vectorMultiReductionLowering, transposeLowering}; 176 populateVectorBroadcastLoweringPatterns(patterns); 177 populateVectorContractLoweringPatterns(patterns, options); 178 populateVectorMaskOpLoweringPatterns(patterns); 179 if (!lowerToShuffleTranspose) 180 populateVectorShapeCastLoweringPatterns(patterns); 181 populateVectorTransposeLoweringPatterns(patterns, options); 182 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 183 } 184 }; 185 186 struct TestVectorUnrollingPatterns 187 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 188 StringRef getArgument() const final { 189 return "test-vector-unrolling-patterns"; 190 } 191 StringRef getDescription() const final { 192 return "Test conversion patterns to unroll contract ops in the vector " 193 "dialect"; 194 } 195 TestVectorUnrollingPatterns() = default; 196 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 197 void runOnFunction() override { 198 MLIRContext *ctx = &getContext(); 199 RewritePatternSet patterns(ctx); 200 populateVectorUnrollPatterns( 201 patterns, UnrollVectorOptions() 202 .setNativeShape(ArrayRef<int64_t>{2, 2}) 203 .setFilterConstraint([](Operation *op) { 204 return success(isa<arith::AddFOp, vector::FMAOp>(op)); 205 })); 206 207 if (unrollBasedOnType) { 208 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 209 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 210 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 211 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 212 if (auto floatType = contractOp.getLhsType() 213 .getElementType() 214 .dyn_cast<FloatType>()) { 215 if (floatType.getWidth() == 16) { 216 nativeShape[2] = 4; 217 } 218 } 219 return nativeShape; 220 }; 221 populateVectorUnrollPatterns(patterns, 222 UnrollVectorOptions() 223 .setNativeShapeFn(nativeShapeFn) 224 .setFilterConstraint([](Operation *op) { 225 return success(isa<ContractionOp>(op)); 226 })); 227 } else { 228 populateVectorUnrollPatterns( 229 patterns, UnrollVectorOptions() 230 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 231 .setFilterConstraint([](Operation *op) { 232 return success(isa<ContractionOp>(op)); 233 })); 234 } 235 populateVectorToVectorCanonicalizationPatterns(patterns); 236 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 237 } 238 239 Option<bool> unrollBasedOnType{ 240 *this, "unroll-based-on-type", 241 llvm::cl::desc("Set the unroll factor based on type of the operation"), 242 llvm::cl::init(false)}; 243 }; 244 245 struct TestVectorDistributePatterns 246 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 247 StringRef getArgument() const final { 248 return "test-vector-distribute-patterns"; 249 } 250 StringRef getDescription() const final { 251 return "Test conversion patterns to distribute vector ops in the vector " 252 "dialect"; 253 } 254 TestVectorDistributePatterns() = default; 255 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 256 void getDependentDialects(DialectRegistry ®istry) const override { 257 registry.insert<VectorDialect>(); 258 registry.insert<AffineDialect>(); 259 } 260 ListOption<int32_t> multiplicity{ 261 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 262 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 263 264 void runOnFunction() override { 265 MLIRContext *ctx = &getContext(); 266 RewritePatternSet patterns(ctx); 267 FuncOp func = getFunction(); 268 func.walk([&](arith::AddFOp op) { 269 OpBuilder builder(op); 270 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 271 SmallVector<int64_t, 2> mul; 272 SmallVector<AffineExpr, 2> perm; 273 SmallVector<Value, 2> ids; 274 unsigned count = 0; 275 // Remove the multiplicity of 1 and calculate the affine map based on 276 // the multiplicity. 277 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 278 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 279 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 280 mul.push_back(m[i]); 281 ids.push_back(func.getArgument(count++)); 282 perm.push_back(getAffineDimExpr(i, ctx)); 283 } 284 } 285 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 286 perm, ctx); 287 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 288 builder, op.getOperation(), ids, mul, map); 289 if (ops.hasValue()) { 290 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 291 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 292 extractOp); 293 } 294 } 295 }); 296 populatePropagateVectorDistributionPatterns(patterns); 297 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 298 } 299 }; 300 301 struct TestVectorToLoopPatterns 302 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 303 StringRef getArgument() const final { return "test-vector-to-forloop"; } 304 StringRef getDescription() const final { 305 return "Test conversion patterns to break up a vector op into a for loop"; 306 } 307 TestVectorToLoopPatterns() = default; 308 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 309 void getDependentDialects(DialectRegistry ®istry) const override { 310 registry.insert<VectorDialect>(); 311 registry.insert<AffineDialect>(); 312 } 313 Option<int32_t> multiplicity{ 314 *this, "distribution-multiplicity", 315 llvm::cl::desc("Set the multiplicity used for distributing vector"), 316 llvm::cl::init(32)}; 317 void runOnFunction() override { 318 MLIRContext *ctx = &getContext(); 319 RewritePatternSet patterns(ctx); 320 FuncOp func = getFunction(); 321 func.walk([&](arith::AddFOp op) { 322 // Check that the operation type can be broken down into a loop. 323 VectorType type = op.getType().dyn_cast<VectorType>(); 324 if (!type || type.getRank() != 1 || 325 type.getNumElements() % multiplicity != 0) 326 return mlir::WalkResult::advance(); 327 auto filterAlloc = [](Operation *op) { 328 if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op)) 329 return false; 330 return true; 331 }; 332 auto dependentOps = getSlice(op, filterAlloc); 333 // Create a loop and move instructions from the Op slice into the loop. 334 OpBuilder builder(op); 335 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 336 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 337 auto numIter = 338 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 339 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 340 for (Operation *it : dependentOps) { 341 it->moveBefore(forOp.getBody()->getTerminator()); 342 } 343 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 344 // break up the original op and let the patterns propagate. 345 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 346 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 347 map); 348 if (ops.hasValue()) { 349 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 350 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 351 } 352 return mlir::WalkResult::interrupt(); 353 }); 354 populatePropagateVectorDistributionPatterns(patterns); 355 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 356 } 357 }; 358 359 struct TestVectorTransferUnrollingPatterns 360 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 361 void getDependentDialects(DialectRegistry ®istry) const override { 362 registry.insert<AffineDialect>(); 363 } 364 StringRef getArgument() const final { 365 return "test-vector-transfer-unrolling-patterns"; 366 } 367 StringRef getDescription() const final { 368 return "Test conversion patterns to unroll transfer ops in the vector " 369 "dialect"; 370 } 371 void runOnFunction() override { 372 MLIRContext *ctx = &getContext(); 373 RewritePatternSet patterns(ctx); 374 populateVectorUnrollPatterns( 375 patterns, 376 UnrollVectorOptions() 377 .setNativeShape(ArrayRef<int64_t>{2, 2}) 378 .setFilterConstraint([](Operation *op) { 379 return success( 380 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 381 })); 382 populateVectorToVectorCanonicalizationPatterns(patterns); 383 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 384 } 385 }; 386 387 struct TestVectorTransferFullPartialSplitPatterns 388 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 389 FunctionPass> { 390 StringRef getArgument() const final { 391 return "test-vector-transfer-full-partial-split"; 392 } 393 StringRef getDescription() const final { 394 return "Test conversion patterns to split " 395 "transfer ops via scf.if + linalg ops"; 396 } 397 TestVectorTransferFullPartialSplitPatterns() = default; 398 TestVectorTransferFullPartialSplitPatterns( 399 const TestVectorTransferFullPartialSplitPatterns &pass) {} 400 401 void getDependentDialects(DialectRegistry ®istry) const override { 402 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 403 scf::SCFDialect>(); 404 } 405 406 Option<bool> useLinalgOps{ 407 *this, "use-linalg-copy", 408 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 409 "linalg.copy operations."), 410 llvm::cl::init(false)}; 411 void runOnFunction() override { 412 MLIRContext *ctx = &getContext(); 413 RewritePatternSet patterns(ctx); 414 VectorTransformsOptions options; 415 if (useLinalgOps) 416 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 417 else 418 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 419 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 420 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 421 } 422 }; 423 424 struct TestVectorTransferOpt 425 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 426 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 427 StringRef getDescription() const final { 428 return "Test optimization transformations for transfer ops"; 429 } 430 void runOnFunction() override { transferOpflowOpt(getFunction()); } 431 }; 432 433 struct TestVectorTransferLoweringPatterns 434 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 435 void getDependentDialects(DialectRegistry ®istry) const override { 436 registry.insert<memref::MemRefDialect>(); 437 } 438 StringRef getArgument() const final { 439 return "test-vector-transfer-lowering-patterns"; 440 } 441 StringRef getDescription() const final { 442 return "Test conversion patterns to lower transfer ops to other vector ops"; 443 } 444 void runOnFunction() override { 445 RewritePatternSet patterns(&getContext()); 446 populateVectorTransferLoweringPatterns(patterns); 447 populateVectorTransferPermutationMapLoweringPatterns(patterns); 448 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 449 } 450 }; 451 452 struct TestVectorMultiReductionLoweringPatterns 453 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 454 FunctionPass> { 455 TestVectorMultiReductionLoweringPatterns() = default; 456 TestVectorMultiReductionLoweringPatterns( 457 const TestVectorMultiReductionLoweringPatterns &pass) {} 458 void getDependentDialects(DialectRegistry ®istry) const override { 459 registry.insert<memref::MemRefDialect>(); 460 } 461 StringRef getArgument() const final { 462 return "test-vector-multi-reduction-lowering-patterns"; 463 } 464 StringRef getDescription() const final { 465 return "Test conversion patterns to lower vector.multi_reduction to other " 466 "vector ops"; 467 } 468 Option<bool> useOuterReductions{ 469 *this, "use-outer-reductions", 470 llvm::cl::desc("Move reductions to outer most dimensions"), 471 llvm::cl::init(false)}; 472 void runOnFunction() override { 473 RewritePatternSet patterns(&getContext()); 474 populateVectorMultiReductionLoweringPatterns( 475 patterns, useOuterReductions 476 ? vector::VectorMultiReductionLowering::InnerParallel 477 : vector::VectorMultiReductionLowering::InnerReduction); 478 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 479 } 480 }; 481 482 struct TestVectorTransferCollapseInnerMostContiguousDims 483 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 484 FunctionPass> { 485 TestVectorTransferCollapseInnerMostContiguousDims() = default; 486 TestVectorTransferCollapseInnerMostContiguousDims( 487 const TestVectorTransferCollapseInnerMostContiguousDims &pass) {} 488 489 void getDependentDialects(DialectRegistry ®istry) const override { 490 registry.insert<memref::MemRefDialect, AffineDialect>(); 491 } 492 493 StringRef getArgument() const final { 494 return "test-vector-transfer-collapse-inner-most-dims"; 495 } 496 497 StringRef getDescription() const final { 498 return "Test conversion patterns that reducedes the rank of the vector " 499 "transfer memory and vector operands."; 500 } 501 502 void runOnFunction() override { 503 RewritePatternSet patterns(&getContext()); 504 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 505 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 506 } 507 }; 508 509 struct TestVectorReduceToContractPatternsPatterns 510 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 511 FunctionPass> { 512 StringRef getArgument() const final { 513 return "test-vector-reduction-to-contract-patterns"; 514 } 515 StringRef getDescription() const final { 516 return "Test patterns to convert multireduce op to contract and combine " 517 "broadcast/transpose to contract"; 518 } 519 void runOnFunction() override { 520 RewritePatternSet patterns(&getContext()); 521 populateVectorReductionToContractPatterns(patterns); 522 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 523 } 524 }; 525 526 } // end anonymous namespace 527 528 namespace mlir { 529 namespace test { 530 void registerTestVectorConversions() { 531 PassRegistration<TestVectorToVectorConversion>(); 532 533 PassRegistration<TestVectorContractionConversion>(); 534 535 PassRegistration<TestVectorUnrollingPatterns>(); 536 537 PassRegistration<TestVectorTransferUnrollingPatterns>(); 538 539 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 540 541 PassRegistration<TestVectorDistributePatterns>(); 542 543 PassRegistration<TestVectorToLoopPatterns>(); 544 545 PassRegistration<TestVectorTransferOpt>(); 546 547 PassRegistration<TestVectorTransferLoweringPatterns>(); 548 549 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 550 551 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 552 553 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 554 } 555 } // namespace test 556 } // namespace mlir 557