1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/Dialect/Vector/VectorTransforms.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 namespace { 25 26 struct TestVectorToVectorConversion 27 : public PassWrapper<TestVectorToVectorConversion, FunctionPass> { 28 TestVectorToVectorConversion() = default; 29 TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} 30 31 void getDependentDialects(DialectRegistry ®istry) const override { 32 registry.insert<AffineDialect>(); 33 } 34 35 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 36 llvm::cl::init(false)}; 37 38 void runOnFunction() override { 39 auto *ctx = &getContext(); 40 RewritePatternSet patterns(ctx); 41 if (unroll) { 42 patterns.add<UnrollVectorPattern>( 43 ctx, 44 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 45 filter)); 46 } 47 populateVectorToVectorCanonicalizationPatterns(patterns); 48 populateVectorToVectorTransformationPatterns(patterns); 49 populateBubbleVectorBitCastOpPatterns(patterns); 50 populateCastAwayVectorLeadingOneDimPatterns(patterns); 51 populateSplitVectorTransferPatterns(patterns); 52 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 53 } 54 55 private: 56 // Return the target shape based on op type. 57 static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { 58 if (isa<AddFOp, SelectOp, CmpFOp>(op)) 59 return SmallVector<int64_t, 4>(2, 2); 60 if (isa<vector::ContractionOp>(op)) 61 return SmallVector<int64_t, 4>(3, 2); 62 return llvm::None; 63 } 64 65 static LogicalResult filter(Operation *op) { 66 return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op)); 67 } 68 }; 69 70 struct TestVectorSlicesConversion 71 : public PassWrapper<TestVectorSlicesConversion, FunctionPass> { 72 void runOnFunction() override { 73 RewritePatternSet patterns(&getContext()); 74 populateVectorSlicesLoweringPatterns(patterns); 75 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 76 } 77 }; 78 79 struct TestVectorContractionConversion 80 : public PassWrapper<TestVectorContractionConversion, FunctionPass> { 81 TestVectorContractionConversion() = default; 82 TestVectorContractionConversion(const TestVectorContractionConversion &pass) { 83 } 84 85 Option<bool> lowerToFlatMatrix{ 86 *this, "vector-lower-matrix-intrinsics", 87 llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), 88 llvm::cl::init(false)}; 89 Option<bool> lowerToFlatTranspose{ 90 *this, "vector-flat-transpose", 91 llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), 92 llvm::cl::init(false)}; 93 Option<bool> lowerToOuterProduct{ 94 *this, "vector-outerproduct", 95 llvm::cl::desc("Lower vector.contract to vector.outerproduct"), 96 llvm::cl::init(false)}; 97 Option<bool> lowerToFilterOuterProduct{ 98 *this, "vector-filter-outerproduct", 99 llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " 100 "vectors of size 4."), 101 llvm::cl::init(false)}; 102 103 void runOnFunction() override { 104 RewritePatternSet patterns(&getContext()); 105 106 // Test on one pattern in isolation. 107 if (lowerToOuterProduct) { 108 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 109 VectorTransformsOptions options{lowering}; 110 patterns.add<ContractionOpToOuterProductOpLowering>(options, 111 &getContext()); 112 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 113 return; 114 } 115 116 // Test on one pattern in isolation. 117 if (lowerToFilterOuterProduct) { 118 VectorContractLowering lowering = VectorContractLowering::OuterProduct; 119 VectorTransformsOptions options{lowering}; 120 patterns.add<ContractionOpToOuterProductOpLowering>( 121 options, &getContext(), [](vector::ContractionOp op) { 122 // Only lowers vector.contract where the lhs as a type vector<MxNx?> 123 // where M is not 4. 124 if (op.getRhsType().getShape()[0] == 4) 125 return failure(); 126 return success(); 127 }); 128 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 129 return; 130 } 131 132 // Test on all contract lowering patterns. 133 VectorContractLowering contractLowering = VectorContractLowering::Dot; 134 if (lowerToFlatMatrix) 135 contractLowering = VectorContractLowering::Matmul; 136 VectorTransposeLowering transposeLowering = 137 VectorTransposeLowering::EltWise; 138 if (lowerToFlatTranspose) 139 transposeLowering = VectorTransposeLowering::Flat; 140 VectorTransformsOptions options{contractLowering, transposeLowering}; 141 populateVectorContractLoweringPatterns(patterns, options); 142 populateVectorTransposeLoweringPatterns(patterns, options); 143 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 144 } 145 }; 146 147 struct TestVectorUnrollingPatterns 148 : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> { 149 TestVectorUnrollingPatterns() = default; 150 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} 151 void runOnFunction() override { 152 MLIRContext *ctx = &getContext(); 153 RewritePatternSet patterns(ctx); 154 patterns.add<UnrollVectorPattern>( 155 ctx, UnrollVectorOptions() 156 .setNativeShape(ArrayRef<int64_t>{2, 2}) 157 .setFilterConstraint([](Operation *op) { 158 return success(isa<AddFOp, vector::FMAOp>(op)); 159 })); 160 161 if (unrollBasedOnType) { 162 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 163 [](Operation *op) -> Optional<SmallVector<int64_t, 4>> { 164 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 165 SmallVector<int64_t, 4> nativeShape = {4, 4, 2}; 166 if (auto floatType = contractOp.getLhsType() 167 .getElementType() 168 .dyn_cast<FloatType>()) { 169 if (floatType.getWidth() == 16) { 170 nativeShape[2] = 4; 171 } 172 } 173 return nativeShape; 174 }; 175 patterns.add<UnrollVectorPattern>( 176 ctx, UnrollVectorOptions() 177 .setNativeShapeFn(nativeShapeFn) 178 .setFilterConstraint([](Operation *op) { 179 return success(isa<ContractionOp>(op)); 180 })); 181 } else { 182 patterns.add<UnrollVectorPattern>( 183 ctx, UnrollVectorOptions() 184 .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) 185 .setFilterConstraint([](Operation *op) { 186 return success(isa<ContractionOp>(op)); 187 })); 188 } 189 populateVectorToVectorCanonicalizationPatterns(patterns); 190 populateVectorToVectorTransformationPatterns(patterns); 191 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 192 } 193 194 Option<bool> unrollBasedOnType{ 195 *this, "unroll-based-on-type", 196 llvm::cl::desc("Set the unroll factor based on type of the operation"), 197 llvm::cl::init(false)}; 198 }; 199 200 struct TestVectorDistributePatterns 201 : public PassWrapper<TestVectorDistributePatterns, FunctionPass> { 202 TestVectorDistributePatterns() = default; 203 TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {} 204 void getDependentDialects(DialectRegistry ®istry) const override { 205 registry.insert<VectorDialect>(); 206 registry.insert<AffineDialect>(); 207 } 208 ListOption<int32_t> multiplicity{ 209 *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, 210 llvm::cl::desc("Set the multiplicity used for distributing vector")}; 211 212 void runOnFunction() override { 213 MLIRContext *ctx = &getContext(); 214 RewritePatternSet patterns(ctx); 215 FuncOp func = getFunction(); 216 func.walk([&](AddFOp op) { 217 OpBuilder builder(op); 218 if (auto vecType = op.getType().dyn_cast<VectorType>()) { 219 SmallVector<int64_t, 2> mul; 220 SmallVector<AffineExpr, 2> perm; 221 SmallVector<Value, 2> ids; 222 unsigned count = 0; 223 // Remove the multiplicity of 1 and calculate the affine map based on 224 // the multiplicity. 225 SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end()); 226 for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { 227 if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { 228 mul.push_back(m[i]); 229 ids.push_back(func.getArgument(count++)); 230 perm.push_back(getAffineDimExpr(i, ctx)); 231 } 232 } 233 auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0, 234 perm, ctx); 235 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 236 builder, op.getOperation(), ids, mul, map); 237 if (ops.hasValue()) { 238 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 239 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), 240 extractOp); 241 } 242 } 243 }); 244 patterns.add<PointwiseExtractPattern>(ctx); 245 populateVectorToVectorTransformationPatterns(patterns); 246 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 247 } 248 }; 249 250 struct TestVectorToLoopPatterns 251 : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> { 252 TestVectorToLoopPatterns() = default; 253 TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} 254 void getDependentDialects(DialectRegistry ®istry) const override { 255 registry.insert<VectorDialect>(); 256 registry.insert<AffineDialect>(); 257 } 258 Option<int32_t> multiplicity{ 259 *this, "distribution-multiplicity", 260 llvm::cl::desc("Set the multiplicity used for distributing vector"), 261 llvm::cl::init(32)}; 262 void runOnFunction() override { 263 MLIRContext *ctx = &getContext(); 264 RewritePatternSet patterns(ctx); 265 FuncOp func = getFunction(); 266 func.walk([&](AddFOp op) { 267 // Check that the operation type can be broken down into a loop. 268 VectorType type = op.getType().dyn_cast<VectorType>(); 269 if (!type || type.getRank() != 1 || 270 type.getNumElements() % multiplicity != 0) 271 return mlir::WalkResult::advance(); 272 auto filterAlloc = [](Operation *op) { 273 if (isa<ConstantOp, memref::AllocOp, CallOp>(op)) 274 return false; 275 return true; 276 }; 277 auto dependentOps = getSlice(op, filterAlloc); 278 // Create a loop and move instructions from the Op slice into the loop. 279 OpBuilder builder(op); 280 auto zero = builder.create<ConstantOp>( 281 op.getLoc(), builder.getIndexType(), 282 builder.getIntegerAttr(builder.getIndexType(), 0)); 283 auto one = builder.create<ConstantOp>( 284 op.getLoc(), builder.getIndexType(), 285 builder.getIntegerAttr(builder.getIndexType(), 1)); 286 auto numIter = builder.create<ConstantOp>( 287 op.getLoc(), builder.getIndexType(), 288 builder.getIntegerAttr(builder.getIndexType(), multiplicity)); 289 auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one); 290 for (Operation *it : dependentOps) { 291 it->moveBefore(forOp.getBody()->getTerminator()); 292 } 293 auto map = AffineMap::getMultiDimIdentityMap(1, ctx); 294 // break up the original op and let the patterns propagate. 295 Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp( 296 builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, 297 map); 298 if (ops.hasValue()) { 299 SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert}); 300 op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); 301 } 302 return mlir::WalkResult::interrupt(); 303 }); 304 patterns.add<PointwiseExtractPattern>(ctx); 305 populateVectorToVectorTransformationPatterns(patterns); 306 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 307 } 308 }; 309 310 struct TestVectorTransferUnrollingPatterns 311 : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> { 312 void getDependentDialects(DialectRegistry ®istry) const override { 313 registry.insert<AffineDialect>(); 314 } 315 void runOnFunction() override { 316 MLIRContext *ctx = &getContext(); 317 RewritePatternSet patterns(ctx); 318 patterns.add<UnrollVectorPattern>( 319 ctx, 320 UnrollVectorOptions() 321 .setNativeShape(ArrayRef<int64_t>{2, 2}) 322 .setFilterConstraint([](Operation *op) { 323 return success( 324 isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); 325 })); 326 populateVectorToVectorCanonicalizationPatterns(patterns); 327 populateVectorToVectorTransformationPatterns(patterns); 328 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 329 } 330 }; 331 332 struct TestVectorTransferFullPartialSplitPatterns 333 : public PassWrapper<TestVectorTransferFullPartialSplitPatterns, 334 FunctionPass> { 335 TestVectorTransferFullPartialSplitPatterns() = default; 336 TestVectorTransferFullPartialSplitPatterns( 337 const TestVectorTransferFullPartialSplitPatterns &pass) {} 338 339 void getDependentDialects(DialectRegistry ®istry) const override { 340 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 341 scf::SCFDialect>(); 342 } 343 344 Option<bool> useLinalgOps{ 345 *this, "use-linalg-copy", 346 llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " 347 "linalg.copy operations."), 348 llvm::cl::init(false)}; 349 void runOnFunction() override { 350 MLIRContext *ctx = &getContext(); 351 RewritePatternSet patterns(ctx); 352 VectorTransformsOptions options; 353 if (useLinalgOps) 354 options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); 355 else 356 options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); 357 patterns.add<VectorTransferFullPartialRewriter>(ctx, options); 358 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 359 } 360 }; 361 362 struct TestVectorTransferOpt 363 : public PassWrapper<TestVectorTransferOpt, FunctionPass> { 364 void runOnFunction() override { transferOpflowOpt(getFunction()); } 365 }; 366 367 struct TestVectorTransferLoweringPatterns 368 : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> { 369 void getDependentDialects(DialectRegistry ®istry) const override { 370 registry.insert<memref::MemRefDialect>(); 371 } 372 void runOnFunction() override { 373 RewritePatternSet patterns(&getContext()); 374 populateVectorTransferLoweringPatterns(patterns); 375 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 376 } 377 }; 378 379 struct TestVectorMultiReductionLoweringPatterns 380 : public PassWrapper<TestVectorMultiReductionLoweringPatterns, 381 FunctionPass> { 382 void getDependentDialects(DialectRegistry ®istry) const override { 383 registry.insert<memref::MemRefDialect>(); 384 } 385 void runOnFunction() override { 386 RewritePatternSet patterns(&getContext()); 387 populateVectorMultiReductionLoweringPatterns(patterns); 388 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 389 } 390 }; 391 392 } // end anonymous namespace 393 394 namespace mlir { 395 namespace test { 396 void registerTestVectorConversions() { 397 PassRegistration<TestVectorToVectorConversion> vectorToVectorPass( 398 "test-vector-to-vector-conversion", 399 "Test conversion patterns between ops in the vector dialect"); 400 401 PassRegistration<TestVectorSlicesConversion> slicesPass( 402 "test-vector-slices-conversion", 403 "Test conversion patterns that lower slices ops in the vector dialect"); 404 405 PassRegistration<TestVectorContractionConversion> contractionPass( 406 "test-vector-contraction-conversion", 407 "Test conversion patterns that lower contract ops in the vector dialect"); 408 409 PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass( 410 "test-vector-unrolling-patterns", 411 "Test conversion patterns to unroll contract ops in the vector dialect"); 412 413 PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass( 414 "test-vector-transfer-unrolling-patterns", 415 "Test conversion patterns to unroll transfer ops in the vector dialect"); 416 417 PassRegistration<TestVectorTransferFullPartialSplitPatterns> 418 vectorTransformFullPartialPass("test-vector-transfer-full-partial-split", 419 "Test conversion patterns to split " 420 "transfer ops via scf.if + linalg ops"); 421 422 PassRegistration<TestVectorDistributePatterns> distributePass( 423 "test-vector-distribute-patterns", 424 "Test conversion patterns to distribute vector ops in the vector " 425 "dialect"); 426 427 PassRegistration<TestVectorToLoopPatterns> vectorToForLoop( 428 "test-vector-to-forloop", 429 "Test conversion patterns to break up a vector op into a for loop"); 430 431 PassRegistration<TestVectorTransferOpt> transferOpOpt( 432 "test-vector-transferop-opt", 433 "Test optimization transformations for transfer ops"); 434 435 PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass( 436 "test-vector-transfer-lowering-patterns", 437 "Test conversion patterns to lower transfer ops to other vector ops"); 438 439 PassRegistration<TestVectorMultiReductionLoweringPatterns> 440 multiDimReductionOpLoweringPass( 441 "test-vector-multi-reduction-lowering-patterns", 442 "Test conversion patterns to lower vector.multi_reduction to other " 443 "vector ops"); 444 } 445 } // namespace test 446 } // namespace mlir 447