1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/GPU/GPUDialect.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 19 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/Linalg/Utils/Utils.h" 22 #include "mlir/Dialect/StandardOps/IR/Ops.h" 23 #include "mlir/Dialect/Vector/VectorOps.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 namespace { 34 struct TestLinalgTransforms 35 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 36 TestLinalgTransforms() = default; 37 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 38 39 void getDependentDialects(DialectRegistry ®istry) const override { 40 // clang-format off 41 registry.insert<AffineDialect, 42 memref::MemRefDialect, 43 scf::SCFDialect, 44 StandardOpsDialect, 45 vector::VectorDialect, 46 gpu::GPUDialect>(); 47 // clang-format on 48 } 49 StringRef getArgument() const final { 50 return "test-linalg-transform-patterns"; 51 } 52 StringRef getDescription() const final { 53 return "Test Linalg transformation patterns by applying them greedily."; 54 } 55 56 void runOnFunction() override; 57 58 Option<bool> testPatterns{*this, "test-patterns", 59 llvm::cl::desc("Test a mixed set of patterns"), 60 llvm::cl::init(false)}; 61 Option<bool> testMatmulToVectorPatterns1dTiling{ 62 *this, "test-matmul-to-vector-patterns-tile-1d", 63 llvm::cl::desc( 64 "Test a fused pass that applies patterns from matmul to vectors via " 65 "1-d tiling"), 66 llvm::cl::init(false)}; 67 Option<bool> testMatmulToVectorPatterns2dTiling{ 68 *this, "test-matmul-to-vector-patterns-tile-2d", 69 llvm::cl::desc( 70 "Test a fused pass that applies patterns from matmul to vectors via " 71 "2-d tiling"), 72 llvm::cl::init(false)}; 73 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 74 llvm::cl::desc("Test promotion options"), 75 llvm::cl::init(false)}; 76 Option<bool> testTileAndDistributionOptions{ 77 *this, "test-tile-and-distribute-options", 78 llvm::cl::desc("Test tile and distribute options"), 79 llvm::cl::init(false)}; 80 Option<bool> testVectorTransferForwardingPatterns{ 81 *this, "test-vector-transfer-forwarding-patterns", 82 llvm::cl::desc( 83 "Test a fused pass that forwards linalg.copy to vector.transfer"), 84 llvm::cl::init(false)}; 85 Option<bool> testGenericToVectorPattern{ 86 *this, "test-linalg-to-vector-patterns", 87 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 88 "in vector.contract form"), 89 llvm::cl::init(false)}; 90 Option<bool> testTilePattern{*this, "test-tile-pattern", 91 llvm::cl::desc("Test tile pattern"), 92 llvm::cl::init(false)}; 93 Option<bool> testTileScalarizeDynamicDims{ 94 *this, "test-tile-scalarize-dynamic-dims", 95 llvm::cl::desc("Test tiling of dynamic dims by 1"), 96 llvm::cl::init(false)}; 97 Option<bool> testTransformPadTensor{ 98 *this, "test-transform-pad-tensor", 99 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 100 llvm::cl::init(false)}; 101 Option<bool> testGeneralizePadTensor{ 102 *this, "test-generalize-pad-tensor", 103 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 104 llvm::cl::init(false)}; 105 Option<bool> testSwapSubTensorPadTensor{ 106 *this, "test-swap-subtensor-padtensor", 107 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 108 "pad_tensor(subtensor)"), 109 llvm::cl::init(false)}; 110 ListOption<int64_t> peeledLoops{ 111 *this, "peeled-loops", 112 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 113 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 114 ListOption<int64_t> tileSizes{ 115 *this, "tile-sizes", 116 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 117 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 118 ListOption<unsigned> testTiledLoopPeeling{ 119 *this, "test-tiled-loop-peeling", 120 llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), 121 llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; 122 Option<bool> skipPartial{ 123 *this, "skip-partial", 124 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 125 llvm::cl::init(false)}; 126 Option<std::string> loopType{ 127 *this, "loop-type", 128 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 129 "tiled_loop"), 130 llvm::cl::init("for")}; 131 Option<bool> testDecomposeConvolutionPattern{ 132 *this, "test-decompose-convolution-patterns", 133 llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops " 134 "into low-D ones"), 135 llvm::cl::init(false)}; 136 }; 137 } // end anonymous namespace 138 139 static void applyPatterns(FuncOp funcOp) { 140 MLIRContext *ctx = funcOp.getContext(); 141 RewritePatternSet patterns(ctx); 142 143 //===--------------------------------------------------------------------===// 144 // Linalg tiling patterns. 145 //===--------------------------------------------------------------------===// 146 patterns.add<LinalgTilingPattern<MatmulOp>>( 147 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 148 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 149 StringAttr::get(ctx, "L3"))); 150 patterns.add<LinalgTilingPattern<MatmulOp>>( 151 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 152 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 153 StringAttr::get(ctx, "L2"))); 154 patterns.add<LinalgTilingPattern<MatmulOp>>( 155 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 156 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 157 StringAttr::get(ctx, "L1"))); 158 patterns.add<LinalgTilingPattern<MatmulOp>>( 159 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 160 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 161 StringAttr::get(ctx, "REG"))); 162 163 patterns.add<LinalgTilingPattern<MatvecOp>>( 164 ctx, 165 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 166 LinalgTilingLoopType::ParallelLoops), 167 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 168 StringAttr::get(ctx, "L1"))); 169 170 patterns.add<LinalgTilingPattern<DotOp>>( 171 ctx, LinalgTilingOptions().setTileSizes(8000), 172 LinalgTransformationFilter( 173 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 174 StringAttr::get(ctx, "L3"), 175 StringAttr::get(ctx, "L2")}, 176 StringAttr::get(ctx, "REG"))); 177 178 //===--------------------------------------------------------------------===// 179 // Linalg tiling and permutation patterns. 180 //===--------------------------------------------------------------------===// 181 patterns.add<LinalgTilingPattern<MatmulOp>>( 182 ctx, 183 LinalgTilingOptions() 184 .setTileSizes({2000, 3000, 4000}) 185 .setInterchange({1, 2, 0}), 186 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 187 StringAttr::get(ctx, "L2__with_perm__"))); 188 patterns.add<LinalgTilingPattern<MatmulOp>>( 189 ctx, 190 LinalgTilingOptions() 191 .setTileSizes({200, 300, 400}) 192 .setInterchange({1, 0, 2}), 193 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 194 StringAttr::get(ctx, "L1__with_perm__"))); 195 patterns.add<LinalgTilingPattern<MatmulOp>>( 196 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 197 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 198 StringAttr::get(ctx, "REG__with_perm__"))); 199 200 patterns.add<LinalgTilingPattern<MatvecOp>>( 201 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 202 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 203 StringAttr::get(ctx, "L1__with_perm__"))); 204 205 patterns.add<LinalgTilingPattern<MatmulOp>>( 206 ctx, 207 LinalgTilingOptions() 208 .setTileSizes({16, 8, 4}) 209 .setInterchange({1, 2, 0}) 210 .setLoopType(LinalgTilingLoopType::ParallelLoops), 211 LinalgTransformationFilter( 212 StringAttr::get(ctx, "par__with_perm__"), 213 StringAttr::get(ctx, "after_par__with_perm__"))); 214 215 //===--------------------------------------------------------------------===// 216 // Linalg to loops patterns. 217 //===--------------------------------------------------------------------===// 218 patterns.add<LinalgLoweringPattern<DotOp>>( 219 ctx, 220 /*loweringType=*/LinalgLoweringType::Loops, 221 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 222 223 //===--------------------------------------------------------------------===// 224 // Linalg distribution patterns. 225 //===--------------------------------------------------------------------===// 226 LinalgLoopDistributionOptions distributionOptions; 227 228 //===--------------------------------------------------------------------===// 229 // Linalg to vector contraction patterns. 230 //===--------------------------------------------------------------------===// 231 patterns.add<LinalgVectorizationPattern>( 232 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 233 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 234 235 //===--------------------------------------------------------------------===// 236 // Linalg generic interchange pattern. 237 //===--------------------------------------------------------------------===// 238 patterns.add<GenericOpInterchangePattern>( 239 ctx, 240 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 241 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 242 StringAttr::get(ctx, "PERMUTED"))); 243 244 //===--------------------------------------------------------------------===// 245 // Linalg subview operands promotion. 246 //===--------------------------------------------------------------------===// 247 patterns.add<LinalgPromotionPattern<MatmulOp>>( 248 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 249 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 250 StringAttr::get(ctx, "_views_promoted_"))); 251 patterns.add<LinalgPromotionPattern<MatmulOp>>( 252 ctx, 253 LinalgPromotionOptions() 254 .setOperandsToPromote({0}) 255 .setUseFullTileBuffersByDefault(true), 256 LinalgTransformationFilter( 257 StringAttr::get(ctx, "_promote_first_view_"), 258 StringAttr::get(ctx, "_first_view_promoted_"))); 259 patterns.add<LinalgPromotionPattern<FillOp>>( 260 ctx, 261 LinalgPromotionOptions() 262 .setOperandsToPromote({1}) 263 .setUseFullTileBuffers({false, true}) 264 .setAlignment(32), 265 LinalgTransformationFilter( 266 StringAttr::get(ctx, "_promote_views_aligned_"), 267 StringAttr::get(ctx, "_views_aligned_promoted_"))); 268 269 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 270 271 // Drop the marker. 272 funcOp.walk([](LinalgOp op) { 273 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 274 }); 275 } 276 277 static void fillL1TilingAndMatmulToVectorPatterns( 278 FuncOp funcOp, StringRef startMarker, 279 SmallVectorImpl<RewritePatternSet> &patternsVector) { 280 MLIRContext *ctx = funcOp.getContext(); 281 patternsVector.emplace_back( 282 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 283 ctx, 284 LinalgTilingOptions() 285 .setTileSizes({8, 12, 16}) 286 .setInterchange({1, 0, 2}), 287 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 288 StringAttr::get(ctx, "L1")))); 289 290 patternsVector.emplace_back( 291 ctx, 292 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 293 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 294 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 295 StringAttr::get(ctx, "VEC")))); 296 297 patternsVector.emplace_back( 298 ctx, std::make_unique<LinalgVectorizationPattern>( 299 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 300 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 301 patternsVector.back().add<LinalgVectorizationPattern>( 302 ctx, LinalgTransformationFilter().addFilter( 303 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 304 } 305 306 //===----------------------------------------------------------------------===// 307 // Test promotion callbacks 308 //===----------------------------------------------------------------------===// 309 310 // Allocation call back 311 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 312 ArrayRef<Value> boundingSubViewSize, 313 DataLayout &layout) { 314 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 315 return b 316 .create<memref::AllocOp>( 317 subView.getLoc(), 318 MemRefType::get(shape, subView.getType().getElementType(), 319 /*affineMapComposition =*/{}, 3), 320 boundingSubViewSize) 321 .getResult(); 322 } 323 324 // Deallocation callback 325 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 326 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 327 return success(); 328 } 329 330 // Copy in call back 331 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 332 bool isOutput) { 333 auto floatType = src.getType().cast<MemRefType>().getElementType(); 334 if (!floatType.isa<FloatType>()) 335 return failure(); 336 if (!isOutput) { 337 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 338 FloatAttr::get(floatType, 42.0)); 339 b.create<FillOp>(src.getLoc(), cst, dst); 340 } 341 b.create<CopyOp>(src.getLoc(), src, dst); 342 return success(); 343 } 344 345 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 346 RewritePatternSet &patterns) { 347 patterns.add<LinalgTilingPattern<MatmulOp>>( 348 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 349 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 350 StringAttr::get(ctx, "PROMOTE"))); 351 patterns.add<LinalgPromotionPattern<MatmulOp>>( 352 ctx, 353 LinalgPromotionOptions() 354 .setOperandsToPromote({0, 2}) 355 .setUseFullTileBuffers({false, false}) 356 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 357 .setCopyInOutFns( 358 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 359 return copyCallBackFn(b, src, dst, false); 360 }, 361 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 362 return copyCallBackFn(b, src, dst, true); 363 }), 364 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 365 } 366 367 template <typename IdOp, typename NProcsOp> 368 static SmallVector<ProcInfo, 2> 369 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 370 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 371 SmallVector<ProcInfo, 2> procInfo(count); 372 const char *xyz[] = {"x", "y", "z"}; 373 Type indexType = b.getIndexType(); 374 for (unsigned i = 0; i < count; ++i) { 375 procInfo[count - 1 - i] = { 376 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 377 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 378 } 379 return procInfo; 380 } 381 382 static void fillTileAndDistributePatterns(MLIRContext *context, 383 RewritePatternSet &patterns) { 384 { 385 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 386 cyclicNprocsEqNiters.distributionMethod.resize( 387 2, DistributionMethod::CyclicNumProcsEqNumIters); 388 cyclicNprocsEqNiters.procInfo = 389 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 390 patterns.add<LinalgTilingPattern<MatmulOp>>( 391 context, 392 LinalgTilingOptions() 393 .setTileSizes({8, 8, 4}) 394 .setLoopType(LinalgTilingLoopType::ParallelLoops) 395 .setDistributionOptions(cyclicNprocsEqNiters), 396 LinalgTransformationFilter( 397 StringAttr::get(context, "distribute1"), 398 StringAttr::get(context, "after_distribute1"))); 399 } 400 401 { 402 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 403 cyclicNprocsGeNiters.distributionMethod.resize( 404 2, DistributionMethod::CyclicNumProcsGeNumIters); 405 cyclicNprocsGeNiters.procInfo = 406 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 407 patterns.add<LinalgTilingPattern<MatmulOp>>( 408 context, 409 LinalgTilingOptions() 410 .setTileSizes({8, 8, 4}) 411 .setLoopType(LinalgTilingLoopType::ParallelLoops) 412 .setDistributionOptions(cyclicNprocsGeNiters), 413 LinalgTransformationFilter( 414 StringAttr::get(context, "distribute2"), 415 StringAttr::get(context, "after_distribute2"))); 416 } 417 418 { 419 LinalgLoopDistributionOptions cyclicNprocsDefault; 420 cyclicNprocsDefault.distributionMethod.resize(2, 421 DistributionMethod::Cyclic); 422 cyclicNprocsDefault.procInfo = 423 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 424 patterns.add<LinalgTilingPattern<MatmulOp>>( 425 context, 426 LinalgTilingOptions() 427 .setTileSizes({8, 8, 4}) 428 .setLoopType(LinalgTilingLoopType::ParallelLoops) 429 .setDistributionOptions(cyclicNprocsDefault), 430 LinalgTransformationFilter( 431 StringAttr::get(context, "distribute3"), 432 StringAttr::get(context, "after_distribute3"))); 433 } 434 435 { 436 LinalgLoopDistributionOptions cyclicNprocsMixed1; 437 cyclicNprocsMixed1.distributionMethod = { 438 DistributionMethod::CyclicNumProcsEqNumIters, 439 DistributionMethod::CyclicNumProcsGeNumIters}; 440 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 441 patterns.add<LinalgTilingPattern<MatmulOp>>( 442 context, 443 LinalgTilingOptions() 444 .setTileSizes({8, 8, 4}) 445 .setLoopType(LinalgTilingLoopType::ParallelLoops) 446 .setDistributionOptions(cyclicNprocsMixed1), 447 LinalgTransformationFilter( 448 StringAttr::get(context, "distribute4"), 449 StringAttr::get(context, "after_distribute4"))); 450 } 451 452 { 453 LinalgLoopDistributionOptions cyclicNprocsMixed2; 454 cyclicNprocsMixed2.distributionMethod = { 455 DistributionMethod::CyclicNumProcsGeNumIters, 456 DistributionMethod::Cyclic}; 457 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 458 patterns.add<LinalgTilingPattern<MatmulOp>>( 459 context, 460 LinalgTilingOptions() 461 .setTileSizes({8, 8, 4}) 462 .setLoopType(LinalgTilingLoopType::ParallelLoops) 463 .setDistributionOptions(cyclicNprocsMixed2), 464 LinalgTransformationFilter( 465 StringAttr::get(context, "distribute5"), 466 StringAttr::get(context, "after_distribute5"))); 467 } 468 469 { 470 LinalgLoopDistributionOptions cyclicNprocsMixed3; 471 cyclicNprocsMixed3.distributionMethod = { 472 DistributionMethod::Cyclic, 473 DistributionMethod::CyclicNumProcsEqNumIters}; 474 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 475 476 patterns.add<LinalgTilingPattern<MatmulOp>>( 477 context, 478 LinalgTilingOptions() 479 .setTileSizes({8, 8, 4}) 480 .setLoopType(LinalgTilingLoopType::ParallelLoops) 481 .setDistributionOptions(cyclicNprocsMixed3), 482 LinalgTransformationFilter( 483 StringAttr::get(context, "distribute6"), 484 StringAttr::get(context, "after_distribute6"))); 485 } 486 487 { 488 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 489 cyclicNprocsEqNiters.distributionMethod.resize(2, 490 DistributionMethod::Cyclic); 491 cyclicNprocsEqNiters.procInfo = 492 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 493 patterns.add<LinalgTilingPattern<MatmulOp>>( 494 context, 495 LinalgTilingOptions() 496 .setTileSizes({8, 8, 4}) 497 .setLoopType(LinalgTilingLoopType::Loops) 498 .setDistributionOptions(cyclicNprocsEqNiters), 499 LinalgTransformationFilter( 500 StringAttr::get(context, "tensors_distribute1"), 501 StringAttr::get(context, "tensors_after_distribute1"))); 502 } 503 } 504 505 static void 506 applyMatmulToVectorPatterns(FuncOp funcOp, 507 bool testMatmulToVectorPatterns1dTiling, 508 bool testMatmulToVectorPatterns2dTiling) { 509 MLIRContext *ctx = funcOp.getContext(); 510 SmallVector<RewritePatternSet, 4> stage1Patterns; 511 if (testMatmulToVectorPatterns1dTiling) { 512 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 513 } else if (testMatmulToVectorPatterns2dTiling) { 514 stage1Patterns.emplace_back( 515 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 516 ctx, 517 LinalgTilingOptions() 518 .setTileSizes({768, 264, 768}) 519 .setInterchange({1, 2, 0}), 520 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 521 StringAttr::get(ctx, "L2")))); 522 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 523 } 524 { 525 // Canonicalization patterns 526 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 527 vector::populateVectorTransferPermutationMapLoweringPatterns( 528 canonicalizationPatterns); 529 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 530 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 531 } 532 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 533 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 534 FrozenRewritePatternSet stage2Patterns = 535 getLinalgTilingCanonicalizationPatterns(ctx); 536 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 537 std::move(stage2Patterns)); 538 } 539 540 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 541 RewritePatternSet forwardPattern(funcOp.getContext()); 542 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 543 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 544 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 545 } 546 547 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 548 RewritePatternSet patterns(funcOp.getContext()); 549 patterns.add<LinalgVectorizationPattern>( 550 funcOp.getContext(), 551 LinalgTransformationFilter() 552 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 553 populatePadTensorOpVectorizationPatterns(patterns); 554 populateConvolutionVectorizationPatterns(patterns); 555 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 556 } 557 558 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 559 RewritePatternSet patterns(funcOp.getContext()); 560 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 561 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 562 } 563 564 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 565 RewritePatternSet patterns(funcOp.getContext()); 566 patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext()); 567 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 568 } 569 570 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 571 RewritePatternSet patterns(funcOp.getContext()); 572 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 573 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 574 } 575 576 static void applyTilePattern(FuncOp funcOp, std::string loopType, 577 ArrayRef<int64_t> tileSizes, 578 ArrayRef<int64_t> peeledLoops, 579 bool scalarizeDynamicDims) { 580 MLIRContext *context = funcOp.getContext(); 581 RewritePatternSet tilingPattern(context); 582 LinalgTilingLoopType type = 583 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 584 .Case("for", LinalgTilingLoopType::Loops) 585 .Case("affine", LinalgTilingLoopType::AffineLoops) 586 .Case("parallel", LinalgTilingLoopType::ParallelLoops) 587 .Case("tiled_loop", LinalgTilingLoopType::TiledLoops); 588 auto linalgTilingOptions = linalg::LinalgTilingOptions() 589 .setPeeledLoops(peeledLoops) 590 .setLoopType(type); 591 if (scalarizeDynamicDims) { 592 linalgTilingOptions.scalarizeDynamicDims(); 593 assert(tileSizes.empty() && 594 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 595 } else { 596 linalgTilingOptions.setTileSizes(tileSizes); 597 } 598 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, 599 linalg::LinalgTilingPattern<linalg::GenericOp>>( 600 context, linalgTilingOptions, 601 linalg::LinalgTransformationFilter(StringAttr::get(context, "tile"))); 602 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 603 } 604 605 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 606 static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; 607 608 namespace { 609 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the 610 /// `idx`-th loop contains only "full" iterations and a second loop for the 611 /// remaining partial iteration (if any). 612 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { 613 TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) 614 : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) { 615 } 616 617 LogicalResult matchAndRewrite(TiledLoopOp loopOp, 618 PatternRewriter &rewriter) const override { 619 SmallVector<int64_t> peeledLoops; 620 if (loopOp->hasAttr(kPeeledLoopsLabel)) { 621 auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>(); 622 peeledLoops = 623 llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { 624 return attr.cast<IntegerAttr>().getInt(); 625 })); 626 // Check if the loop was already peeled. 627 if (llvm::find(peeledLoops, idx) != peeledLoops.end()) 628 return failure(); 629 } 630 if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) 631 // No peeling of loop nests with a partial iteration. 632 return failure(); 633 634 if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) 635 return failure(); 636 637 // Peel loop and canonicalize. 638 TiledLoopOp result; 639 if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, 640 result))) 641 return failure(); 642 643 // Apply label, so that the same loop is not rewritten a second time. 644 peeledLoops.push_back(idx); 645 rewriter.updateRootInPlace(loopOp, [&]() { 646 loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 647 }); 648 result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 649 result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); 650 651 return success(); 652 } 653 654 /// Index of loop to peel. 655 int64_t idx; 656 657 /// If set to true, do not peel TiledLoopOps with a partial iteration. 658 bool skipPartial; 659 }; 660 } // namespace 661 662 static void applyTiledLoopPeelingPattern(FuncOp funcOp, 663 ArrayRef<unsigned> loops, 664 bool skipPartial) { 665 MLIRContext *ctx = funcOp.getContext(); 666 RewritePatternSet patterns(ctx); 667 for (unsigned idx : loops) 668 patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial); 669 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 670 671 // Drop the markers. 672 funcOp.walk([](TiledLoopOp op) { 673 op->removeAttr(kPeeledLoopsLabel); 674 op->removeAttr(kPartialIterationLabel); 675 }); 676 } 677 678 /// Apply transformations specified as patterns. 679 void TestLinalgTransforms::runOnFunction() { 680 auto lambda = [&](void *) { 681 getFunction().walk([](LinalgOp op) { 682 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 683 }); 684 }; 685 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 686 687 if (testPromotionOptions) { 688 RewritePatternSet patterns(&getContext()); 689 fillPromotionCallBackPatterns(&getContext(), patterns); 690 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 691 return; 692 } 693 if (testTileAndDistributionOptions) { 694 RewritePatternSet patterns(&getContext()); 695 fillTileAndDistributePatterns(&getContext(), patterns); 696 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 697 return; 698 } 699 if (testPatterns) 700 return applyPatterns(getFunction()); 701 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 702 return applyMatmulToVectorPatterns(getFunction(), 703 testMatmulToVectorPatterns1dTiling, 704 testMatmulToVectorPatterns2dTiling); 705 if (testVectorTransferForwardingPatterns) 706 return applyVectorTransferForwardingPatterns(getFunction()); 707 if (testGenericToVectorPattern) 708 return applyLinalgToVectorPatterns(getFunction()); 709 if (testTransformPadTensor) 710 return applyPadTensorToGenericPatterns(getFunction()); 711 if (testGeneralizePadTensor) 712 return applyGeneralizePadTensorPatterns(getFunction()); 713 if (testSwapSubTensorPadTensor) 714 return applyExtractSliceOfPadTensorSwapPattern(getFunction()); 715 if (testTiledLoopPeeling.hasValue()) 716 return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, 717 skipPartial); 718 if (testTilePattern) 719 return applyTilePattern(getFunction(), loopType, tileSizes, peeledLoops, 720 /*scalarizeDynamicDims=*/false); 721 if (testTileScalarizeDynamicDims) 722 return applyTilePattern(getFunction(), loopType, tileSizes, 723 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 724 if (testDecomposeConvolutionPattern) { 725 // TODO: thread all tests through LinalgStrategy passes. 726 OpPassManager dynamicPM("builtin.func"); 727 dynamicPM.addPass(createLinalgStrategyDecomposePass()); 728 if (failed(runPipeline(dynamicPM, getFunction()))) 729 return signalPassFailure(); 730 } 731 } 732 733 namespace mlir { 734 namespace test { 735 void registerTestLinalgTransforms() { 736 PassRegistration<TestLinalgTransforms>(); 737 } 738 } // namespace test 739 } // namespace mlir 740