1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/Dialect/Vector/VectorTransforms.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 namespace { 25 26 struct TestVectorToVectorConversion 27 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> { 28 TestVectorToVectorConversion() = default; 29 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} 30 StringRef getArgument() const final { 31 return "test-vector-to-vector-conversion"; 32 } 33 StringRef getDescription() const final { 34 return "Test conversion patterns between ops in the vector dialect"; 35 } 36 37 void getDependentDialects(DialectRegistry ®istry) const override { 38 registry.insert<AffineDialect>(); 39 } 40 41 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 42 llvm::cl::init(false)}; 43 44 void runOnFunction() override { 45 auto *ctx = &getContext(); 46 RewritePatternSet patterns(ctx); 47 if (unroll) { 48 patterns.add<UnrollVectorPattern>( 49 ctx, 50 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 51 filter)); 52 } 53 populateVectorToVectorCanonicalizationPatterns(patterns); 54 populateVectorToVectorTransformationPatterns(patterns); 55 populateBubbleVectorBitCastOpPatterns(patterns); 56 populateCastAwayVectorLeadingOneDimPatterns(patterns); 57 populateSplitVectorTransferPatterns(patterns); 58 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 59 } 60 61 private: 62 // Return the target shape based on op type. 63 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 64 if (isa<AddFOp, SelectOp, CmpFOp>(op)) 65 return SmallVector<int64_t, 4>(2, 2); 66 if (isa<vector::ContractionOp>(op)) 67 return SmallVector<int64_t, 4>(3, 2); 68 return llvm::None; 69 } 70 71 static LogicalResult filter(Operation *op) { 72 return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op)); 73 } 74 }; 75 76 struct TestVectorSlicesConversion 77 : public PassWrapper<TestVectorSlicesConversion, FunctionPass> { 78 StringRef getArgument() const final { 79 return "test-vector-slices-conversion"; 80 } 81 StringRef getDescription() const final { 82 return "Test conversion patterns that lower slices ops in the vector " 83 "dialect"; 84 } 85 void runOnFunction() override { 86 RewritePatternSet patterns(&getContext()); 87 populateVectorSlicesLoweringPatterns(patterns); 88 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 89 } 90 }; 91 92 struct TestVectorContractionConversion 93 : public PassWrapper<TestVectorContractionConversion, FunctionPass> { 94 StringRef getArgument() const final { 95 return "test-vector-contraction-conversion"; 96 } 97 StringRef getDescription() const final { 98 return "Test conversion patterns that lower contract ops in the vector " 99 "dialect"; 100 } 101 TestVectorContractionConversion() = default; 102 TestVectorContractionConversion(const TestVectorContractionConversion &pass) { 103 } 104 105 Option<bool> lowerToFlatMatrix{ 106 *this, "vector-lower-matrix-intrinsics", 107 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 108 llvm::cl::init(false)}; 109 Option<bool> lowerToFlatTranspose{ 110 *this, "vector-flat-transpose", 111 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 112 llvm::cl::init(false)}; 113 Option<bool> lowerToOuterProduct{ 114 *this, "vector-outerproduct", 115 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 116 llvm::cl::init(false)}; 117 Option<bool> lowerToFilterOuterProduct{ 118 *this, "vector-filter-outerproduct", 119 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 120 "vectors of size 4."), 121 llvm::cl::init(false)}; 122 123 void runOnFunction() override { 124 RewritePatternSet patterns(&getContext()); 125 126 // Test on one pattern in isolation. 127 if (lowerToOuterProduct) { 128 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 129 VectorTransformsOptions options{lowering}; 130 patterns.add<ContractionOpToOuterProductOpLowering>(options, 131 &getContext()); 132 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 133 return; 134 } 135 136 // Test on one pattern in isolation. 137 if (lowerToFilterOuterProduct) { 138 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 139 VectorTransformsOptions options{lowering}; 140 patterns.add<ContractionOpToOuterProductOpLowering>( 141 options, &getContext(), [](vector::ContractionOp op) { 142 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 143 // where M is not 4. 144 if (op.getRhsType().getShape()[0] == 4) 145 return failure(); 146 return success(); 147 }); 148 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 149 return; 150 } 151 152 // Test on all contract lowering patterns. 153 VectorContractLowering contractLowering = VectorContractLowering::Dot; 154 if (lowerToFlatMatrix) 155 contractLowering = VectorContractLowering::Matmul; 156 VectorTransposeLowering transposeLowering = 157 VectorTransposeLowering::EltWise; 158 if (lowerToFlatTranspose) 159 transposeLowering = VectorTransposeLowering::Flat; 160 VectorTransformsOptions options{contractLowering, transposeLowering}; 161 populateVectorContractLoweringPatterns(patterns, options); 162 populateVectorTransposeLoweringPatterns(patterns, options); 163 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 164 } 165 }; 166 167 struct TestVectorUnrollingPatterns 168 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 169 StringRef getArgument() const final { 170 return "test-vector-unrolling-patterns"; 171 } 172 StringRef getDescription() const final { 173 return "Test conversion patterns to unroll contract ops in the vector " 174 "dialect"; 175 } 176 TestVectorUnrollingPatterns() = default; 177 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 178 void runOnFunction() override { 179 MLIRContext *ctx = &getContext(); 180 RewritePatternSet patterns(ctx); 181 patterns.add<UnrollVectorPattern>( 182 ctx, UnrollVectorOptions() 183 .setNativeShape(ArrayRef<int64_t>{2, 2}) 184 .setFilterConstraint([](Operation *op) { 185 return success(isa<AddFOp, vector::FMAOp>(op)); 186 })); 187 188 if (unrollBasedOnType) { 189 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 190 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 191 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 192 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 193 if (auto floatType = contractOp.getLhsType() 194 .getElementType() 195 .dyn_cast<FloatType>()) { 196 if (floatType.getWidth() == 16) { 197 nativeShape[2] = 4; 198 } 199 } 200 return nativeShape; 201 }; 202 patterns.add<UnrollVectorPattern>( 203 ctx, UnrollVectorOptions() 204 .setNativeShapeFn(nativeShapeFn) 205 .setFilterConstraint([](Operation *op) { 206 return success(isa<ContractionOp>(op)); 207 })); 208 } else { 209 patterns.add<UnrollVectorPattern>( 210 ctx, UnrollVectorOptions() 211 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 212 .setFilterConstraint([](Operation *op) { 213 return success(isa<ContractionOp>(op)); 214 })); 215 } 216 populateVectorToVectorCanonicalizationPatterns(patterns); 217 populateVectorToVectorTransformationPatterns(patterns); 218 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 219 } 220 221 Option<bool> unrollBasedOnType{ 222 *this, "unroll-based-on-type", 223 llvm::cl::desc("Set the unroll factor based on type of the operation"), 224 llvm::cl::init(false)}; 225 }; 226 227 struct TestVectorDistributePatterns 228 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 229 StringRef getArgument() const final { 230 return "test-vector-distribute-patterns"; 231 } 232 StringRef getDescription() const final { 233 return "Test conversion patterns to distribute vector ops in the vector " 234 "dialect"; 235 } 236 TestVectorDistributePatterns() = default; 237 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 238 void getDependentDialects(DialectRegistry ®istry) const override { 239 registry.insert<VectorDialect>(); 240 registry.insert<AffineDialect>(); 241 } 242 ListOption<int32_t> multiplicity{ 243 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 244 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 245 246 void runOnFunction() override { 247 MLIRContext *ctx = &getContext(); 248 RewritePatternSet patterns(ctx); 249 FuncOp func = getFunction(); 250 func.walk([&](AddFOp op) { 251 OpBuilder builder(op); 252 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 253 SmallVector<int64_t, 2> mul; 254 SmallVector<AffineExpr, 2> perm; 255 SmallVector<Value, 2> ids; 256 unsigned count = 0; 257 // Remove the multiplicity of 1 and calculate the affine map based on 258 // the multiplicity. 259 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 260 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 261 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 262 mul.push_back(m[i]); 263 ids.push_back(func.getArgument(count++)); 264 perm.push_back(getAffineDimExpr(i, ctx)); 265 } 266 } 267 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 268 perm, ctx); 269 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 270 builder, op.getOperation(), ids, mul, map); 271 if (ops.hasValue()) { 272 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 273 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 274 extractOp); 275 } 276 } 277 }); 278 populatePropagateVectorDistributionPatterns(patterns); 279 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 280 } 281 }; 282 283 struct TestVectorToLoopPatterns 284 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 285 StringRef getArgument() const final { return "test-vector-to-forloop"; } 286 StringRef getDescription() const final { 287 return "Test conversion patterns to break up a vector op into a for loop"; 288 } 289 TestVectorToLoopPatterns() = default; 290 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 291 void getDependentDialects(DialectRegistry ®istry) const override { 292 registry.insert<VectorDialect>(); 293 registry.insert<AffineDialect>(); 294 } 295 Option<int32_t> multiplicity{ 296 *this, "distribution-multiplicity", 297 llvm::cl::desc("Set the multiplicity used for distributing vector"), 298 llvm::cl::init(32)}; 299 void runOnFunction() override { 300 MLIRContext *ctx = &getContext(); 301 RewritePatternSet patterns(ctx); 302 FuncOp func = getFunction(); 303 func.walk([&](AddFOp op) { 304 // Check that the operation type can be broken down into a loop. 305 VectorType type = op.getType().dyn_cast<VectorType>(); 306 if (!type || type.getRank() != 1 || 307 type.getNumElements() % multiplicity != 0) 308 return mlir::WalkResult::advance(); 309 auto filterAlloc = [](Operation *op) { 310 if (isa<ConstantOp, memref::AllocOp, CallOp>(op)) 311 return false; 312 return true; 313 }; 314 auto dependentOps = getSlice(op, filterAlloc); 315 // Create a loop and move instructions from the Op slice into the loop. 316 OpBuilder builder(op); 317 auto zero = builder.create<ConstantOp>( 318 op.getLoc(), builder.getIndexType(), 319 builder.getIntegerAttr(builder.getIndexType(), 0)); 320 auto one = builder.create<ConstantOp>( 321 op.getLoc(), builder.getIndexType(), 322 builder.getIntegerAttr(builder.getIndexType(), 1)); 323 auto numIter = builder.create<ConstantOp>( 324 op.getLoc(), builder.getIndexType(), 325 builder.getIntegerAttr(builder.getIndexType(), multiplicity)); 326 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 327 for (Operation *it : dependentOps) { 328 it->moveBefore(forOp.getBody()->getTerminator()); 329 } 330 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 331 // break up the original op and let the patterns propagate. 332 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 333 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 334 map); 335 if (ops.hasValue()) { 336 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 337 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 338 } 339 return mlir::WalkResult::interrupt(); 340 }); 341 populatePropagateVectorDistributionPatterns(patterns); 342 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 343 } 344 }; 345 346 struct TestVectorTransferUnrollingPatterns 347 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 348 void getDependentDialects(DialectRegistry ®istry) const override { 349 registry.insert<AffineDialect>(); 350 } 351 StringRef getArgument() const final { 352 return "test-vector-transfer-unrolling-patterns"; 353 } 354 StringRef getDescription() const final { 355 return "Test conversion patterns to unroll transfer ops in the vector " 356 "dialect"; 357 } 358 void runOnFunction() override { 359 MLIRContext *ctx = &getContext(); 360 RewritePatternSet patterns(ctx); 361 patterns.add<UnrollVectorPattern>( 362 ctx, 363 UnrollVectorOptions() 364 .setNativeShape(ArrayRef<int64_t>{2, 2}) 365 .setFilterConstraint([](Operation *op) { 366 return success( 367 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 368 })); 369 populateVectorToVectorCanonicalizationPatterns(patterns); 370 populateVectorToVectorTransformationPatterns(patterns); 371 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 372 } 373 }; 374 375 struct TestVectorTransferFullPartialSplitPatterns 376 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 377 FunctionPass> { 378 StringRef getArgument() const final { 379 return "test-vector-transfer-full-partial-split"; 380 } 381 StringRef getDescription() const final { 382 return "Test conversion patterns to split " 383 "transfer ops via scf.if + linalg ops"; 384 } 385 TestVectorTransferFullPartialSplitPatterns() = default; 386 TestVectorTransferFullPartialSplitPatterns( 387 const TestVectorTransferFullPartialSplitPatterns &pass) {} 388 389 void getDependentDialects(DialectRegistry ®istry) const override { 390 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 391 scf::SCFDialect>(); 392 } 393 394 Option<bool> useLinalgOps{ 395 *this, "use-linalg-copy", 396 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 397 "linalg.copy operations."), 398 llvm::cl::init(false)}; 399 void runOnFunction() override { 400 MLIRContext *ctx = &getContext(); 401 RewritePatternSet patterns(ctx); 402 VectorTransformsOptions options; 403 if (useLinalgOps) 404 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 405 else 406 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 407 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 408 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 409 } 410 }; 411 412 struct TestVectorTransferOpt 413 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 414 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 415 StringRef getDescription() const final { 416 return "Test optimization transformations for transfer ops"; 417 } 418 void runOnFunction() override { transferOpflowOpt(getFunction()); } 419 }; 420 421 struct TestVectorTransferLoweringPatterns 422 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 423 void getDependentDialects(DialectRegistry ®istry) const override { 424 registry.insert<memref::MemRefDialect>(); 425 } 426 StringRef getArgument() const final { 427 return "test-vector-transfer-lowering-patterns"; 428 } 429 StringRef getDescription() const final { 430 return "Test conversion patterns to lower transfer ops to other vector ops"; 431 } 432 void runOnFunction() override { 433 RewritePatternSet patterns(&getContext()); 434 populateVectorTransferLoweringPatterns(patterns); 435 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 436 } 437 }; 438 439 struct TestVectorMultiReductionLoweringPatterns 440 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 441 FunctionPass> { 442 void getDependentDialects(DialectRegistry ®istry) const override { 443 registry.insert<memref::MemRefDialect>(); 444 } 445 StringRef getArgument() const final { 446 return "test-vector-multi-reduction-lowering-patterns"; 447 } 448 StringRef getDescription() const final { 449 return "Test conversion patterns to lower vector.multi_reduction to other " 450 "vector ops"; 451 } 452 void runOnFunction() override { 453 RewritePatternSet patterns(&getContext()); 454 populateVectorMultiReductionLoweringPatterns(patterns); 455 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 456 } 457 }; 458 459 } // end anonymous namespace 460 461 namespace mlir { 462 namespace test { 463 void registerTestVectorConversions() { 464 PassRegistration<TestVectorToVectorConversion>(); 465 466 PassRegistration<TestVectorSlicesConversion>(); 467 468 PassRegistration<TestVectorContractionConversion>(); 469 470 PassRegistration<TestVectorUnrollingPatterns>(); 471 472 PassRegistration<TestVectorTransferUnrollingPatterns>(); 473 474 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 475 476 PassRegistration<TestVectorDistributePatterns>(); 477 478 PassRegistration<TestVectorToLoopPatterns>(); 479 480 PassRegistration<TestVectorTransferOpt>(); 481 482 PassRegistration<TestVectorTransferLoweringPatterns>(); 483 484 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 485 } 486 } // namespace test 487 } // namespace mlir 488