1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/Dialect/Vector/VectorTransforms.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 namespace { 25 26 struct TestVectorToVectorConversion 27 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> { 28 TestVectorToVectorConversion() = default; 29 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} 30 StringRef getArgument() const final { 31 return "test-vector-to-vector-conversion"; 32 } 33 StringRef getDescription() const final { 34 return "Test conversion patterns between ops in the vector dialect"; 35 } 36 37 void getDependentDialects(DialectRegistry ®istry) const override { 38 registry.insert<AffineDialect>(); 39 } 40 41 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 42 llvm::cl::init(false)}; 43 44 void runOnFunction() override { 45 auto *ctx = &getContext(); 46 RewritePatternSet patterns(ctx); 47 if (unroll) { 48 populateVectorUnrollPatterns( 49 patterns, 50 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 51 filter)); 52 } 53 populateVectorToVectorCanonicalizationPatterns(patterns); 54 populateBubbleVectorBitCastOpPatterns(patterns); 55 populateCastAwayVectorLeadingOneDimPatterns(patterns); 56 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 57 } 58 59 private: 60 // Return the target shape based on op type. 61 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 62 if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op)) 63 return SmallVector<int64_t, 4>(2, 2); 64 if (isa<vector::ContractionOp>(op)) 65 return SmallVector<int64_t, 4>(3, 2); 66 // For transfer ops, just propagate the shape coming from 67 // InsertStridedSlices/ExtractStridedSlices. 68 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 69 VectorType dstVec; 70 for (Operation *users : readOp->getUsers()) { 71 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 72 if (!extract) 73 return llvm::None; 74 auto vecType = extract.getResult().getType().cast<VectorType>(); 75 if (dstVec && dstVec != vecType) 76 return llvm::None; 77 dstVec = vecType; 78 } 79 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 80 dstVec.getShape().end()); 81 } 82 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 83 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 84 if (!insert) 85 return llvm::None; 86 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 87 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 88 } 89 return llvm::None; 90 } 91 92 static LogicalResult filter(Operation *op) { 93 return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp, 94 TransferReadOp, TransferWriteOp>(op)); 95 } 96 }; 97 98 struct TestVectorContractionConversion 99 : public PassWrapper<TestVectorContractionConversion, FunctionPass> { 100 StringRef getArgument() const final { 101 return "test-vector-contraction-conversion"; 102 } 103 StringRef getDescription() const final { 104 return "Test conversion patterns that lower contract ops in the vector " 105 "dialect"; 106 } 107 TestVectorContractionConversion() = default; 108 TestVectorContractionConversion(const TestVectorContractionConversion &pass) { 109 } 110 111 Option<bool> lowerToFlatMatrix{ 112 *this, "vector-lower-matrix-intrinsics", 113 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 114 llvm::cl::init(false)}; 115 Option<bool> lowerToFlatTranspose{ 116 *this, "vector-flat-transpose", 117 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 118 llvm::cl::init(false)}; 119 Option<bool> lowerToOuterProduct{ 120 *this, "vector-outerproduct", 121 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 122 llvm::cl::init(false)}; 123 Option<bool> lowerToFilterOuterProduct{ 124 *this, "vector-filter-outerproduct", 125 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 126 "vectors of size 4."), 127 llvm::cl::init(false)}; 128 129 void runOnFunction() override { 130 RewritePatternSet patterns(&getContext()); 131 132 // Test on one pattern in isolation. 133 if (lowerToOuterProduct) { 134 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 135 VectorTransformsOptions options{lowering}; 136 patterns.add<ContractionOpToOuterProductOpLowering>(options, 137 &getContext()); 138 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 139 return; 140 } 141 142 // Test on one pattern in isolation. 143 if (lowerToFilterOuterProduct) { 144 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 145 VectorTransformsOptions options{lowering}; 146 patterns.add<ContractionOpToOuterProductOpLowering>( 147 options, &getContext(), [](vector::ContractionOp op) { 148 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 149 // where M is not 4. 150 if (op.getRhsType().getShape()[0] == 4) 151 return failure(); 152 return success(); 153 }); 154 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 155 return; 156 } 157 158 // Test on all contract lowering patterns. 159 VectorContractLowering contractLowering = VectorContractLowering::Dot; 160 if (lowerToFlatMatrix) 161 contractLowering = VectorContractLowering::Matmul; 162 VectorTransposeLowering transposeLowering = 163 VectorTransposeLowering::EltWise; 164 if (lowerToFlatTranspose) 165 transposeLowering = VectorTransposeLowering::Flat; 166 VectorTransformsOptions options{contractLowering, transposeLowering}; 167 populateVectorBroadcastLoweringPatterns(patterns); 168 populateVectorContractLoweringPatterns(patterns, options); 169 populateVectorMaskOpLoweringPatterns(patterns); 170 populateVectorShapeCastLoweringPatterns(patterns); 171 populateVectorTransposeLoweringPatterns(patterns, options); 172 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 173 } 174 }; 175 176 struct TestVectorUnrollingPatterns 177 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 178 StringRef getArgument() const final { 179 return "test-vector-unrolling-patterns"; 180 } 181 StringRef getDescription() const final { 182 return "Test conversion patterns to unroll contract ops in the vector " 183 "dialect"; 184 } 185 TestVectorUnrollingPatterns() = default; 186 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 187 void runOnFunction() override { 188 MLIRContext *ctx = &getContext(); 189 RewritePatternSet patterns(ctx); 190 populateVectorUnrollPatterns( 191 patterns, UnrollVectorOptions() 192 .setNativeShape(ArrayRef<int64_t>{2, 2}) 193 .setFilterConstraint([](Operation *op) { 194 return success(isa<arith::AddFOp, vector::FMAOp>(op)); 195 })); 196 197 if (unrollBasedOnType) { 198 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 199 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 200 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 201 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 202 if (auto floatType = contractOp.getLhsType() 203 .getElementType() 204 .dyn_cast<FloatType>()) { 205 if (floatType.getWidth() == 16) { 206 nativeShape[2] = 4; 207 } 208 } 209 return nativeShape; 210 }; 211 populateVectorUnrollPatterns(patterns, 212 UnrollVectorOptions() 213 .setNativeShapeFn(nativeShapeFn) 214 .setFilterConstraint([](Operation *op) { 215 return success(isa<ContractionOp>(op)); 216 })); 217 } else { 218 populateVectorUnrollPatterns( 219 patterns, UnrollVectorOptions() 220 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 221 .setFilterConstraint([](Operation *op) { 222 return success(isa<ContractionOp>(op)); 223 })); 224 } 225 populateVectorToVectorCanonicalizationPatterns(patterns); 226 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 227 } 228 229 Option<bool> unrollBasedOnType{ 230 *this, "unroll-based-on-type", 231 llvm::cl::desc("Set the unroll factor based on type of the operation"), 232 llvm::cl::init(false)}; 233 }; 234 235 struct TestVectorDistributePatterns 236 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 237 StringRef getArgument() const final { 238 return "test-vector-distribute-patterns"; 239 } 240 StringRef getDescription() const final { 241 return "Test conversion patterns to distribute vector ops in the vector " 242 "dialect"; 243 } 244 TestVectorDistributePatterns() = default; 245 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 246 void getDependentDialects(DialectRegistry ®istry) const override { 247 registry.insert<VectorDialect>(); 248 registry.insert<AffineDialect>(); 249 } 250 ListOption<int32_t> multiplicity{ 251 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 252 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 253 254 void runOnFunction() override { 255 MLIRContext *ctx = &getContext(); 256 RewritePatternSet patterns(ctx); 257 FuncOp func = getFunction(); 258 func.walk([&](arith::AddFOp op) { 259 OpBuilder builder(op); 260 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 261 SmallVector<int64_t, 2> mul; 262 SmallVector<AffineExpr, 2> perm; 263 SmallVector<Value, 2> ids; 264 unsigned count = 0; 265 // Remove the multiplicity of 1 and calculate the affine map based on 266 // the multiplicity. 267 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 268 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 269 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 270 mul.push_back(m[i]); 271 ids.push_back(func.getArgument(count++)); 272 perm.push_back(getAffineDimExpr(i, ctx)); 273 } 274 } 275 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 276 perm, ctx); 277 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 278 builder, op.getOperation(), ids, mul, map); 279 if (ops.hasValue()) { 280 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 281 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 282 extractOp); 283 } 284 } 285 }); 286 populatePropagateVectorDistributionPatterns(patterns); 287 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 288 } 289 }; 290 291 struct TestVectorToLoopPatterns 292 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 293 StringRef getArgument() const final { return "test-vector-to-forloop"; } 294 StringRef getDescription() const final { 295 return "Test conversion patterns to break up a vector op into a for loop"; 296 } 297 TestVectorToLoopPatterns() = default; 298 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 299 void getDependentDialects(DialectRegistry ®istry) const override { 300 registry.insert<VectorDialect>(); 301 registry.insert<AffineDialect>(); 302 } 303 Option<int32_t> multiplicity{ 304 *this, "distribution-multiplicity", 305 llvm::cl::desc("Set the multiplicity used for distributing vector"), 306 llvm::cl::init(32)}; 307 void runOnFunction() override { 308 MLIRContext *ctx = &getContext(); 309 RewritePatternSet patterns(ctx); 310 FuncOp func = getFunction(); 311 func.walk([&](arith::AddFOp op) { 312 // Check that the operation type can be broken down into a loop. 313 VectorType type = op.getType().dyn_cast<VectorType>(); 314 if (!type || type.getRank() != 1 || 315 type.getNumElements() % multiplicity != 0) 316 return mlir::WalkResult::advance(); 317 auto filterAlloc = [](Operation *op) { 318 if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op)) 319 return false; 320 return true; 321 }; 322 auto dependentOps = getSlice(op, filterAlloc); 323 // Create a loop and move instructions from the Op slice into the loop. 324 OpBuilder builder(op); 325 auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0); 326 auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1); 327 auto numIter = 328 builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity); 329 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 330 for (Operation *it : dependentOps) { 331 it->moveBefore(forOp.getBody()->getTerminator()); 332 } 333 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 334 // break up the original op and let the patterns propagate. 335 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 336 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 337 map); 338 if (ops.hasValue()) { 339 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 340 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 341 } 342 return mlir::WalkResult::interrupt(); 343 }); 344 populatePropagateVectorDistributionPatterns(patterns); 345 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 346 } 347 }; 348 349 struct TestVectorTransferUnrollingPatterns 350 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 351 void getDependentDialects(DialectRegistry ®istry) const override { 352 registry.insert<AffineDialect>(); 353 } 354 StringRef getArgument() const final { 355 return "test-vector-transfer-unrolling-patterns"; 356 } 357 StringRef getDescription() const final { 358 return "Test conversion patterns to unroll transfer ops in the vector " 359 "dialect"; 360 } 361 void runOnFunction() override { 362 MLIRContext *ctx = &getContext(); 363 RewritePatternSet patterns(ctx); 364 populateVectorUnrollPatterns( 365 patterns, 366 UnrollVectorOptions() 367 .setNativeShape(ArrayRef<int64_t>{2, 2}) 368 .setFilterConstraint([](Operation *op) { 369 return success( 370 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 371 })); 372 populateVectorToVectorCanonicalizationPatterns(patterns); 373 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 374 } 375 }; 376 377 struct TestVectorTransferFullPartialSplitPatterns 378 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 379 FunctionPass> { 380 StringRef getArgument() const final { 381 return "test-vector-transfer-full-partial-split"; 382 } 383 StringRef getDescription() const final { 384 return "Test conversion patterns to split " 385 "transfer ops via scf.if + linalg ops"; 386 } 387 TestVectorTransferFullPartialSplitPatterns() = default; 388 TestVectorTransferFullPartialSplitPatterns( 389 const TestVectorTransferFullPartialSplitPatterns &pass) {} 390 391 void getDependentDialects(DialectRegistry ®istry) const override { 392 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 393 scf::SCFDialect>(); 394 } 395 396 Option<bool> useLinalgOps{ 397 *this, "use-linalg-copy", 398 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 399 "linalg.copy operations."), 400 llvm::cl::init(false)}; 401 void runOnFunction() override { 402 MLIRContext *ctx = &getContext(); 403 RewritePatternSet patterns(ctx); 404 VectorTransformsOptions options; 405 if (useLinalgOps) 406 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 407 else 408 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 409 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 410 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 411 } 412 }; 413 414 struct TestVectorTransferOpt 415 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 416 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 417 StringRef getDescription() const final { 418 return "Test optimization transformations for transfer ops"; 419 } 420 void runOnFunction() override { transferOpflowOpt(getFunction()); } 421 }; 422 423 struct TestVectorTransferLoweringPatterns 424 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 425 void getDependentDialects(DialectRegistry ®istry) const override { 426 registry.insert<memref::MemRefDialect>(); 427 } 428 StringRef getArgument() const final { 429 return "test-vector-transfer-lowering-patterns"; 430 } 431 StringRef getDescription() const final { 432 return "Test conversion patterns to lower transfer ops to other vector ops"; 433 } 434 void runOnFunction() override { 435 RewritePatternSet patterns(&getContext()); 436 populateVectorTransferLoweringPatterns(patterns); 437 populateVectorTransferPermutationMapLoweringPatterns(patterns); 438 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 439 } 440 }; 441 442 struct TestVectorMultiReductionLoweringPatterns 443 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 444 FunctionPass> { 445 TestVectorMultiReductionLoweringPatterns() = default; 446 TestVectorMultiReductionLoweringPatterns( 447 const TestVectorMultiReductionLoweringPatterns &pass) {} 448 void getDependentDialects(DialectRegistry ®istry) const override { 449 registry.insert<memref::MemRefDialect>(); 450 } 451 StringRef getArgument() const final { 452 return "test-vector-multi-reduction-lowering-patterns"; 453 } 454 StringRef getDescription() const final { 455 return "Test conversion patterns to lower vector.multi_reduction to other " 456 "vector ops"; 457 } 458 Option<bool> useOuterReductions{ 459 *this, "use-outer-reductions", 460 llvm::cl::desc("Move reductions to outer most dimensions"), 461 llvm::cl::init(false)}; 462 void runOnFunction() override { 463 RewritePatternSet patterns(&getContext()); 464 populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions); 465 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 466 } 467 }; 468 469 } // end anonymous namespace 470 471 namespace mlir { 472 namespace test { 473 void registerTestVectorConversions() { 474 PassRegistration<TestVectorToVectorConversion>(); 475 476 PassRegistration<TestVectorContractionConversion>(); 477 478 PassRegistration<TestVectorUnrollingPatterns>(); 479 480 PassRegistration<TestVectorTransferUnrollingPatterns>(); 481 482 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 483 484 PassRegistration<TestVectorDistributePatterns>(); 485 486 PassRegistration<TestVectorToLoopPatterns>(); 487 488 PassRegistration<TestVectorTransferOpt>(); 489 490 PassRegistration<TestVectorTransferLoweringPatterns>(); 491 492 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 493 } 494 } // namespace test 495 } // namespace mlir 496