1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 21 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 22 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 23 #include "mlir/Dialect/Linalg/Utils/Utils.h" 24 #include "mlir/Dialect/Vector/IR/VectorOps.h" 25 #include "mlir/Pass/PassManager.h" 26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 27 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/ADT/SmallVector.h" 30 31 using namespace mlir; 32 using namespace mlir::linalg; 33 34 namespace { 35 struct TestLinalgTransforms 36 : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> { 37 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms) 38 39 TestLinalgTransforms() = default; 40 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 41 42 void getDependentDialects(DialectRegistry ®istry) const override { 43 // clang-format off 44 registry.insert<AffineDialect, 45 bufferization::BufferizationDialect, 46 memref::MemRefDialect, 47 scf::SCFDialect, 48 linalg::LinalgDialect, 49 vector::VectorDialect, 50 gpu::GPUDialect>(); 51 // clang-format on 52 } 53 StringRef getArgument() const final { 54 return "test-linalg-transform-patterns"; 55 } 56 StringRef getDescription() const final { 57 return "Test Linalg transformation patterns by applying them greedily."; 58 } 59 60 void runOnOperation() override; 61 62 Option<bool> testPatterns{*this, "test-patterns", 63 llvm::cl::desc("Test a mixed set of patterns"), 64 llvm::cl::init(false)}; 65 Option<bool> testTileAndDistributionOptions{ 66 *this, "test-tile-and-distribute-options", 67 llvm::cl::desc("Test tile and distribute options"), 68 llvm::cl::init(false)}; 69 Option<bool> testTileFuseAndDistributionOptions{ 70 *this, "test-tile-fuse-and-distribute-options", 71 llvm::cl::desc("Test tile, fuse and distribute options"), 72 llvm::cl::init(false)}; 73 Option<bool> testVectorTransferForwardingPatterns{ 74 *this, "test-vector-transfer-forwarding-patterns", 75 llvm::cl::desc( 76 "Test a fused pass that forwards memref.copy to vector.transfer"), 77 llvm::cl::init(false)}; 78 Option<bool> testGenericToVectorPattern{ 79 *this, "test-linalg-to-vector-patterns", 80 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 81 "in vector.contract form"), 82 llvm::cl::init(false)}; 83 Option<bool> testTilePattern{*this, "test-tile-pattern", 84 llvm::cl::desc("Test tile pattern"), 85 llvm::cl::init(false)}; 86 Option<bool> testTileScalarizeDynamicDims{ 87 *this, "test-tile-scalarize-dynamic-dims", 88 llvm::cl::desc("Test tiling of dynamic dims by 1"), 89 llvm::cl::init(false)}; 90 Option<bool> testTransformPadTensor{ 91 *this, "test-transform-pad-tensor", 92 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 93 llvm::cl::init(false)}; 94 Option<bool> testGeneralizePadTensor{ 95 *this, "test-generalize-pad-tensor", 96 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 97 llvm::cl::init(false)}; 98 Option<bool> testSwapSubTensorPadTensor{ 99 *this, "test-swap-subtensor-padtensor", 100 llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into " 101 "tensor.pad(subtensor)"), 102 llvm::cl::init(false)}; 103 Option<bool> testSplitReduction{ 104 *this, "test-split-reduction", 105 llvm::cl::desc("Test split reduction transformation"), 106 llvm::cl::init(false)}; 107 ListOption<int64_t> peeledLoops{ 108 *this, "peeled-loops", 109 llvm::cl::desc("Loops to be peeled when test-tile-pattern")}; 110 ListOption<int64_t> tileSizes{ 111 *this, "tile-sizes", 112 llvm::cl::desc("Linalg tile sizes for test-tile-pattern")}; 113 Option<bool> skipPartial{ 114 *this, "skip-partial", 115 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 116 llvm::cl::init(false)}; 117 Option<std::string> loopType{ 118 *this, "loop-type", 119 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 120 "tiled_loop"), 121 llvm::cl::init("for")}; 122 Option<bool> testBubbleUpExtractSliceOpPattern{ 123 *this, "test-bubble-up-extract-slice-op-pattern", 124 llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " 125 "extract_slice + linalgOp"), 126 llvm::cl::init(false)}; 127 }; 128 } // namespace 129 130 static void applyPatterns(func::FuncOp funcOp) { 131 MLIRContext *ctx = funcOp.getContext(); 132 RewritePatternSet patterns(ctx); 133 134 //===--------------------------------------------------------------------===// 135 // Linalg tiling patterns. 136 //===--------------------------------------------------------------------===// 137 patterns.add<LinalgTilingPattern>( 138 MatmulOp::getOperationName(), ctx, 139 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 140 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 141 StringAttr::get(ctx, "L3"))); 142 patterns.add<LinalgTilingPattern>( 143 MatmulOp::getOperationName(), ctx, 144 LinalgTilingOptions().setTileSizes({200, 300, 400}), 145 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 146 StringAttr::get(ctx, "L2"))); 147 patterns.add<LinalgTilingPattern>( 148 MatmulOp::getOperationName(), ctx, 149 LinalgTilingOptions().setTileSizes({20, 30, 40}), 150 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 151 StringAttr::get(ctx, "L1"))); 152 patterns.add<LinalgTilingPattern>( 153 MatmulOp::getOperationName(), ctx, 154 LinalgTilingOptions().setTileSizes({2, 3, 4}), 155 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 156 StringAttr::get(ctx, "REG"))); 157 158 patterns.add<LinalgTilingPattern>( 159 MatvecOp::getOperationName(), ctx, 160 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 161 LinalgTilingLoopType::ParallelLoops), 162 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 163 StringAttr::get(ctx, "L1"))); 164 165 patterns.add<LinalgTilingPattern>( 166 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), 167 LinalgTransformationFilter( 168 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 169 StringAttr::get(ctx, "L3"), 170 StringAttr::get(ctx, "L2")}, 171 StringAttr::get(ctx, "REG"))); 172 173 //===--------------------------------------------------------------------===// 174 // Linalg tiling and permutation patterns. 175 //===--------------------------------------------------------------------===// 176 patterns.add<LinalgTilingPattern>( 177 MatmulOp::getOperationName(), ctx, 178 LinalgTilingOptions() 179 .setTileSizes({2000, 3000, 4000}) 180 .setInterchange({1, 2, 0}), 181 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 182 StringAttr::get(ctx, "L2__with_perm__"))); 183 patterns.add<LinalgTilingPattern>( 184 MatmulOp::getOperationName(), ctx, 185 LinalgTilingOptions() 186 .setTileSizes({200, 300, 400}) 187 .setInterchange({1, 0, 2}), 188 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 189 StringAttr::get(ctx, "L1__with_perm__"))); 190 patterns.add<LinalgTilingPattern>( 191 MatmulOp::getOperationName(), ctx, 192 LinalgTilingOptions().setTileSizes({20, 30, 40}), 193 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 194 StringAttr::get(ctx, "REG__with_perm__"))); 195 196 patterns.add<LinalgTilingPattern>( 197 MatvecOp::getOperationName(), ctx, 198 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 199 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 200 StringAttr::get(ctx, "L1__with_perm__"))); 201 202 patterns.add<LinalgTilingPattern>( 203 MatmulOp::getOperationName(), ctx, 204 LinalgTilingOptions() 205 .setTileSizes({16, 8, 4}) 206 .setInterchange({1, 2, 0}) 207 .setLoopType(LinalgTilingLoopType::ParallelLoops), 208 LinalgTransformationFilter( 209 StringAttr::get(ctx, "par__with_perm__"), 210 StringAttr::get(ctx, "after_par__with_perm__"))); 211 212 //===--------------------------------------------------------------------===// 213 // Linalg to loops patterns. 214 //===--------------------------------------------------------------------===// 215 patterns.add<LinalgLoweringPattern<DotOp>>( 216 ctx, 217 /*loweringType=*/LinalgLoweringType::Loops, 218 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 219 220 //===--------------------------------------------------------------------===// 221 // Linalg distribution patterns. 222 //===--------------------------------------------------------------------===// 223 LinalgLoopDistributionOptions distributionOptions; 224 225 //===--------------------------------------------------------------------===// 226 // Linalg to vector contraction patterns. 227 //===--------------------------------------------------------------------===// 228 patterns.add<LinalgVectorizationPattern>( 229 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 230 .addOpFilter<MatmulOp, FillOp, GenericOp>()); 231 patterns.add<CopyVectorizationPattern>(ctx); 232 233 //===--------------------------------------------------------------------===// 234 // Linalg generic interchange pattern. 235 //===--------------------------------------------------------------------===// 236 patterns.add<GenericOpInterchangePattern>( 237 ctx, 238 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 239 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 240 StringAttr::get(ctx, "PERMUTED"))); 241 242 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 243 244 // Drop the marker. 245 funcOp.walk([](LinalgOp op) { 246 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 247 }); 248 } 249 250 template <typename IdOp, typename NProcsOp> 251 static SmallVector<ProcInfo, 2> 252 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 253 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 254 SmallVector<ProcInfo, 2> procInfo(count); 255 Type indexType = b.getIndexType(); 256 for (unsigned i = 0; i < count; ++i) { 257 gpu::Dimension dim = *gpu::symbolizeDimension(i); 258 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 259 b.create<NProcsOp>(loc, indexType, dim)}; 260 } 261 return procInfo; 262 } 263 264 static void fillTileAndDistributePatterns(MLIRContext *context, 265 RewritePatternSet &patterns) { 266 { 267 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 268 cyclicNprocsEqNiters.distributionMethod.resize( 269 2, DistributionMethod::CyclicNumProcsEqNumIters); 270 cyclicNprocsEqNiters.procInfo = 271 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 272 patterns.add<LinalgTilingPattern>( 273 MatmulOp::getOperationName(), context, 274 LinalgTilingOptions() 275 .setTileSizes({8, 8, 4}) 276 .setLoopType(LinalgTilingLoopType::ParallelLoops) 277 .setDistributionOptions(cyclicNprocsEqNiters), 278 LinalgTransformationFilter( 279 StringAttr::get(context, "distribute1"), 280 StringAttr::get(context, "after_distribute1"))); 281 } 282 283 { 284 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 285 cyclicNprocsGeNiters.distributionMethod.resize( 286 2, DistributionMethod::CyclicNumProcsGeNumIters); 287 cyclicNprocsGeNiters.procInfo = 288 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 289 patterns.add<LinalgTilingPattern>( 290 MatmulOp::getOperationName(), context, 291 LinalgTilingOptions() 292 .setTileSizes({8, 8, 4}) 293 .setLoopType(LinalgTilingLoopType::ParallelLoops) 294 .setDistributionOptions(cyclicNprocsGeNiters), 295 LinalgTransformationFilter( 296 StringAttr::get(context, "distribute2"), 297 StringAttr::get(context, "after_distribute2"))); 298 } 299 300 { 301 LinalgLoopDistributionOptions cyclicNprocsDefault; 302 cyclicNprocsDefault.distributionMethod.resize(2, 303 DistributionMethod::Cyclic); 304 cyclicNprocsDefault.procInfo = 305 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 306 patterns.add<LinalgTilingPattern>( 307 MatmulOp::getOperationName(), context, 308 LinalgTilingOptions() 309 .setTileSizes({8, 8, 4}) 310 .setLoopType(LinalgTilingLoopType::ParallelLoops) 311 .setDistributionOptions(cyclicNprocsDefault), 312 LinalgTransformationFilter( 313 StringAttr::get(context, "distribute3"), 314 StringAttr::get(context, "after_distribute3"))); 315 } 316 317 { 318 LinalgLoopDistributionOptions cyclicNprocsMixed1; 319 cyclicNprocsMixed1.distributionMethod = { 320 DistributionMethod::CyclicNumProcsEqNumIters, 321 DistributionMethod::CyclicNumProcsGeNumIters}; 322 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 323 patterns.add<LinalgTilingPattern>( 324 MatmulOp::getOperationName(), context, 325 LinalgTilingOptions() 326 .setTileSizes({8, 8, 4}) 327 .setLoopType(LinalgTilingLoopType::ParallelLoops) 328 .setDistributionOptions(cyclicNprocsMixed1), 329 LinalgTransformationFilter( 330 StringAttr::get(context, "distribute4"), 331 StringAttr::get(context, "after_distribute4"))); 332 } 333 334 { 335 LinalgLoopDistributionOptions cyclicNprocsMixed2; 336 cyclicNprocsMixed2.distributionMethod = { 337 DistributionMethod::CyclicNumProcsGeNumIters, 338 DistributionMethod::Cyclic}; 339 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 340 patterns.add<LinalgTilingPattern>( 341 MatmulOp::getOperationName(), context, 342 LinalgTilingOptions() 343 .setTileSizes({8, 8, 4}) 344 .setLoopType(LinalgTilingLoopType::ParallelLoops) 345 .setDistributionOptions(cyclicNprocsMixed2), 346 LinalgTransformationFilter( 347 StringAttr::get(context, "distribute5"), 348 StringAttr::get(context, "after_distribute5"))); 349 } 350 351 { 352 LinalgLoopDistributionOptions cyclicNprocsMixed3; 353 cyclicNprocsMixed3.distributionMethod = { 354 DistributionMethod::Cyclic, 355 DistributionMethod::CyclicNumProcsEqNumIters}; 356 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 357 358 patterns.add<LinalgTilingPattern>( 359 MatmulOp::getOperationName(), context, 360 LinalgTilingOptions() 361 .setTileSizes({8, 8, 4}) 362 .setLoopType(LinalgTilingLoopType::ParallelLoops) 363 .setDistributionOptions(cyclicNprocsMixed3), 364 LinalgTransformationFilter( 365 StringAttr::get(context, "distribute6"), 366 StringAttr::get(context, "after_distribute6"))); 367 } 368 369 { 370 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 371 cyclicNprocsEqNiters.distributionMethod.resize(2, 372 DistributionMethod::Cyclic); 373 cyclicNprocsEqNiters.procInfo = 374 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 375 patterns.add<LinalgTilingPattern>( 376 MatmulOp::getOperationName(), context, 377 LinalgTilingOptions() 378 .setTileSizes({8, 8, 4}) 379 .setLoopType(LinalgTilingLoopType::Loops) 380 .setDistributionOptions(cyclicNprocsEqNiters), 381 LinalgTransformationFilter( 382 StringAttr::get(context, "tensors_distribute1"), 383 StringAttr::get(context, "tensors_after_distribute1"))); 384 } 385 } 386 387 static void fillTileFuseAndDistributePatterns(MLIRContext *context, 388 RewritePatternSet &patterns) { 389 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 390 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); 391 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 392 patterns.add<LinalgTileAndFuseTensorOpsPattern>( 393 MatmulOp::getOperationName(), context, 394 LinalgTilingAndFusionOptions() 395 .setTileSizes({8, 8, 4}) 396 .setDistributionOptions(cyclicNprocsEqNiters), 397 LinalgTransformationFilter( 398 StringAttr::get(context, "tensors_fuse_distribute1"), 399 StringAttr::get(context, "tensors_after_fuse_distribute1"))); 400 } 401 402 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { 403 RewritePatternSet forwardPattern(funcOp.getContext()); 404 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 405 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 406 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 407 } 408 409 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { 410 RewritePatternSet patterns(funcOp.getContext()); 411 auto *ctx = funcOp.getContext(); 412 patterns.add<LinalgVectorizationPattern>( 413 ctx, LinalgTransformationFilter() 414 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>()); 415 patterns.add<CopyVectorizationPattern>(ctx); 416 populatePadOpVectorizationPatterns(patterns); 417 populateConvolutionVectorizationPatterns(patterns); 418 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 419 } 420 421 static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) { 422 RewritePatternSet patterns(funcOp.getContext()); 423 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 424 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 425 } 426 427 static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) { 428 RewritePatternSet patterns(funcOp.getContext()); 429 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 430 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 431 } 432 433 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) { 434 RewritePatternSet patterns(funcOp.getContext()); 435 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 436 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 437 } 438 439 static void applyTilePattern(func::FuncOp funcOp, const std::string &loopType, 440 ArrayRef<int64_t> tileSizes, 441 ArrayRef<int64_t> peeledLoops, 442 bool scalarizeDynamicDims) { 443 MLIRContext *context = funcOp.getContext(); 444 RewritePatternSet tilingPattern(context); 445 LinalgTilingLoopType type = 446 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 447 .Case("for", LinalgTilingLoopType::Loops) 448 .Case("affine", LinalgTilingLoopType::AffineLoops) 449 .Case("parallel", LinalgTilingLoopType::ParallelLoops); 450 auto linalgTilingOptions = linalg::LinalgTilingOptions() 451 .setPeeledLoops(peeledLoops) 452 .setLoopType(type); 453 if (scalarizeDynamicDims) { 454 linalgTilingOptions.scalarizeDynamicDims(); 455 assert(tileSizes.empty() && 456 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 457 } else { 458 linalgTilingOptions.setTileSizes(tileSizes); 459 } 460 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 461 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 462 tilingPattern, linalgTilingOptions, f); 463 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 464 } 465 466 static void applySplitReduction(func::FuncOp funcOp) { 467 RewritePatternSet patterns(funcOp.getContext()); 468 linalg::populateSplitReductionPattern( 469 patterns, 470 [](LinalgOp op) { 471 unsigned insertDimIndex = op.getNumLoops() - 1; 472 return std::make_pair(4, insertDimIndex); 473 }, 474 LinalgTransformationFilter( 475 ArrayRef<StringAttr>{}, 476 StringAttr::get(funcOp.getContext(), "SPLIT"))); 477 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 478 } 479 480 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) { 481 RewritePatternSet patterns(funcOp.getContext()); 482 populateBubbleUpExtractSliceOpPatterns(patterns); 483 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 484 } 485 486 /// Apply transformations specified as patterns. 487 void TestLinalgTransforms::runOnOperation() { 488 auto lambda = [&](void *) { 489 getOperation().walk([](LinalgOp op) { 490 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 491 }); 492 }; 493 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 494 495 if (testTileAndDistributionOptions) { 496 RewritePatternSet patterns(&getContext()); 497 fillTileAndDistributePatterns(&getContext(), patterns); 498 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 499 return; 500 } 501 if (testTileFuseAndDistributionOptions) { 502 RewritePatternSet patterns(&getContext()); 503 fillTileFuseAndDistributePatterns(&getContext(), patterns); 504 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 505 return; 506 } 507 if (testPatterns) 508 return applyPatterns(getOperation()); 509 if (testVectorTransferForwardingPatterns) 510 return applyVectorTransferForwardingPatterns(getOperation()); 511 if (testGenericToVectorPattern) 512 return applyLinalgToVectorPatterns(getOperation()); 513 if (testTransformPadTensor) 514 return applyPadTensorToGenericPatterns(getOperation()); 515 if (testGeneralizePadTensor) 516 return applyGeneralizePadTensorPatterns(getOperation()); 517 if (testSwapSubTensorPadTensor) 518 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 519 if (testTilePattern) 520 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops, 521 /*scalarizeDynamicDims=*/false); 522 if (testTileScalarizeDynamicDims) 523 return applyTilePattern(getOperation(), loopType, tileSizes, 524 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 525 if (testSplitReduction) 526 return applySplitReduction(getOperation()); 527 if (testBubbleUpExtractSliceOpPattern) 528 return applyBubbleUpExtractSliceOpPattern(getOperation()); 529 } 530 531 namespace mlir { 532 namespace test { 533 void registerTestLinalgTransforms() { 534 PassRegistration<TestLinalgTransforms>(); 535 } 536 } // namespace test 537 } // namespace mlir 538