//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Test pass exercising various Linalg transformation patterns. Each
/// command-line flag below selects one family of patterns to populate and
/// apply greedily on the function under test; runOnOperation() dispatches
/// on the first flag that is set.
struct TestLinalgTransforms
    : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> {
  TestLinalgTransforms() = default;
  // Copy constructor required by PassWrapper so pass options survive cloning.
  TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    // clang-format off
    registry.insert<AffineDialect,
                    memref::MemRefDialect,
                    scf::SCFDialect,
                    StandardOpsDialect,
                    linalg::LinalgDialect,
                    vector::VectorDialect,
                    gpu::GPUDialect>();
    // clang-format on
  }
  StringRef getArgument() const final {
    return "test-linalg-transform-patterns";
  }
  StringRef getDescription() const final {
    return "Test Linalg transformation patterns by applying them greedily.";
  }

  void runOnOperation() override;

  Option<bool> testPatterns{*this, "test-patterns",
                            llvm::cl::desc("Test a mixed set of patterns"),
                            llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns1dTiling{
      *this, "test-matmul-to-vector-patterns-tile-1d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "1-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns2dTiling{
      *this, "test-matmul-to-vector-patterns-tile-2d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "2-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                    llvm::cl::desc("Test promotion options"),
                                    llvm::cl::init(false)};
  Option<bool> testTileAndDistributionOptions{
      *this, "test-tile-and-distribute-options",
      llvm::cl::desc("Test tile and distribute options"),
      llvm::cl::init(false)};
  Option<bool> testTileFuseAndDistributionOptions{
      *this, "test-tile-fuse-and-distribute-options",
      llvm::cl::desc("Test tile, fuse and distribute options"),
      llvm::cl::init(false)};
  Option<bool> testVectorTransferForwardingPatterns{
      *this, "test-vector-transfer-forwarding-patterns",
      llvm::cl::desc(
          "Test a fused pass that forwards memref.copy to vector.transfer"),
      llvm::cl::init(false)};
  Option<bool> testGenericToVectorPattern{
      *this, "test-linalg-to-vector-patterns",
      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                     "in vector.contract form"),
      llvm::cl::init(false)};
  Option<bool> testTilePattern{*this, "test-tile-pattern",
                               llvm::cl::desc("Test tile pattern"),
                               llvm::cl::init(false)};
  Option<bool> testTileScalarizeDynamicDims{
      *this, "test-tile-scalarize-dynamic-dims",
      llvm::cl::desc("Test tiling of dynamic dims by 1"),
      llvm::cl::init(false)};
  Option<bool> testTransformPadTensor{
      *this, "test-transform-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizePadTensor{
      *this, "test-generalize-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testSwapSubTensorPadTensor{
      *this, "test-swap-subtensor-padtensor",
      llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                     "pad_tensor(subtensor)"),
      llvm::cl::init(false)};
  // Auxiliary options consumed by test-tile-pattern / test-tiled-loop-peeling.
  ListOption<int64_t> peeledLoops{
      *this, "peeled-loops",
      llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> tileSizes{
      *this, "tile-sizes",
      llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<unsigned> testTiledLoopPeeling{
      *this, "test-tiled-loop-peeling",
      llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
      llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
  Option<bool> skipPartial{
      *this, "skip-partial",
      llvm::cl::desc("Skip loops inside partial iterations during peeling"),
      llvm::cl::init(false)};
  Option<std::string> loopType{
      *this, "loop-type",
      llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                     "tiled_loop"),
      llvm::cl::init("for")};
};
} // namespace

/// Populate a mixed set of tiling/promotion/vectorization patterns gated by
/// LinalgTransformationFilter markers and apply them greedily to `funcOp`.
/// Ops are matched only when they carry the filter's "from" marker attribute
/// and are re-tagged with the "to" marker, so each op progresses through the
/// chain at most once. All markers are erased afterwards.
static void applyPatterns(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);

  //===--------------------------------------------------------------------===//
  // Linalg tiling patterns.
  //===--------------------------------------------------------------------===//
  // Progressive matmul tiling chain: MEM -> L3 -> L2 -> L1 -> REG.
  patterns.add<LinalgTilingPattern>(
      MatmulOp::getOperationName(), ctx,
      LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
      LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
                                 StringAttr::get(ctx, "L3")));
  patterns.add<LinalgTilingPattern>(
      MatmulOp::getOperationName(), ctx,
      LinalgTilingOptions().setTileSizes({200, 300, 400}),
      LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
                                 StringAttr::get(ctx, "L2")));
  patterns.add<LinalgTilingPattern>(
      MatmulOp::getOperationName(), ctx,
      LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
                                 StringAttr::get(ctx, "L1")));
  patterns.add<LinalgTilingPattern>(
      MatmulOp::getOperationName(), ctx,
      LinalgTilingOptions().setTileSizes({2, 3, 4}),
      LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
                                 StringAttr::get(ctx, "REG")));

  // Matvec: single parallel-loop tiling step; matches unmarked ops
  // (empty "from" marker list).
  patterns.add<LinalgTilingPattern>(
      MatvecOp::getOperationName(), ctx,
      LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
          LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(ArrayRef<StringAttr>{},
                                 StringAttr::get(ctx, "L1")));

  // Dot: one tiling step accepting any of several "from" markers.
  patterns.add<LinalgTilingPattern>(
      DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
      LinalgTransformationFilter(
          ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
                               StringAttr::get(ctx, "L3"),
                               StringAttr::get(ctx, "L2")},
          StringAttr::get(ctx, "REG")));

  //===--------------------------------------------------------------------===//
  // Linalg tiling and permutation patterns.
  //===--------------------------------------------------------------------===//
  // Same progressive chains, but with loop interchange applied at each step.
  patterns.add<LinalgTilingPattern>(
      MatmulOp::getOperationName(), ctx,
      LinalgTilingOptions()
          .setTileSizes({2000, 3000, 4000})
          .setInterchange({1, 2, 0}),
      LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
                                 StringAttr::get(ctx, "L2__with_perm__")));
  patterns.add<LinalgTilingPattern>(
      MatmulOp::getOperationName(), ctx,
      LinalgTilingOptions()
          .setTileSizes({200, 300, 400})
          .setInterchange({1, 0, 2}),
      LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
                                 StringAttr::get(ctx, "L1__with_perm__")));
  patterns.add<LinalgTilingPattern>(
      MatmulOp::getOperationName(), ctx,
      LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
                                 StringAttr::get(ctx, "REG__with_perm__")));

  patterns.add<LinalgTilingPattern>(
      MatvecOp::getOperationName(), ctx,
      LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
      LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
                                 StringAttr::get(ctx, "L1__with_perm__")));

  // Tiling + interchange + parallel loops combined in one step.
  patterns.add<LinalgTilingPattern>(
      MatmulOp::getOperationName(), ctx,
      LinalgTilingOptions()
          .setTileSizes({16, 8, 4})
          .setInterchange({1, 2, 0})
          .setLoopType(LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(
          StringAttr::get(ctx, "par__with_perm__"),
          StringAttr::get(ctx, "after_par__with_perm__")));

  //===--------------------------------------------------------------------===//
  // Linalg to loops patterns.
  //===--------------------------------------------------------------------===//
  // Lower fully-tiled ("REG"-marked) dot ops to explicit loops.
  patterns.add<LinalgLoweringPattern<DotOp>>(
      ctx,
      /*loweringType=*/LinalgLoweringType::Loops,
      LinalgTransformationFilter(StringAttr::get(ctx, "REG")));

  //===--------------------------------------------------------------------===//
  // Linalg distribution patterns.
  //===--------------------------------------------------------------------===//
  // NOTE(review): declared but never used here; distribution is exercised by
  // fillTileAndDistributePatterns instead.
  LinalgLoopDistributionOptions distributionOptions;

  //===--------------------------------------------------------------------===//
  // Linalg to vector contraction patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
               .addOpFilter<MatmulOp, FillOp, GenericOp>());
  patterns.add<CopyVectorizationPattern>(ctx);

  //===--------------------------------------------------------------------===//
  // Linalg generic interchange pattern.
  //===--------------------------------------------------------------------===//
  patterns.add<GenericOpInterchangePattern>(
      ctx,
      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
      LinalgTransformationFilter(ArrayRef<StringAttr>{},
                                 StringAttr::get(ctx, "PERMUTED")));

  //===--------------------------------------------------------------------===//
  // Linalg subview operands promotion.
  //===--------------------------------------------------------------------===//
  // Promote all operands, only the first operand, and (for fill) a single
  // operand with 32-byte alignment, each guarded by its own marker pair.
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"),
                                 StringAttr::get(ctx, "_views_promoted_")));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0})
          .setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(
          StringAttr::get(ctx, "_promote_first_view_"),
          StringAttr::get(ctx, "_first_view_promoted_")));
  patterns.add<LinalgPromotionPattern<FillOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({1})
          .setUseFullTileBuffers({false, true})
          .setAlignment(32),
      LinalgTransformationFilter(
          StringAttr::get(ctx, "_promote_views_aligned_"),
          StringAttr::get(ctx, "_views_aligned_promoted_")));

  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the marker.
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}

/// Build the staged pattern list for the matmul-to-vector pipeline:
/// stage 1 = tile (startMarker -> L1), stage 2 = promote (L1 -> VEC),
/// stage 3 = vectorize matmul/fill/copy. Each stage is appended as its own
/// RewritePatternSet so applyStagedPatterns can run them in order.
static void fillL1TilingAndMatmulToVectorPatterns(
    FuncOp funcOp, StringRef startMarker,
    SmallVectorImpl<RewritePatternSet> &patternsVector) {
  MLIRContext *ctx = funcOp.getContext();
  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgTilingPattern>(
               MatmulOp::getOperationName(), ctx,
               LinalgTilingOptions()
                   .setTileSizes({8, 12, 16})
                   .setInterchange({1, 0, 2}),
               LinalgTransformationFilter(StringAttr::get(ctx, startMarker),
                                          StringAttr::get(ctx, "L1"))));

  patternsVector.emplace_back(
      ctx,
      std::make_unique<LinalgPromotionPattern<MatmulOp>>(
          ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
          LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
                                     StringAttr::get(ctx, "VEC"))));

  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgVectorizationPattern>(
               MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
               LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
  // Also vectorize fills and copies created by promotion (no marker needed).
  patternsVector.back().add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter().addOpFilter<FillOp>());
  patternsVector.back().add<CopyVectorizationPattern>(ctx);
}

//===----------------------------------------------------------------------===//
// Test promotion callbacks
//===----------------------------------------------------------------------===//

// Allocation call back
// Allocates a fully-dynamic memref of the bounding-subview size. The memory
// space is hard-coded to 3 for testing purposes (an arbitrary non-default
// space; the test only checks that it is propagated).
static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
                                       ArrayRef<Value> boundingSubViewSize,
                                       DataLayout &layout) {
  // -1 marks every dimension as dynamic; the actual sizes are passed as
  // operands to memref.alloc.
  SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
  return b
      .create<memref::AllocOp>(
          subView.getLoc(),
          MemRefType::get(shape, subView.getType().getElementType(),
                          /*affineMapComposition =*/{}, 3),
          boundingSubViewSize)
      .getResult();
}

// Deallocation
callback 332 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 333 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 334 return success(); 335 } 336 337 // Copy in call back 338 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 339 bool isOutput) { 340 auto floatType = src.getType().cast<MemRefType>().getElementType(); 341 if (!floatType.isa<FloatType>()) 342 return failure(); 343 if (!isOutput) { 344 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 345 FloatAttr::get(floatType, 42.0)); 346 b.create<FillOp>(src.getLoc(), cst, dst); 347 } 348 b.create<memref::CopyOp>(src.getLoc(), src, dst); 349 return success(); 350 } 351 352 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 353 RewritePatternSet &patterns) { 354 patterns.add<LinalgTilingPattern>( 355 MatmulOp::getOperationName(), ctx, 356 LinalgTilingOptions().setTileSizes({16, 16, 16}), 357 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 358 StringAttr::get(ctx, "PROMOTE"))); 359 patterns.add<LinalgPromotionPattern<MatmulOp>>( 360 ctx, 361 LinalgPromotionOptions() 362 .setOperandsToPromote({0, 2}) 363 .setUseFullTileBuffers({false, false}) 364 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 365 .setCopyInOutFns( 366 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 367 return copyCallBackFn(b, src, dst, false); 368 }, 369 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 370 return copyCallBackFn(b, src, dst, true); 371 }), 372 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 373 } 374 375 template <typename IdOp, typename NProcsOp> 376 static SmallVector<ProcInfo, 2> 377 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 378 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 379 SmallVector<ProcInfo, 2> procInfo(count); 380 Type indexType = b.getIndexType(); 381 for (unsigned i = 0; i < count; ++i) { 382 gpu::Dimension dim = *gpu::symbolizeDimension(i); 383 
procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 384 b.create<NProcsOp>(loc, indexType, dim)}; 385 } 386 return procInfo; 387 } 388 389 static void fillTileAndDistributePatterns(MLIRContext *context, 390 RewritePatternSet &patterns) { 391 { 392 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 393 cyclicNprocsEqNiters.distributionMethod.resize( 394 2, DistributionMethod::CyclicNumProcsEqNumIters); 395 cyclicNprocsEqNiters.procInfo = 396 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 397 patterns.add<LinalgTilingPattern>( 398 MatmulOp::getOperationName(), context, 399 LinalgTilingOptions() 400 .setTileSizes({8, 8, 4}) 401 .setLoopType(LinalgTilingLoopType::ParallelLoops) 402 .setDistributionOptions(cyclicNprocsEqNiters), 403 LinalgTransformationFilter( 404 StringAttr::get(context, "distribute1"), 405 StringAttr::get(context, "after_distribute1"))); 406 } 407 408 { 409 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 410 cyclicNprocsGeNiters.distributionMethod.resize( 411 2, DistributionMethod::CyclicNumProcsGeNumIters); 412 cyclicNprocsGeNiters.procInfo = 413 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 414 patterns.add<LinalgTilingPattern>( 415 MatmulOp::getOperationName(), context, 416 LinalgTilingOptions() 417 .setTileSizes({8, 8, 4}) 418 .setLoopType(LinalgTilingLoopType::ParallelLoops) 419 .setDistributionOptions(cyclicNprocsGeNiters), 420 LinalgTransformationFilter( 421 StringAttr::get(context, "distribute2"), 422 StringAttr::get(context, "after_distribute2"))); 423 } 424 425 { 426 LinalgLoopDistributionOptions cyclicNprocsDefault; 427 cyclicNprocsDefault.distributionMethod.resize(2, 428 DistributionMethod::Cyclic); 429 cyclicNprocsDefault.procInfo = 430 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 431 patterns.add<LinalgTilingPattern>( 432 MatmulOp::getOperationName(), context, 433 LinalgTilingOptions() 434 .setTileSizes({8, 8, 4}) 435 .setLoopType(LinalgTilingLoopType::ParallelLoops) 436 .setDistributionOptions(cyclicNprocsDefault), 
437 LinalgTransformationFilter( 438 StringAttr::get(context, "distribute3"), 439 StringAttr::get(context, "after_distribute3"))); 440 } 441 442 { 443 LinalgLoopDistributionOptions cyclicNprocsMixed1; 444 cyclicNprocsMixed1.distributionMethod = { 445 DistributionMethod::CyclicNumProcsEqNumIters, 446 DistributionMethod::CyclicNumProcsGeNumIters}; 447 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 448 patterns.add<LinalgTilingPattern>( 449 MatmulOp::getOperationName(), context, 450 LinalgTilingOptions() 451 .setTileSizes({8, 8, 4}) 452 .setLoopType(LinalgTilingLoopType::ParallelLoops) 453 .setDistributionOptions(cyclicNprocsMixed1), 454 LinalgTransformationFilter( 455 StringAttr::get(context, "distribute4"), 456 StringAttr::get(context, "after_distribute4"))); 457 } 458 459 { 460 LinalgLoopDistributionOptions cyclicNprocsMixed2; 461 cyclicNprocsMixed2.distributionMethod = { 462 DistributionMethod::CyclicNumProcsGeNumIters, 463 DistributionMethod::Cyclic}; 464 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 465 patterns.add<LinalgTilingPattern>( 466 MatmulOp::getOperationName(), context, 467 LinalgTilingOptions() 468 .setTileSizes({8, 8, 4}) 469 .setLoopType(LinalgTilingLoopType::ParallelLoops) 470 .setDistributionOptions(cyclicNprocsMixed2), 471 LinalgTransformationFilter( 472 StringAttr::get(context, "distribute5"), 473 StringAttr::get(context, "after_distribute5"))); 474 } 475 476 { 477 LinalgLoopDistributionOptions cyclicNprocsMixed3; 478 cyclicNprocsMixed3.distributionMethod = { 479 DistributionMethod::Cyclic, 480 DistributionMethod::CyclicNumProcsEqNumIters}; 481 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 482 483 patterns.add<LinalgTilingPattern>( 484 MatmulOp::getOperationName(), context, 485 LinalgTilingOptions() 486 .setTileSizes({8, 8, 4}) 487 .setLoopType(LinalgTilingLoopType::ParallelLoops) 488 .setDistributionOptions(cyclicNprocsMixed3), 489 
LinalgTransformationFilter( 490 StringAttr::get(context, "distribute6"), 491 StringAttr::get(context, "after_distribute6"))); 492 } 493 494 { 495 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 496 cyclicNprocsEqNiters.distributionMethod.resize(2, 497 DistributionMethod::Cyclic); 498 cyclicNprocsEqNiters.procInfo = 499 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 500 patterns.add<LinalgTilingPattern>( 501 MatmulOp::getOperationName(), context, 502 LinalgTilingOptions() 503 .setTileSizes({8, 8, 4}) 504 .setLoopType(LinalgTilingLoopType::Loops) 505 .setDistributionOptions(cyclicNprocsEqNiters), 506 LinalgTransformationFilter( 507 StringAttr::get(context, "tensors_distribute1"), 508 StringAttr::get(context, "tensors_after_distribute1"))); 509 } 510 } 511 512 static void fillTileFuseAndDistributePatterns(MLIRContext *context, 513 RewritePatternSet &patterns) { 514 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 515 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); 516 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 517 patterns.add<LinalgTileAndFuseTensorOpsPattern>( 518 MatmulOp::getOperationName(), context, 519 LinalgTilingAndFusionOptions() 520 .setTileSizes({8, 8, 4}) 521 .setDistributionOptions(cyclicNprocsEqNiters), 522 LinalgTransformationFilter( 523 StringAttr::get(context, "tensors_fuse_distribute1"), 524 StringAttr::get(context, "tensors_after_fuse_distribute1"))); 525 } 526 527 static void 528 applyMatmulToVectorPatterns(FuncOp funcOp, 529 bool testMatmulToVectorPatterns1dTiling, 530 bool testMatmulToVectorPatterns2dTiling) { 531 MLIRContext *ctx = funcOp.getContext(); 532 SmallVector<RewritePatternSet, 4> stage1Patterns; 533 if (testMatmulToVectorPatterns1dTiling) { 534 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 535 } else if (testMatmulToVectorPatterns2dTiling) { 536 stage1Patterns.emplace_back( 537 ctx, std::make_unique<LinalgTilingPattern>( 538 
MatmulOp::getOperationName(), ctx, 539 LinalgTilingOptions() 540 .setTileSizes({768, 264, 768}) 541 .setInterchange({1, 2, 0}), 542 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 543 StringAttr::get(ctx, "L2")))); 544 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 545 } 546 { 547 // Canonicalization patterns 548 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 549 vector::populateVectorTransferPermutationMapLoweringPatterns( 550 canonicalizationPatterns); 551 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 552 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 553 } 554 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 555 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 556 FrozenRewritePatternSet stage2Patterns = 557 getLinalgTilingCanonicalizationPatterns(ctx); 558 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); 559 } 560 561 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 562 RewritePatternSet forwardPattern(funcOp.getContext()); 563 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 564 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 565 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 566 } 567 568 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 569 RewritePatternSet patterns(funcOp.getContext()); 570 auto *ctx = funcOp.getContext(); 571 patterns.add<LinalgVectorizationPattern>( 572 ctx, LinalgTransformationFilter() 573 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>()); 574 patterns.add<CopyVectorizationPattern>(ctx); 575 populatePadOpVectorizationPatterns(patterns); 576 populateConvolutionVectorizationPatterns(patterns); 577 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 578 } 579 580 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 581 RewritePatternSet 
patterns(funcOp.getContext()); 582 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 583 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 584 } 585 586 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 587 RewritePatternSet patterns(funcOp.getContext()); 588 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 589 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 590 } 591 592 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 593 RewritePatternSet patterns(funcOp.getContext()); 594 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 595 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 596 } 597 598 static void applyTilePattern(FuncOp funcOp, const std::string &loopType, 599 ArrayRef<int64_t> tileSizes, 600 ArrayRef<int64_t> peeledLoops, 601 bool scalarizeDynamicDims) { 602 MLIRContext *context = funcOp.getContext(); 603 RewritePatternSet tilingPattern(context); 604 LinalgTilingLoopType type = 605 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 606 .Case("for", LinalgTilingLoopType::Loops) 607 .Case("affine", LinalgTilingLoopType::AffineLoops) 608 .Case("parallel", LinalgTilingLoopType::ParallelLoops) 609 .Case("tiled_loop", LinalgTilingLoopType::TiledLoops); 610 auto linalgTilingOptions = linalg::LinalgTilingOptions() 611 .setPeeledLoops(peeledLoops) 612 .setLoopType(type); 613 if (scalarizeDynamicDims) { 614 linalgTilingOptions.scalarizeDynamicDims(); 615 assert(tileSizes.empty() && 616 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 617 } else { 618 linalgTilingOptions.setTileSizes(tileSizes); 619 } 620 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 621 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 622 tilingPattern, linalgTilingOptions, f); 623 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 624 } 625 626 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 627 
// Marker attribute identifying the loop that holds the partial iteration.
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";

namespace {
/// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
/// `idx`-th loop contains only "full" iterations and a second loop for the
/// remaining partial iteration (if any).
struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
  TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
      : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
  }

  LogicalResult matchAndRewrite(TiledLoopOp loopOp,
                                PatternRewriter &rewriter) const override {
    // Collect the loops already peeled on this op (recorded in the
    // kPeeledLoopsLabel attribute) so each loop is peeled at most once.
    SmallVector<int64_t> peeledLoops;
    if (loopOp->hasAttr(kPeeledLoopsLabel)) {
      auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
      peeledLoops =
          llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
            return attr.cast<IntegerAttr>().getInt();
          }));
      // Check if the loop was already peeled.
      if (llvm::find(peeledLoops, idx) != peeledLoops.end())
        return failure();
    }
    if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
      // No peeling of loop nests with a partial iteration.
      return failure();

    // Bail out if the op does not have an `idx`-th loop to peel.
    if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
      return failure();

    // Peel loop and canonicalize.
    TiledLoopOp result;
    if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
                                                    result)))
      return failure();

    // Apply label, so that the same loop is not rewritten a second time.
    peeledLoops.push_back(idx);
    rewriter.updateRootInPlace(loopOp, [&]() {
      loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    });
    // The newly created loop holds the partial iteration; mark it as both
    // peeled (same set of loops) and partial.
    result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());

    return success();
  }

  /// Index of loop to peel.
  int64_t idx;

  /// If set to true, do not peel TiledLoopOps with a partial iteration.
  bool skipPartial;
};
} // namespace

/// Peel each loop index in `loops` on every linalg.tiled_loop in `funcOp`,
/// then strip the bookkeeping attributes the pattern left behind.
static void applyTiledLoopPeelingPattern(FuncOp funcOp,
                                         ArrayRef<unsigned> loops,
                                         bool skipPartial) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  for (unsigned idx : loops)
    patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the markers.
  funcOp.walk([](TiledLoopOp op) {
    op->removeAttr(kPeeledLoopsLabel);
    op->removeAttr(kPartialIterationLabel);
  });
}

/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
  // Scope guard (built from a dummy unique_ptr) that strips all Linalg
  // transformation markers when this function returns, whichever early
  // `return` below is taken.
  auto lambda = [&](void *) {
    getOperation().walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });
  };
  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};

  // Dispatch on the first test flag that is set; each branch applies its
  // pattern family and returns.
  if (testPromotionOptions) {
    RewritePatternSet patterns(&getContext());
    fillPromotionCallBackPatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    return;
  }
  if (testTileAndDistributionOptions) {
    RewritePatternSet patterns(&getContext());
    fillTileAndDistributePatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    return;
  }
  if (testTileFuseAndDistributionOptions) {
    RewritePatternSet patterns(&getContext());
    fillTileFuseAndDistributePatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    return;
  }
  if (testPatterns)
    return applyPatterns(getOperation());
  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
    return applyMatmulToVectorPatterns(getOperation(),
                                       testMatmulToVectorPatterns1dTiling,
                                       testMatmulToVectorPatterns2dTiling);
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getOperation());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getOperation());
  if (testTransformPadTensor)
    return applyPadTensorToGenericPatterns(getOperation());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getOperation());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getOperation());
  if (testTiledLoopPeeling.hasValue())
    return applyTiledLoopPeelingPattern(getOperation(), testTiledLoopPeeling,
                                        skipPartial);
  if (testTilePattern)
    return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
                            /*scalarizeDynamicDims=*/false);
  if (testTileScalarizeDynamicDims)
    return applyTilePattern(getOperation(), loopType, tileSizes,
                            /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
}

namespace mlir {
namespace test {
// Registers the pass under -test-linalg-transform-patterns.
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir