//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Function-level test pass that selects one set of Linalg transformation
/// patterns via a command-line option and applies it greedily.
struct TestLinalgTransforms
    : public PassWrapper<TestLinalgTransforms, FunctionPass> {
  TestLinalgTransforms() = default;
  // NOTE: the Option/ListOption members below are deliberately not copied;
  // the copy constructor's body is intentionally empty.
  TestLinalgTransforms(const TestLinalgTransforms &pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    // clang-format off
    registry.insert<AffineDialect,
                    memref::MemRefDialect,
                    scf::SCFDialect,
                    StandardOpsDialect,
                    vector::VectorDialect,
                    gpu::GPUDialect>();
    // clang-format on
  }
  StringRef getArgument() const final {
    return "test-linalg-transform-patterns";
  }
  StringRef getDescription() const final {
    return "Test Linalg transformation patterns by applying them greedily.";
  }

  void runOnFunction() override;

  // Pass options. Each boolean option selects an independent set of test
  // patterns; runOnFunction dispatches on the first one that is set.
  Option<bool> testPatterns{*this, "test-patterns",
                            llvm::cl::desc("Test a mixed set of patterns"),
                            llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns1dTiling{
      *this, "test-matmul-to-vector-patterns-tile-1d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "1-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns2dTiling{
      *this, "test-matmul-to-vector-patterns-tile-2d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "2-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                    llvm::cl::desc("Test promotion options"),
                                    llvm::cl::init(false)};
  Option<bool> testTileAndDistributionOptions{
      *this, "test-tile-and-distribute-options",
      llvm::cl::desc("Test tile and distribute options"),
      llvm::cl::init(false)};
  Option<bool> testVectorTransferForwardingPatterns{
      *this, "test-vector-transfer-forwarding-patterns",
      llvm::cl::desc(
          "Test a fused pass that forwards linalg.copy to vector.transfer"),
      llvm::cl::init(false)};
  Option<bool> testGenericToVectorPattern{
      *this, "test-linalg-to-vector-patterns",
      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                     "in vector.contract form"),
      llvm::cl::init(false)};
  Option<bool> testTilePattern{*this, "test-tile-pattern",
                               llvm::cl::desc("Test tile pattern"),
                               llvm::cl::init(false)};
  Option<bool> testTileScalarizeDynamicDims{
      *this, "test-tile-scalarize-dynamic-dims",
      llvm::cl::desc("Test tiling of dynamic dims by 1"),
      llvm::cl::init(false)};
  Option<bool> testTransformPadTensor{
      *this, "test-transform-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  // NOTE(review): this option's description duplicates the one for
  // test-transform-pad-tensor above — likely a copy/paste slip. (The desc is
  // a runtime string, so it is left unchanged here.)
  Option<bool> testGeneralizePadTensor{
      *this, "test-generalize-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testSwapSubTensorPadTensor{
      *this, "test-swap-subtensor-padtensor",
      llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                     "pad_tensor(subtensor)"),
      llvm::cl::init(false)};
  ListOption<int64_t> peeledLoops{
      *this, "peeled-loops",
      llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> tileSizes{
      *this, "tile-sizes",
      llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<unsigned> testTiledLoopPeeling{
      *this, "test-tiled-loop-peeling",
      llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
      llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
  Option<bool> skipPartial{
      *this, "skip-partial",
      llvm::cl::desc("Skip loops inside partial iterations during peeling"),
      llvm::cl::init(false)};
  Option<std::string> loopType{
      *this, "loop-type",
      llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                     "tiled_loop"),
      llvm::cl::init("for")};
  Option<bool> testDecomposeConvolutionPattern{
      *this, "test-decompose-convolution-patterns",
      llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops "
                     "into low-D ones"),
      llvm::cl::init(false)};
};
} // end anonymous namespace

/// Apply a mixed set of Linalg test patterns: cache-level tiling chains
/// (driven by transformation-filter markers MEM -> L3 -> L2 -> L1 -> REG),
/// tiling with loop interchange, lowering to loops, vectorization, generic-op
/// interchange, and subview-operand promotion. Markers are stripped at the
/// end so the test output is not polluted by them.
static void applyPatterns(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);

  //===--------------------------------------------------------------------===//
  // Linalg tiling patterns.
  //===--------------------------------------------------------------------===//
  // Each pattern consumes the marker produced by the previous one, forming a
  // tile chain: MEM -> L3 -> L2 -> L1 -> REG with shrinking tile sizes.
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
      LinalgTransformationFilter(Identifier::get("MEM", ctx),
                                 Identifier::get("L3", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
      LinalgTransformationFilter(Identifier::get("L3", ctx),
                                 Identifier::get("L2", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L2", ctx),
                                 Identifier::get("L1", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
      LinalgTransformationFilter(Identifier::get("L1", ctx),
                                 Identifier::get("REG", ctx)));

  // Matvec: single tiling step to parallel loops; matches only unmarked ops
  // (empty match-marker list).
  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx,
      LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
          LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("L1", ctx)));

  // Dot: one pattern accepting any of several entry markers.
  patterns.add<LinalgTilingPattern<DotOp>>(
      ctx, LinalgTilingOptions().setTileSizes(8000),
      LinalgTransformationFilter(
          ArrayRef<Identifier>{Identifier::get("MEM", ctx),
                               Identifier::get("L3", ctx),
                               Identifier::get("L2", ctx)},
          Identifier::get("REG", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg tiling and permutation patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({2000, 3000, 4000})
          .setInterchange({1, 2, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L2__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({200, 300, 400})
          .setInterchange({1, 0, 2}),
      LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
                                 Identifier::get("REG__with_perm__", ctx)));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));

  // Tiling + interchange + parallel-loop generation combined.
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({16, 8, 4})
          .setInterchange({1, 2, 0})
          .setLoopType(LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(
          Identifier::get("par__with_perm__", ctx),
          Identifier::get("after_par__with_perm__", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg to loops patterns.
  //===--------------------------------------------------------------------===//
  // Ops that reached the end of the tile chain (marker "REG") are lowered to
  // loops.
  patterns.add<LinalgLoweringPattern<DotOp>>(
      ctx,
      /*loweringType=*/LinalgLoweringType::Loops,
      LinalgTransformationFilter(Identifier::get("REG", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg distribution patterns.
  //===--------------------------------------------------------------------===//
  // NOTE(review): this local is never used in this function; distribution is
  // exercised separately in fillTileAndDistributePatterns below.
  LinalgLoopDistributionOptions distributionOptions;

  //===--------------------------------------------------------------------===//
  // Linalg to vector contraction patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
               .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());

  //===--------------------------------------------------------------------===//
  // Linalg generic interchange pattern.
  //===--------------------------------------------------------------------===//
  patterns.add<GenericOpInterchangePattern>(
      ctx,
      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("PERMUTED", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg subview operands promotion.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
                                 Identifier::get("_views_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0})
          .setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(
          Identifier::get("_promote_first_view_", ctx),
          Identifier::get("_first_view_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<FillOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({1})
          .setUseFullTileBuffers({false, true})
          .setAlignment(32),
      LinalgTransformationFilter(
          Identifier::get("_promote_views_aligned_", ctx),
          Identifier::get("_views_aligned_promoted_", ctx)));

  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the marker.
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}

/// Populate `patternsVector` with a staged matmul-to-vector pipeline:
///   1. tile matmul (with interchange) from `startMarker` to "L1",
///   2. promote the matmul operands ("L1" -> "VEC"),
///   3. vectorize the matmul (marker "VEC") plus any fill/copy ops.
/// Each stage is a separate RewritePatternSet so the stages can be applied
/// one after another by applyStagedPatterns.
static void fillL1TilingAndMatmulToVectorPatterns(
    FuncOp funcOp, StringRef startMarker,
    SmallVectorImpl<RewritePatternSet> &patternsVector) {
  MLIRContext *ctx = funcOp.getContext();
  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
               ctx,
               LinalgTilingOptions()
                   .setTileSizes({8, 12, 16})
                   .setInterchange({1, 0, 2}),
               LinalgTransformationFilter(Identifier::get(startMarker, ctx),
                                          Identifier::get("L1", ctx))));

  patternsVector.emplace_back(
      ctx,
      std::make_unique<LinalgPromotionPattern<MatmulOp>>(
          ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
          LinalgTransformationFilter(Identifier::get("L1", ctx),
                                     Identifier::get("VEC", ctx))));

  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgVectorizationPattern>(
               MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
               LinalgTransformationFilter(Identifier::get("VEC", ctx))));
  // Also vectorize fill/copy ops in the same (last) stage, regardless of
  // marker.
  patternsVector.back().add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter().addFilter(
               [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
}

//===----------------------------------------------------------------------===//
// Test promotion callbacks
//===----------------------------------------------------------------------===//

// Allocation call back: allocates a fully dynamic memref (all dims -1) of the
// subview's element type in memory space 3, sized by `boundingSubViewSize`.
static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
                                       ArrayRef<Value> boundingSubViewSize,
                                       DataLayout &layout) {
  SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
  return b
      .create<memref::AllocOp>(
          subView.getLoc(),
          MemRefType::get(shape, subView.getType().getElementType(),
                          /*affineMapComposition=*/{}, 3),
          boundingSubViewSize)
      .getResult();
}

// Deallocation callback: emits a matching memref.dealloc for the promoted
// buffer.
static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
  b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
  return success();
}

// Copy in call back. Fails for non-float element types; for input copies,
// fills the destination with the sentinel 42.0 before copying so the test can
// recognize the promoted buffer.
static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
                                    bool isOutput) {
  auto floatType = src.getType().cast<MemRefType>().getElementType();
  if (!floatType.isa<FloatType>())
    return failure();
  if (!isOutput) {
    Value cst = b.create<arith::ConstantOp>(src.getLoc(),
                                            FloatAttr::get(floatType, 42.0));
    b.create<FillOp>(src.getLoc(), cst, dst);
  }
  b.create<CopyOp>(src.getLoc(), src, dst);
  return success();
}

/// Populate `patterns` with a tile-then-promote pipeline ("START" ->
/// "PROMOTE") that promotes operands 0 and 2 using the custom
/// alloc/dealloc/copy callbacks above.
static void fillPromotionCallBackPatterns(MLIRContext *ctx,
                                          RewritePatternSet &patterns) {
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
      LinalgTransformationFilter(Identifier::get("START", ctx),
                                 Identifier::get("PROMOTE", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0, 2})
          .setUseFullTileBuffers({false, false})
          .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
          .setCopyInOutFns(
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, false);
              },
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, true);
              }),
      LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
}

/// Create (processor-id, processor-count) op pairs for up to the first 3
/// parallel loop ranges. Results are stored in reverse order, so the last
/// covered loop is mapped to dimension "x", the one before it to "y", etc.
template <typename IdOp, typename NProcsOp>
static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
  size_t count = std::min<size_t>(3, parallelLoopRanges.size());
  SmallVector<ProcInfo, 2> procInfo(count);
  const char *xyz[] = {"x", "y", "z"};
  Type indexType = b.getIndexType();
  for (unsigned i = 0; i < count; ++i) {
    procInfo[count - 1 - i] = {
        b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
        b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
  }
  return procInfo;
}

/// Populate `patterns` with tiling+distribution variants (markers
/// "distribute1".."distribute6" and "tensors_distribute1"), each exercising a
/// different combination of DistributionMethod values on the two distributed
/// dimensions, with proc info supplied by GPU block ids / grid dims.
static void fillTileAndDistributePatterns(MLIRContext *context,
                                          RewritePatternSet &patterns) {
  {
    // Both dims: cyclic, #procs == #iters.
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsEqNumIters);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute1", context),
            Identifier::get("after_distribute1", context)));
  }

  {
    // Both dims: cyclic, #procs >= #iters.
    LinalgLoopDistributionOptions cyclicNprocsGeNiters;
    cyclicNprocsGeNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsGeNumIters);
    cyclicNprocsGeNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsGeNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute2", context),
            Identifier::get("after_distribute2", context)));
  }

  {
    // Both dims: plain cyclic distribution.
    LinalgLoopDistributionOptions cyclicNprocsDefault;
    cyclicNprocsDefault.distributionMethod.resize(2,
                                                  DistributionMethod::Cyclic);
    cyclicNprocsDefault.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsDefault),
        LinalgTransformationFilter(
            Identifier::get("distribute3", context),
            Identifier::get("after_distribute3", context)));
  }

  {
    // Mixed: Eq on dim 0, Ge on dim 1.
    LinalgLoopDistributionOptions cyclicNprocsMixed1;
    cyclicNprocsMixed1.distributionMethod = {
        DistributionMethod::CyclicNumProcsEqNumIters,
        DistributionMethod::CyclicNumProcsGeNumIters};
    cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed1),
        LinalgTransformationFilter(
            Identifier::get("distribute4", context),
            Identifier::get("after_distribute4", context)));
  }

  {
    // Mixed: Ge on dim 0, plain cyclic on dim 1.
    LinalgLoopDistributionOptions cyclicNprocsMixed2;
    cyclicNprocsMixed2.distributionMethod = {
        DistributionMethod::CyclicNumProcsGeNumIters,
        DistributionMethod::Cyclic};
    cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed2),
        LinalgTransformationFilter(
            Identifier::get("distribute5", context),
            Identifier::get("after_distribute5", context)));
  }

  {
    // Mixed: plain cyclic on dim 0, Eq on dim 1.
    LinalgLoopDistributionOptions cyclicNprocsMixed3;
    cyclicNprocsMixed3.distributionMethod = {
        DistributionMethod::Cyclic,
        DistributionMethod::CyclicNumProcsEqNumIters};
    cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;

    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed3),
        LinalgTransformationFilter(
            Identifier::get("distribute6", context),
            Identifier::get("after_distribute6", context)));
  }

  {
    // Distribution over sequential (non-parallel) loops.
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(2,
                                                   DistributionMethod::Cyclic);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::Loops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("tensors_distribute1", context),
            Identifier::get("tensors_after_distribute1", context)));
  }
}

/// Run the staged matmul-to-vector pipeline. For 2-d tiling, an extra L2
/// tiling stage ("START" -> "L2") is prepended before the shared L1 pipeline.
/// A final stage of vector canonicalization patterns is appended, and Linalg
/// tiling canonicalizations are interleaved between stages.
static void
applyMatmulToVectorPatterns(FuncOp funcOp,
                            bool testMatmulToVectorPatterns1dTiling,
                            bool testMatmulToVectorPatterns2dTiling) {
  MLIRContext *ctx = funcOp.getContext();
  SmallVector<RewritePatternSet, 4> stage1Patterns;
  if (testMatmulToVectorPatterns1dTiling) {
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
                                          stage1Patterns);
  } else if (testMatmulToVectorPatterns2dTiling) {
    stage1Patterns.emplace_back(
        ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
                 ctx,
                 LinalgTilingOptions()
                     .setTileSizes({768, 264, 768})
                     .setInterchange({1, 2, 0}),
                 LinalgTransformationFilter(Identifier::get("START", ctx),
                                            Identifier::get("L2", ctx))));
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                          stage1Patterns);
  }
  {
    // Canonicalization patterns
    RewritePatternSet canonicalizationPatterns(funcOp.getContext());
    vector::populateVectorTransferPermutationMapLoweringPatterns(
        canonicalizationPatterns);
    vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
    stage1Patterns.push_back(std::move(canonicalizationPatterns));
  }
  SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
  llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
  FrozenRewritePatternSet stage2Patterns =
      getLinalgTilingCanonicalizationPatterns(ctx);
  (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
                            std::move(stage2Patterns));
}

/// Forward linalg.copy ops into vector.transfer_read/write operands.
static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
  RewritePatternSet forwardPattern(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}

/// Vectorize contractions, fill/copy and generic ops, plus pad-tensor and
/// convolution vectorization patterns.
static void applyLinalgToVectorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<LinalgVectorizationPattern>(
      funcOp.getContext(),
      LinalgTransformationFilter()
          .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
  populatePadTensorOpVectorizationPatterns(patterns);
  populateConvolutionVectorizationPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Rewrite high-dimensional convolution ops into lower-dimensional ones.
static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  populateDecomposeConvolutionPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Rewrite linalg.pad_tensor into init/fill/copy with generic ops.
static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Generalize linalg.pad_tensor ops.
static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Swap tensor.extract_slice(linalg.pad_tensor) into
/// linalg.pad_tensor(tensor.extract_slice).
static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Tile matmul/generic ops marked "tile" with the given loop type, tile sizes
/// and peeled loops; alternatively scalarize their dynamic dimensions.
/// `tileSizes` and `scalarizeDynamicDims` are mutually exclusive.
// NOTE(review): `loopType` is taken by value; a const reference would avoid a
// copy per call.
static void applyTilePattern(FuncOp funcOp, std::string loopType,
                             ArrayRef<int64_t> tileSizes,
                             ArrayRef<int64_t> peeledLoops,
                             bool scalarizeDynamicDims) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet tilingPattern(context);
  // NOTE(review): no .Default case — an unrecognized loop-type string will
  // trip StringSwitch's unmatched-value assertion. Confirm all callers pass
  // one of the listed strings.
  LinalgTilingLoopType type =
      llvm::StringSwitch<LinalgTilingLoopType>(loopType)
          .Case("for", LinalgTilingLoopType::Loops)
          .Case("affine", LinalgTilingLoopType::AffineLoops)
          .Case("parallel", LinalgTilingLoopType::ParallelLoops)
          .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
  auto linalgTilingOptions = linalg::LinalgTilingOptions()
                                 .setPeeledLoops(peeledLoops)
                                 .setLoopType(type);
  if (scalarizeDynamicDims) {
    linalgTilingOptions.scalarizeDynamicDims();
    assert(tileSizes.empty() &&
           "tileSizes and scalarizeDynamicDims is mutually exclusive");
  } else {
    linalgTilingOptions.setTileSizes(tileSizes);
  }
  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
                    linalg::LinalgTilingPattern<linalg::GenericOp>>(
      context, linalgTilingOptions,
      linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}

// Attribute labels used to keep track of which loops were already peeled and
// which loop nests contain a partial iteration.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";

namespace {
/// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
/// `idx`-th loop contains only "full" iterations and a second loop for the
/// remaining partial iteration (if any).
struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
  TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
      : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
  }

  LogicalResult matchAndRewrite(TiledLoopOp loopOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> peeledLoops;
    if (loopOp->hasAttr(kPeeledLoopsLabel)) {
      auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
      peeledLoops =
          llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
            return attr.cast<IntegerAttr>().getInt();
          }));
      // Check if the loop was already peeled.
      if (llvm::find(peeledLoops, idx) != peeledLoops.end())
        return failure();
    }
    if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
      // No peeling of loop nests with a partial iteration.
      return failure();

    // Nothing to peel if the loop nest has fewer loops than `idx`.
    if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
      return failure();

    // Peel loop and canonicalize.
    TiledLoopOp result;
    if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
                                                    result)))
      return failure();

    // Apply label, so that the same loop is not rewritten a second time.
    peeledLoops.push_back(idx);
    rewriter.updateRootInPlace(loopOp, [&]() {
      loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    });
    result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());

    return success();
  }

  /// Index of loop to peel.
  int64_t idx;

  /// If set to true, do not peel TiledLoopOps with a partial iteration.
  bool skipPartial;
};
} // namespace

/// Peel each loop index in `loops` out of every linalg.tiled_loop in the
/// function, then strip the bookkeeping attributes left by the pattern.
static void applyTiledLoopPeelingPattern(FuncOp funcOp,
                                         ArrayRef<unsigned> loops,
                                         bool skipPartial) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  for (unsigned idx : loops)
    patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the markers.
  funcOp.walk([](TiledLoopOp op) {
    op->removeAttr(kPeeledLoopsLabel);
    op->removeAttr(kPartialIterationLabel);
  });
}

/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
  // Scope guard: strip transformation markers from all Linalg ops when this
  // function returns, whichever early-return path is taken below. The
  // unique_ptr holds a dummy non-null pointer so its deleter (the lambda)
  // always runs on destruction.
  auto lambda = [&](void *) {
    getFunction().walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });
  };
  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};

  if (testPromotionOptions) {
    RewritePatternSet patterns(&getContext());
    fillPromotionCallBackPatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testTileAndDistributionOptions) {
    RewritePatternSet patterns(&getContext());
    fillTileAndDistributePatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testPatterns)
    return applyPatterns(getFunction());
  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
    return applyMatmulToVectorPatterns(getFunction(),
                                       testMatmulToVectorPatterns1dTiling,
                                       testMatmulToVectorPatterns2dTiling);
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getFunction());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getFunction());
  if (testTransformPadTensor)
    return applyPadTensorToGenericPatterns(getFunction());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getFunction());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getFunction());
  if (testTiledLoopPeeling.hasValue())
    return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
                                        skipPartial);
  if (testTilePattern)
    return applyTilePattern(getFunction(), loopType, tileSizes, peeledLoops,
                            /*scalarizeDynamicDims=*/false);
  if (testTileScalarizeDynamicDims)
    return applyTilePattern(getFunction(), loopType, tileSizes,
                            /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
  if (testDecomposeConvolutionPattern)
    return applyDecomposeConvolutionPatterns(getFunction());
}

namespace mlir {
namespace test {
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir