1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/GPU/GPUDialect.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 19 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/Linalg/Utils/Utils.h" 22 #include "mlir/Dialect/StandardOps/IR/Ops.h" 23 #include "mlir/Dialect/Vector/VectorOps.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 namespace { 34 struct TestLinalgTransforms 35 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 36 TestLinalgTransforms() = default; 37 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 38 39 void getDependentDialects(DialectRegistry ®istry) const override { 40 // clang-format off 41 registry.insert<AffineDialect, 42 memref::MemRefDialect, 43 scf::SCFDialect, 44 StandardOpsDialect, 45 vector::VectorDialect, 46 gpu::GPUDialect>(); 47 // clang-format on 48 } 49 StringRef getArgument() const final { 50 return "test-linalg-transform-patterns"; 51 } 52 StringRef getDescription() const final { 53 return "Test Linalg transformation patterns by applying them greedily."; 54 } 55 56 void runOnFunction() override; 57 58 Option<bool> testPatterns{*this, "test-patterns", 59 llvm::cl::desc("Test a mixed set of patterns"), 60 llvm::cl::init(false)}; 61 Option<bool> testMatmulToVectorPatterns1dTiling{ 62 *this, "test-matmul-to-vector-patterns-tile-1d", 63 llvm::cl::desc( 64 "Test a fused pass that applies patterns from matmul to vectors via " 65 "1-d tiling"), 66 llvm::cl::init(false)}; 67 Option<bool> testMatmulToVectorPatterns2dTiling{ 68 *this, "test-matmul-to-vector-patterns-tile-2d", 69 llvm::cl::desc( 70 "Test a fused pass that applies patterns from matmul to vectors via " 71 "2-d tiling"), 72 llvm::cl::init(false)}; 73 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 74 llvm::cl::desc("Test promotion options"), 75 llvm::cl::init(false)}; 76 Option<bool> testTileAndDistributionOptions{ 77 *this, "test-tile-and-distribute-options", 78 llvm::cl::desc("Test tile and distribute options"), 79 llvm::cl::init(false)}; 80 Option<bool> testVectorTransferForwardingPatterns{ 81 *this, "test-vector-transfer-forwarding-patterns", 82 llvm::cl::desc( 83 "Test a fused pass that forwards linalg.copy to vector.transfer"), 84 llvm::cl::init(false)}; 85 Option<bool> testGenericToVectorPattern{ 86 *this, "test-linalg-to-vector-patterns", 87 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 88 "in vector.contract form"), 89 llvm::cl::init(false)}; 90 Option<bool> testTilePattern{*this, "test-tile-pattern", 91 llvm::cl::desc("Test tile pattern"), 92 llvm::cl::init(false)}; 93 Option<bool> testTileScalarizeDynamicDims{ 94 *this, "test-tile-scalarize-dynamic-dims", 95 llvm::cl::desc("Test tiling of dynamic dims by 1"), 96 llvm::cl::init(false)}; 97 Option<bool> testTransformPadTensor{ 98 *this, "test-transform-pad-tensor", 99 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 100 llvm::cl::init(false)}; 101 Option<bool> testGeneralizePadTensor{ 102 *this, "test-generalize-pad-tensor", 103 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 104 llvm::cl::init(false)}; 105 Option<bool> testSwapSubTensorPadTensor{ 106 *this, "test-swap-subtensor-padtensor", 107 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 108 "pad_tensor(subtensor)"), 109 llvm::cl::init(false)}; 110 ListOption<int64_t> peeledLoops{ 111 *this, "peeled-loops", 112 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 113 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 114 ListOption<int64_t> tileSizes{ 115 *this, "tile-sizes", 116 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 117 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 118 ListOption<unsigned> testTiledLoopPeeling{ 119 *this, "test-tiled-loop-peeling", 120 llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), 121 llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; 122 Option<bool> skipPartial{ 123 *this, "skip-partial", 124 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 125 llvm::cl::init(false)}; 126 Option<std::string> loopType{ 127 *this, "loop-type", 128 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 129 "tiled_loop"), 130 llvm::cl::init("for")}; 131 }; 132 } // end anonymous namespace 133 134 static void applyPatterns(FuncOp funcOp) { 135 MLIRContext *ctx = funcOp.getContext(); 136 RewritePatternSet patterns(ctx); 137 138 //===--------------------------------------------------------------------===// 139 // Linalg tiling patterns. 140 //===--------------------------------------------------------------------===// 141 patterns.add<LinalgTilingPattern<MatmulOp>>( 142 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 143 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 144 StringAttr::get(ctx, "L3"))); 145 patterns.add<LinalgTilingPattern<MatmulOp>>( 146 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 147 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 148 StringAttr::get(ctx, "L2"))); 149 patterns.add<LinalgTilingPattern<MatmulOp>>( 150 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 151 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 152 StringAttr::get(ctx, "L1"))); 153 patterns.add<LinalgTilingPattern<MatmulOp>>( 154 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 155 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 156 StringAttr::get(ctx, "REG"))); 157 158 patterns.add<LinalgTilingPattern<MatvecOp>>( 159 ctx, 160 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 161 LinalgTilingLoopType::ParallelLoops), 162 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 163 StringAttr::get(ctx, "L1"))); 164 165 patterns.add<LinalgTilingPattern<DotOp>>( 166 ctx, LinalgTilingOptions().setTileSizes(8000), 167 LinalgTransformationFilter( 168 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 169 StringAttr::get(ctx, "L3"), 170 StringAttr::get(ctx, "L2")}, 171 StringAttr::get(ctx, "REG"))); 172 173 //===--------------------------------------------------------------------===// 174 // Linalg tiling and permutation patterns. 175 //===--------------------------------------------------------------------===// 176 patterns.add<LinalgTilingPattern<MatmulOp>>( 177 ctx, 178 LinalgTilingOptions() 179 .setTileSizes({2000, 3000, 4000}) 180 .setInterchange({1, 2, 0}), 181 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 182 StringAttr::get(ctx, "L2__with_perm__"))); 183 patterns.add<LinalgTilingPattern<MatmulOp>>( 184 ctx, 185 LinalgTilingOptions() 186 .setTileSizes({200, 300, 400}) 187 .setInterchange({1, 0, 2}), 188 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 189 StringAttr::get(ctx, "L1__with_perm__"))); 190 patterns.add<LinalgTilingPattern<MatmulOp>>( 191 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 192 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 193 StringAttr::get(ctx, "REG__with_perm__"))); 194 195 patterns.add<LinalgTilingPattern<MatvecOp>>( 196 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 197 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 198 StringAttr::get(ctx, "L1__with_perm__"))); 199 200 patterns.add<LinalgTilingPattern<MatmulOp>>( 201 ctx, 202 LinalgTilingOptions() 203 .setTileSizes({16, 8, 4}) 204 .setInterchange({1, 2, 0}) 205 .setLoopType(LinalgTilingLoopType::ParallelLoops), 206 LinalgTransformationFilter( 207 StringAttr::get(ctx, "par__with_perm__"), 208 StringAttr::get(ctx, "after_par__with_perm__"))); 209 210 //===--------------------------------------------------------------------===// 211 // Linalg to loops patterns. 212 //===--------------------------------------------------------------------===// 213 patterns.add<LinalgLoweringPattern<DotOp>>( 214 ctx, 215 /*loweringType=*/LinalgLoweringType::Loops, 216 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 217 218 //===--------------------------------------------------------------------===// 219 // Linalg distribution patterns. 220 //===--------------------------------------------------------------------===// 221 LinalgLoopDistributionOptions distributionOptions; 222 223 //===--------------------------------------------------------------------===// 224 // Linalg to vector contraction patterns. 225 //===--------------------------------------------------------------------===// 226 patterns.add<LinalgVectorizationPattern>( 227 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 228 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 229 230 //===--------------------------------------------------------------------===// 231 // Linalg generic interchange pattern. 232 //===--------------------------------------------------------------------===// 233 patterns.add<GenericOpInterchangePattern>( 234 ctx, 235 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 236 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 237 StringAttr::get(ctx, "PERMUTED"))); 238 239 //===--------------------------------------------------------------------===// 240 // Linalg subview operands promotion. 241 //===--------------------------------------------------------------------===// 242 patterns.add<LinalgPromotionPattern<MatmulOp>>( 243 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 244 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 245 StringAttr::get(ctx, "_views_promoted_"))); 246 patterns.add<LinalgPromotionPattern<MatmulOp>>( 247 ctx, 248 LinalgPromotionOptions() 249 .setOperandsToPromote({0}) 250 .setUseFullTileBuffersByDefault(true), 251 LinalgTransformationFilter( 252 StringAttr::get(ctx, "_promote_first_view_"), 253 StringAttr::get(ctx, "_first_view_promoted_"))); 254 patterns.add<LinalgPromotionPattern<FillOp>>( 255 ctx, 256 LinalgPromotionOptions() 257 .setOperandsToPromote({1}) 258 .setUseFullTileBuffers({false, true}) 259 .setAlignment(32), 260 LinalgTransformationFilter( 261 StringAttr::get(ctx, "_promote_views_aligned_"), 262 StringAttr::get(ctx, "_views_aligned_promoted_"))); 263 264 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 265 266 // Drop the marker. 267 funcOp.walk([](LinalgOp op) { 268 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 269 }); 270 } 271 272 static void fillL1TilingAndMatmulToVectorPatterns( 273 FuncOp funcOp, StringRef startMarker, 274 SmallVectorImpl<RewritePatternSet> &patternsVector) { 275 MLIRContext *ctx = funcOp.getContext(); 276 patternsVector.emplace_back( 277 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 278 ctx, 279 LinalgTilingOptions() 280 .setTileSizes({8, 12, 16}) 281 .setInterchange({1, 0, 2}), 282 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 283 StringAttr::get(ctx, "L1")))); 284 285 patternsVector.emplace_back( 286 ctx, 287 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 288 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 289 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 290 StringAttr::get(ctx, "VEC")))); 291 292 patternsVector.emplace_back( 293 ctx, std::make_unique<LinalgVectorizationPattern>( 294 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 295 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 296 patternsVector.back().add<LinalgVectorizationPattern>( 297 ctx, LinalgTransformationFilter().addFilter( 298 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 299 } 300 301 //===----------------------------------------------------------------------===// 302 // Test promotion callbacks 303 //===----------------------------------------------------------------------===// 304 305 // Allocation call back 306 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 307 ArrayRef<Value> boundingSubViewSize, 308 DataLayout &layout) { 309 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 310 return b 311 .create<memref::AllocOp>( 312 subView.getLoc(), 313 MemRefType::get(shape, subView.getType().getElementType(), 314 /*affineMapComposition =*/{}, 3), 315 boundingSubViewSize) 316 .getResult(); 317 } 318 319 // Deallocation callback 320 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 321 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 322 return success(); 323 } 324 325 // Copy in call back 326 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 327 bool isOutput) { 328 auto floatType = src.getType().cast<MemRefType>().getElementType(); 329 if (!floatType.isa<FloatType>()) 330 return failure(); 331 if (!isOutput) { 332 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 333 FloatAttr::get(floatType, 42.0)); 334 b.create<FillOp>(src.getLoc(), cst, dst); 335 } 336 b.create<CopyOp>(src.getLoc(), src, dst); 337 return success(); 338 } 339 340 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 341 RewritePatternSet &patterns) { 342 patterns.add<LinalgTilingPattern<MatmulOp>>( 343 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 344 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 345 StringAttr::get(ctx, "PROMOTE"))); 346 patterns.add<LinalgPromotionPattern<MatmulOp>>( 347 ctx, 348 LinalgPromotionOptions() 349 .setOperandsToPromote({0, 2}) 350 .setUseFullTileBuffers({false, false}) 351 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 352 .setCopyInOutFns( 353 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 354 return copyCallBackFn(b, src, dst, false); 355 }, 356 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 357 return copyCallBackFn(b, src, dst, true); 358 }), 359 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 360 } 361 362 template <typename IdOp, typename NProcsOp> 363 static SmallVector<ProcInfo, 2> 364 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 365 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 366 SmallVector<ProcInfo, 2> procInfo(count); 367 const char *xyz[] = {"x", "y", "z"}; 368 Type indexType = b.getIndexType(); 369 for (unsigned i = 0; i < count; ++i) { 370 procInfo[count - 1 - i] = { 371 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 372 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 373 } 374 return procInfo; 375 } 376 377 static void fillTileAndDistributePatterns(MLIRContext *context, 378 RewritePatternSet &patterns) { 379 { 380 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 381 cyclicNprocsEqNiters.distributionMethod.resize( 382 2, DistributionMethod::CyclicNumProcsEqNumIters); 383 cyclicNprocsEqNiters.procInfo = 384 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 385 patterns.add<LinalgTilingPattern<MatmulOp>>( 386 context, 387 LinalgTilingOptions() 388 .setTileSizes({8, 8, 4}) 389 .setLoopType(LinalgTilingLoopType::ParallelLoops) 390 .setDistributionOptions(cyclicNprocsEqNiters), 391 LinalgTransformationFilter( 392 StringAttr::get(context, "distribute1"), 393 StringAttr::get(context, "after_distribute1"))); 394 } 395 396 { 397 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 398 cyclicNprocsGeNiters.distributionMethod.resize( 399 2, DistributionMethod::CyclicNumProcsGeNumIters); 400 cyclicNprocsGeNiters.procInfo = 401 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 402 patterns.add<LinalgTilingPattern<MatmulOp>>( 403 context, 404 LinalgTilingOptions() 405 .setTileSizes({8, 8, 4}) 406 .setLoopType(LinalgTilingLoopType::ParallelLoops) 407 .setDistributionOptions(cyclicNprocsGeNiters), 408 LinalgTransformationFilter( 409 StringAttr::get(context, "distribute2"), 410 StringAttr::get(context, "after_distribute2"))); 411 } 412 413 { 414 LinalgLoopDistributionOptions cyclicNprocsDefault; 415 cyclicNprocsDefault.distributionMethod.resize(2, 416 DistributionMethod::Cyclic); 417 cyclicNprocsDefault.procInfo = 418 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 419 patterns.add<LinalgTilingPattern<MatmulOp>>( 420 context, 421 LinalgTilingOptions() 422 .setTileSizes({8, 8, 4}) 423 .setLoopType(LinalgTilingLoopType::ParallelLoops) 424 .setDistributionOptions(cyclicNprocsDefault), 425 LinalgTransformationFilter( 426 StringAttr::get(context, "distribute3"), 427 StringAttr::get(context, "after_distribute3"))); 428 } 429 430 { 431 LinalgLoopDistributionOptions cyclicNprocsMixed1; 432 cyclicNprocsMixed1.distributionMethod = { 433 DistributionMethod::CyclicNumProcsEqNumIters, 434 DistributionMethod::CyclicNumProcsGeNumIters}; 435 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 436 patterns.add<LinalgTilingPattern<MatmulOp>>( 437 context, 438 LinalgTilingOptions() 439 .setTileSizes({8, 8, 4}) 440 .setLoopType(LinalgTilingLoopType::ParallelLoops) 441 .setDistributionOptions(cyclicNprocsMixed1), 442 LinalgTransformationFilter( 443 StringAttr::get(context, "distribute4"), 444 StringAttr::get(context, "after_distribute4"))); 445 } 446 447 { 448 LinalgLoopDistributionOptions cyclicNprocsMixed2; 449 cyclicNprocsMixed2.distributionMethod = { 450 DistributionMethod::CyclicNumProcsGeNumIters, 451 DistributionMethod::Cyclic}; 452 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 453 patterns.add<LinalgTilingPattern<MatmulOp>>( 454 context, 455 LinalgTilingOptions() 456 .setTileSizes({8, 8, 4}) 457 .setLoopType(LinalgTilingLoopType::ParallelLoops) 458 .setDistributionOptions(cyclicNprocsMixed2), 459 LinalgTransformationFilter( 460 StringAttr::get(context, "distribute5"), 461 StringAttr::get(context, "after_distribute5"))); 462 } 463 464 { 465 LinalgLoopDistributionOptions cyclicNprocsMixed3; 466 cyclicNprocsMixed3.distributionMethod = { 467 DistributionMethod::Cyclic, 468 DistributionMethod::CyclicNumProcsEqNumIters}; 469 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 470 471 patterns.add<LinalgTilingPattern<MatmulOp>>( 472 context, 473 LinalgTilingOptions() 474 .setTileSizes({8, 8, 4}) 475 .setLoopType(LinalgTilingLoopType::ParallelLoops) 476 .setDistributionOptions(cyclicNprocsMixed3), 477 LinalgTransformationFilter( 478 StringAttr::get(context, "distribute6"), 479 StringAttr::get(context, "after_distribute6"))); 480 } 481 482 { 483 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 484 cyclicNprocsEqNiters.distributionMethod.resize(2, 485 DistributionMethod::Cyclic); 486 cyclicNprocsEqNiters.procInfo = 487 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 488 patterns.add<LinalgTilingPattern<MatmulOp>>( 489 context, 490 LinalgTilingOptions() 491 .setTileSizes({8, 8, 4}) 492 .setLoopType(LinalgTilingLoopType::Loops) 493 .setDistributionOptions(cyclicNprocsEqNiters), 494 LinalgTransformationFilter( 495 StringAttr::get(context, "tensors_distribute1"), 496 StringAttr::get(context, "tensors_after_distribute1"))); 497 } 498 } 499 500 static void 501 applyMatmulToVectorPatterns(FuncOp funcOp, 502 bool testMatmulToVectorPatterns1dTiling, 503 bool testMatmulToVectorPatterns2dTiling) { 504 MLIRContext *ctx = funcOp.getContext(); 505 SmallVector<RewritePatternSet, 4> stage1Patterns; 506 if (testMatmulToVectorPatterns1dTiling) { 507 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 508 } else if (testMatmulToVectorPatterns2dTiling) { 509 stage1Patterns.emplace_back( 510 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 511 ctx, 512 LinalgTilingOptions() 513 .setTileSizes({768, 264, 768}) 514 .setInterchange({1, 2, 0}), 515 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 516 StringAttr::get(ctx, "L2")))); 517 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 518 } 519 { 520 // Canonicalization patterns 521 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 522 vector::populateVectorTransferPermutationMapLoweringPatterns( 523 canonicalizationPatterns); 524 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 525 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 526 } 527 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 528 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 529 FrozenRewritePatternSet stage2Patterns = 530 getLinalgTilingCanonicalizationPatterns(ctx); 531 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 532 std::move(stage2Patterns)); 533 } 534 535 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 536 RewritePatternSet forwardPattern(funcOp.getContext()); 537 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 538 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 539 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 540 } 541 542 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 543 RewritePatternSet patterns(funcOp.getContext()); 544 patterns.add<LinalgVectorizationPattern>( 545 funcOp.getContext(), 546 LinalgTransformationFilter() 547 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 548 populatePadTensorOpVectorizationPatterns(patterns); 549 populateConvolutionVectorizationPatterns(patterns); 550 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 551 } 552 553 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 554 RewritePatternSet patterns(funcOp.getContext()); 555 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 556 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 557 } 558 559 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 560 RewritePatternSet patterns(funcOp.getContext()); 561 patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext()); 562 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 563 } 564 565 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 566 RewritePatternSet patterns(funcOp.getContext()); 567 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 568 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 569 } 570 571 static void applyTilePattern(FuncOp funcOp, std::string loopType, 572 ArrayRef<int64_t> tileSizes, 573 ArrayRef<int64_t> peeledLoops, 574 bool scalarizeDynamicDims) { 575 MLIRContext *context = funcOp.getContext(); 576 RewritePatternSet tilingPattern(context); 577 LinalgTilingLoopType type = 578 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 579 .Case("for", LinalgTilingLoopType::Loops) 580 .Case("affine", LinalgTilingLoopType::AffineLoops) 581 .Case("parallel", LinalgTilingLoopType::ParallelLoops) 582 .Case("tiled_loop", LinalgTilingLoopType::TiledLoops); 583 auto linalgTilingOptions = linalg::LinalgTilingOptions() 584 .setPeeledLoops(peeledLoops) 585 .setLoopType(type); 586 if (scalarizeDynamicDims) { 587 linalgTilingOptions.scalarizeDynamicDims(); 588 assert(tileSizes.empty() && 589 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 590 } else { 591 linalgTilingOptions.setTileSizes(tileSizes); 592 } 593 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, 594 linalg::LinalgTilingPattern<linalg::GenericOp>>( 595 context, linalgTilingOptions, 596 linalg::LinalgTransformationFilter(StringAttr::get(context, "tile"))); 597 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 598 } 599 600 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 601 static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; 602 603 namespace { 604 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the 605 /// `idx`-th loop contains only "full" iterations and a second loop for the 606 /// remaining partial iteration (if any). 607 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { 608 TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) 609 : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) { 610 } 611 612 LogicalResult matchAndRewrite(TiledLoopOp loopOp, 613 PatternRewriter &rewriter) const override { 614 SmallVector<int64_t> peeledLoops; 615 if (loopOp->hasAttr(kPeeledLoopsLabel)) { 616 auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>(); 617 peeledLoops = 618 llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { 619 return attr.cast<IntegerAttr>().getInt(); 620 })); 621 // Check if the loop was already peeled. 622 if (llvm::find(peeledLoops, idx) != peeledLoops.end()) 623 return failure(); 624 } 625 if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) 626 // No peeling of loop nests with a partial iteration. 627 return failure(); 628 629 if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) 630 return failure(); 631 632 // Peel loop and canonicalize. 633 TiledLoopOp result; 634 if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, 635 result))) 636 return failure(); 637 638 // Apply label, so that the same loop is not rewritten a second time. 639 peeledLoops.push_back(idx); 640 rewriter.updateRootInPlace(loopOp, [&]() { 641 loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 642 }); 643 result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 644 result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); 645 646 return success(); 647 } 648 649 /// Index of loop to peel. 650 int64_t idx; 651 652 /// If set to true, do not peel TiledLoopOps with a partial iteration. 653 bool skipPartial; 654 }; 655 } // namespace 656 657 static void applyTiledLoopPeelingPattern(FuncOp funcOp, 658 ArrayRef<unsigned> loops, 659 bool skipPartial) { 660 MLIRContext *ctx = funcOp.getContext(); 661 RewritePatternSet patterns(ctx); 662 for (unsigned idx : loops) 663 patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial); 664 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 665 666 // Drop the markers. 667 funcOp.walk([](TiledLoopOp op) { 668 op->removeAttr(kPeeledLoopsLabel); 669 op->removeAttr(kPartialIterationLabel); 670 }); 671 } 672 673 /// Apply transformations specified as patterns. 674 void TestLinalgTransforms::runOnFunction() { 675 auto lambda = [&](void *) { 676 getFunction().walk([](LinalgOp op) { 677 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 678 }); 679 }; 680 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 681 682 if (testPromotionOptions) { 683 RewritePatternSet patterns(&getContext()); 684 fillPromotionCallBackPatterns(&getContext(), patterns); 685 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 686 return; 687 } 688 if (testTileAndDistributionOptions) { 689 RewritePatternSet patterns(&getContext()); 690 fillTileAndDistributePatterns(&getContext(), patterns); 691 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 692 return; 693 } 694 if (testPatterns) 695 return applyPatterns(getFunction()); 696 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 697 return applyMatmulToVectorPatterns(getFunction(), 698 testMatmulToVectorPatterns1dTiling, 699 testMatmulToVectorPatterns2dTiling); 700 if (testVectorTransferForwardingPatterns) 701 return applyVectorTransferForwardingPatterns(getFunction()); 702 if (testGenericToVectorPattern) 703 return applyLinalgToVectorPatterns(getFunction()); 704 if (testTransformPadTensor) 705 return applyPadTensorToGenericPatterns(getFunction()); 706 if (testGeneralizePadTensor) 707 return applyGeneralizePadTensorPatterns(getFunction()); 708 if (testSwapSubTensorPadTensor) 709 return applyExtractSliceOfPadTensorSwapPattern(getFunction()); 710 if (testTiledLoopPeeling.hasValue()) 711 return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, 712 skipPartial); 713 if (testTilePattern) 714 return applyTilePattern(getFunction(), loopType, tileSizes, peeledLoops, 715 /*scalarizeDynamicDims=*/false); 716 if (testTileScalarizeDynamicDims) 717 return applyTilePattern(getFunction(), loopType, tileSizes, 718 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 719 } 720 721 namespace mlir { 722 namespace test { 723 void registerTestLinalgTransforms() { 724 PassRegistration<TestLinalgTransforms>(); 725 } 726 } // namespace test 727 } // namespace mlir 728