1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/Dialect/Vector/VectorTransforms.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 namespace { 25 26 struct TestVectorToVectorConversion 27 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> { 28 TestVectorToVectorConversion() = default; 29 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} 30 StringRef getArgument() const final { 31 return "test-vector-to-vector-conversion"; 32 } 33 StringRef getDescription() const final { 34 return "Test conversion patterns between ops in the vector dialect"; 35 } 36 37 void getDependentDialects(DialectRegistry ®istry) const override { 38 registry.insert<AffineDialect>(); 39 } 40 41 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 42 llvm::cl::init(false)}; 43 44 void runOnFunction() override { 45 auto *ctx = &getContext(); 46 RewritePatternSet patterns(ctx); 47 if (unroll) { 48 populateVectorUnrollPatterns( 49 patterns, 50 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 51 filter)); 52 } 53 populateVectorToVectorCanonicalizationPatterns(patterns); 54 populateBubbleVectorBitCastOpPatterns(patterns); 55 populateCastAwayVectorLeadingOneDimPatterns(patterns); 56 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 57 } 58 59 private: 60 // Return the target shape based on op type. 61 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 62 if (isa<AddFOp, SelectOp, CmpFOp>(op)) 63 return SmallVector<int64_t, 4>(2, 2); 64 if (isa<vector::ContractionOp>(op)) 65 return SmallVector<int64_t, 4>(3, 2); 66 // For transfer ops, just propagate the shape coming from 67 // InsertStridedSlices/ExtractStridedSlices. 68 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 69 VectorType dstVec; 70 for (Operation *users : readOp->getUsers()) { 71 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 72 if (!extract) 73 return llvm::None; 74 auto vecType = extract.getResult().getType().cast<VectorType>(); 75 if (dstVec && dstVec != vecType) 76 return llvm::None; 77 dstVec = vecType; 78 } 79 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 80 dstVec.getShape().end()); 81 } 82 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 83 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 84 if (!insert) 85 return llvm::None; 86 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 87 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 88 } 89 return llvm::None; 90 } 91 92 static LogicalResult filter(Operation *op) { 93 return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp, TransferReadOp, 94 TransferWriteOp>(op)); 95 } 96 }; 97 98 struct TestVectorContractionConversion 99 : public PassWrapper<TestVectorContractionConversion, FunctionPass> { 100 StringRef getArgument() const final { 101 return "test-vector-contraction-conversion"; 102 } 103 StringRef getDescription() const final { 104 return "Test conversion patterns that lower contract ops in the vector " 105 "dialect"; 106 } 107 TestVectorContractionConversion() = default; 108 TestVectorContractionConversion(const TestVectorContractionConversion &pass) { 109 } 110 111 Option<bool> lowerToFlatMatrix{ 112 *this, "vector-lower-matrix-intrinsics", 113 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 114 llvm::cl::init(false)}; 115 Option<bool> lowerToFlatTranspose{ 116 *this, "vector-flat-transpose", 117 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 118 llvm::cl::init(false)}; 119 Option<bool> lowerToOuterProduct{ 120 *this, "vector-outerproduct", 121 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 122 llvm::cl::init(false)}; 123 Option<bool> lowerToFilterOuterProduct{ 124 *this, "vector-filter-outerproduct", 125 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 126 "vectors of size 4."), 127 llvm::cl::init(false)}; 128 129 void runOnFunction() override { 130 RewritePatternSet patterns(&getContext()); 131 132 // Test on one pattern in isolation. 133 if (lowerToOuterProduct) { 134 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 135 VectorTransformsOptions options{lowering}; 136 patterns.add<ContractionOpToOuterProductOpLowering>(options, 137 &getContext()); 138 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 139 return; 140 } 141 142 // Test on one pattern in isolation. 143 if (lowerToFilterOuterProduct) { 144 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 145 VectorTransformsOptions options{lowering}; 146 patterns.add<ContractionOpToOuterProductOpLowering>( 147 options, &getContext(), [](vector::ContractionOp op) { 148 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 149 // where M is not 4. 150 if (op.getRhsType().getShape()[0] == 4) 151 return failure(); 152 return success(); 153 }); 154 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 155 return; 156 } 157 158 // Test on all contract lowering patterns. 159 VectorContractLowering contractLowering = VectorContractLowering::Dot; 160 if (lowerToFlatMatrix) 161 contractLowering = VectorContractLowering::Matmul; 162 VectorTransposeLowering transposeLowering = 163 VectorTransposeLowering::EltWise; 164 if (lowerToFlatTranspose) 165 transposeLowering = VectorTransposeLowering::Flat; 166 VectorTransformsOptions options{contractLowering, transposeLowering}; 167 populateVectorBroadcastLoweringPatterns(patterns); 168 populateVectorContractLoweringPatterns(patterns, options); 169 populateVectorMaskOpLoweringPatterns(patterns); 170 populateVectorShapeCastLoweringPatterns(patterns); 171 populateVectorTransposeLoweringPatterns(patterns, options); 172 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 173 } 174 }; 175 176 struct TestVectorUnrollingPatterns 177 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 178 StringRef getArgument() const final { 179 return "test-vector-unrolling-patterns"; 180 } 181 StringRef getDescription() const final { 182 return "Test conversion patterns to unroll contract ops in the vector " 183 "dialect"; 184 } 185 TestVectorUnrollingPatterns() = default; 186 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 187 void runOnFunction() override { 188 MLIRContext *ctx = &getContext(); 189 RewritePatternSet patterns(ctx); 190 populateVectorUnrollPatterns( 191 patterns, UnrollVectorOptions() 192 .setNativeShape(ArrayRef<int64_t>{2, 2}) 193 .setFilterConstraint([](Operation *op) { 194 return success(isa<AddFOp, vector::FMAOp>(op)); 195 })); 196 197 if (unrollBasedOnType) { 198 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 199 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 200 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 201 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 202 if (auto floatType = contractOp.getLhsType() 203 .getElementType() 204 .dyn_cast<FloatType>()) { 205 if (floatType.getWidth() == 16) { 206 nativeShape[2] = 4; 207 } 208 } 209 return nativeShape; 210 }; 211 populateVectorUnrollPatterns(patterns, 212 UnrollVectorOptions() 213 .setNativeShapeFn(nativeShapeFn) 214 .setFilterConstraint([](Operation *op) { 215 return success(isa<ContractionOp>(op)); 216 })); 217 } else { 218 populateVectorUnrollPatterns( 219 patterns, UnrollVectorOptions() 220 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 221 .setFilterConstraint([](Operation *op) { 222 return success(isa<ContractionOp>(op)); 223 })); 224 } 225 populateVectorToVectorCanonicalizationPatterns(patterns); 226 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 227 } 228 229 Option<bool> unrollBasedOnType{ 230 *this, "unroll-based-on-type", 231 llvm::cl::desc("Set the unroll factor based on type of the operation"), 232 llvm::cl::init(false)}; 233 }; 234 235 struct TestVectorDistributePatterns 236 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 237 StringRef getArgument() const final { 238 return "test-vector-distribute-patterns"; 239 } 240 StringRef getDescription() const final { 241 return "Test conversion patterns to distribute vector ops in the vector " 242 "dialect"; 243 } 244 TestVectorDistributePatterns() = default; 245 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 246 void getDependentDialects(DialectRegistry ®istry) const override { 247 registry.insert<VectorDialect>(); 248 registry.insert<AffineDialect>(); 249 } 250 ListOption<int32_t> multiplicity{ 251 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 252 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 253 254 void runOnFunction() override { 255 MLIRContext *ctx = &getContext(); 256 RewritePatternSet patterns(ctx); 257 FuncOp func = getFunction(); 258 func.walk([&](AddFOp op) { 259 OpBuilder builder(op); 260 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 261 SmallVector<int64_t, 2> mul; 262 SmallVector<AffineExpr, 2> perm; 263 SmallVector<Value, 2> ids; 264 unsigned count = 0; 265 // Remove the multiplicity of 1 and calculate the affine map based on 266 // the multiplicity. 267 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 268 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 269 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 270 mul.push_back(m[i]); 271 ids.push_back(func.getArgument(count++)); 272 perm.push_back(getAffineDimExpr(i, ctx)); 273 } 274 } 275 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 276 perm, ctx); 277 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 278 builder, op.getOperation(), ids, mul, map); 279 if (ops.hasValue()) { 280 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 281 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 282 extractOp); 283 } 284 } 285 }); 286 populatePropagateVectorDistributionPatterns(patterns); 287 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 288 } 289 }; 290 291 struct TestVectorToLoopPatterns 292 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 293 StringRef getArgument() const final { return "test-vector-to-forloop"; } 294 StringRef getDescription() const final { 295 return "Test conversion patterns to break up a vector op into a for loop"; 296 } 297 TestVectorToLoopPatterns() = default; 298 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 299 void getDependentDialects(DialectRegistry ®istry) const override { 300 registry.insert<VectorDialect>(); 301 registry.insert<AffineDialect>(); 302 } 303 Option<int32_t> multiplicity{ 304 *this, "distribution-multiplicity", 305 llvm::cl::desc("Set the multiplicity used for distributing vector"), 306 llvm::cl::init(32)}; 307 void runOnFunction() override { 308 MLIRContext *ctx = &getContext(); 309 RewritePatternSet patterns(ctx); 310 FuncOp func = getFunction(); 311 func.walk([&](AddFOp op) { 312 // Check that the operation type can be broken down into a loop. 313 VectorType type = op.getType().dyn_cast<VectorType>(); 314 if (!type || type.getRank() != 1 || 315 type.getNumElements() % multiplicity != 0) 316 return mlir::WalkResult::advance(); 317 auto filterAlloc = [](Operation *op) { 318 if (isa<ConstantOp, memref::AllocOp, CallOp>(op)) 319 return false; 320 return true; 321 }; 322 auto dependentOps = getSlice(op, filterAlloc); 323 // Create a loop and move instructions from the Op slice into the loop. 324 OpBuilder builder(op); 325 auto zero = builder.create<ConstantOp>( 326 op.getLoc(), builder.getIndexType(), 327 builder.getIntegerAttr(builder.getIndexType(), 0)); 328 auto one = builder.create<ConstantOp>( 329 op.getLoc(), builder.getIndexType(), 330 builder.getIntegerAttr(builder.getIndexType(), 1)); 331 auto numIter = builder.create<ConstantOp>( 332 op.getLoc(), builder.getIndexType(), 333 builder.getIntegerAttr(builder.getIndexType(), multiplicity)); 334 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 335 for (Operation *it : dependentOps) { 336 it->moveBefore(forOp.getBody()->getTerminator()); 337 } 338 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 339 // break up the original op and let the patterns propagate. 340 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 341 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 342 map); 343 if (ops.hasValue()) { 344 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 345 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 346 } 347 return mlir::WalkResult::interrupt(); 348 }); 349 populatePropagateVectorDistributionPatterns(patterns); 350 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 351 } 352 }; 353 354 struct TestVectorTransferUnrollingPatterns 355 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 356 void getDependentDialects(DialectRegistry ®istry) const override { 357 registry.insert<AffineDialect>(); 358 } 359 StringRef getArgument() const final { 360 return "test-vector-transfer-unrolling-patterns"; 361 } 362 StringRef getDescription() const final { 363 return "Test conversion patterns to unroll transfer ops in the vector " 364 "dialect"; 365 } 366 void runOnFunction() override { 367 MLIRContext *ctx = &getContext(); 368 RewritePatternSet patterns(ctx); 369 populateVectorUnrollPatterns( 370 patterns, 371 UnrollVectorOptions() 372 .setNativeShape(ArrayRef<int64_t>{2, 2}) 373 .setFilterConstraint([](Operation *op) { 374 return success( 375 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 376 })); 377 populateVectorToVectorCanonicalizationPatterns(patterns); 378 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 379 } 380 }; 381 382 struct TestVectorTransferFullPartialSplitPatterns 383 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 384 FunctionPass> { 385 StringRef getArgument() const final { 386 return "test-vector-transfer-full-partial-split"; 387 } 388 StringRef getDescription() const final { 389 return "Test conversion patterns to split " 390 "transfer ops via scf.if + linalg ops"; 391 } 392 TestVectorTransferFullPartialSplitPatterns() = default; 393 TestVectorTransferFullPartialSplitPatterns( 394 const TestVectorTransferFullPartialSplitPatterns &pass) {} 395 396 void getDependentDialects(DialectRegistry ®istry) const override { 397 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 398 scf::SCFDialect>(); 399 } 400 401 Option<bool> useLinalgOps{ 402 *this, "use-linalg-copy", 403 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 404 "linalg.copy operations."), 405 llvm::cl::init(false)}; 406 void runOnFunction() override { 407 MLIRContext *ctx = &getContext(); 408 RewritePatternSet patterns(ctx); 409 VectorTransformsOptions options; 410 if (useLinalgOps) 411 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 412 else 413 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 414 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 415 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 416 } 417 }; 418 419 struct TestVectorTransferOpt 420 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 421 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 422 StringRef getDescription() const final { 423 return "Test optimization transformations for transfer ops"; 424 } 425 void runOnFunction() override { transferOpflowOpt(getFunction()); } 426 }; 427 428 struct TestVectorTransferLoweringPatterns 429 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 430 void getDependentDialects(DialectRegistry ®istry) const override { 431 registry.insert<memref::MemRefDialect>(); 432 } 433 StringRef getArgument() const final { 434 return "test-vector-transfer-lowering-patterns"; 435 } 436 StringRef getDescription() const final { 437 return "Test conversion patterns to lower transfer ops to other vector ops"; 438 } 439 void runOnFunction() override { 440 RewritePatternSet patterns(&getContext()); 441 populateVectorTransferLoweringPatterns(patterns); 442 populateVectorTransferPermutationMapLoweringPatterns(patterns); 443 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 444 } 445 }; 446 447 struct TestVectorMultiReductionLoweringPatterns 448 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 449 FunctionPass> { 450 TestVectorMultiReductionLoweringPatterns() = default; 451 TestVectorMultiReductionLoweringPatterns( 452 const TestVectorMultiReductionLoweringPatterns &pass) {} 453 void getDependentDialects(DialectRegistry ®istry) const override { 454 registry.insert<memref::MemRefDialect>(); 455 } 456 StringRef getArgument() const final { 457 return "test-vector-multi-reduction-lowering-patterns"; 458 } 459 StringRef getDescription() const final { 460 return "Test conversion patterns to lower vector.multi_reduction to other " 461 "vector ops"; 462 } 463 Option<bool> useOuterReductions{ 464 *this, "use-outer-reductions", 465 llvm::cl::desc("Move reductions to outer most dimensions"), 466 llvm::cl::init(false)}; 467 void runOnFunction() override { 468 RewritePatternSet patterns(&getContext()); 469 populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions); 470 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 471 } 472 }; 473 474 } // end anonymous namespace 475 476 namespace mlir { 477 namespace test { 478 void registerTestVectorConversions() { 479 PassRegistration<TestVectorToVectorConversion>(); 480 481 PassRegistration<TestVectorContractionConversion>(); 482 483 PassRegistration<TestVectorUnrollingPatterns>(); 484 485 PassRegistration<TestVectorTransferUnrollingPatterns>(); 486 487 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 488 489 PassRegistration<TestVectorDistributePatterns>(); 490 491 PassRegistration<TestVectorToLoopPatterns>(); 492 493 PassRegistration<TestVectorTransferOpt>(); 494 495 PassRegistration<TestVectorTransferLoweringPatterns>(); 496 497 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 498 } 499 } // namespace test 500 } // namespace mlir 501