1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/GPU/GPUDialect.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 19 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/Linalg/Utils/Utils.h" 22 #include "mlir/Dialect/StandardOps/IR/Ops.h" 23 #include "mlir/Dialect/Vector/VectorOps.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 namespace { 34 struct TestLinalgTransforms 35 : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> { 36 TestLinalgTransforms() = default; 37 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 38 39 void getDependentDialects(DialectRegistry ®istry) const override { 40 // clang-format off 41 registry.insert<AffineDialect, 42 memref::MemRefDialect, 43 scf::SCFDialect, 44 StandardOpsDialect, 45 linalg::LinalgDialect, 46 vector::VectorDialect, 47 gpu::GPUDialect>(); 48 // clang-format on 49 } 50 StringRef getArgument() const final { 51 return "test-linalg-transform-patterns"; 52 } 53 StringRef getDescription() const final { 54 return "Test Linalg transformation patterns by applying them greedily."; 55 } 56 57 void runOnOperation() override; 58 59 Option<bool> testPatterns{*this, "test-patterns", 60 llvm::cl::desc("Test a mixed set of patterns"), 61 llvm::cl::init(false)}; 62 Option<bool> testMatmulToVectorPatterns1dTiling{ 63 *this, "test-matmul-to-vector-patterns-tile-1d", 64 llvm::cl::desc( 65 "Test a fused pass that applies patterns from matmul to vectors via " 66 "1-d tiling"), 67 llvm::cl::init(false)}; 68 Option<bool> testMatmulToVectorPatterns2dTiling{ 69 *this, "test-matmul-to-vector-patterns-tile-2d", 70 llvm::cl::desc( 71 "Test a fused pass that applies patterns from matmul to vectors via " 72 "2-d tiling"), 73 llvm::cl::init(false)}; 74 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 75 llvm::cl::desc("Test promotion options"), 76 llvm::cl::init(false)}; 77 Option<bool> testTileAndDistributionOptions{ 78 *this, "test-tile-and-distribute-options", 79 llvm::cl::desc("Test tile and distribute options"), 80 llvm::cl::init(false)}; 81 Option<bool> testVectorTransferForwardingPatterns{ 82 *this, "test-vector-transfer-forwarding-patterns", 83 llvm::cl::desc( 84 "Test a fused pass that forwards linalg.copy to vector.transfer"), 85 llvm::cl::init(false)}; 86 Option<bool> testGenericToVectorPattern{ 87 *this, "test-linalg-to-vector-patterns", 88 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 89 "in vector.contract form"), 90 llvm::cl::init(false)}; 91 Option<bool> testTilePattern{*this, "test-tile-pattern", 92 llvm::cl::desc("Test tile pattern"), 93 llvm::cl::init(false)}; 94 Option<bool> testTileScalarizeDynamicDims{ 95 *this, "test-tile-scalarize-dynamic-dims", 96 llvm::cl::desc("Test tiling of dynamic dims by 1"), 97 llvm::cl::init(false)}; 98 Option<bool> testTransformPadTensor{ 99 *this, "test-transform-pad-tensor", 100 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 101 llvm::cl::init(false)}; 102 Option<bool> testGeneralizePadTensor{ 103 *this, "test-generalize-pad-tensor", 104 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 105 llvm::cl::init(false)}; 106 Option<bool> testSwapSubTensorPadTensor{ 107 *this, "test-swap-subtensor-padtensor", 108 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 109 "pad_tensor(subtensor)"), 110 llvm::cl::init(false)}; 111 ListOption<int64_t> peeledLoops{ 112 *this, "peeled-loops", 113 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 114 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 115 ListOption<int64_t> tileSizes{ 116 *this, "tile-sizes", 117 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 118 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 119 ListOption<unsigned> testTiledLoopPeeling{ 120 *this, "test-tiled-loop-peeling", 121 llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), 122 llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; 123 Option<bool> skipPartial{ 124 *this, "skip-partial", 125 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 126 llvm::cl::init(false)}; 127 Option<std::string> loopType{ 128 *this, "loop-type", 129 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 130 "tiled_loop"), 131 llvm::cl::init("for")}; 132 }; 133 } // namespace 134 135 static void applyPatterns(FuncOp funcOp) { 136 MLIRContext *ctx = funcOp.getContext(); 137 RewritePatternSet patterns(ctx); 138 139 //===--------------------------------------------------------------------===// 140 // Linalg tiling patterns. 141 //===--------------------------------------------------------------------===// 142 patterns.add<LinalgTilingPattern>( 143 MatmulOp::getOperationName(), ctx, 144 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 145 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 146 StringAttr::get(ctx, "L3"))); 147 patterns.add<LinalgTilingPattern>( 148 MatmulOp::getOperationName(), ctx, 149 LinalgTilingOptions().setTileSizes({200, 300, 400}), 150 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 151 StringAttr::get(ctx, "L2"))); 152 patterns.add<LinalgTilingPattern>( 153 MatmulOp::getOperationName(), ctx, 154 LinalgTilingOptions().setTileSizes({20, 30, 40}), 155 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 156 StringAttr::get(ctx, "L1"))); 157 patterns.add<LinalgTilingPattern>( 158 MatmulOp::getOperationName(), ctx, 159 LinalgTilingOptions().setTileSizes({2, 3, 4}), 160 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 161 StringAttr::get(ctx, "REG"))); 162 163 patterns.add<LinalgTilingPattern>( 164 MatvecOp::getOperationName(), ctx, 165 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 166 LinalgTilingLoopType::ParallelLoops), 167 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 168 StringAttr::get(ctx, "L1"))); 169 170 patterns.add<LinalgTilingPattern>( 171 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), 172 LinalgTransformationFilter( 173 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 174 StringAttr::get(ctx, "L3"), 175 StringAttr::get(ctx, "L2")}, 176 StringAttr::get(ctx, "REG"))); 177 178 //===--------------------------------------------------------------------===// 179 // Linalg tiling and permutation patterns. 180 //===--------------------------------------------------------------------===// 181 patterns.add<LinalgTilingPattern>( 182 MatmulOp::getOperationName(), ctx, 183 LinalgTilingOptions() 184 .setTileSizes({2000, 3000, 4000}) 185 .setInterchange({1, 2, 0}), 186 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 187 StringAttr::get(ctx, "L2__with_perm__"))); 188 patterns.add<LinalgTilingPattern>( 189 MatmulOp::getOperationName(), ctx, 190 LinalgTilingOptions() 191 .setTileSizes({200, 300, 400}) 192 .setInterchange({1, 0, 2}), 193 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 194 StringAttr::get(ctx, "L1__with_perm__"))); 195 patterns.add<LinalgTilingPattern>( 196 MatmulOp::getOperationName(), ctx, 197 LinalgTilingOptions().setTileSizes({20, 30, 40}), 198 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 199 StringAttr::get(ctx, "REG__with_perm__"))); 200 201 patterns.add<LinalgTilingPattern>( 202 MatvecOp::getOperationName(), ctx, 203 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 204 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 205 StringAttr::get(ctx, "L1__with_perm__"))); 206 207 patterns.add<LinalgTilingPattern>( 208 MatmulOp::getOperationName(), ctx, 209 LinalgTilingOptions() 210 .setTileSizes({16, 8, 4}) 211 .setInterchange({1, 2, 0}) 212 .setLoopType(LinalgTilingLoopType::ParallelLoops), 213 LinalgTransformationFilter( 214 StringAttr::get(ctx, "par__with_perm__"), 215 StringAttr::get(ctx, "after_par__with_perm__"))); 216 217 //===--------------------------------------------------------------------===// 218 // Linalg to loops patterns. 219 //===--------------------------------------------------------------------===// 220 patterns.add<LinalgLoweringPattern<DotOp>>( 221 ctx, 222 /*loweringType=*/LinalgLoweringType::Loops, 223 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 224 225 //===--------------------------------------------------------------------===// 226 // Linalg distribution patterns. 227 //===--------------------------------------------------------------------===// 228 LinalgLoopDistributionOptions distributionOptions; 229 230 //===--------------------------------------------------------------------===// 231 // Linalg to vector contraction patterns. 232 //===--------------------------------------------------------------------===// 233 patterns.add<LinalgVectorizationPattern>( 234 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 235 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 236 237 //===--------------------------------------------------------------------===// 238 // Linalg generic interchange pattern. 239 //===--------------------------------------------------------------------===// 240 patterns.add<GenericOpInterchangePattern>( 241 ctx, 242 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 243 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 244 StringAttr::get(ctx, "PERMUTED"))); 245 246 //===--------------------------------------------------------------------===// 247 // Linalg subview operands promotion. 248 //===--------------------------------------------------------------------===// 249 patterns.add<LinalgPromotionPattern<MatmulOp>>( 250 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 251 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 252 StringAttr::get(ctx, "_views_promoted_"))); 253 patterns.add<LinalgPromotionPattern<MatmulOp>>( 254 ctx, 255 LinalgPromotionOptions() 256 .setOperandsToPromote({0}) 257 .setUseFullTileBuffersByDefault(true), 258 LinalgTransformationFilter( 259 StringAttr::get(ctx, "_promote_first_view_"), 260 StringAttr::get(ctx, "_first_view_promoted_"))); 261 patterns.add<LinalgPromotionPattern<FillOp>>( 262 ctx, 263 LinalgPromotionOptions() 264 .setOperandsToPromote({1}) 265 .setUseFullTileBuffers({false, true}) 266 .setAlignment(32), 267 LinalgTransformationFilter( 268 StringAttr::get(ctx, "_promote_views_aligned_"), 269 StringAttr::get(ctx, "_views_aligned_promoted_"))); 270 271 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 272 273 // Drop the marker. 274 funcOp.walk([](LinalgOp op) { 275 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 276 }); 277 } 278 279 static void fillL1TilingAndMatmulToVectorPatterns( 280 FuncOp funcOp, StringRef startMarker, 281 SmallVectorImpl<RewritePatternSet> &patternsVector) { 282 MLIRContext *ctx = funcOp.getContext(); 283 patternsVector.emplace_back( 284 ctx, std::make_unique<LinalgTilingPattern>( 285 MatmulOp::getOperationName(), ctx, 286 LinalgTilingOptions() 287 .setTileSizes({8, 12, 16}) 288 .setInterchange({1, 0, 2}), 289 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 290 StringAttr::get(ctx, "L1")))); 291 292 patternsVector.emplace_back( 293 ctx, 294 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 295 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 296 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 297 StringAttr::get(ctx, "VEC")))); 298 299 patternsVector.emplace_back( 300 ctx, std::make_unique<LinalgVectorizationPattern>( 301 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 302 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 303 patternsVector.back().add<LinalgVectorizationPattern>( 304 ctx, LinalgTransformationFilter().addOpFilter<FillOp, CopyOp>()); 305 } 306 307 //===----------------------------------------------------------------------===// 308 // Test promotion callbacks 309 //===----------------------------------------------------------------------===// 310 311 // Allocation call back 312 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 313 ArrayRef<Value> boundingSubViewSize, 314 DataLayout &layout) { 315 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 316 return b 317 .create<memref::AllocOp>( 318 subView.getLoc(), 319 MemRefType::get(shape, subView.getType().getElementType(), 320 /*affineMapComposition =*/{}, 3), 321 boundingSubViewSize) 322 .getResult(); 323 } 324 325 // Deallocation callback 326 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 327 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 328 return success(); 329 } 330 331 // Copy in call back 332 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 333 bool isOutput) { 334 auto floatType = src.getType().cast<MemRefType>().getElementType(); 335 if (!floatType.isa<FloatType>()) 336 return failure(); 337 if (!isOutput) { 338 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 339 FloatAttr::get(floatType, 42.0)); 340 b.create<FillOp>(src.getLoc(), cst, dst); 341 } 342 b.create<CopyOp>(src.getLoc(), src, dst); 343 return success(); 344 } 345 346 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 347 RewritePatternSet &patterns) { 348 patterns.add<LinalgTilingPattern>( 349 MatmulOp::getOperationName(), ctx, 350 LinalgTilingOptions().setTileSizes({16, 16, 16}), 351 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 352 StringAttr::get(ctx, "PROMOTE"))); 353 patterns.add<LinalgPromotionPattern<MatmulOp>>( 354 ctx, 355 LinalgPromotionOptions() 356 .setOperandsToPromote({0, 2}) 357 .setUseFullTileBuffers({false, false}) 358 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 359 .setCopyInOutFns( 360 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 361 return copyCallBackFn(b, src, dst, false); 362 }, 363 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 364 return copyCallBackFn(b, src, dst, true); 365 }), 366 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 367 } 368 369 template <typename IdOp, typename NProcsOp> 370 static SmallVector<ProcInfo, 2> 371 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 372 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 373 SmallVector<ProcInfo, 2> procInfo(count); 374 Type indexType = b.getIndexType(); 375 for (unsigned i = 0; i < count; ++i) { 376 gpu::Dimension dim = *gpu::symbolizeDimension(i); 377 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 378 b.create<NProcsOp>(loc, indexType, dim)}; 379 } 380 return procInfo; 381 } 382 383 static void fillTileAndDistributePatterns(MLIRContext *context, 384 RewritePatternSet &patterns) { 385 { 386 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 387 cyclicNprocsEqNiters.distributionMethod.resize( 388 2, DistributionMethod::CyclicNumProcsEqNumIters); 389 cyclicNprocsEqNiters.procInfo = 390 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 391 patterns.add<LinalgTilingPattern>( 392 MatmulOp::getOperationName(), context, 393 LinalgTilingOptions() 394 .setTileSizes({8, 8, 4}) 395 .setLoopType(LinalgTilingLoopType::ParallelLoops) 396 .setDistributionOptions(cyclicNprocsEqNiters), 397 LinalgTransformationFilter( 398 StringAttr::get(context, "distribute1"), 399 StringAttr::get(context, "after_distribute1"))); 400 } 401 402 { 403 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 404 cyclicNprocsGeNiters.distributionMethod.resize( 405 2, DistributionMethod::CyclicNumProcsGeNumIters); 406 cyclicNprocsGeNiters.procInfo = 407 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 408 patterns.add<LinalgTilingPattern>( 409 MatmulOp::getOperationName(), context, 410 LinalgTilingOptions() 411 .setTileSizes({8, 8, 4}) 412 .setLoopType(LinalgTilingLoopType::ParallelLoops) 413 .setDistributionOptions(cyclicNprocsGeNiters), 414 LinalgTransformationFilter( 415 StringAttr::get(context, "distribute2"), 416 StringAttr::get(context, "after_distribute2"))); 417 } 418 419 { 420 LinalgLoopDistributionOptions cyclicNprocsDefault; 421 cyclicNprocsDefault.distributionMethod.resize(2, 422 DistributionMethod::Cyclic); 423 cyclicNprocsDefault.procInfo = 424 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 425 patterns.add<LinalgTilingPattern>( 426 MatmulOp::getOperationName(), context, 427 LinalgTilingOptions() 428 .setTileSizes({8, 8, 4}) 429 .setLoopType(LinalgTilingLoopType::ParallelLoops) 430 .setDistributionOptions(cyclicNprocsDefault), 431 LinalgTransformationFilter( 432 StringAttr::get(context, "distribute3"), 433 StringAttr::get(context, "after_distribute3"))); 434 } 435 436 { 437 LinalgLoopDistributionOptions cyclicNprocsMixed1; 438 cyclicNprocsMixed1.distributionMethod = { 439 DistributionMethod::CyclicNumProcsEqNumIters, 440 DistributionMethod::CyclicNumProcsGeNumIters}; 441 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 442 patterns.add<LinalgTilingPattern>( 443 MatmulOp::getOperationName(), context, 444 LinalgTilingOptions() 445 .setTileSizes({8, 8, 4}) 446 .setLoopType(LinalgTilingLoopType::ParallelLoops) 447 .setDistributionOptions(cyclicNprocsMixed1), 448 LinalgTransformationFilter( 449 StringAttr::get(context, "distribute4"), 450 StringAttr::get(context, "after_distribute4"))); 451 } 452 453 { 454 LinalgLoopDistributionOptions cyclicNprocsMixed2; 455 cyclicNprocsMixed2.distributionMethod = { 456 DistributionMethod::CyclicNumProcsGeNumIters, 457 DistributionMethod::Cyclic}; 458 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 459 patterns.add<LinalgTilingPattern>( 460 MatmulOp::getOperationName(), context, 461 LinalgTilingOptions() 462 .setTileSizes({8, 8, 4}) 463 .setLoopType(LinalgTilingLoopType::ParallelLoops) 464 .setDistributionOptions(cyclicNprocsMixed2), 465 LinalgTransformationFilter( 466 StringAttr::get(context, "distribute5"), 467 StringAttr::get(context, "after_distribute5"))); 468 } 469 470 { 471 LinalgLoopDistributionOptions cyclicNprocsMixed3; 472 cyclicNprocsMixed3.distributionMethod = { 473 DistributionMethod::Cyclic, 474 DistributionMethod::CyclicNumProcsEqNumIters}; 475 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 476 477 patterns.add<LinalgTilingPattern>( 478 MatmulOp::getOperationName(), context, 479 LinalgTilingOptions() 480 .setTileSizes({8, 8, 4}) 481 .setLoopType(LinalgTilingLoopType::ParallelLoops) 482 .setDistributionOptions(cyclicNprocsMixed3), 483 LinalgTransformationFilter( 484 StringAttr::get(context, "distribute6"), 485 StringAttr::get(context, "after_distribute6"))); 486 } 487 488 { 489 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 490 cyclicNprocsEqNiters.distributionMethod.resize(2, 491 DistributionMethod::Cyclic); 492 cyclicNprocsEqNiters.procInfo = 493 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 494 patterns.add<LinalgTilingPattern>( 495 MatmulOp::getOperationName(), context, 496 LinalgTilingOptions() 497 .setTileSizes({8, 8, 4}) 498 .setLoopType(LinalgTilingLoopType::Loops) 499 .setDistributionOptions(cyclicNprocsEqNiters), 500 LinalgTransformationFilter( 501 StringAttr::get(context, "tensors_distribute1"), 502 StringAttr::get(context, "tensors_after_distribute1"))); 503 } 504 } 505 506 static void 507 applyMatmulToVectorPatterns(FuncOp funcOp, 508 bool testMatmulToVectorPatterns1dTiling, 509 bool testMatmulToVectorPatterns2dTiling) { 510 MLIRContext *ctx = funcOp.getContext(); 511 SmallVector<RewritePatternSet, 4> stage1Patterns; 512 if (testMatmulToVectorPatterns1dTiling) { 513 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 514 } else if (testMatmulToVectorPatterns2dTiling) { 515 stage1Patterns.emplace_back( 516 ctx, std::make_unique<LinalgTilingPattern>( 517 MatmulOp::getOperationName(), ctx, 518 LinalgTilingOptions() 519 .setTileSizes({768, 264, 768}) 520 .setInterchange({1, 2, 0}), 521 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 522 StringAttr::get(ctx, "L2")))); 523 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 524 } 525 { 526 // Canonicalization patterns 527 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 528 vector::populateVectorTransferPermutationMapLoweringPatterns( 529 canonicalizationPatterns); 530 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 531 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 532 } 533 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 534 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 535 FrozenRewritePatternSet stage2Patterns = 536 getLinalgTilingCanonicalizationPatterns(ctx); 537 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); 538 } 539 540 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 541 RewritePatternSet forwardPattern(funcOp.getContext()); 542 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 543 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 544 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 545 } 546 547 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 548 RewritePatternSet patterns(funcOp.getContext()); 549 patterns.add<LinalgVectorizationPattern>( 550 funcOp.getContext(), 551 LinalgTransformationFilter() 552 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 553 populatePadOpVectorizationPatterns(patterns); 554 populateConvolutionVectorizationPatterns(patterns); 555 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 556 } 557 558 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 559 RewritePatternSet patterns(funcOp.getContext()); 560 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 561 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 562 } 563 564 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 565 RewritePatternSet patterns(funcOp.getContext()); 566 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 567 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 568 } 569 570 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 571 RewritePatternSet patterns(funcOp.getContext()); 572 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 573 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 574 } 575 576 static void applyTilePattern(FuncOp funcOp, const std::string &loopType, 577 ArrayRef<int64_t> tileSizes, 578 ArrayRef<int64_t> peeledLoops, 579 bool scalarizeDynamicDims) { 580 MLIRContext *context = funcOp.getContext(); 581 RewritePatternSet tilingPattern(context); 582 LinalgTilingLoopType type = 583 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 584 .Case("for", LinalgTilingLoopType::Loops) 585 .Case("affine", LinalgTilingLoopType::AffineLoops) 586 .Case("parallel", LinalgTilingLoopType::ParallelLoops) 587 .Case("tiled_loop", LinalgTilingLoopType::TiledLoops); 588 auto linalgTilingOptions = linalg::LinalgTilingOptions() 589 .setPeeledLoops(peeledLoops) 590 .setLoopType(type); 591 if (scalarizeDynamicDims) { 592 linalgTilingOptions.scalarizeDynamicDims(); 593 assert(tileSizes.empty() && 594 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 595 } else { 596 linalgTilingOptions.setTileSizes(tileSizes); 597 } 598 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 599 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 600 tilingPattern, linalgTilingOptions, f); 601 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 602 } 603 604 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 605 static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; 606 607 namespace { 608 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the 609 /// `idx`-th loop contains only "full" iterations and a second loop for the 610 /// remaining partial iteration (if any). 611 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { 612 TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) 613 : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) { 614 } 615 616 LogicalResult matchAndRewrite(TiledLoopOp loopOp, 617 PatternRewriter &rewriter) const override { 618 SmallVector<int64_t> peeledLoops; 619 if (loopOp->hasAttr(kPeeledLoopsLabel)) { 620 auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>(); 621 peeledLoops = 622 llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { 623 return attr.cast<IntegerAttr>().getInt(); 624 })); 625 // Check if the loop was already peeled. 626 if (llvm::find(peeledLoops, idx) != peeledLoops.end()) 627 return failure(); 628 } 629 if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) 630 // No peeling of loop nests with a partial iteration. 631 return failure(); 632 633 if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) 634 return failure(); 635 636 // Peel loop and canonicalize. 637 TiledLoopOp result; 638 if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, 639 result))) 640 return failure(); 641 642 // Apply label, so that the same loop is not rewritten a second time. 643 peeledLoops.push_back(idx); 644 rewriter.updateRootInPlace(loopOp, [&]() { 645 loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 646 }); 647 result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 648 result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); 649 650 return success(); 651 } 652 653 /// Index of loop to peel. 654 int64_t idx; 655 656 /// If set to true, do not peel TiledLoopOps with a partial iteration. 657 bool skipPartial; 658 }; 659 } // namespace 660 661 static void applyTiledLoopPeelingPattern(FuncOp funcOp, 662 ArrayRef<unsigned> loops, 663 bool skipPartial) { 664 MLIRContext *ctx = funcOp.getContext(); 665 RewritePatternSet patterns(ctx); 666 for (unsigned idx : loops) 667 patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial); 668 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 669 670 // Drop the markers. 671 funcOp.walk([](TiledLoopOp op) { 672 op->removeAttr(kPeeledLoopsLabel); 673 op->removeAttr(kPartialIterationLabel); 674 }); 675 } 676 677 /// Apply transformations specified as patterns. 678 void TestLinalgTransforms::runOnOperation() { 679 auto lambda = [&](void *) { 680 getOperation().walk([](LinalgOp op) { 681 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 682 }); 683 }; 684 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 685 686 if (testPromotionOptions) { 687 RewritePatternSet patterns(&getContext()); 688 fillPromotionCallBackPatterns(&getContext(), patterns); 689 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 690 return; 691 } 692 if (testTileAndDistributionOptions) { 693 RewritePatternSet patterns(&getContext()); 694 fillTileAndDistributePatterns(&getContext(), patterns); 695 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 696 return; 697 } 698 if (testPatterns) 699 return applyPatterns(getOperation()); 700 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 701 return applyMatmulToVectorPatterns(getOperation(), 702 testMatmulToVectorPatterns1dTiling, 703 testMatmulToVectorPatterns2dTiling); 704 if (testVectorTransferForwardingPatterns) 705 return applyVectorTransferForwardingPatterns(getOperation()); 706 if (testGenericToVectorPattern) 707 return applyLinalgToVectorPatterns(getOperation()); 708 if (testTransformPadTensor) 709 return applyPadTensorToGenericPatterns(getOperation()); 710 if (testGeneralizePadTensor) 711 return applyGeneralizePadTensorPatterns(getOperation()); 712 if (testSwapSubTensorPadTensor) 713 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 714 if (testTiledLoopPeeling.hasValue()) 715 return applyTiledLoopPeelingPattern(getOperation(), testTiledLoopPeeling, 716 skipPartial); 717 if (testTilePattern) 718 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops, 719 /*scalarizeDynamicDims=*/false); 720 if (testTileScalarizeDynamicDims) 721 return applyTilePattern(getOperation(), loopType, tileSizes, 722 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 723 } 724 725 namespace mlir { 726 namespace test { 727 void registerTestLinalgTransforms() { 728 PassRegistration<TestLinalgTransforms>(); 729 } 730 } // namespace test 731 } // namespace mlir 732