1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/GPU/GPUDialect.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 19 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/Linalg/Utils/Utils.h" 22 #include "mlir/Dialect/StandardOps/IR/Ops.h" 23 #include "mlir/Dialect/Vector/IR/VectorOps.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 namespace { 34 struct TestLinalgTransforms 35 : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> { 36 TestLinalgTransforms() = default; 37 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 38 39 void getDependentDialects(DialectRegistry ®istry) const override { 40 // clang-format off 41 registry.insert<AffineDialect, 42 memref::MemRefDialect, 43 scf::SCFDialect, 44 StandardOpsDialect, 45 linalg::LinalgDialect, 46 vector::VectorDialect, 47 gpu::GPUDialect>(); 48 // clang-format on 49 } 50 StringRef getArgument() const final { 51 return "test-linalg-transform-patterns"; 52 } 53 StringRef getDescription() const final { 54 return "Test Linalg transformation patterns by applying them greedily."; 55 } 56 57 void runOnOperation() override; 58 59 Option<bool> testPatterns{*this, "test-patterns", 60 llvm::cl::desc("Test a mixed set of patterns"), 61 llvm::cl::init(false)}; 62 Option<bool> testMatmulToVectorPatterns1dTiling{ 63 *this, "test-matmul-to-vector-patterns-tile-1d", 64 llvm::cl::desc( 65 "Test a fused pass that applies patterns from matmul to vectors via " 66 "1-d tiling"), 67 llvm::cl::init(false)}; 68 Option<bool> testMatmulToVectorPatterns2dTiling{ 69 *this, "test-matmul-to-vector-patterns-tile-2d", 70 llvm::cl::desc( 71 "Test a fused pass that applies patterns from matmul to vectors via " 72 "2-d tiling"), 73 llvm::cl::init(false)}; 74 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 75 llvm::cl::desc("Test promotion options"), 76 llvm::cl::init(false)}; 77 Option<bool> testTileAndDistributionOptions{ 78 *this, "test-tile-and-distribute-options", 79 llvm::cl::desc("Test tile and distribute options"), 80 llvm::cl::init(false)}; 81 Option<bool> testTileFuseAndDistributionOptions{ 82 *this, "test-tile-fuse-and-distribute-options", 83 llvm::cl::desc("Test tile, fuse and distribute options"), 84 llvm::cl::init(false)}; 85 Option<bool> testVectorTransferForwardingPatterns{ 86 *this, "test-vector-transfer-forwarding-patterns", 87 llvm::cl::desc( 88 "Test a fused pass that forwards memref.copy to vector.transfer"), 89 llvm::cl::init(false)}; 90 Option<bool> testGenericToVectorPattern{ 91 *this, "test-linalg-to-vector-patterns", 92 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 93 "in vector.contract form"), 94 llvm::cl::init(false)}; 95 Option<bool> testTilePattern{*this, "test-tile-pattern", 96 llvm::cl::desc("Test tile pattern"), 97 llvm::cl::init(false)}; 98 Option<bool> testTileScalarizeDynamicDims{ 99 *this, "test-tile-scalarize-dynamic-dims", 100 llvm::cl::desc("Test tiling of dynamic dims by 1"), 101 llvm::cl::init(false)}; 102 Option<bool> testTransformPadTensor{ 103 *this, "test-transform-pad-tensor", 104 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 105 llvm::cl::init(false)}; 106 Option<bool> testGeneralizePadTensor{ 107 *this, "test-generalize-pad-tensor", 108 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 109 llvm::cl::init(false)}; 110 Option<bool> testSwapSubTensorPadTensor{ 111 *this, "test-swap-subtensor-padtensor", 112 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 113 "pad_tensor(subtensor)"), 114 llvm::cl::init(false)}; 115 ListOption<int64_t> peeledLoops{ 116 *this, "peeled-loops", 117 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 118 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 119 ListOption<int64_t> tileSizes{ 120 *this, "tile-sizes", 121 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 122 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 123 Option<bool> skipPartial{ 124 *this, "skip-partial", 125 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 126 llvm::cl::init(false)}; 127 Option<std::string> loopType{ 128 *this, "loop-type", 129 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 130 "tiled_loop"), 131 llvm::cl::init("for")}; 132 }; 133 } // namespace 134 135 static void applyPatterns(FuncOp funcOp) { 136 MLIRContext *ctx = funcOp.getContext(); 137 RewritePatternSet patterns(ctx); 138 139 //===--------------------------------------------------------------------===// 140 // Linalg tiling patterns. 141 //===--------------------------------------------------------------------===// 142 patterns.add<LinalgTilingPattern>( 143 MatmulOp::getOperationName(), ctx, 144 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 145 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 146 StringAttr::get(ctx, "L3"))); 147 patterns.add<LinalgTilingPattern>( 148 MatmulOp::getOperationName(), ctx, 149 LinalgTilingOptions().setTileSizes({200, 300, 400}), 150 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 151 StringAttr::get(ctx, "L2"))); 152 patterns.add<LinalgTilingPattern>( 153 MatmulOp::getOperationName(), ctx, 154 LinalgTilingOptions().setTileSizes({20, 30, 40}), 155 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 156 StringAttr::get(ctx, "L1"))); 157 patterns.add<LinalgTilingPattern>( 158 MatmulOp::getOperationName(), ctx, 159 LinalgTilingOptions().setTileSizes({2, 3, 4}), 160 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 161 StringAttr::get(ctx, "REG"))); 162 163 patterns.add<LinalgTilingPattern>( 164 MatvecOp::getOperationName(), ctx, 165 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 166 LinalgTilingLoopType::ParallelLoops), 167 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 168 StringAttr::get(ctx, "L1"))); 169 170 patterns.add<LinalgTilingPattern>( 171 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), 172 LinalgTransformationFilter( 173 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 174 StringAttr::get(ctx, "L3"), 175 StringAttr::get(ctx, "L2")}, 176 StringAttr::get(ctx, "REG"))); 177 178 //===--------------------------------------------------------------------===// 179 // Linalg tiling and permutation patterns. 180 //===--------------------------------------------------------------------===// 181 patterns.add<LinalgTilingPattern>( 182 MatmulOp::getOperationName(), ctx, 183 LinalgTilingOptions() 184 .setTileSizes({2000, 3000, 4000}) 185 .setInterchange({1, 2, 0}), 186 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 187 StringAttr::get(ctx, "L2__with_perm__"))); 188 patterns.add<LinalgTilingPattern>( 189 MatmulOp::getOperationName(), ctx, 190 LinalgTilingOptions() 191 .setTileSizes({200, 300, 400}) 192 .setInterchange({1, 0, 2}), 193 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 194 StringAttr::get(ctx, "L1__with_perm__"))); 195 patterns.add<LinalgTilingPattern>( 196 MatmulOp::getOperationName(), ctx, 197 LinalgTilingOptions().setTileSizes({20, 30, 40}), 198 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 199 StringAttr::get(ctx, "REG__with_perm__"))); 200 201 patterns.add<LinalgTilingPattern>( 202 MatvecOp::getOperationName(), ctx, 203 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 204 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 205 StringAttr::get(ctx, "L1__with_perm__"))); 206 207 patterns.add<LinalgTilingPattern>( 208 MatmulOp::getOperationName(), ctx, 209 LinalgTilingOptions() 210 .setTileSizes({16, 8, 4}) 211 .setInterchange({1, 2, 0}) 212 .setLoopType(LinalgTilingLoopType::ParallelLoops), 213 LinalgTransformationFilter( 214 StringAttr::get(ctx, "par__with_perm__"), 215 StringAttr::get(ctx, "after_par__with_perm__"))); 216 217 //===--------------------------------------------------------------------===// 218 // Linalg to loops patterns. 219 //===--------------------------------------------------------------------===// 220 patterns.add<LinalgLoweringPattern<DotOp>>( 221 ctx, 222 /*loweringType=*/LinalgLoweringType::Loops, 223 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 224 225 //===--------------------------------------------------------------------===// 226 // Linalg distribution patterns. 227 //===--------------------------------------------------------------------===// 228 LinalgLoopDistributionOptions distributionOptions; 229 230 //===--------------------------------------------------------------------===// 231 // Linalg to vector contraction patterns. 232 //===--------------------------------------------------------------------===// 233 patterns.add<LinalgVectorizationPattern>( 234 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 235 .addOpFilter<MatmulOp, FillOp, GenericOp>()); 236 patterns.add<CopyVectorizationPattern>(ctx); 237 238 //===--------------------------------------------------------------------===// 239 // Linalg generic interchange pattern. 240 //===--------------------------------------------------------------------===// 241 patterns.add<GenericOpInterchangePattern>( 242 ctx, 243 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 244 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 245 StringAttr::get(ctx, "PERMUTED"))); 246 247 //===--------------------------------------------------------------------===// 248 // Linalg subview operands promotion. 249 //===--------------------------------------------------------------------===// 250 patterns.add<LinalgPromotionPattern<MatmulOp>>( 251 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 252 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 253 StringAttr::get(ctx, "_views_promoted_"))); 254 patterns.add<LinalgPromotionPattern<MatmulOp>>( 255 ctx, 256 LinalgPromotionOptions() 257 .setOperandsToPromote({0}) 258 .setUseFullTileBuffersByDefault(true), 259 LinalgTransformationFilter( 260 StringAttr::get(ctx, "_promote_first_view_"), 261 StringAttr::get(ctx, "_first_view_promoted_"))); 262 patterns.add<LinalgPromotionPattern<FillOp>>( 263 ctx, 264 LinalgPromotionOptions() 265 .setOperandsToPromote({1}) 266 .setUseFullTileBuffers({false, true}) 267 .setAlignment(32), 268 LinalgTransformationFilter( 269 StringAttr::get(ctx, "_promote_views_aligned_"), 270 StringAttr::get(ctx, "_views_aligned_promoted_"))); 271 272 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 273 274 // Drop the marker. 275 funcOp.walk([](LinalgOp op) { 276 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 277 }); 278 } 279 280 static void fillL1TilingAndMatmulToVectorPatterns( 281 FuncOp funcOp, StringRef startMarker, 282 SmallVectorImpl<RewritePatternSet> &patternsVector) { 283 MLIRContext *ctx = funcOp.getContext(); 284 patternsVector.emplace_back( 285 ctx, std::make_unique<LinalgTilingPattern>( 286 MatmulOp::getOperationName(), ctx, 287 LinalgTilingOptions() 288 .setTileSizes({8, 12, 16}) 289 .setInterchange({1, 0, 2}), 290 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 291 StringAttr::get(ctx, "L1")))); 292 293 patternsVector.emplace_back( 294 ctx, 295 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 296 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 297 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 298 StringAttr::get(ctx, "VEC")))); 299 300 patternsVector.emplace_back( 301 ctx, std::make_unique<LinalgVectorizationPattern>( 302 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 303 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 304 patternsVector.back().add<LinalgVectorizationPattern>( 305 ctx, LinalgTransformationFilter().addOpFilter<FillOp>()); 306 patternsVector.back().add<CopyVectorizationPattern>(ctx); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // Test promotion callbacks 311 //===----------------------------------------------------------------------===// 312 313 // Allocation call back 314 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 315 ArrayRef<Value> boundingSubViewSize, 316 DataLayout &layout) { 317 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 318 return b 319 .create<memref::AllocOp>( 320 subView.getLoc(), 321 MemRefType::get(shape, subView.getType().getElementType(), 322 /*affineMapComposition =*/{}, 3), 323 boundingSubViewSize) 324 .getResult(); 325 } 326 327 // Deallocation callback 328 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 329 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 330 return success(); 331 } 332 333 // Copy in call back 334 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 335 bool isOutput) { 336 auto floatType = src.getType().cast<MemRefType>().getElementType(); 337 if (!floatType.isa<FloatType>()) 338 return failure(); 339 if (!isOutput) { 340 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 341 FloatAttr::get(floatType, 42.0)); 342 b.create<FillOp>(src.getLoc(), cst, dst); 343 } 344 b.create<memref::CopyOp>(src.getLoc(), src, dst); 345 return success(); 346 } 347 348 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 349 RewritePatternSet &patterns) { 350 patterns.add<LinalgTilingPattern>( 351 MatmulOp::getOperationName(), ctx, 352 LinalgTilingOptions().setTileSizes({16, 16, 16}), 353 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 354 StringAttr::get(ctx, "PROMOTE"))); 355 patterns.add<LinalgPromotionPattern<MatmulOp>>( 356 ctx, 357 LinalgPromotionOptions() 358 .setOperandsToPromote({0, 2}) 359 .setUseFullTileBuffers({false, false}) 360 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 361 .setCopyInOutFns( 362 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 363 return copyCallBackFn(b, src, dst, false); 364 }, 365 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 366 return copyCallBackFn(b, src, dst, true); 367 }), 368 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 369 } 370 371 template <typename IdOp, typename NProcsOp> 372 static SmallVector<ProcInfo, 2> 373 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 374 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 375 SmallVector<ProcInfo, 2> procInfo(count); 376 Type indexType = b.getIndexType(); 377 for (unsigned i = 0; i < count; ++i) { 378 gpu::Dimension dim = *gpu::symbolizeDimension(i); 379 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 380 b.create<NProcsOp>(loc, indexType, dim)}; 381 } 382 return procInfo; 383 } 384 385 static void fillTileAndDistributePatterns(MLIRContext *context, 386 RewritePatternSet &patterns) { 387 { 388 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 389 cyclicNprocsEqNiters.distributionMethod.resize( 390 2, DistributionMethod::CyclicNumProcsEqNumIters); 391 cyclicNprocsEqNiters.procInfo = 392 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 393 patterns.add<LinalgTilingPattern>( 394 MatmulOp::getOperationName(), context, 395 LinalgTilingOptions() 396 .setTileSizes({8, 8, 4}) 397 .setLoopType(LinalgTilingLoopType::ParallelLoops) 398 .setDistributionOptions(cyclicNprocsEqNiters), 399 LinalgTransformationFilter( 400 StringAttr::get(context, "distribute1"), 401 StringAttr::get(context, "after_distribute1"))); 402 } 403 404 { 405 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 406 cyclicNprocsGeNiters.distributionMethod.resize( 407 2, DistributionMethod::CyclicNumProcsGeNumIters); 408 cyclicNprocsGeNiters.procInfo = 409 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 410 patterns.add<LinalgTilingPattern>( 411 MatmulOp::getOperationName(), context, 412 LinalgTilingOptions() 413 .setTileSizes({8, 8, 4}) 414 .setLoopType(LinalgTilingLoopType::ParallelLoops) 415 .setDistributionOptions(cyclicNprocsGeNiters), 416 LinalgTransformationFilter( 417 StringAttr::get(context, "distribute2"), 418 StringAttr::get(context, "after_distribute2"))); 419 } 420 421 { 422 LinalgLoopDistributionOptions cyclicNprocsDefault; 423 cyclicNprocsDefault.distributionMethod.resize(2, 424 DistributionMethod::Cyclic); 425 cyclicNprocsDefault.procInfo = 426 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 427 patterns.add<LinalgTilingPattern>( 428 MatmulOp::getOperationName(), context, 429 LinalgTilingOptions() 430 .setTileSizes({8, 8, 4}) 431 .setLoopType(LinalgTilingLoopType::ParallelLoops) 432 .setDistributionOptions(cyclicNprocsDefault), 433 LinalgTransformationFilter( 434 StringAttr::get(context, "distribute3"), 435 StringAttr::get(context, "after_distribute3"))); 436 } 437 438 { 439 LinalgLoopDistributionOptions cyclicNprocsMixed1; 440 cyclicNprocsMixed1.distributionMethod = { 441 DistributionMethod::CyclicNumProcsEqNumIters, 442 DistributionMethod::CyclicNumProcsGeNumIters}; 443 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 444 patterns.add<LinalgTilingPattern>( 445 MatmulOp::getOperationName(), context, 446 LinalgTilingOptions() 447 .setTileSizes({8, 8, 4}) 448 .setLoopType(LinalgTilingLoopType::ParallelLoops) 449 .setDistributionOptions(cyclicNprocsMixed1), 450 LinalgTransformationFilter( 451 StringAttr::get(context, "distribute4"), 452 StringAttr::get(context, "after_distribute4"))); 453 } 454 455 { 456 LinalgLoopDistributionOptions cyclicNprocsMixed2; 457 cyclicNprocsMixed2.distributionMethod = { 458 DistributionMethod::CyclicNumProcsGeNumIters, 459 DistributionMethod::Cyclic}; 460 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 461 patterns.add<LinalgTilingPattern>( 462 MatmulOp::getOperationName(), context, 463 LinalgTilingOptions() 464 .setTileSizes({8, 8, 4}) 465 .setLoopType(LinalgTilingLoopType::ParallelLoops) 466 .setDistributionOptions(cyclicNprocsMixed2), 467 LinalgTransformationFilter( 468 StringAttr::get(context, "distribute5"), 469 StringAttr::get(context, "after_distribute5"))); 470 } 471 472 { 473 LinalgLoopDistributionOptions cyclicNprocsMixed3; 474 cyclicNprocsMixed3.distributionMethod = { 475 DistributionMethod::Cyclic, 476 DistributionMethod::CyclicNumProcsEqNumIters}; 477 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 478 479 patterns.add<LinalgTilingPattern>( 480 MatmulOp::getOperationName(), context, 481 LinalgTilingOptions() 482 .setTileSizes({8, 8, 4}) 483 .setLoopType(LinalgTilingLoopType::ParallelLoops) 484 .setDistributionOptions(cyclicNprocsMixed3), 485 LinalgTransformationFilter( 486 StringAttr::get(context, "distribute6"), 487 StringAttr::get(context, "after_distribute6"))); 488 } 489 490 { 491 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 492 cyclicNprocsEqNiters.distributionMethod.resize(2, 493 DistributionMethod::Cyclic); 494 cyclicNprocsEqNiters.procInfo = 495 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 496 patterns.add<LinalgTilingPattern>( 497 MatmulOp::getOperationName(), context, 498 LinalgTilingOptions() 499 .setTileSizes({8, 8, 4}) 500 .setLoopType(LinalgTilingLoopType::Loops) 501 .setDistributionOptions(cyclicNprocsEqNiters), 502 LinalgTransformationFilter( 503 StringAttr::get(context, "tensors_distribute1"), 504 StringAttr::get(context, "tensors_after_distribute1"))); 505 } 506 } 507 508 static void fillTileFuseAndDistributePatterns(MLIRContext *context, 509 RewritePatternSet &patterns) { 510 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 511 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); 512 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 513 patterns.add<LinalgTileAndFuseTensorOpsPattern>( 514 MatmulOp::getOperationName(), context, 515 LinalgTilingAndFusionOptions() 516 .setTileSizes({8, 8, 4}) 517 .setDistributionOptions(cyclicNprocsEqNiters), 518 LinalgTransformationFilter( 519 StringAttr::get(context, "tensors_fuse_distribute1"), 520 StringAttr::get(context, "tensors_after_fuse_distribute1"))); 521 } 522 523 static void 524 applyMatmulToVectorPatterns(FuncOp funcOp, 525 bool testMatmulToVectorPatterns1dTiling, 526 bool testMatmulToVectorPatterns2dTiling) { 527 MLIRContext *ctx = funcOp.getContext(); 528 SmallVector<RewritePatternSet, 4> stage1Patterns; 529 if (testMatmulToVectorPatterns1dTiling) { 530 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 531 } else if (testMatmulToVectorPatterns2dTiling) { 532 stage1Patterns.emplace_back( 533 ctx, std::make_unique<LinalgTilingPattern>( 534 MatmulOp::getOperationName(), ctx, 535 LinalgTilingOptions() 536 .setTileSizes({768, 264, 768}) 537 .setInterchange({1, 2, 0}), 538 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 539 StringAttr::get(ctx, "L2")))); 540 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 541 } 542 { 543 // Canonicalization patterns 544 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 545 vector::populateVectorTransferPermutationMapLoweringPatterns( 546 canonicalizationPatterns); 547 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 548 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 549 } 550 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 551 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 552 FrozenRewritePatternSet stage2Patterns = 553 getLinalgTilingCanonicalizationPatterns(ctx); 554 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); 555 } 556 557 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 558 RewritePatternSet forwardPattern(funcOp.getContext()); 559 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 560 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 561 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 562 } 563 564 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 565 RewritePatternSet patterns(funcOp.getContext()); 566 auto *ctx = funcOp.getContext(); 567 patterns.add<LinalgVectorizationPattern>( 568 ctx, LinalgTransformationFilter() 569 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>()); 570 patterns.add<CopyVectorizationPattern>(ctx); 571 populatePadOpVectorizationPatterns(patterns); 572 populateConvolutionVectorizationPatterns(patterns); 573 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 574 } 575 576 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 577 RewritePatternSet patterns(funcOp.getContext()); 578 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 579 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 580 } 581 582 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 583 RewritePatternSet patterns(funcOp.getContext()); 584 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 585 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 586 } 587 588 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 589 RewritePatternSet patterns(funcOp.getContext()); 590 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 591 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 592 } 593 594 static void applyTilePattern(FuncOp funcOp, const std::string &loopType, 595 ArrayRef<int64_t> tileSizes, 596 ArrayRef<int64_t> peeledLoops, 597 bool scalarizeDynamicDims) { 598 MLIRContext *context = funcOp.getContext(); 599 RewritePatternSet tilingPattern(context); 600 LinalgTilingLoopType type = 601 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 602 .Case("for", LinalgTilingLoopType::Loops) 603 .Case("affine", LinalgTilingLoopType::AffineLoops) 604 .Case("parallel", LinalgTilingLoopType::ParallelLoops); 605 auto linalgTilingOptions = linalg::LinalgTilingOptions() 606 .setPeeledLoops(peeledLoops) 607 .setLoopType(type); 608 if (scalarizeDynamicDims) { 609 linalgTilingOptions.scalarizeDynamicDims(); 610 assert(tileSizes.empty() && 611 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 612 } else { 613 linalgTilingOptions.setTileSizes(tileSizes); 614 } 615 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 616 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 617 tilingPattern, linalgTilingOptions, f); 618 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 619 } 620 621 /// Apply transformations specified as patterns. 622 void TestLinalgTransforms::runOnOperation() { 623 auto lambda = [&](void *) { 624 getOperation().walk([](LinalgOp op) { 625 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 626 }); 627 }; 628 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 629 630 if (testPromotionOptions) { 631 RewritePatternSet patterns(&getContext()); 632 fillPromotionCallBackPatterns(&getContext(), patterns); 633 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 634 return; 635 } 636 if (testTileAndDistributionOptions) { 637 RewritePatternSet patterns(&getContext()); 638 fillTileAndDistributePatterns(&getContext(), patterns); 639 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 640 return; 641 } 642 if (testTileFuseAndDistributionOptions) { 643 RewritePatternSet patterns(&getContext()); 644 fillTileFuseAndDistributePatterns(&getContext(), patterns); 645 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 646 return; 647 } 648 if (testPatterns) 649 return applyPatterns(getOperation()); 650 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 651 return applyMatmulToVectorPatterns(getOperation(), 652 testMatmulToVectorPatterns1dTiling, 653 testMatmulToVectorPatterns2dTiling); 654 if (testVectorTransferForwardingPatterns) 655 return applyVectorTransferForwardingPatterns(getOperation()); 656 if (testGenericToVectorPattern) 657 return applyLinalgToVectorPatterns(getOperation()); 658 if (testTransformPadTensor) 659 return applyPadTensorToGenericPatterns(getOperation()); 660 if (testGeneralizePadTensor) 661 return applyGeneralizePadTensorPatterns(getOperation()); 662 if (testSwapSubTensorPadTensor) 663 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 664 if (testTilePattern) 665 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops, 666 /*scalarizeDynamicDims=*/false); 667 if (testTileScalarizeDynamicDims) 668 return applyTilePattern(getOperation(), loopType, tileSizes, 669 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 670 } 671 672 namespace mlir { 673 namespace test { 674 void registerTestLinalgTransforms() { 675 PassRegistration<TestLinalgTransforms>(); 676 } 677 } // namespace test 678 } // namespace mlir 679