1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorTransforms.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace mlir::vector; 23 24 namespace { 25 26 struct TestVectorToVectorConversion 27 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> { 28 TestVectorToVectorConversion() = default; 29 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} 30 StringRef getArgument() const final { 31 return "test-vector-to-vector-conversion"; 32 } 33 StringRef getDescription() const final { 34 return "Test conversion patterns between ops in the vector dialect"; 35 } 36 37 void getDependentDialects(DialectRegistry ®istry) const override { 38 registry.insert<AffineDialect>(); 39 } 40 41 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 42 llvm::cl::init(false)}; 43 44 void runOnFunction() override { 45 auto *ctx = &getContext(); 46 RewritePatternSet patterns(ctx); 47 if (unroll) { 48 populateVectorUnrollPatterns( 49 patterns, 50 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 51 filter)); 52 } 53 populateVectorToVectorCanonicalizationPatterns(patterns); 54 populateBubbleVectorBitCastOpPatterns(patterns); 55 populateCastAwayVectorLeadingOneDimPatterns(patterns); 56 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 57 } 58 59 private: 60 // Return the target shape based on op type. 61 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 62 if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op)) 63 return SmallVector<int64_t, 4>(2, 2); 64 if (isa<vector::ContractionOp>(op)) 65 return SmallVector<int64_t, 4>(3, 2); 66 // For transfer ops, just propagate the shape coming from 67 // InsertStridedSlices/ExtractStridedSlices. 68 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 69 VectorType dstVec; 70 for (Operation *users : readOp->getUsers()) { 71 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 72 if (!extract) 73 return llvm::None; 74 auto vecType = extract.getResult().getType().cast<VectorType>(); 75 if (dstVec && dstVec != vecType) 76 return llvm::None; 77 dstVec = vecType; 78 } 79 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 80 dstVec.getShape().end()); 81 } 82 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 83 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 84 if (!insert) 85 return llvm::None; 86 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 87 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 88 } 89 return llvm::None; 90 } 91 92 static LogicalResult filter(Operation *op) { 93 return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp, 94 TransferReadOp, TransferWriteOp>(op)); 95 } 96 }; 97 98 struct TestVectorContractionConversion 99 : public PassWrapper<TestVectorContractionConversion, FunctionPass> { 100 StringRef getArgument() const final { 101 return "test-vector-contraction-conversion"; 102 } 103 StringRef getDescription() const final { 104 return "Test conversion patterns that lower contract ops in the vector " 105 "dialect"; 106 } 107 TestVectorContractionConversion() = default; 108 TestVectorContractionConversion(const TestVectorContractionConversion &pass) { 109 } 110 111 Option<bool> lowerToFlatMatrix{ 112 *this, "vector-lower-matrix-intrinsics", 113 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 114 llvm::cl::init(false)}; 115 Option<bool> lowerToFlatTranspose{ 116 *this, "vector-flat-transpose", 117 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 118 llvm::cl::init(false)}; 119 Option<bool> lowerToOuterProduct{ 120 *this, "vector-outerproduct", 121 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 122 llvm::cl::init(false)}; 123 Option<bool> lowerToFilterOuterProduct{ 124 *this, "vector-filter-outerproduct", 125 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 126 "vectors of size 4."), 127 llvm::cl::init(false)}; 128 129 void runOnFunction() override { 130 RewritePatternSet patterns(&getContext()); 131 132 // Test on one pattern in isolation. 133 if (lowerToOuterProduct) { 134 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 135 VectorTransformsOptions options{lowering}; 136 patterns.add<ContractionOpToOuterProductOpLowering>(options, 137 &getContext()); 138 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 139 return; 140 } 141 142 // Test on one pattern in isolation. 143 if (lowerToFilterOuterProduct) { 144 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 145 VectorTransformsOptions options{lowering}; 146 patterns.add<ContractionOpToOuterProductOpLowering>( 147 options, &getContext(), [](vector::ContractionOp op) { 148 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 149 // where M is not 4. 150 if (op.getRhsType().getShape()[0] == 4) 151 return failure(); 152 return success(); 153 }); 154 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 155 return; 156 } 157 158 // Test on all contract lowering patterns. 159 VectorContractLowering contractLowering = VectorContractLowering::Dot; 160 if (lowerToFlatMatrix) 161 contractLowering = VectorContractLowering::Matmul; 162 VectorMultiReductionLowering vectorMultiReductionLowering = 163 VectorMultiReductionLowering::InnerParallel; 164 VectorTransposeLowering transposeLowering = 165 VectorTransposeLowering::EltWise; 166 if (lowerToFlatTranspose) 167 transposeLowering = VectorTransposeLowering::Flat; 168 VectorTransformsOptions options{ 169 contractLowering, vectorMultiReductionLowering, transposeLowering}; 170 populateVectorBroadcastLoweringPatterns(patterns); 171 populateVectorContractLoweringPatterns(patterns, options); 172 populateVectorMaskOpLoweringPatterns(patterns); 173 populateVectorShapeCastLoweringPatterns(patterns); 174 populateVectorTransposeLoweringPatterns(patterns, options); 175 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 176 } 177 }; 178 179 struct TestVectorUnrollingPatterns 180 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 181 StringRef getArgument() const final { 182 return "test-vector-unrolling-patterns"; 183 } 184 StringRef getDescription() const final { 185 return "Test conversion patterns to unroll contract ops in the vector " 186 "dialect"; 187 } 188 TestVectorUnrollingPatterns() = default; 189 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 190 void runOnFunction() override { 191 MLIRContext *ctx = &getContext(); 192 RewritePatternSet patterns(ctx); 193 populateVectorUnrollPatterns( 194 patterns, UnrollVectorOptions() 195 .setNativeShape(ArrayRef<int64_t>{2, 2}) 196 .setFilterConstraint([](Operation *op) { 197 return success(isa<arith::AddFOp, vector::FMAOp>(op)); 198 })); 199 200 if (unrollBasedOnType) { 201 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 202 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 203 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 204 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 205 if (auto floatType = contractOp.getLhsType() 206 .getElementType() 207 .dyn_cast<FloatType>()) { 208 if (floatType.getWidth() == 16) { 209 nativeShape[2] = 4; 210 } 211 } 212 return nativeShape; 213 }; 214 populateVectorUnrollPatterns(patterns, 215 UnrollVectorOptions() 216 .setNativeShapeFn(nativeShapeFn) 217 .setFilterConstraint([](Operation *op) { 218 return success(isa<ContractionOp>(op)); 219 })); 220 } else { 221 populateVectorUnrollPatterns( 222 patterns, UnrollVectorOptions() 223 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 224 .setFilterConstraint([](Operation *op) { 225 return success(isa<ContractionOp>(op)); 226 })); 227 } 228 populateVectorToVectorCanonicalizationPatterns(patterns); 229 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 230 } 231 232 Option<bool> unrollBasedOnType{ 233 *this, "unroll-based-on-type", 234 llvm::cl::desc("Set the unroll factor based on type of the operation"), 235 llvm::cl::init(false)}; 236 }; 237 238 struct TestVectorDistributePatterns 239 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 240 StringRef getArgument() const final { 241 return "test-vector-distribute-patterns"; 242 } 243 StringRef getDescription() const final { 244 return "Test conversion patterns to distribute vector ops in the vector " 245 "dialect"; 246 } 247 TestVectorDistributePatterns() = default; 248 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 249 void getDependentDialects(DialectRegistry ®istry) const override { 250 registry.insert<VectorDialect>(); 251 registry.insert<AffineDialect>(); 252 } 253 ListOption<int32_t> multiplicity{ 254 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 255 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 256 257 void runOnFunction() override { 258 MLIRContext *ctx = &getContext(); 259 RewritePatternSet patterns(ctx); 260 FuncOp func = getFunction(); 261 func.walk([&](arith::AddFOp op) { 262 OpBuilder builder(op); 263 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 264 SmallVector<int64_t, 2> mul; 265 SmallVector<AffineExpr, 2> perm; 266 SmallVector<Value, 2> ids; 267 unsigned count = 0; 268 // Remove the multiplicity of 1 and calculate the affine map based on 269 // the multiplicity. 270 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 271 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 272 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 273 mul.push_back(m[i]); 274 ids.push_back(func.getArgument(count++)); 275 perm.push_back(getAffineDimExpr(i, ctx)); 276 } 277 } 278 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 279 perm, ctx); 280 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 281 builder, op.getOperation(), ids, mul, map); 282 if (ops.hasValue()) { 283 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 284 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 285 extractOp); 286 } 287 } 288 }); 289 populatePropagateVectorDistributionPatterns(patterns); 290 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 291 } 292 }; 293 294 struct TestVectorToLoopPatterns 295 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 296 StringRef getArgument() const final { return "test-vector-to-forloop"; } 297 StringRef getDescription() const final { 298 return "Test conversion patterns to break up a vector op into a for loop"; 299 } 300 TestVectorToLoopPatterns() = default; 301 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 302 void getDependentDialects(DialectRegistry ®istry) const override { 303 registry.insert<VectorDialect>(); 304 registry.insert<AffineDialect>(); 305 } 306 Option<int32_t> multiplicity{ 307 *this, "distribution-multiplicity", 308 llvm::cl::desc("Set the multiplicity used for distributing vector"), 309 llvm::cl::init(32)}; 310 void runOnFunction() override { 311 MLIRContext *ctx = &getContext(); 312 RewritePatternSet patterns(ctx); 313 FuncOp func = getFunction(); 314 func.walk([&](arith::AddFOp op) { 315 // Check that the operation type can be broken down into a loop. 316 VectorType type = op.getType().dyn_cast<VectorType>(); 317 if (!type || type.getRank() != 1 || 318 type.getNumElements() % multiplicity != 0) 319 return mlir::WalkResult::advance(); 320 auto filterAlloc = [](Operation *op) { 321 if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op)) 322 return false; 323 return true; 324 }; 325 auto dependentOps = getSlice(op, filterAlloc); 326 // Create a loop and move instructions from the Op slice into the loop. 327 OpBuilder builder(op); 328 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 329 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 330 auto numIter = 331 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 332 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 333 for (Operation *it : dependentOps) { 334 it->moveBefore(forOp.getBody()->getTerminator()); 335 } 336 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 337 // break up the original op and let the patterns propagate. 338 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 339 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 340 map); 341 if (ops.hasValue()) { 342 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 343 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 344 } 345 return mlir::WalkResult::interrupt(); 346 }); 347 populatePropagateVectorDistributionPatterns(patterns); 348 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 349 } 350 }; 351 352 struct TestVectorTransferUnrollingPatterns 353 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 354 void getDependentDialects(DialectRegistry ®istry) const override { 355 registry.insert<AffineDialect>(); 356 } 357 StringRef getArgument() const final { 358 return "test-vector-transfer-unrolling-patterns"; 359 } 360 StringRef getDescription() const final { 361 return "Test conversion patterns to unroll transfer ops in the vector " 362 "dialect"; 363 } 364 void runOnFunction() override { 365 MLIRContext *ctx = &getContext(); 366 RewritePatternSet patterns(ctx); 367 populateVectorUnrollPatterns( 368 patterns, 369 UnrollVectorOptions() 370 .setNativeShape(ArrayRef<int64_t>{2, 2}) 371 .setFilterConstraint([](Operation *op) { 372 return success( 373 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 374 })); 375 populateVectorToVectorCanonicalizationPatterns(patterns); 376 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 377 } 378 }; 379 380 struct TestVectorTransferFullPartialSplitPatterns 381 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 382 FunctionPass> { 383 StringRef getArgument() const final { 384 return "test-vector-transfer-full-partial-split"; 385 } 386 StringRef getDescription() const final { 387 return "Test conversion patterns to split " 388 "transfer ops via scf.if + linalg ops"; 389 } 390 TestVectorTransferFullPartialSplitPatterns() = default; 391 TestVectorTransferFullPartialSplitPatterns( 392 const TestVectorTransferFullPartialSplitPatterns &pass) {} 393 394 void getDependentDialects(DialectRegistry ®istry) const override { 395 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 396 scf::SCFDialect>(); 397 } 398 399 Option<bool> useLinalgOps{ 400 *this, "use-linalg-copy", 401 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 402 "linalg.copy operations."), 403 llvm::cl::init(false)}; 404 void runOnFunction() override { 405 MLIRContext *ctx = &getContext(); 406 RewritePatternSet patterns(ctx); 407 VectorTransformsOptions options; 408 if (useLinalgOps) 409 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 410 else 411 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 412 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 413 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 414 } 415 }; 416 417 struct TestVectorTransferOpt 418 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 419 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 420 StringRef getDescription() const final { 421 return "Test optimization transformations for transfer ops"; 422 } 423 void runOnFunction() override { transferOpflowOpt(getFunction()); } 424 }; 425 426 struct TestVectorTransferLoweringPatterns 427 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 428 void getDependentDialects(DialectRegistry ®istry) const override { 429 registry.insert<memref::MemRefDialect>(); 430 } 431 StringRef getArgument() const final { 432 return "test-vector-transfer-lowering-patterns"; 433 } 434 StringRef getDescription() const final { 435 return "Test conversion patterns to lower transfer ops to other vector ops"; 436 } 437 void runOnFunction() override { 438 RewritePatternSet patterns(&getContext()); 439 populateVectorTransferLoweringPatterns(patterns); 440 populateVectorTransferPermutationMapLoweringPatterns(patterns); 441 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 442 } 443 }; 444 445 struct TestVectorMultiReductionLoweringPatterns 446 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 447 FunctionPass> { 448 TestVectorMultiReductionLoweringPatterns() = default; 449 TestVectorMultiReductionLoweringPatterns( 450 const TestVectorMultiReductionLoweringPatterns &pass) {} 451 void getDependentDialects(DialectRegistry ®istry) const override { 452 registry.insert<memref::MemRefDialect>(); 453 } 454 StringRef getArgument() const final { 455 return "test-vector-multi-reduction-lowering-patterns"; 456 } 457 StringRef getDescription() const final { 458 return "Test conversion patterns to lower vector.multi_reduction to other " 459 "vector ops"; 460 } 461 Option<bool> useOuterReductions{ 462 *this, "use-outer-reductions", 463 llvm::cl::desc("Move reductions to outer most dimensions"), 464 llvm::cl::init(false)}; 465 void runOnFunction() override { 466 RewritePatternSet patterns(&getContext()); 467 populateVectorMultiReductionLoweringPatterns( 468 patterns, useOuterReductions 469 ? vector::VectorMultiReductionLowering::InnerParallel 470 : vector::VectorMultiReductionLowering::InnerReduction); 471 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 472 } 473 }; 474 475 struct TestVectorTransferCollapseInnerMostContiguousDims 476 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 477 FunctionPass> { 478 TestVectorTransferCollapseInnerMostContiguousDims() = default; 479 TestVectorTransferCollapseInnerMostContiguousDims( 480 const TestVectorTransferCollapseInnerMostContiguousDims &pass) {} 481 482 void getDependentDialects(DialectRegistry ®istry) const override { 483 registry.insert<memref::MemRefDialect, AffineDialect>(); 484 } 485 486 StringRef getArgument() const final { 487 return "test-vector-transfer-collapse-inner-most-dims"; 488 } 489 490 StringRef getDescription() const final { 491 return "Test conversion patterns that reducedes the rank of the vector " 492 "transfer memory and vector operands."; 493 } 494 495 void runOnFunction() override { 496 RewritePatternSet patterns(&getContext()); 497 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 498 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 499 } 500 }; 501 502 struct TestVectorReduceToContractPatternsPatterns 503 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 504 FunctionPass> { 505 StringRef getArgument() const final { 506 return "test-vector-reduction-to-contract-patterns"; 507 } 508 StringRef getDescription() const final { 509 return "Test patterns to convert multireduce op to contract and combine " 510 "broadcast/transpose to contract"; 511 } 512 void runOnFunction() override { 513 RewritePatternSet patterns(&getContext()); 514 populateVectorReductionToContractPatterns(patterns); 515 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 516 } 517 }; 518 519 } // end anonymous namespace 520 521 namespace mlir { 522 namespace test { 523 void registerTestVectorConversions() { 524 PassRegistration<TestVectorToVectorConversion>(); 525 526 PassRegistration<TestVectorContractionConversion>(); 527 528 PassRegistration<TestVectorUnrollingPatterns>(); 529 530 PassRegistration<TestVectorTransferUnrollingPatterns>(); 531 532 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 533 534 PassRegistration<TestVectorDistributePatterns>(); 535 536 PassRegistration<TestVectorToLoopPatterns>(); 537 538 PassRegistration<TestVectorTransferOpt>(); 539 540 PassRegistration<TestVectorTransferLoweringPatterns>(); 541 542 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 543 544 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 545 546 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 547 } 548 } // namespace test 549 } // namespace mlir 550