1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/Dialect/Vector/VectorTransforms.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 namespace { 25 26 struct TestVectorToVectorConversion 27 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> { 28 TestVectorToVectorConversion() = default; 29 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} 30 StringRef getArgument() const final { 31 return "test-vector-to-vector-conversion"; 32 } 33 StringRef getDescription() const final { 34 return "Test conversion patterns between ops in the vector dialect"; 35 } 36 37 void getDependentDialects(DialectRegistry ®istry) const override { 38 registry.insert<AffineDialect>(); 39 } 40 41 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 42 llvm::cl::init(false)}; 43 44 void runOnFunction() override { 45 auto *ctx = &getContext(); 46 RewritePatternSet patterns(ctx); 47 if (unroll) { 48 patterns.add<UnrollVectorPattern>( 49 ctx, 50 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 51 filter)); 52 } 53 populateVectorToVectorCanonicalizationPatterns(patterns); 54 populateVectorToVectorTransformationPatterns(patterns); 55 populateBubbleVectorBitCastOpPatterns(patterns); 56 populateCastAwayVectorLeadingOneDimPatterns(patterns); 57 populateSplitVectorTransferPatterns(patterns); 58 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 59 } 60 61 private: 62 // Return the target shape based on op type. 63 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 64 if (isa<AddFOp, SelectOp, CmpFOp>(op)) 65 return SmallVector<int64_t, 4>(2, 2); 66 if (isa<vector::ContractionOp>(op)) 67 return SmallVector<int64_t, 4>(3, 2); 68 return llvm::None; 69 } 70 71 static LogicalResult filter(Operation *op) { 72 return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op)); 73 } 74 }; 75 76 struct TestVectorSlicesConversion 77 : public PassWrapper<TestVectorSlicesConversion, FunctionPass> { 78 StringRef getArgument() const final { 79 return "test-vector-slices-conversion"; 80 } 81 StringRef getDescription() const final { 82 return "Test conversion patterns that lower slices ops in the vector " 83 "dialect"; 84 } 85 void runOnFunction() override { 86 RewritePatternSet patterns(&getContext()); 87 populateVectorSlicesLoweringPatterns(patterns); 88 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 89 } 90 }; 91 92 struct TestVectorContractionConversion 93 : public PassWrapper<TestVectorContractionConversion, FunctionPass> { 94 StringRef getArgument() const final { 95 return "test-vector-contraction-conversion"; 96 } 97 StringRef getDescription() const final { 98 return "Test conversion patterns that lower contract ops in the vector " 99 "dialect"; 100 } 101 TestVectorContractionConversion() = default; 102 TestVectorContractionConversion(const TestVectorContractionConversion &pass) { 103 } 104 105 Option<bool> lowerToFlatMatrix{ 106 *this, "vector-lower-matrix-intrinsics", 107 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 108 llvm::cl::init(false)}; 109 Option<bool> lowerToFlatTranspose{ 110 *this, "vector-flat-transpose", 111 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 112 llvm::cl::init(false)}; 113 Option<bool> lowerToOuterProduct{ 114 *this, "vector-outerproduct", 115 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 116 llvm::cl::init(false)}; 117 Option<bool> lowerToFilterOuterProduct{ 118 *this, "vector-filter-outerproduct", 119 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 120 "vectors of size 4."), 121 llvm::cl::init(false)}; 122 123 void runOnFunction() override { 124 RewritePatternSet patterns(&getContext()); 125 126 // Test on one pattern in isolation. 127 if (lowerToOuterProduct) { 128 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 129 VectorTransformsOptions options{lowering}; 130 patterns.add<ContractionOpToOuterProductOpLowering>(options, 131 &getContext()); 132 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 133 return; 134 } 135 136 // Test on one pattern in isolation. 137 if (lowerToFilterOuterProduct) { 138 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 139 VectorTransformsOptions options{lowering}; 140 patterns.add<ContractionOpToOuterProductOpLowering>( 141 options, &getContext(), [](vector::ContractionOp op) { 142 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 143 // where M is not 4. 144 if (op.getRhsType().getShape()[0] == 4) 145 return failure(); 146 return success(); 147 }); 148 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 149 return; 150 } 151 152 // Test on all contract lowering patterns. 153 VectorContractLowering contractLowering = VectorContractLowering::Dot; 154 if (lowerToFlatMatrix) 155 contractLowering = VectorContractLowering::Matmul; 156 VectorTransposeLowering transposeLowering = 157 VectorTransposeLowering::EltWise; 158 if (lowerToFlatTranspose) 159 transposeLowering = VectorTransposeLowering::Flat; 160 VectorTransformsOptions options{contractLowering, transposeLowering}; 161 populateVectorContractLoweringPatterns(patterns, options); 162 populateVectorTransposeLoweringPatterns(patterns, options); 163 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 164 } 165 }; 166 167 struct TestVectorUnrollingPatterns 168 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 169 StringRef getArgument() const final { 170 return "test-vector-unrolling-patterns"; 171 } 172 StringRef getDescription() const final { 173 return "Test conversion patterns to unroll contract ops in the vector " 174 "dialect"; 175 } 176 TestVectorUnrollingPatterns() = default; 177 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 178 void runOnFunction() override { 179 MLIRContext *ctx = &getContext(); 180 RewritePatternSet patterns(ctx); 181 patterns.add<UnrollVectorPattern>( 182 ctx, UnrollVectorOptions() 183 .setNativeShape(ArrayRef<int64_t>{2, 2}) 184 .setFilterConstraint([](Operation *op) { 185 return success(isa<AddFOp, vector::FMAOp>(op)); 186 })); 187 188 if (unrollBasedOnType) { 189 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 190 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 191 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 192 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 193 if (auto floatType = contractOp.getLhsType() 194 .getElementType() 195 .dyn_cast<FloatType>()) { 196 if (floatType.getWidth() == 16) { 197 nativeShape[2] = 4; 198 } 199 } 200 return nativeShape; 201 }; 202 patterns.add<UnrollVectorPattern>( 203 ctx, UnrollVectorOptions() 204 .setNativeShapeFn(nativeShapeFn) 205 .setFilterConstraint([](Operation *op) { 206 return success(isa<ContractionOp>(op)); 207 })); 208 } else { 209 patterns.add<UnrollVectorPattern>( 210 ctx, UnrollVectorOptions() 211 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 212 .setFilterConstraint([](Operation *op) { 213 return success(isa<ContractionOp>(op)); 214 })); 215 } 216 populateVectorToVectorCanonicalizationPatterns(patterns); 217 populateVectorToVectorTransformationPatterns(patterns); 218 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 219 } 220 221 Option<bool> unrollBasedOnType{ 222 *this, "unroll-based-on-type", 223 llvm::cl::desc("Set the unroll factor based on type of the operation"), 224 llvm::cl::init(false)}; 225 }; 226 227 struct TestVectorDistributePatterns 228 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 229 StringRef getArgument() const final { 230 return "test-vector-distribute-patterns"; 231 } 232 StringRef getDescription() const final { 233 return "Test conversion patterns to distribute vector ops in the vector " 234 "dialect"; 235 } 236 TestVectorDistributePatterns() = default; 237 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 238 void getDependentDialects(DialectRegistry ®istry) const override { 239 registry.insert<VectorDialect>(); 240 registry.insert<AffineDialect>(); 241 } 242 ListOption<int32_t> multiplicity{ 243 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 244 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 245 246 void runOnFunction() override { 247 MLIRContext *ctx = &getContext(); 248 RewritePatternSet patterns(ctx); 249 FuncOp func = getFunction(); 250 func.walk([&](AddFOp op) { 251 OpBuilder builder(op); 252 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 253 SmallVector<int64_t, 2> mul; 254 SmallVector<AffineExpr, 2> perm; 255 SmallVector<Value, 2> ids; 256 unsigned count = 0; 257 // Remove the multiplicity of 1 and calculate the affine map based on 258 // the multiplicity. 259 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 260 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 261 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 262 mul.push_back(m[i]); 263 ids.push_back(func.getArgument(count++)); 264 perm.push_back(getAffineDimExpr(i, ctx)); 265 } 266 } 267 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 268 perm, ctx); 269 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 270 builder, op.getOperation(), ids, mul, map); 271 if (ops.hasValue()) { 272 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 273 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 274 extractOp); 275 } 276 } 277 }); 278 patterns.add<PointwiseExtractPattern>(ctx); 279 populateVectorToVectorTransformationPatterns(patterns); 280 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 281 } 282 }; 283 284 struct TestVectorToLoopPatterns 285 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 286 StringRef getArgument() const final { return "test-vector-to-forloop"; } 287 StringRef getDescription() const final { 288 return "Test conversion patterns to break up a vector op into a for loop"; 289 } 290 TestVectorToLoopPatterns() = default; 291 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 292 void getDependentDialects(DialectRegistry ®istry) const override { 293 registry.insert<VectorDialect>(); 294 registry.insert<AffineDialect>(); 295 } 296 Option<int32_t> multiplicity{ 297 *this, "distribution-multiplicity", 298 llvm::cl::desc("Set the multiplicity used for distributing vector"), 299 llvm::cl::init(32)}; 300 void runOnFunction() override { 301 MLIRContext *ctx = &getContext(); 302 RewritePatternSet patterns(ctx); 303 FuncOp func = getFunction(); 304 func.walk([&](AddFOp op) { 305 // Check that the operation type can be broken down into a loop. 306 VectorType type = op.getType().dyn_cast<VectorType>(); 307 if (!type || type.getRank() != 1 || 308 type.getNumElements() % multiplicity != 0) 309 return mlir::WalkResult::advance(); 310 auto filterAlloc = [](Operation *op) { 311 if (isa<ConstantOp, memref::AllocOp, CallOp>(op)) 312 return false; 313 return true; 314 }; 315 auto dependentOps = getSlice(op, filterAlloc); 316 // Create a loop and move instructions from the Op slice into the loop. 317 OpBuilder builder(op); 318 auto zero = builder.create<ConstantOp>( 319 op.getLoc(), builder.getIndexType(), 320 builder.getIntegerAttr(builder.getIndexType(), 0)); 321 auto one = builder.create<ConstantOp>( 322 op.getLoc(), builder.getIndexType(), 323 builder.getIntegerAttr(builder.getIndexType(), 1)); 324 auto numIter = builder.create<ConstantOp>( 325 op.getLoc(), builder.getIndexType(), 326 builder.getIntegerAttr(builder.getIndexType(), multiplicity)); 327 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 328 for (Operation *it : dependentOps) { 329 it->moveBefore(forOp.getBody()->getTerminator()); 330 } 331 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 332 // break up the original op and let the patterns propagate. 333 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 334 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 335 map); 336 if (ops.hasValue()) { 337 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 338 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 339 } 340 return mlir::WalkResult::interrupt(); 341 }); 342 patterns.add<PointwiseExtractPattern>(ctx); 343 populateVectorToVectorTransformationPatterns(patterns); 344 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 345 } 346 }; 347 348 struct TestVectorTransferUnrollingPatterns 349 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 350 void getDependentDialects(DialectRegistry ®istry) const override { 351 registry.insert<AffineDialect>(); 352 } 353 StringRef getArgument() const final { 354 return "test-vector-transfer-unrolling-patterns"; 355 } 356 StringRef getDescription() const final { 357 return "Test conversion patterns to unroll transfer ops in the vector " 358 "dialect"; 359 } 360 void runOnFunction() override { 361 MLIRContext *ctx = &getContext(); 362 RewritePatternSet patterns(ctx); 363 patterns.add<UnrollVectorPattern>( 364 ctx, 365 UnrollVectorOptions() 366 .setNativeShape(ArrayRef<int64_t>{2, 2}) 367 .setFilterConstraint([](Operation *op) { 368 return success( 369 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 370 })); 371 populateVectorToVectorCanonicalizationPatterns(patterns); 372 populateVectorToVectorTransformationPatterns(patterns); 373 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 374 } 375 }; 376 377 struct TestVectorTransferFullPartialSplitPatterns 378 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 379 FunctionPass> { 380 StringRef getArgument() const final { 381 return "test-vector-transfer-full-partial-split"; 382 } 383 StringRef getDescription() const final { 384 return "Test conversion patterns to split " 385 "transfer ops via scf.if + linalg ops"; 386 } 387 TestVectorTransferFullPartialSplitPatterns() = default; 388 TestVectorTransferFullPartialSplitPatterns( 389 const TestVectorTransferFullPartialSplitPatterns &pass) {} 390 391 void getDependentDialects(DialectRegistry ®istry) const override { 392 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 393 scf::SCFDialect>(); 394 } 395 396 Option<bool> useLinalgOps{ 397 *this, "use-linalg-copy", 398 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 399 "linalg.copy operations."), 400 llvm::cl::init(false)}; 401 void runOnFunction() override { 402 MLIRContext *ctx = &getContext(); 403 RewritePatternSet patterns(ctx); 404 VectorTransformsOptions options; 405 if (useLinalgOps) 406 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 407 else 408 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 409 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 410 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 411 } 412 }; 413 414 struct TestVectorTransferOpt 415 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 416 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 417 StringRef getDescription() const final { 418 return "Test optimization transformations for transfer ops"; 419 } 420 void runOnFunction() override { transferOpflowOpt(getFunction()); } 421 }; 422 423 struct TestVectorTransferLoweringPatterns 424 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 425 void getDependentDialects(DialectRegistry ®istry) const override { 426 registry.insert<memref::MemRefDialect>(); 427 } 428 StringRef getArgument() const final { 429 return "test-vector-transfer-lowering-patterns"; 430 } 431 StringRef getDescription() const final { 432 return "Test conversion patterns to lower transfer ops to other vector ops"; 433 } 434 void runOnFunction() override { 435 RewritePatternSet patterns(&getContext()); 436 populateVectorTransferLoweringPatterns(patterns); 437 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 438 } 439 }; 440 441 struct TestVectorMultiReductionLoweringPatterns 442 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 443 FunctionPass> { 444 void getDependentDialects(DialectRegistry ®istry) const override { 445 registry.insert<memref::MemRefDialect>(); 446 } 447 StringRef getArgument() const final { 448 return "test-vector-multi-reduction-lowering-patterns"; 449 } 450 StringRef getDescription() const final { 451 return "Test conversion patterns to lower vector.multi_reduction to other " 452 "vector ops"; 453 } 454 void runOnFunction() override { 455 RewritePatternSet patterns(&getContext()); 456 populateVectorMultiReductionLoweringPatterns(patterns); 457 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 458 } 459 }; 460 461 } // end anonymous namespace 462 463 namespace mlir { 464 namespace test { 465 void registerTestVectorConversions() { 466 PassRegistration<TestVectorToVectorConversion>(); 467 468 PassRegistration<TestVectorSlicesConversion>(); 469 470 PassRegistration<TestVectorContractionConversion>(); 471 472 PassRegistration<TestVectorUnrollingPatterns>(); 473 474 PassRegistration<TestVectorTransferUnrollingPatterns>(); 475 476 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 477 478 PassRegistration<TestVectorDistributePatterns>(); 479 480 PassRegistration<TestVectorToLoopPatterns>(); 481 482 PassRegistration<TestVectorTransferOpt>(); 483 484 PassRegistration<TestVectorTransferLoweringPatterns>(); 485 486 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 487 } 488 } // namespace test 489 } // namespace mlir 490