//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Test pass exposing a large set of Linalg transformation patterns behind
/// command-line flags. Exactly one test flag is expected to be set per
/// invocation; runOnFunction() dispatches on the first flag that is enabled.
struct TestLinalgTransforms
    : public PassWrapper<TestLinalgTransforms, FunctionPass> {
  TestLinalgTransforms() = default;
  // Copy constructor deliberately does not copy anything: pass Options are
  // not copyable, and the pass framework only needs a fresh default instance.
  TestLinalgTransforms(const TestLinalgTransforms &pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    // clang-format off
    registry.insert<AffineDialect,
                    memref::MemRefDialect,
                    scf::SCFDialect,
                    StandardOpsDialect,
                    vector::VectorDialect,
                    gpu::GPUDialect>();
    // clang-format on
  }
  StringRef getArgument() const final {
    return "test-linalg-transform-patterns";
  }
  StringRef getDescription() const final {
    return "Test Linalg transformation patterns by applying them greedily.";
  }

  void runOnFunction() override;

  Option<bool> testPatterns{*this, "test-patterns",
                            llvm::cl::desc("Test a mixed set of patterns"),
                            llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns1dTiling{
      *this, "test-matmul-to-vector-patterns-tile-1d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "1-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns2dTiling{
      *this, "test-matmul-to-vector-patterns-tile-2d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "2-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                    llvm::cl::desc("Test promotion options"),
                                    llvm::cl::init(false)};
  Option<bool> testTileAndDistributionOptions{
      *this, "test-tile-and-distribute-options",
      llvm::cl::desc("Test tile and distribute options"),
      llvm::cl::init(false)};
  Option<bool> testVectorTransferForwardingPatterns{
      *this, "test-vector-transfer-forwarding-patterns",
      llvm::cl::desc(
          "Test a fused pass that forwards linalg.copy to vector.transfer"),
      llvm::cl::init(false)};
  Option<bool> testGenericToVectorPattern{
      *this, "test-linalg-to-vector-patterns",
      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                     "in vector.contract form"),
      llvm::cl::init(false)};
  Option<bool> testTilePattern{*this, "test-tile-pattern",
                               llvm::cl::desc("Test tile pattern"),
                               llvm::cl::init(false)};
  Option<bool> testTileScalarizeDynamicDims{
      *this, "test-tile-scalarize-dynamic-dims",
      llvm::cl::desc("Test tiling of dynamic dims by 1"),
      llvm::cl::init(false)};
  // Non-zero value doubles as the number of loops to hoist padding across.
  Option<int> testHoistPadding{*this, "test-hoist-padding",
                               llvm::cl::desc("Test hoist padding"),
                               llvm::cl::init(0)};
  Option<bool> testTransformPadTensor{
      *this, "test-transform-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizePadTensor{
      *this, "test-generalize-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testSwapSubTensorPadTensor{
      *this, "test-swap-subtensor-padtensor",
      llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                     "pad_tensor(subtensor)"),
      llvm::cl::init(false)};
  Option<bool> padTiles{*this, "pad-tiles",
                        llvm::cl::desc("Pad tiles when test-tile-pattern"),
                        llvm::cl::init(false)};
  ListOption<int64_t> peeledLoops{
      *this, "peeled-loops",
      llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> tileSizes{
      *this, "tile-sizes",
      llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<unsigned> testInterchangePattern{
      *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::desc("Test the interchange pattern.")};
  ListOption<unsigned> testTiledLoopPeeling{
      *this, "test-tiled-loop-peeling",
      llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
      llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
  Option<bool> skipPartial{
      *this, "skip-partial",
      llvm::cl::desc("Skip loops inside partial iterations during peeling"),
      llvm::cl::init(false)};
  Option<std::string> loopType{
      *this, "loop-type",
      llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                     "tiled_loop"),
      llvm::cl::init("for")};
};
} // end anonymous namespace

/// Populate and greedily apply a mixed set of tiling, interchange, lowering,
/// vectorization and promotion patterns. Patterns are chained via transform
/// markers (e.g. "MEM" -> "L3" -> "L2" -> "L1" -> "REG"): each pattern only
/// fires on ops carrying its start marker and re-tags the result, so one
/// greedy rewrite sweep performs the whole cascade.
static void applyPatterns(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);

  //===--------------------------------------------------------------------===//
  // Linalg tiling patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
      LinalgTransformationFilter(Identifier::get("MEM", ctx),
                                 Identifier::get("L3", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
      LinalgTransformationFilter(Identifier::get("L3", ctx),
                                 Identifier::get("L2", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L2", ctx),
                                 Identifier::get("L1", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
      LinalgTransformationFilter(Identifier::get("L1", ctx),
                                 Identifier::get("REG", ctx)));

  // Empty start-marker list: matches any matvec not yet tagged.
  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx,
      LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
          LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("L1", ctx)));

  // Multiple start markers: fires on any of MEM/L3/L2-tagged dot ops.
  patterns.add<LinalgTilingPattern<DotOp>>(
      ctx, LinalgTilingOptions().setTileSizes(8000),
      LinalgTransformationFilter(
          ArrayRef<Identifier>{Identifier::get("MEM", ctx),
                               Identifier::get("L3", ctx),
                               Identifier::get("L2", ctx)},
          Identifier::get("REG", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg tiling and permutation patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({2000, 3000, 4000})
          .setInterchange({1, 2, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L2__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({200, 300, 400})
          .setInterchange({1, 0, 2}),
      LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
                                 Identifier::get("REG__with_perm__", ctx)));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));

  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({16, 8, 4})
          .setInterchange({1, 2, 0})
          .setLoopType(LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(
          Identifier::get("par__with_perm__", ctx),
          Identifier::get("after_par__with_perm__", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg to loops patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgLoweringPattern<DotOp>>(
      ctx,
      /*loweringType=*/LinalgLoweringType::Loops,
      LinalgTransformationFilter(Identifier::get("REG", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg distribution patterns.
  //===--------------------------------------------------------------------===//
  // NOTE(review): this local is never used below — candidate for removal.
  LinalgLoopDistributionOptions distributionOptions;

  //===--------------------------------------------------------------------===//
  // Linalg to vector contraction patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
               .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());

  //===--------------------------------------------------------------------===//
  // Linalg generic interchange pattern.
  //===--------------------------------------------------------------------===//
  patterns.add<GenericOpInterchangePattern>(
      ctx,
      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("PERMUTED", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg subview operands promotion.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
                                 Identifier::get("_views_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0})
          .setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(
          Identifier::get("_promote_first_view_", ctx),
          Identifier::get("_first_view_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<FillOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({1})
          .setUseFullTileBuffers({false, true})
          .setAlignment(32),
      LinalgTransformationFilter(
          Identifier::get("_promote_views_aligned_", ctx),
          Identifier::get("_views_aligned_promoted_", ctx)));

  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the marker.
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}

/// Fill `patternsVector` with staged pattern sets taking a matmul tagged with
/// `startMarker` through L1 tiling ("L1"), promotion ("VEC") and finally
/// vectorization. The last stage also vectorizes any fill/copy ops.
static void fillL1TilingAndMatmulToVectorPatterns(
    FuncOp funcOp, StringRef startMarker,
    SmallVectorImpl<RewritePatternSet> &patternsVector) {
  MLIRContext *ctx = funcOp.getContext();
  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
               ctx,
               LinalgTilingOptions()
                   .setTileSizes({8, 12, 16})
                   .setInterchange({1, 0, 2}),
               LinalgTransformationFilter(Identifier::get(startMarker, ctx),
                                          Identifier::get("L1", ctx))));

  patternsVector.emplace_back(
      ctx,
      std::make_unique<LinalgPromotionPattern<MatmulOp>>(
          ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
          LinalgTransformationFilter(Identifier::get("L1", ctx),
                                     Identifier::get("VEC", ctx))));

  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgVectorizationPattern>(
               MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
               LinalgTransformationFilter(Identifier::get("VEC", ctx))));
  patternsVector.back().add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter().addFilter(
               [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
}

//===----------------------------------------------------------------------===//
// Test promotion callbacks
//===----------------------------------------------------------------------===//

// Allocation call back
// Allocates a fully-dynamic buffer in memory space 3 sized by the bounding
// subview; None is never returned here (allocation always succeeds).
static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
                                       ArrayRef<Value> boundingSubViewSize,
                                       DataLayout &layout) {
  SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
  return b
      .create<memref::AllocOp>(
          subView.getLoc(),
          MemRefType::get(shape, subView.getType().getElementType(),
                          /*affineMapComposition =*/{}, 3),
          boundingSubViewSize)
      .getResult();
}

// Deallocation callback
static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
  b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
  return success();
}

// Copy in call back
// Fails for non-float element types. On the input path (isOutput == false),
// first fills the destination with the recognizable constant 42.0 so tests
// can spot the promoted buffer, then copies src into dst.
static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
                                    bool isOutput) {
  auto floatType = src.getType().cast<MemRefType>().getElementType();
  if (!floatType.isa<FloatType>())
    return failure();
  if (!isOutput) {
    Value cst =
        b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0));
    b.create<FillOp>(src.getLoc(), cst, dst);
  }
  b.create<CopyOp>(src.getLoc(), src, dst);
  return success();
}

/// Tile matmuls tagged "START" ("START" -> "PROMOTE"), then promote operands
/// 0 and 2 using the custom alloc/dealloc/copy callbacks above.
static void fillPromotionCallBackPatterns(MLIRContext *ctx,
                                          RewritePatternSet &patterns) {
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
      LinalgTransformationFilter(Identifier::get("START", ctx),
                                 Identifier::get("PROMOTE", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0, 2})
          .setUseFullTileBuffers({false, false})
          .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
          .setCopyInOutFns(
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, false);
              },
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, true);
              }),
      LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
}

/// Build per-loop (id, nprocs) pairs from GPU id/dim ops for up to 3 parallel
/// loops. Results are stored innermost-first: dimension "x" maps to the last
/// (fastest-varying) parallel loop, hence the reversed `count - 1 - i` index.
template <typename IdOp, typename NProcsOp>
static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
  size_t count = std::min<size_t>(3, parallelLoopRanges.size());
  SmallVector<ProcInfo, 2> procInfo(count);
  const char *xyz[] = {"x", "y", "z"};
  Type indexType = b.getIndexType();
  for (unsigned i = 0; i < count; ++i) {
    procInfo[count - 1 - i] = {
        b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
        b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
  }
  return procInfo;
}

/// Register one tile-and-distribute pattern per distribution-method
/// combination. Each case keys off its own start marker ("distribute1" ...
/// "distribute6", plus "tensors_distribute1" for the sequential-loop case)
/// so a single test input can exercise all variants independently.
static void fillTileAndDistributePatterns(MLIRContext *context,
                                          RewritePatternSet &patterns) {
  {
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsEqNumIters);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute1", context),
            Identifier::get("after_distribute1", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsGeNiters;
    cyclicNprocsGeNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsGeNumIters);
    cyclicNprocsGeNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsGeNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute2", context),
            Identifier::get("after_distribute2", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsDefault;
    cyclicNprocsDefault.distributionMethod.resize(2,
                                                  DistributionMethod::Cyclic);
    cyclicNprocsDefault.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsDefault),
        LinalgTransformationFilter(
            Identifier::get("distribute3", context),
            Identifier::get("after_distribute3", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed1;
    cyclicNprocsMixed1.distributionMethod = {
        DistributionMethod::CyclicNumProcsEqNumIters,
        DistributionMethod::CyclicNumProcsGeNumIters};
    cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed1),
        LinalgTransformationFilter(
            Identifier::get("distribute4", context),
            Identifier::get("after_distribute4", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed2;
    cyclicNprocsMixed2.distributionMethod = {
        DistributionMethod::CyclicNumProcsGeNumIters,
        DistributionMethod::Cyclic};
    cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed2),
        LinalgTransformationFilter(
            Identifier::get("distribute5", context),
            Identifier::get("after_distribute5", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed3;
    cyclicNprocsMixed3.distributionMethod = {
        DistributionMethod::Cyclic,
        DistributionMethod::CyclicNumProcsEqNumIters};
    cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;

    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed3),
        LinalgTransformationFilter(
            Identifier::get("distribute6", context),
            Identifier::get("after_distribute6", context)));
  }

  {
    // Same cyclic distribution but on sequential (scf.for) loops.
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(2,
                                                   DistributionMethod::Cyclic);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::Loops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("tensors_distribute1", context),
            Identifier::get("tensors_after_distribute1", context)));
  }
}

/// Run the matmul-to-vector cascade in stages: optional 2-d pre-tiling
/// ("START" -> "L2"), then the L1/promotion/vectorization stages, with
/// tiling canonicalizations applied between stages.
static void
applyMatmulToVectorPatterns(FuncOp funcOp,
                            bool testMatmulToVectorPatterns1dTiling,
                            bool testMatmulToVectorPatterns2dTiling) {
  MLIRContext *ctx = funcOp.getContext();
  SmallVector<RewritePatternSet, 4> stage1Patterns;
  if (testMatmulToVectorPatterns1dTiling) {
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
                                          stage1Patterns);
  } else if (testMatmulToVectorPatterns2dTiling) {
    stage1Patterns.emplace_back(
        ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
                 ctx,
                 LinalgTilingOptions()
                     .setTileSizes({768, 264, 768})
                     .setInterchange({1, 2, 0}),
                 LinalgTransformationFilter(Identifier::get("START", ctx),
                                            Identifier::get("L2", ctx))));
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                          stage1Patterns);
  }
  SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
  llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
  FrozenRewritePatternSet stage2Patterns =
      getLinalgTilingCanonicalizationPatterns(ctx);
  (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
                            std::move(stage2Patterns));
}

/// Forward linalg.copy ops into subsequent vector.transfer_read/write ops.
static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
  RewritePatternSet forwardPattern(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}

/// Vectorize contractions, fill, copy and generic ops; also vectorize
/// pad_tensor ops.
static void applyLinalgToVectorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<LinalgVectorizationPattern>(
      funcOp.getContext(),
      LinalgTransformationFilter()
          .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
  populatePadTensorOpVectorizationPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Rewrite pad_tensor ops as fill + copy with generic ops.
static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Generalize pad_tensor ops (lower to simpler ops).
static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Swap extract_slice(pad_tensor) into pad_tensor(extract_slice).
static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

// For now, just assume it is the zero of type.
// In the future, it should be the zero of type + op.
/// Padding-value callback for tiling: returns a zero constant of the
/// operand's element type (see the note above about future op-dependent
/// neutral elements).
static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
  auto t = getElementTypeOrSelf(op.get());
  return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
}

/// Greedily tile matmul/matmul_i8_i8_i32/generic ops tagged "tile", using the
/// requested loop type, peeled loops, tile sizes (or scalarized dynamic dims)
/// and optional zero-padding of partial tiles.
static void applyTilePattern(FuncOp funcOp, std::string loopType,
                             ArrayRef<int64_t> tileSizes, bool padTiles,
                             ArrayRef<int64_t> peeledLoops,
                             bool scalarizeDynamicDims) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet tilingPattern(context);
  // NOTE(review): no .Default() case — an unrecognized loop-type string
  // leaves the StringSwitch without a match; confirm intended strings only.
  LinalgTilingLoopType type =
      llvm::StringSwitch<LinalgTilingLoopType>(loopType)
          .Case("for", LinalgTilingLoopType::Loops)
          .Case("affine", LinalgTilingLoopType::AffineLoops)
          .Case("parallel", LinalgTilingLoopType::ParallelLoops)
          .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
  auto linalgTilingOptions = linalg::LinalgTilingOptions()
                                 .setPeeledLoops(peeledLoops)
                                 .setLoopType(type);
  if (scalarizeDynamicDims) {
    linalgTilingOptions.scalarizeDynamicDims();
    assert(tileSizes.empty() &&
           "tileSizes and scalarizeDynamicDims is mutually exclusive");
  } else {
    linalgTilingOptions.setTileSizes(tileSizes);
  }
  if (padTiles)
    linalgTilingOptions.setPaddingValueComputationFunction(
        getNeutralOfLinalgOp);

  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
                    linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
                    linalg::LinalgTilingPattern<linalg::GenericOp>>(
      context, linalgTilingOptions,
      linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}

/// Interchange the iterators of linalg.generic ops (tagging results with
/// "interchange" so each op is only rewritten once).
static void applyInterchangePattern(FuncOp funcOp,
                                    ArrayRef<unsigned> interchangeVector) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet interchangePattern(context);
  interchangePattern.add<GenericOpInterchangePattern>(
      context, interchangeVector,
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("interchange", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
}

// Markers used by TiledLoopPeelingPattern to avoid re-peeling loops.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";

namespace {
/// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
/// `idx`-th loop contains only "full" iterations and a second loop for the
/// remaining partial iteration (if any).
struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
  TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
      : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
  }

  LogicalResult matchAndRewrite(TiledLoopOp loopOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> peeledLoops;
    if (loopOp->hasAttr(kPeeledLoopsLabel)) {
      auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
      peeledLoops =
          llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
            return attr.cast<IntegerAttr>().getInt();
          }));
      // Check if the loop was already peeled.
      if (llvm::find(peeledLoops, idx) != peeledLoops.end())
        return failure();
    }
    if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
      // No peeling of loop nests with a partial iteration.
      return failure();

    // Bail out if the loop has fewer dimensions than the requested index.
    if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
      return failure();

    // Peel loop and canonicalize.
    TiledLoopOp result;
    if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
                                                    result)))
      return failure();

    // Apply label, so that the same loop is not rewritten a second time.
    peeledLoops.push_back(idx);
    rewriter.updateRootInPlace(loopOp, [&]() {
      loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    });
    result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());

    return success();
  }

  /// Index of loop to peel.
  int64_t idx;

  /// If set to true, do not peel TiledLoopOps with a partial iteration.
  bool skipPartial;
};
} // namespace

/// Peel the given loop indices of all linalg.tiled_loop ops in `funcOp`,
/// then strip the bookkeeping attributes the pattern leaves behind.
static void applyTiledLoopPeelingPattern(FuncOp funcOp,
                                         ArrayRef<unsigned> loops,
                                         bool skipPartial) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  for (unsigned idx : loops)
    patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the markers.
  funcOp.walk([](TiledLoopOp op) {
    op->removeAttr(kPeeledLoopsLabel);
    op->removeAttr(kPartialIterationLabel);
  });
}

/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
  // Scope-exit guard: the dummy unique_ptr's deleter runs the lambda when this
  // function returns (on every early `return` below), stripping any leftover
  // transform markers from the function. The (void *)1 payload is never
  // dereferenced; it only keeps the unique_ptr non-null so the deleter fires.
  auto lambda = [&](void *) {
    getFunction().walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });
  };
  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};

  if (testPromotionOptions) {
    RewritePatternSet patterns(&getContext());
    fillPromotionCallBackPatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testTileAndDistributionOptions) {
    RewritePatternSet patterns(&getContext());
    fillTileAndDistributePatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testPatterns)
    return applyPatterns(getFunction());
  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
    return applyMatmulToVectorPatterns(getFunction(),
                                       testMatmulToVectorPatterns1dTiling,
                                       testMatmulToVectorPatterns2dTiling);
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getFunction());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getFunction());
  if (testTransformPadTensor)
    return applyPadTensorToGenericPatterns(getFunction());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getFunction());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getFunction());
  if (testTiledLoopPeeling.hasValue())
    return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
                                        skipPartial);
  if (testTilePattern)
    return applyTilePattern(getFunction(), loopType, tileSizes, padTiles,
                            peeledLoops, /*scalarizeDynamicDims=*/false);
  if (testTileScalarizeDynamicDims)
    return applyTilePattern(getFunction(), loopType, tileSizes, padTiles,
                            /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
  // Note: unlike the cases above, hoist-padding does not return early, so it
  // may combine with test-interchange-pattern below.
  if (testHoistPadding) {
    getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
      (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
    });
  }
  if (testInterchangePattern.hasValue())
    return applyInterchangePattern(getFunction(), testInterchangePattern);
}

namespace mlir {
namespace test {
/// Register the test pass with the global pass registry.
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir