1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/GPU/GPUDialect.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 #include "llvm/ADT/SetVector.h" 25 #include "llvm/ADT/SmallVector.h" 26 27 using namespace mlir; 28 using namespace mlir::linalg; 29 30 namespace { 31 struct TestLinalgTransforms 32 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 33 TestLinalgTransforms() = default; 34 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 35 36 void getDependentDialects(DialectRegistry ®istry) const override { 37 // clang-format off 38 registry.insert<AffineDialect, 39 memref::MemRefDialect, 40 scf::SCFDialect, 41 StandardOpsDialect, 42 vector::VectorDialect, 43 gpu::GPUDialect>(); 44 // clang-format on 45 } 46 StringRef getArgument() const final { 47 return "test-linalg-transform-patterns"; 48 } 49 StringRef getDescription() const final { 50 return "Test Linalg transformation patterns by applying them greedily."; 51 } 52 53 void runOnFunction() override; 54 55 Option<bool> testPatterns{*this, "test-patterns", 56 llvm::cl::desc("Test a mixed set of patterns"), 57 llvm::cl::init(false)}; 58 Option<bool> testMatmulToVectorPatterns1dTiling{ 59 *this, "test-matmul-to-vector-patterns-tile-1d", 60 llvm::cl::desc( 61 "Test a fused pass that applies patterns from matmul to vectors via " 62 "1-d tiling"), 63 llvm::cl::init(false)}; 64 Option<bool> testMatmulToVectorPatterns2dTiling{ 65 *this, "test-matmul-to-vector-patterns-tile-2d", 66 llvm::cl::desc( 67 "Test a fused pass that applies patterns from matmul to vectors via " 68 "2-d tiling"), 69 llvm::cl::init(false)}; 70 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 71 llvm::cl::desc("Test promotion options"), 72 llvm::cl::init(false)}; 73 Option<bool> testTileAndDistributionOptions{ 74 *this, "test-tile-and-distribute-options", 75 llvm::cl::desc("Test tile and distribute options"), 76 llvm::cl::init(false)}; 77 Option<bool> testVectorTransferForwardingPatterns{ 78 *this, "test-vector-transfer-forwarding-patterns", 79 llvm::cl::desc( 80 "Test a fused pass that forwards linalg.copy to vector.transfer"), 81 llvm::cl::init(false)}; 82 Option<bool> testGenericToVectorPattern{ 83 *this, "test-linalg-to-vector-patterns", 84 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 85 "in vector.contract form"), 86 llvm::cl::init(false)}; 87 Option<bool> testTilePattern{*this, "test-tile-pattern", 88 llvm::cl::desc("Test tile pattern"), 89 llvm::cl::init(false)}; 90 Option<bool> testTileScalarizeDynamicDims{ 91 *this, "test-tile-scalarize-dynamic-dims", 92 llvm::cl::desc("Test tiling of dynamic dims by 1"), 93 llvm::cl::init(false)}; 94 Option<int> testHoistPadding{*this, "test-hoist-padding", 95 llvm::cl::desc("Test hoist padding"), 96 llvm::cl::init(0)}; 97 Option<bool> testTransformPadTensor{ 98 *this, "test-transform-pad-tensor", 99 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 100 llvm::cl::init(false)}; 101 Option<bool> testGeneralizePadTensor{ 102 *this, "test-generalize-pad-tensor", 103 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 104 llvm::cl::init(false)}; 105 Option<bool> testSwapSubTensorPadTensor{ 106 *this, "test-swap-subtensor-padtensor", 107 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 108 "pad_tensor(subtensor)"), 109 llvm::cl::init(false)}; 110 Option<bool> padTiles{*this, "pad-tiles", 111 llvm::cl::desc("Pad tiles when test-tile-pattern"), 112 llvm::cl::init(false)}; 113 ListOption<int64_t> peeledLoops{ 114 *this, "peeled-loops", 115 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 116 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 117 ListOption<int64_t> tileSizes{ 118 *this, "tile-sizes", 119 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 120 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 121 ListOption<unsigned> testInterchangePattern{ 122 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 123 llvm::cl::desc("Test the interchange pattern.")}; 124 ListOption<unsigned> testTiledLoopPeeling{ 125 *this, "test-tiled-loop-peeling", 126 llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), 127 llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; 128 Option<bool> skipPartial{ 129 *this, "skip-partial", 130 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 131 llvm::cl::init(false)}; 132 }; 133 } // end anonymous namespace 134 135 static void applyPatterns(FuncOp funcOp) { 136 MLIRContext *ctx = funcOp.getContext(); 137 RewritePatternSet patterns(ctx); 138 139 //===--------------------------------------------------------------------===// 140 // Linalg tiling patterns. 141 //===--------------------------------------------------------------------===// 142 patterns.add<LinalgTilingPattern<MatmulOp>>( 143 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 144 LinalgTransformationFilter(Identifier::get("MEM", ctx), 145 Identifier::get("L3", ctx))); 146 patterns.add<LinalgTilingPattern<MatmulOp>>( 147 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 148 LinalgTransformationFilter(Identifier::get("L3", ctx), 149 Identifier::get("L2", ctx))); 150 patterns.add<LinalgTilingPattern<MatmulOp>>( 151 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 152 LinalgTransformationFilter(Identifier::get("L2", ctx), 153 Identifier::get("L1", ctx))); 154 patterns.add<LinalgTilingPattern<MatmulOp>>( 155 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 156 LinalgTransformationFilter(Identifier::get("L1", ctx), 157 Identifier::get("REG", ctx))); 158 159 patterns.add<LinalgTilingPattern<MatvecOp>>( 160 ctx, 161 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 162 LinalgTilingLoopType::ParallelLoops), 163 LinalgTransformationFilter(ArrayRef<Identifier>{}, 164 Identifier::get("L1", ctx))); 165 166 patterns.add<LinalgTilingPattern<DotOp>>( 167 ctx, LinalgTilingOptions().setTileSizes(8000), 168 LinalgTransformationFilter( 169 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 170 Identifier::get("L3", ctx), 171 Identifier::get("L2", ctx)}, 172 Identifier::get("REG", ctx))); 173 174 //===--------------------------------------------------------------------===// 175 // Linalg tiling and permutation patterns. 176 //===--------------------------------------------------------------------===// 177 patterns.add<LinalgTilingPattern<MatmulOp>>( 178 ctx, 179 LinalgTilingOptions() 180 .setTileSizes({2000, 3000, 4000}) 181 .setInterchange({1, 2, 0}), 182 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 183 Identifier::get("L2__with_perm__", ctx))); 184 patterns.add<LinalgTilingPattern<MatmulOp>>( 185 ctx, 186 LinalgTilingOptions() 187 .setTileSizes({200, 300, 400}) 188 .setInterchange({1, 0, 2}), 189 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 190 Identifier::get("L1__with_perm__", ctx))); 191 patterns.add<LinalgTilingPattern<MatmulOp>>( 192 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 193 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 194 Identifier::get("REG__with_perm__", ctx))); 195 196 patterns.add<LinalgTilingPattern<MatvecOp>>( 197 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 198 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 199 Identifier::get("L1__with_perm__", ctx))); 200 201 patterns.add<LinalgTilingPattern<MatmulOp>>( 202 ctx, 203 LinalgTilingOptions() 204 .setTileSizes({16, 8, 4}) 205 .setInterchange({1, 2, 0}) 206 .setLoopType(LinalgTilingLoopType::ParallelLoops), 207 LinalgTransformationFilter( 208 Identifier::get("par__with_perm__", ctx), 209 Identifier::get("after_par__with_perm__", ctx))); 210 211 //===--------------------------------------------------------------------===// 212 // Linalg to loops patterns. 213 //===--------------------------------------------------------------------===// 214 patterns.add<LinalgLoweringPattern<DotOp>>( 215 ctx, 216 /*loweringType=*/LinalgLoweringType::Loops, 217 LinalgTransformationFilter(Identifier::get("REG", ctx))); 218 219 //===--------------------------------------------------------------------===// 220 // Linalg distribution patterns. 221 //===--------------------------------------------------------------------===// 222 LinalgLoopDistributionOptions distributionOptions; 223 224 //===--------------------------------------------------------------------===// 225 // Linalg to vector contraction patterns. 226 //===--------------------------------------------------------------------===// 227 patterns.add<LinalgVectorizationPattern>( 228 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 229 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 230 231 //===--------------------------------------------------------------------===// 232 // Linalg generic interchange pattern. 233 //===--------------------------------------------------------------------===// 234 patterns.add<GenericOpInterchangePattern>( 235 ctx, 236 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 237 LinalgTransformationFilter(ArrayRef<Identifier>{}, 238 Identifier::get("PERMUTED", ctx))); 239 240 //===--------------------------------------------------------------------===// 241 // Linalg subview operands promotion. 242 //===--------------------------------------------------------------------===// 243 patterns.add<LinalgPromotionPattern<MatmulOp>>( 244 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 245 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 246 Identifier::get("_views_promoted_", ctx))); 247 patterns.add<LinalgPromotionPattern<MatmulOp>>( 248 ctx, 249 LinalgPromotionOptions() 250 .setOperandsToPromote({0}) 251 .setUseFullTileBuffersByDefault(true), 252 LinalgTransformationFilter( 253 Identifier::get("_promote_first_view_", ctx), 254 Identifier::get("_first_view_promoted_", ctx))); 255 patterns.add<LinalgPromotionPattern<FillOp>>( 256 ctx, 257 LinalgPromotionOptions() 258 .setOperandsToPromote({1}) 259 .setUseFullTileBuffers({false, true}) 260 .setAlignment(32), 261 LinalgTransformationFilter( 262 Identifier::get("_promote_views_aligned_", ctx), 263 Identifier::get("_views_aligned_promoted_", ctx))); 264 265 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 266 267 // Drop the marker. 268 funcOp.walk([](LinalgOp op) { 269 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 270 }); 271 } 272 273 static void fillL1TilingAndMatmulToVectorPatterns( 274 FuncOp funcOp, StringRef startMarker, 275 SmallVectorImpl<RewritePatternSet> &patternsVector) { 276 MLIRContext *ctx = funcOp.getContext(); 277 patternsVector.emplace_back( 278 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 279 ctx, 280 LinalgTilingOptions() 281 .setTileSizes({8, 12, 16}) 282 .setInterchange({1, 0, 2}), 283 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 284 Identifier::get("L1", ctx)))); 285 286 patternsVector.emplace_back( 287 ctx, 288 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 289 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 290 LinalgTransformationFilter(Identifier::get("L1", ctx), 291 Identifier::get("VEC", ctx)))); 292 293 patternsVector.emplace_back( 294 ctx, std::make_unique<LinalgVectorizationPattern>( 295 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 296 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 297 patternsVector.back().add<LinalgVectorizationPattern>( 298 ctx, LinalgTransformationFilter().addFilter( 299 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 300 } 301 302 //===----------------------------------------------------------------------===// 303 // Test promotion callbacks 304 //===----------------------------------------------------------------------===// 305 306 // Allocation call back 307 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 308 ArrayRef<Value> boundingSubViewSize, 309 DataLayout &layout) { 310 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 311 return b 312 .create<memref::AllocOp>( 313 subView.getLoc(), 314 MemRefType::get(shape, subView.getType().getElementType(), 315 /*affineMapComposition =*/{}, 3), 316 boundingSubViewSize) 317 .getResult(); 318 } 319 320 // Deallocation callback 321 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 322 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 323 return success(); 324 } 325 326 // Copy in call back 327 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 328 bool isOutput) { 329 auto floatType = src.getType().cast<MemRefType>().getElementType(); 330 if (!floatType.isa<FloatType>()) 331 return failure(); 332 if (!isOutput) { 333 Value cst = 334 b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)); 335 b.create<FillOp>(src.getLoc(), cst, dst); 336 } 337 b.create<CopyOp>(src.getLoc(), src, dst); 338 return success(); 339 } 340 341 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 342 RewritePatternSet &patterns) { 343 patterns.add<LinalgTilingPattern<MatmulOp>>( 344 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 345 LinalgTransformationFilter(Identifier::get("START", ctx), 346 Identifier::get("PROMOTE", ctx))); 347 patterns.add<LinalgPromotionPattern<MatmulOp>>( 348 ctx, 349 LinalgPromotionOptions() 350 .setOperandsToPromote({0, 2}) 351 .setUseFullTileBuffers({false, false}) 352 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 353 .setCopyInOutFns( 354 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 355 return copyCallBackFn(b, src, dst, false); 356 }, 357 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 358 return copyCallBackFn(b, src, dst, true); 359 }), 360 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 361 } 362 363 template <typename IdOp, typename NProcsOp> 364 static SmallVector<ProcInfo, 2> 365 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 366 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 367 SmallVector<ProcInfo, 2> procInfo(count); 368 const char *xyz[] = {"x", "y", "z"}; 369 Type indexType = b.getIndexType(); 370 for (unsigned i = 0; i < count; ++i) { 371 procInfo[count - 1 - i] = { 372 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 373 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 374 } 375 return procInfo; 376 } 377 378 static void fillTileAndDistributePatterns(MLIRContext *context, 379 RewritePatternSet &patterns) { 380 { 381 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 382 cyclicNprocsEqNiters.distributionMethod.resize( 383 2, DistributionMethod::CyclicNumProcsEqNumIters); 384 cyclicNprocsEqNiters.procInfo = 385 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 386 patterns.add<LinalgTilingPattern<MatmulOp>>( 387 context, 388 LinalgTilingOptions() 389 .setTileSizes({8, 8, 4}) 390 .setLoopType(LinalgTilingLoopType::ParallelLoops) 391 .setDistributionOptions(cyclicNprocsEqNiters), 392 LinalgTransformationFilter( 393 Identifier::get("distribute1", context), 394 Identifier::get("after_distribute1", context))); 395 } 396 397 { 398 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 399 cyclicNprocsGeNiters.distributionMethod.resize( 400 2, DistributionMethod::CyclicNumProcsGeNumIters); 401 cyclicNprocsGeNiters.procInfo = 402 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 403 patterns.add<LinalgTilingPattern<MatmulOp>>( 404 context, 405 LinalgTilingOptions() 406 .setTileSizes({8, 8, 4}) 407 .setLoopType(LinalgTilingLoopType::ParallelLoops) 408 .setDistributionOptions(cyclicNprocsGeNiters), 409 LinalgTransformationFilter( 410 Identifier::get("distribute2", context), 411 Identifier::get("after_distribute2", context))); 412 } 413 414 { 415 LinalgLoopDistributionOptions cyclicNprocsDefault; 416 cyclicNprocsDefault.distributionMethod.resize(2, 417 DistributionMethod::Cyclic); 418 cyclicNprocsDefault.procInfo = 419 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 420 patterns.add<LinalgTilingPattern<MatmulOp>>( 421 context, 422 LinalgTilingOptions() 423 .setTileSizes({8, 8, 4}) 424 .setLoopType(LinalgTilingLoopType::ParallelLoops) 425 .setDistributionOptions(cyclicNprocsDefault), 426 LinalgTransformationFilter( 427 Identifier::get("distribute3", context), 428 Identifier::get("after_distribute3", context))); 429 } 430 431 { 432 LinalgLoopDistributionOptions cyclicNprocsMixed1; 433 cyclicNprocsMixed1.distributionMethod = { 434 DistributionMethod::CyclicNumProcsEqNumIters, 435 DistributionMethod::CyclicNumProcsGeNumIters}; 436 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 437 patterns.add<LinalgTilingPattern<MatmulOp>>( 438 context, 439 LinalgTilingOptions() 440 .setTileSizes({8, 8, 4}) 441 .setLoopType(LinalgTilingLoopType::ParallelLoops) 442 .setDistributionOptions(cyclicNprocsMixed1), 443 LinalgTransformationFilter( 444 Identifier::get("distribute4", context), 445 Identifier::get("after_distribute4", context))); 446 } 447 448 { 449 LinalgLoopDistributionOptions cyclicNprocsMixed2; 450 cyclicNprocsMixed2.distributionMethod = { 451 DistributionMethod::CyclicNumProcsGeNumIters, 452 DistributionMethod::Cyclic}; 453 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 454 patterns.add<LinalgTilingPattern<MatmulOp>>( 455 context, 456 LinalgTilingOptions() 457 .setTileSizes({8, 8, 4}) 458 .setLoopType(LinalgTilingLoopType::ParallelLoops) 459 .setDistributionOptions(cyclicNprocsMixed2), 460 LinalgTransformationFilter( 461 Identifier::get("distribute5", context), 462 Identifier::get("after_distribute5", context))); 463 } 464 465 { 466 LinalgLoopDistributionOptions cyclicNprocsMixed3; 467 cyclicNprocsMixed3.distributionMethod = { 468 DistributionMethod::Cyclic, 469 DistributionMethod::CyclicNumProcsEqNumIters}; 470 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 471 472 patterns.add<LinalgTilingPattern<MatmulOp>>( 473 context, 474 LinalgTilingOptions() 475 .setTileSizes({8, 8, 4}) 476 .setLoopType(LinalgTilingLoopType::ParallelLoops) 477 .setDistributionOptions(cyclicNprocsMixed3), 478 LinalgTransformationFilter( 479 Identifier::get("distribute6", context), 480 Identifier::get("after_distribute6", context))); 481 } 482 483 { 484 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 485 cyclicNprocsEqNiters.distributionMethod.resize(2, 486 DistributionMethod::Cyclic); 487 cyclicNprocsEqNiters.procInfo = 488 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 489 patterns.add<LinalgTilingPattern<MatmulOp>>( 490 context, 491 LinalgTilingOptions() 492 .setTileSizes({8, 8, 4}) 493 .setLoopType(LinalgTilingLoopType::Loops) 494 .setDistributionOptions(cyclicNprocsEqNiters), 495 LinalgTransformationFilter( 496 Identifier::get("tensors_distribute1", context), 497 Identifier::get("tensors_after_distribute1", context))); 498 } 499 } 500 501 static void 502 applyMatmulToVectorPatterns(FuncOp funcOp, 503 bool testMatmulToVectorPatterns1dTiling, 504 bool testMatmulToVectorPatterns2dTiling) { 505 MLIRContext *ctx = funcOp.getContext(); 506 SmallVector<RewritePatternSet, 4> stage1Patterns; 507 if (testMatmulToVectorPatterns1dTiling) { 508 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 509 stage1Patterns); 510 } else if (testMatmulToVectorPatterns2dTiling) { 511 stage1Patterns.emplace_back( 512 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 513 ctx, 514 LinalgTilingOptions() 515 .setTileSizes({768, 264, 768}) 516 .setInterchange({1, 2, 0}), 517 LinalgTransformationFilter(Identifier::get("START", ctx), 518 Identifier::get("L2", ctx)))); 519 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 520 stage1Patterns); 521 } 522 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 523 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 524 FrozenRewritePatternSet stage2Patterns = 525 getLinalgTilingCanonicalizationPatterns(ctx); 526 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 527 std::move(stage2Patterns)); 528 } 529 530 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 531 RewritePatternSet forwardPattern(funcOp.getContext()); 532 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 533 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 534 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 535 } 536 537 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 538 RewritePatternSet patterns(funcOp.getContext()); 539 patterns.add<LinalgVectorizationPattern>( 540 funcOp.getContext(), 541 LinalgTransformationFilter() 542 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 543 populatePadTensorOpVectorizationPatterns(patterns); 544 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 545 } 546 547 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 548 RewritePatternSet patterns(funcOp.getContext()); 549 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 550 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 551 } 552 553 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 554 RewritePatternSet patterns(funcOp.getContext()); 555 patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext()); 556 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 557 } 558 559 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 560 RewritePatternSet patterns(funcOp.getContext()); 561 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 562 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 563 } 564 565 // For now, just assume it is the zero of type. 566 // In the future, it should be the zero of type + op. 567 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 568 auto t = getElementTypeOrSelf(op.get()); 569 return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); 570 } 571 572 static void applyTilePattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes, 573 bool padTiles, ArrayRef<int64_t> peeledLoops, 574 bool scalarizeDynamicDims) { 575 MLIRContext *context = funcOp.getContext(); 576 RewritePatternSet tilingPattern(context); 577 auto linalgTilingOptions = 578 linalg::LinalgTilingOptions().setPeeledLoops(peeledLoops); 579 if (scalarizeDynamicDims) { 580 linalgTilingOptions.scalarizeDynamicDims(); 581 assert(tileSizes.empty() && 582 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 583 } else { 584 linalgTilingOptions.setTileSizes(tileSizes); 585 } 586 if (padTiles) 587 linalgTilingOptions.setPaddingValueComputationFunction( 588 getNeutralOfLinalgOp); 589 590 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, 591 linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>, 592 linalg::LinalgTilingPattern<linalg::GenericOp>>( 593 context, linalgTilingOptions, 594 linalg::LinalgTransformationFilter(Identifier::get("tile", context))); 595 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 596 } 597 598 static void applyInterchangePattern(FuncOp funcOp, 599 ArrayRef<unsigned> interchangeVector) { 600 MLIRContext *context = funcOp.getContext(); 601 RewritePatternSet interchangePattern(context); 602 interchangePattern.add<GenericOpInterchangePattern>( 603 context, interchangeVector, 604 LinalgTransformationFilter(ArrayRef<Identifier>{}, 605 Identifier::get("interchange", context))); 606 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 607 } 608 609 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 610 static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; 611 612 namespace { 613 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the 614 /// `idx`-th loop contains only "full" iterations and a second loop for the 615 /// remaining partial iteration (if any). 616 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { 617 TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) 618 : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) { 619 } 620 621 LogicalResult matchAndRewrite(TiledLoopOp loopOp, 622 PatternRewriter &rewriter) const override { 623 SmallVector<int64_t> peeledLoops; 624 if (loopOp->hasAttr(kPeeledLoopsLabel)) { 625 auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>(); 626 peeledLoops = 627 llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { 628 return attr.cast<IntegerAttr>().getInt(); 629 })); 630 // Check if the loop was already peeled. 631 if (llvm::find(peeledLoops, idx) != peeledLoops.end()) 632 return failure(); 633 } 634 if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) 635 // No peeling of loop nests with a partial iteration. 636 return failure(); 637 638 if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) 639 return failure(); 640 641 // Peel loop and canonicalize. 642 TiledLoopOp result; 643 if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, 644 result))) 645 return failure(); 646 647 // Apply label, so that the same loop is not rewritten a second time. 648 peeledLoops.push_back(idx); 649 rewriter.updateRootInPlace(loopOp, [&]() { 650 loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 651 }); 652 result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 653 result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); 654 655 return success(); 656 } 657 658 /// Index of loop to peel. 659 int64_t idx; 660 661 /// If set to true, do not peel TiledLoopOps with a partial iteration. 662 bool skipPartial; 663 }; 664 } // namespace 665 666 static void applyTiledLoopPeelingPattern(FuncOp funcOp, 667 ArrayRef<unsigned> loops, 668 bool skipPartial) { 669 MLIRContext *ctx = funcOp.getContext(); 670 RewritePatternSet patterns(ctx); 671 for (unsigned idx : loops) 672 patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial); 673 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 674 675 // Drop the markers. 676 funcOp.walk([](TiledLoopOp op) { 677 op->removeAttr(kPeeledLoopsLabel); 678 op->removeAttr(kPartialIterationLabel); 679 }); 680 } 681 682 /// Apply transformations specified as patterns. 683 void TestLinalgTransforms::runOnFunction() { 684 auto lambda = [&](void *) { 685 getFunction().walk([](LinalgOp op) { 686 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 687 }); 688 }; 689 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 690 691 if (testPromotionOptions) { 692 RewritePatternSet patterns(&getContext()); 693 fillPromotionCallBackPatterns(&getContext(), patterns); 694 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 695 return; 696 } 697 if (testTileAndDistributionOptions) { 698 RewritePatternSet patterns(&getContext()); 699 fillTileAndDistributePatterns(&getContext(), patterns); 700 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 701 return; 702 } 703 if (testPatterns) 704 return applyPatterns(getFunction()); 705 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 706 return applyMatmulToVectorPatterns(getFunction(), 707 testMatmulToVectorPatterns1dTiling, 708 testMatmulToVectorPatterns2dTiling); 709 if (testVectorTransferForwardingPatterns) 710 return applyVectorTransferForwardingPatterns(getFunction()); 711 if (testGenericToVectorPattern) 712 return applyLinalgToVectorPatterns(getFunction()); 713 if (testTransformPadTensor) 714 return applyPadTensorToGenericPatterns(getFunction()); 715 if (testGeneralizePadTensor) 716 return applyGeneralizePadTensorPatterns(getFunction()); 717 if (testSwapSubTensorPadTensor) 718 return applyExtractSliceOfPadTensorSwapPattern(getFunction()); 719 if (testTiledLoopPeeling.hasValue()) 720 return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, 721 skipPartial); 722 if (testTilePattern) 723 return applyTilePattern(getFunction(), tileSizes, padTiles, peeledLoops, 724 /*scalarizeDynamicDims=*/false); 725 if (testTileScalarizeDynamicDims) 726 return applyTilePattern(getFunction(), tileSizes, padTiles, 727 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 728 if (testHoistPadding) { 729 getFunction().walk([&](linalg::PadTensorOp padTensorOp) { 730 (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); 731 }); 732 } 733 if (testInterchangePattern.hasValue()) 734 return applyInterchangePattern(getFunction(), testInterchangePattern); 735 } 736 737 namespace mlir { 738 namespace test { 739 void registerTestLinalgTransforms() { 740 PassRegistration<TestLinalgTransforms>(); 741 } 742 } // namespace test 743 } // namespace mlir 744