1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/Dialect/Vector/VectorTransforms.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 namespace { 25 26 struct TestVectorToVectorConversion 27 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> { 28 TestVectorToVectorConversion() = default; 29 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} 30 StringRef getArgument() const final { 31 return "test-vector-to-vector-conversion"; 32 } 33 StringRef getDescription() const final { 34 return "Test conversion patterns between ops in the vector dialect"; 35 } 36 37 void getDependentDialects(DialectRegistry ®istry) const override { 38 registry.insert<AffineDialect>(); 39 } 40 41 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 42 llvm::cl::init(false)}; 43 44 void runOnFunction() override { 45 auto *ctx = &getContext(); 46 RewritePatternSet patterns(ctx); 47 if (unroll) { 48 populateVectorUnrollPatterns( 49 patterns, 50 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 51 filter)); 52 } 53 populateVectorToVectorCanonicalizationPatterns(patterns); 54 populateBubbleVectorBitCastOpPatterns(patterns); 55 populateCastAwayVectorLeadingOneDimPatterns(patterns); 56 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 57 } 58 59 private: 60 // Return the target shape based on op type. 61 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 62 if (isa<AddFOp, SelectOp, CmpFOp>(op)) 63 return SmallVector<int64_t, 4>(2, 2); 64 if (isa<vector::ContractionOp>(op)) 65 return SmallVector<int64_t, 4>(3, 2); 66 // For transfer ops, just propagate the shape coming from 67 // InsertStridedSlices/ExtractStridedSlices. 68 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 69 VectorType dstVec; 70 for (Operation *users : readOp->getUsers()) { 71 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 72 if (!extract) 73 return llvm::None; 74 auto vecType = extract.getResult().getType().cast<VectorType>(); 75 if (dstVec && dstVec != vecType) 76 return llvm::None; 77 dstVec = vecType; 78 } 79 return SmallVector<int64_t, 4>(dstVec.getShape().begin(), 80 dstVec.getShape().end()); 81 } 82 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 83 auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>(); 84 if (!insert) 85 return llvm::None; 86 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 87 return SmallVector<int64_t, 4>(shape.begin(), shape.end()); 88 } 89 return llvm::None; 90 } 91 92 static LogicalResult filter(Operation *op) { 93 return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp, TransferReadOp, 94 TransferWriteOp>(op)); 95 } 96 }; 97 98 struct TestVectorContractionConversion 99 : public PassWrapper<TestVectorContractionConversion, FunctionPass> { 100 StringRef getArgument() const final { 101 return "test-vector-contraction-conversion"; 102 } 103 StringRef getDescription() const final { 104 return "Test conversion patterns that lower contract ops in the vector " 105 "dialect"; 106 } 107 TestVectorContractionConversion() = default; 108 TestVectorContractionConversion(const TestVectorContractionConversion &pass) { 109 } 110 111 Option<bool> lowerToFlatMatrix{ 112 *this, "vector-lower-matrix-intrinsics", 113 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 114 llvm::cl::init(false)}; 115 Option<bool> lowerToFlatTranspose{ 116 *this, "vector-flat-transpose", 117 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 118 llvm::cl::init(false)}; 119 Option<bool> lowerToOuterProduct{ 120 *this, "vector-outerproduct", 121 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 122 llvm::cl::init(false)}; 123 Option<bool> lowerToFilterOuterProduct{ 124 *this, "vector-filter-outerproduct", 125 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 126 "vectors of size 4."), 127 llvm::cl::init(false)}; 128 129 void runOnFunction() override { 130 RewritePatternSet patterns(&getContext()); 131 132 // Test on one pattern in isolation. 133 if (lowerToOuterProduct) { 134 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 135 VectorTransformsOptions options{lowering}; 136 patterns.add<ContractionOpToOuterProductOpLowering>(options, 137 &getContext()); 138 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 139 return; 140 } 141 142 // Test on one pattern in isolation. 143 if (lowerToFilterOuterProduct) { 144 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 145 VectorTransformsOptions options{lowering}; 146 patterns.add<ContractionOpToOuterProductOpLowering>( 147 options, &getContext(), [](vector::ContractionOp op) { 148 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 149 // where M is not 4. 150 if (op.getRhsType().getShape()[0] == 4) 151 return failure(); 152 return success(); 153 }); 154 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 155 return; 156 } 157 158 // Test on all contract lowering patterns. 159 VectorContractLowering contractLowering = VectorContractLowering::Dot; 160 if (lowerToFlatMatrix) 161 contractLowering = VectorContractLowering::Matmul; 162 VectorTransposeLowering transposeLowering = 163 VectorTransposeLowering::EltWise; 164 if (lowerToFlatTranspose) 165 transposeLowering = VectorTransposeLowering::Flat; 166 VectorTransformsOptions options{contractLowering, transposeLowering}; 167 populateVectorContractLoweringPatterns(patterns, options); 168 populateVectorTransposeLoweringPatterns(patterns, options); 169 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 170 } 171 }; 172 173 struct TestVectorUnrollingPatterns 174 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 175 StringRef getArgument() const final { 176 return "test-vector-unrolling-patterns"; 177 } 178 StringRef getDescription() const final { 179 return "Test conversion patterns to unroll contract ops in the vector " 180 "dialect"; 181 } 182 TestVectorUnrollingPatterns() = default; 183 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 184 void runOnFunction() override { 185 MLIRContext *ctx = &getContext(); 186 RewritePatternSet patterns(ctx); 187 populateVectorUnrollPatterns( 188 patterns, UnrollVectorOptions() 189 .setNativeShape(ArrayRef<int64_t>{2, 2}) 190 .setFilterConstraint([](Operation *op) { 191 return success(isa<AddFOp, vector::FMAOp>(op)); 192 })); 193 194 if (unrollBasedOnType) { 195 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 196 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 197 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 198 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 199 if (auto floatType = contractOp.getLhsType() 200 .getElementType() 201 .dyn_cast<FloatType>()) { 202 if (floatType.getWidth() == 16) { 203 nativeShape[2] = 4; 204 } 205 } 206 return nativeShape; 207 }; 208 populateVectorUnrollPatterns(patterns, 209 UnrollVectorOptions() 210 .setNativeShapeFn(nativeShapeFn) 211 .setFilterConstraint([](Operation *op) { 212 return success(isa<ContractionOp>(op)); 213 })); 214 } else { 215 populateVectorUnrollPatterns( 216 patterns, UnrollVectorOptions() 217 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 218 .setFilterConstraint([](Operation *op) { 219 return success(isa<ContractionOp>(op)); 220 })); 221 } 222 populateVectorToVectorCanonicalizationPatterns(patterns); 223 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 224 } 225 226 Option<bool> unrollBasedOnType{ 227 *this, "unroll-based-on-type", 228 llvm::cl::desc("Set the unroll factor based on type of the operation"), 229 llvm::cl::init(false)}; 230 }; 231 232 struct TestVectorDistributePatterns 233 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 234 StringRef getArgument() const final { 235 return "test-vector-distribute-patterns"; 236 } 237 StringRef getDescription() const final { 238 return "Test conversion patterns to distribute vector ops in the vector " 239 "dialect"; 240 } 241 TestVectorDistributePatterns() = default; 242 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 243 void getDependentDialects(DialectRegistry ®istry) const override { 244 registry.insert<VectorDialect>(); 245 registry.insert<AffineDialect>(); 246 } 247 ListOption<int32_t> multiplicity{ 248 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 249 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 250 251 void runOnFunction() override { 252 MLIRContext *ctx = &getContext(); 253 RewritePatternSet patterns(ctx); 254 FuncOp func = getFunction(); 255 func.walk([&](AddFOp op) { 256 OpBuilder builder(op); 257 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 258 SmallVector<int64_t, 2> mul; 259 SmallVector<AffineExpr, 2> perm; 260 SmallVector<Value, 2> ids; 261 unsigned count = 0; 262 // Remove the multiplicity of 1 and calculate the affine map based on 263 // the multiplicity. 264 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 265 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 266 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 267 mul.push_back(m[i]); 268 ids.push_back(func.getArgument(count++)); 269 perm.push_back(getAffineDimExpr(i, ctx)); 270 } 271 } 272 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 273 perm, ctx); 274 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 275 builder, op.getOperation(), ids, mul, map); 276 if (ops.hasValue()) { 277 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 278 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 279 extractOp); 280 } 281 } 282 }); 283 populatePropagateVectorDistributionPatterns(patterns); 284 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 285 } 286 }; 287 288 struct TestVectorToLoopPatterns 289 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 290 StringRef getArgument() const final { return "test-vector-to-forloop"; } 291 StringRef getDescription() const final { 292 return "Test conversion patterns to break up a vector op into a for loop"; 293 } 294 TestVectorToLoopPatterns() = default; 295 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 296 void getDependentDialects(DialectRegistry ®istry) const override { 297 registry.insert<VectorDialect>(); 298 registry.insert<AffineDialect>(); 299 } 300 Option<int32_t> multiplicity{ 301 *this, "distribution-multiplicity", 302 llvm::cl::desc("Set the multiplicity used for distributing vector"), 303 llvm::cl::init(32)}; 304 void runOnFunction() override { 305 MLIRContext *ctx = &getContext(); 306 RewritePatternSet patterns(ctx); 307 FuncOp func = getFunction(); 308 func.walk([&](AddFOp op) { 309 // Check that the operation type can be broken down into a loop. 310 VectorType type = op.getType().dyn_cast<VectorType>(); 311 if (!type || type.getRank() != 1 || 312 type.getNumElements() % multiplicity != 0) 313 return mlir::WalkResult::advance(); 314 auto filterAlloc = [](Operation *op) { 315 if (isa<ConstantOp, memref::AllocOp, CallOp>(op)) 316 return false; 317 return true; 318 }; 319 auto dependentOps = getSlice(op, filterAlloc); 320 // Create a loop and move instructions from the Op slice into the loop. 321 OpBuilder builder(op); 322 auto zero = builder.create<ConstantOp>( 323 op.getLoc(), builder.getIndexType(), 324 builder.getIntegerAttr(builder.getIndexType(), 0)); 325 auto one = builder.create<ConstantOp>( 326 op.getLoc(), builder.getIndexType(), 327 builder.getIntegerAttr(builder.getIndexType(), 1)); 328 auto numIter = builder.create<ConstantOp>( 329 op.getLoc(), builder.getIndexType(), 330 builder.getIntegerAttr(builder.getIndexType(), multiplicity)); 331 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 332 for (Operation *it : dependentOps) { 333 it->moveBefore(forOp.getBody()->getTerminator()); 334 } 335 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 336 // break up the original op and let the patterns propagate. 337 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 338 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 339 map); 340 if (ops.hasValue()) { 341 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 342 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 343 } 344 return mlir::WalkResult::interrupt(); 345 }); 346 populatePropagateVectorDistributionPatterns(patterns); 347 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 348 } 349 }; 350 351 struct TestVectorTransferUnrollingPatterns 352 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 353 void getDependentDialects(DialectRegistry ®istry) const override { 354 registry.insert<AffineDialect>(); 355 } 356 StringRef getArgument() const final { 357 return "test-vector-transfer-unrolling-patterns"; 358 } 359 StringRef getDescription() const final { 360 return "Test conversion patterns to unroll transfer ops in the vector " 361 "dialect"; 362 } 363 void runOnFunction() override { 364 MLIRContext *ctx = &getContext(); 365 RewritePatternSet patterns(ctx); 366 populateVectorUnrollPatterns( 367 patterns, 368 UnrollVectorOptions() 369 .setNativeShape(ArrayRef<int64_t>{2, 2}) 370 .setFilterConstraint([](Operation *op) { 371 return success( 372 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 373 })); 374 populateVectorToVectorCanonicalizationPatterns(patterns); 375 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 376 } 377 }; 378 379 struct TestVectorTransferFullPartialSplitPatterns 380 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 381 FunctionPass> { 382 StringRef getArgument() const final { 383 return "test-vector-transfer-full-partial-split"; 384 } 385 StringRef getDescription() const final { 386 return "Test conversion patterns to split " 387 "transfer ops via scf.if + linalg ops"; 388 } 389 TestVectorTransferFullPartialSplitPatterns() = default; 390 TestVectorTransferFullPartialSplitPatterns( 391 const TestVectorTransferFullPartialSplitPatterns &pass) {} 392 393 void getDependentDialects(DialectRegistry ®istry) const override { 394 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 395 scf::SCFDialect>(); 396 } 397 398 Option<bool> useLinalgOps{ 399 *this, "use-linalg-copy", 400 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 401 "linalg.copy operations."), 402 llvm::cl::init(false)}; 403 void runOnFunction() override { 404 MLIRContext *ctx = &getContext(); 405 RewritePatternSet patterns(ctx); 406 VectorTransformsOptions options; 407 if (useLinalgOps) 408 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 409 else 410 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 411 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 412 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 413 } 414 }; 415 416 struct TestVectorTransferOpt 417 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 418 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 419 StringRef getDescription() const final { 420 return "Test optimization transformations for transfer ops"; 421 } 422 void runOnFunction() override { transferOpflowOpt(getFunction()); } 423 }; 424 425 struct TestVectorTransferLoweringPatterns 426 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 427 void getDependentDialects(DialectRegistry ®istry) const override { 428 registry.insert<memref::MemRefDialect>(); 429 } 430 StringRef getArgument() const final { 431 return "test-vector-transfer-lowering-patterns"; 432 } 433 StringRef getDescription() const final { 434 return "Test conversion patterns to lower transfer ops to other vector ops"; 435 } 436 void runOnFunction() override { 437 RewritePatternSet patterns(&getContext()); 438 populateVectorTransferLoweringPatterns(patterns); 439 populateVectorTransferPermutationMapLoweringPatterns(patterns); 440 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 441 } 442 }; 443 444 struct TestVectorMultiReductionLoweringPatterns 445 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 446 FunctionPass> { 447 void getDependentDialects(DialectRegistry ®istry) const override { 448 registry.insert<memref::MemRefDialect>(); 449 } 450 StringRef getArgument() const final { 451 return "test-vector-multi-reduction-lowering-patterns"; 452 } 453 StringRef getDescription() const final { 454 return "Test conversion patterns to lower vector.multi_reduction to other " 455 "vector ops"; 456 } 457 void runOnFunction() override { 458 RewritePatternSet patterns(&getContext()); 459 populateVectorMultiReductionLoweringPatterns(patterns); 460 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 461 } 462 }; 463 464 } // end anonymous namespace 465 466 namespace mlir { 467 namespace test { 468 void registerTestVectorConversions() { 469 PassRegistration<TestVectorToVectorConversion>(); 470 471 PassRegistration<TestVectorContractionConversion>(); 472 473 PassRegistration<TestVectorUnrollingPatterns>(); 474 475 PassRegistration<TestVectorTransferUnrollingPatterns>(); 476 477 PassRegistration<TestVectorTransferFullPartialSplitPatterns>(); 478 479 PassRegistration<TestVectorDistributePatterns>(); 480 481 PassRegistration<TestVectorToLoopPatterns>(); 482 483 PassRegistration<TestVectorTransferOpt>(); 484 485 PassRegistration<TestVectorTransferLoweringPatterns>(); 486 487 PassRegistration<TestVectorMultiReductionLoweringPatterns>(); 488 } 489 } // namespace test 490 } // namespace mlir 491