1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 21 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 22 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 23 #include "mlir/Dialect/Linalg/Utils/Utils.h" 24 #include "mlir/Dialect/Vector/IR/VectorOps.h" 25 #include "mlir/Pass/PassManager.h" 26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 27 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/ADT/SmallVector.h" 30 31 using namespace mlir; 32 using namespace mlir::linalg; 33 34 namespace { 35 struct TestLinalgTransforms 36 : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> { 37 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms) 38 39 TestLinalgTransforms() = default; 40 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 41 42 void getDependentDialects(DialectRegistry ®istry) const override { 43 // clang-format off 44 registry.insert<AffineDialect, 45 bufferization::BufferizationDialect, 46 memref::MemRefDialect, 47 scf::SCFDialect, 48 linalg::LinalgDialect, 49 vector::VectorDialect, 50 gpu::GPUDialect>(); 51 // clang-format on 52 } 53 StringRef getArgument() const final { 54 return "test-linalg-transform-patterns"; 55 } 56 StringRef getDescription() const final { 57 return "Test Linalg transformation patterns by applying them greedily."; 58 } 59 60 void runOnOperation() override; 61 62 Option<bool> testPatterns{*this, "test-patterns", 63 llvm::cl::desc("Test a mixed set of patterns"), 64 llvm::cl::init(false)}; 65 Option<bool> testMatmulToVectorPatterns1dTiling{ 66 *this, "test-matmul-to-vector-patterns-tile-1d", 67 llvm::cl::desc( 68 "Test a fused pass that applies patterns from matmul to vectors via " 69 "1-d tiling"), 70 llvm::cl::init(false)}; 71 Option<bool> testMatmulToVectorPatterns2dTiling{ 72 *this, "test-matmul-to-vector-patterns-tile-2d", 73 llvm::cl::desc( 74 "Test a fused pass that applies patterns from matmul to vectors via " 75 "2-d tiling"), 76 llvm::cl::init(false)}; 77 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 78 llvm::cl::desc("Test promotion options"), 79 llvm::cl::init(false)}; 80 Option<bool> testTileAndDistributionOptions{ 81 *this, "test-tile-and-distribute-options", 82 llvm::cl::desc("Test tile and distribute options"), 83 llvm::cl::init(false)}; 84 Option<bool> testTileFuseAndDistributionOptions{ 85 *this, "test-tile-fuse-and-distribute-options", 86 llvm::cl::desc("Test tile, fuse and distribute options"), 87 llvm::cl::init(false)}; 88 Option<bool> testVectorTransferForwardingPatterns{ 89 *this, "test-vector-transfer-forwarding-patterns", 90 llvm::cl::desc( 91 "Test a fused pass that forwards memref.copy to vector.transfer"), 92 llvm::cl::init(false)}; 93 Option<bool> testGenericToVectorPattern{ 94 *this, "test-linalg-to-vector-patterns", 95 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 96 "in vector.contract form"), 97 llvm::cl::init(false)}; 98 Option<bool> testTilePattern{*this, "test-tile-pattern", 99 llvm::cl::desc("Test tile pattern"), 100 llvm::cl::init(false)}; 101 Option<bool> testTileScalarizeDynamicDims{ 102 *this, "test-tile-scalarize-dynamic-dims", 103 llvm::cl::desc("Test tiling of dynamic dims by 1"), 104 llvm::cl::init(false)}; 105 Option<bool> testTransformPadTensor{ 106 *this, "test-transform-pad-tensor", 107 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 108 llvm::cl::init(false)}; 109 Option<bool> testGeneralizePadTensor{ 110 *this, "test-generalize-pad-tensor", 111 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 112 llvm::cl::init(false)}; 113 Option<bool> testSwapSubTensorPadTensor{ 114 *this, "test-swap-subtensor-padtensor", 115 llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into " 116 "tensor.pad(subtensor)"), 117 llvm::cl::init(false)}; 118 Option<bool> testSplitReduction{ 119 *this, "test-split-reduction", 120 llvm::cl::desc("Test split reduction transformation"), 121 llvm::cl::init(false)}; 122 ListOption<int64_t> peeledLoops{ 123 *this, "peeled-loops", 124 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 125 llvm::cl::ZeroOrMore}; 126 ListOption<int64_t> tileSizes{ 127 *this, "tile-sizes", 128 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 129 llvm::cl::ZeroOrMore}; 130 Option<bool> skipPartial{ 131 *this, "skip-partial", 132 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 133 llvm::cl::init(false)}; 134 Option<std::string> loopType{ 135 *this, "loop-type", 136 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 137 "tiled_loop"), 138 llvm::cl::init("for")}; 139 Option<bool> testBubbleUpExtractSliceOpPattern{ 140 *this, "test-bubble-up-extract-slice-op-pattern", 141 llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " 142 "extract_slice + linalgOp"), 143 llvm::cl::init(false)}; 144 }; 145 } // namespace 146 147 static void applyPatterns(func::FuncOp funcOp) { 148 MLIRContext *ctx = funcOp.getContext(); 149 RewritePatternSet patterns(ctx); 150 151 //===--------------------------------------------------------------------===// 152 // Linalg tiling patterns. 153 //===--------------------------------------------------------------------===// 154 patterns.add<LinalgTilingPattern>( 155 MatmulOp::getOperationName(), ctx, 156 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 157 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 158 StringAttr::get(ctx, "L3"))); 159 patterns.add<LinalgTilingPattern>( 160 MatmulOp::getOperationName(), ctx, 161 LinalgTilingOptions().setTileSizes({200, 300, 400}), 162 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 163 StringAttr::get(ctx, "L2"))); 164 patterns.add<LinalgTilingPattern>( 165 MatmulOp::getOperationName(), ctx, 166 LinalgTilingOptions().setTileSizes({20, 30, 40}), 167 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 168 StringAttr::get(ctx, "L1"))); 169 patterns.add<LinalgTilingPattern>( 170 MatmulOp::getOperationName(), ctx, 171 LinalgTilingOptions().setTileSizes({2, 3, 4}), 172 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 173 StringAttr::get(ctx, "REG"))); 174 175 patterns.add<LinalgTilingPattern>( 176 MatvecOp::getOperationName(), ctx, 177 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 178 LinalgTilingLoopType::ParallelLoops), 179 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 180 StringAttr::get(ctx, "L1"))); 181 182 patterns.add<LinalgTilingPattern>( 183 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), 184 LinalgTransformationFilter( 185 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 186 StringAttr::get(ctx, "L3"), 187 StringAttr::get(ctx, "L2")}, 188 StringAttr::get(ctx, "REG"))); 189 190 //===--------------------------------------------------------------------===// 191 // Linalg tiling and permutation patterns. 192 //===--------------------------------------------------------------------===// 193 patterns.add<LinalgTilingPattern>( 194 MatmulOp::getOperationName(), ctx, 195 LinalgTilingOptions() 196 .setTileSizes({2000, 3000, 4000}) 197 .setInterchange({1, 2, 0}), 198 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 199 StringAttr::get(ctx, "L2__with_perm__"))); 200 patterns.add<LinalgTilingPattern>( 201 MatmulOp::getOperationName(), ctx, 202 LinalgTilingOptions() 203 .setTileSizes({200, 300, 400}) 204 .setInterchange({1, 0, 2}), 205 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 206 StringAttr::get(ctx, "L1__with_perm__"))); 207 patterns.add<LinalgTilingPattern>( 208 MatmulOp::getOperationName(), ctx, 209 LinalgTilingOptions().setTileSizes({20, 30, 40}), 210 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 211 StringAttr::get(ctx, "REG__with_perm__"))); 212 213 patterns.add<LinalgTilingPattern>( 214 MatvecOp::getOperationName(), ctx, 215 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 216 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 217 StringAttr::get(ctx, "L1__with_perm__"))); 218 219 patterns.add<LinalgTilingPattern>( 220 MatmulOp::getOperationName(), ctx, 221 LinalgTilingOptions() 222 .setTileSizes({16, 8, 4}) 223 .setInterchange({1, 2, 0}) 224 .setLoopType(LinalgTilingLoopType::ParallelLoops), 225 LinalgTransformationFilter( 226 StringAttr::get(ctx, "par__with_perm__"), 227 StringAttr::get(ctx, "after_par__with_perm__"))); 228 229 //===--------------------------------------------------------------------===// 230 // Linalg to loops patterns. 231 //===--------------------------------------------------------------------===// 232 patterns.add<LinalgLoweringPattern<DotOp>>( 233 ctx, 234 /*loweringType=*/LinalgLoweringType::Loops, 235 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 236 237 //===--------------------------------------------------------------------===// 238 // Linalg distribution patterns. 239 //===--------------------------------------------------------------------===// 240 LinalgLoopDistributionOptions distributionOptions; 241 242 //===--------------------------------------------------------------------===// 243 // Linalg to vector contraction patterns. 244 //===--------------------------------------------------------------------===// 245 patterns.add<LinalgVectorizationPattern>( 246 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 247 .addOpFilter<MatmulOp, FillOp, GenericOp>()); 248 patterns.add<CopyVectorizationPattern>(ctx); 249 250 //===--------------------------------------------------------------------===// 251 // Linalg generic interchange pattern. 252 //===--------------------------------------------------------------------===// 253 patterns.add<GenericOpInterchangePattern>( 254 ctx, 255 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 256 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 257 StringAttr::get(ctx, "PERMUTED"))); 258 259 //===--------------------------------------------------------------------===// 260 // Linalg subview operands promotion. 261 //===--------------------------------------------------------------------===// 262 patterns.add<LinalgPromotionPattern<MatmulOp>>( 263 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 264 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 265 StringAttr::get(ctx, "_views_promoted_"))); 266 patterns.add<LinalgPromotionPattern<MatmulOp>>( 267 ctx, 268 LinalgPromotionOptions() 269 .setOperandsToPromote({0}) 270 .setUseFullTileBuffersByDefault(true), 271 LinalgTransformationFilter( 272 StringAttr::get(ctx, "_promote_first_view_"), 273 StringAttr::get(ctx, "_first_view_promoted_"))); 274 patterns.add<LinalgPromotionPattern<FillOp>>( 275 ctx, 276 LinalgPromotionOptions() 277 .setOperandsToPromote({1}) 278 .setUseFullTileBuffers({false, true}) 279 .setAlignment(32), 280 LinalgTransformationFilter( 281 StringAttr::get(ctx, "_promote_views_aligned_"), 282 StringAttr::get(ctx, "_views_aligned_promoted_"))); 283 284 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 285 286 // Drop the marker. 287 funcOp.walk([](LinalgOp op) { 288 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 289 }); 290 } 291 292 static void fillL1TilingAndMatmulToVectorPatterns( 293 func::FuncOp funcOp, StringRef startMarker, 294 SmallVectorImpl<RewritePatternSet> &patternsVector) { 295 MLIRContext *ctx = funcOp.getContext(); 296 patternsVector.emplace_back( 297 ctx, std::make_unique<LinalgTilingPattern>( 298 MatmulOp::getOperationName(), ctx, 299 LinalgTilingOptions() 300 .setTileSizes({8, 12, 16}) 301 .setInterchange({1, 0, 2}), 302 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 303 StringAttr::get(ctx, "L1")))); 304 305 patternsVector.emplace_back( 306 ctx, 307 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 308 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 309 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 310 StringAttr::get(ctx, "VEC")))); 311 312 patternsVector.emplace_back( 313 ctx, std::make_unique<LinalgVectorizationPattern>( 314 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 315 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 316 patternsVector.back().add<LinalgVectorizationPattern>( 317 ctx, LinalgTransformationFilter().addOpFilter<FillOp>()); 318 patternsVector.back().add<CopyVectorizationPattern>(ctx); 319 } 320 321 //===----------------------------------------------------------------------===// 322 // Test promotion callbacks 323 //===----------------------------------------------------------------------===// 324 325 // Allocation call back 326 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 327 ArrayRef<Value> boundingSubViewSize, 328 DataLayout &layout) { 329 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 330 return b 331 .create<memref::AllocOp>( 332 subView.getLoc(), 333 MemRefType::get(shape, subView.getType().getElementType(), 334 /*affineMapComposition =*/{}, 3), 335 boundingSubViewSize) 336 .getResult(); 337 } 338 339 // Deallocation callback 340 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 341 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 342 return success(); 343 } 344 345 // Copy in call back 346 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 347 bool isOutput) { 348 auto floatType = src.getType().cast<MemRefType>().getElementType(); 349 if (!floatType.isa<FloatType>()) 350 return failure(); 351 if (!isOutput) { 352 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 353 FloatAttr::get(floatType, 42.0)); 354 b.create<FillOp>(src.getLoc(), cst, dst); 355 } 356 b.create<memref::CopyOp>(src.getLoc(), src, dst); 357 return success(); 358 } 359 360 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 361 RewritePatternSet &patterns) { 362 patterns.add<LinalgTilingPattern>( 363 MatmulOp::getOperationName(), ctx, 364 LinalgTilingOptions().setTileSizes({16, 16, 16}), 365 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 366 StringAttr::get(ctx, "PROMOTE"))); 367 patterns.add<LinalgPromotionPattern<MatmulOp>>( 368 ctx, 369 LinalgPromotionOptions() 370 .setOperandsToPromote({0, 2}) 371 .setUseFullTileBuffers({false, false}) 372 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 373 .setCopyInOutFns( 374 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 375 return copyCallBackFn(b, src, dst, false); 376 }, 377 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 378 return copyCallBackFn(b, src, dst, true); 379 }), 380 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 381 } 382 383 template <typename IdOp, typename NProcsOp> 384 static SmallVector<ProcInfo, 2> 385 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 386 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 387 SmallVector<ProcInfo, 2> procInfo(count); 388 Type indexType = b.getIndexType(); 389 for (unsigned i = 0; i < count; ++i) { 390 gpu::Dimension dim = *gpu::symbolizeDimension(i); 391 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 392 b.create<NProcsOp>(loc, indexType, dim)}; 393 } 394 return procInfo; 395 } 396 397 static void fillTileAndDistributePatterns(MLIRContext *context, 398 RewritePatternSet &patterns) { 399 { 400 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 401 cyclicNprocsEqNiters.distributionMethod.resize( 402 2, DistributionMethod::CyclicNumProcsEqNumIters); 403 cyclicNprocsEqNiters.procInfo = 404 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 405 patterns.add<LinalgTilingPattern>( 406 MatmulOp::getOperationName(), context, 407 LinalgTilingOptions() 408 .setTileSizes({8, 8, 4}) 409 .setLoopType(LinalgTilingLoopType::ParallelLoops) 410 .setDistributionOptions(cyclicNprocsEqNiters), 411 LinalgTransformationFilter( 412 StringAttr::get(context, "distribute1"), 413 StringAttr::get(context, "after_distribute1"))); 414 } 415 416 { 417 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 418 cyclicNprocsGeNiters.distributionMethod.resize( 419 2, DistributionMethod::CyclicNumProcsGeNumIters); 420 cyclicNprocsGeNiters.procInfo = 421 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 422 patterns.add<LinalgTilingPattern>( 423 MatmulOp::getOperationName(), context, 424 LinalgTilingOptions() 425 .setTileSizes({8, 8, 4}) 426 .setLoopType(LinalgTilingLoopType::ParallelLoops) 427 .setDistributionOptions(cyclicNprocsGeNiters), 428 LinalgTransformationFilter( 429 StringAttr::get(context, "distribute2"), 430 StringAttr::get(context, "after_distribute2"))); 431 } 432 433 { 434 LinalgLoopDistributionOptions cyclicNprocsDefault; 435 cyclicNprocsDefault.distributionMethod.resize(2, 436 DistributionMethod::Cyclic); 437 cyclicNprocsDefault.procInfo = 438 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 439 patterns.add<LinalgTilingPattern>( 440 MatmulOp::getOperationName(), context, 441 LinalgTilingOptions() 442 .setTileSizes({8, 8, 4}) 443 .setLoopType(LinalgTilingLoopType::ParallelLoops) 444 .setDistributionOptions(cyclicNprocsDefault), 445 LinalgTransformationFilter( 446 StringAttr::get(context, "distribute3"), 447 StringAttr::get(context, "after_distribute3"))); 448 } 449 450 { 451 LinalgLoopDistributionOptions cyclicNprocsMixed1; 452 cyclicNprocsMixed1.distributionMethod = { 453 DistributionMethod::CyclicNumProcsEqNumIters, 454 DistributionMethod::CyclicNumProcsGeNumIters}; 455 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 456 patterns.add<LinalgTilingPattern>( 457 MatmulOp::getOperationName(), context, 458 LinalgTilingOptions() 459 .setTileSizes({8, 8, 4}) 460 .setLoopType(LinalgTilingLoopType::ParallelLoops) 461 .setDistributionOptions(cyclicNprocsMixed1), 462 LinalgTransformationFilter( 463 StringAttr::get(context, "distribute4"), 464 StringAttr::get(context, "after_distribute4"))); 465 } 466 467 { 468 LinalgLoopDistributionOptions cyclicNprocsMixed2; 469 cyclicNprocsMixed2.distributionMethod = { 470 DistributionMethod::CyclicNumProcsGeNumIters, 471 DistributionMethod::Cyclic}; 472 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 473 patterns.add<LinalgTilingPattern>( 474 MatmulOp::getOperationName(), context, 475 LinalgTilingOptions() 476 .setTileSizes({8, 8, 4}) 477 .setLoopType(LinalgTilingLoopType::ParallelLoops) 478 .setDistributionOptions(cyclicNprocsMixed2), 479 LinalgTransformationFilter( 480 StringAttr::get(context, "distribute5"), 481 StringAttr::get(context, "after_distribute5"))); 482 } 483 484 { 485 LinalgLoopDistributionOptions cyclicNprocsMixed3; 486 cyclicNprocsMixed3.distributionMethod = { 487 DistributionMethod::Cyclic, 488 DistributionMethod::CyclicNumProcsEqNumIters}; 489 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 490 491 patterns.add<LinalgTilingPattern>( 492 MatmulOp::getOperationName(), context, 493 LinalgTilingOptions() 494 .setTileSizes({8, 8, 4}) 495 .setLoopType(LinalgTilingLoopType::ParallelLoops) 496 .setDistributionOptions(cyclicNprocsMixed3), 497 LinalgTransformationFilter( 498 StringAttr::get(context, "distribute6"), 499 StringAttr::get(context, "after_distribute6"))); 500 } 501 502 { 503 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 504 cyclicNprocsEqNiters.distributionMethod.resize(2, 505 DistributionMethod::Cyclic); 506 cyclicNprocsEqNiters.procInfo = 507 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 508 patterns.add<LinalgTilingPattern>( 509 MatmulOp::getOperationName(), context, 510 LinalgTilingOptions() 511 .setTileSizes({8, 8, 4}) 512 .setLoopType(LinalgTilingLoopType::Loops) 513 .setDistributionOptions(cyclicNprocsEqNiters), 514 LinalgTransformationFilter( 515 StringAttr::get(context, "tensors_distribute1"), 516 StringAttr::get(context, "tensors_after_distribute1"))); 517 } 518 } 519 520 static void fillTileFuseAndDistributePatterns(MLIRContext *context, 521 RewritePatternSet &patterns) { 522 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 523 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); 524 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 525 patterns.add<LinalgTileAndFuseTensorOpsPattern>( 526 MatmulOp::getOperationName(), context, 527 LinalgTilingAndFusionOptions() 528 .setTileSizes({8, 8, 4}) 529 .setDistributionOptions(cyclicNprocsEqNiters), 530 LinalgTransformationFilter( 531 StringAttr::get(context, "tensors_fuse_distribute1"), 532 StringAttr::get(context, "tensors_after_fuse_distribute1"))); 533 } 534 535 static void 536 applyMatmulToVectorPatterns(func::FuncOp funcOp, 537 bool testMatmulToVectorPatterns1dTiling, 538 bool testMatmulToVectorPatterns2dTiling) { 539 MLIRContext *ctx = funcOp.getContext(); 540 SmallVector<RewritePatternSet, 4> stage1Patterns; 541 if (testMatmulToVectorPatterns1dTiling) { 542 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 543 } else if (testMatmulToVectorPatterns2dTiling) { 544 stage1Patterns.emplace_back( 545 ctx, std::make_unique<LinalgTilingPattern>( 546 MatmulOp::getOperationName(), ctx, 547 LinalgTilingOptions() 548 .setTileSizes({768, 264, 768}) 549 .setInterchange({1, 2, 0}), 550 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 551 StringAttr::get(ctx, "L2")))); 552 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 553 } 554 { 555 // Canonicalization patterns 556 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 557 vector::populateVectorTransferPermutationMapLoweringPatterns( 558 canonicalizationPatterns); 559 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 560 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 561 } 562 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 563 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 564 FrozenRewritePatternSet stage2Patterns = 565 getLinalgTilingCanonicalizationPatterns(ctx); 566 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); 567 } 568 569 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { 570 RewritePatternSet forwardPattern(funcOp.getContext()); 571 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 572 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 573 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 574 } 575 576 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { 577 RewritePatternSet patterns(funcOp.getContext()); 578 auto *ctx = funcOp.getContext(); 579 patterns.add<LinalgVectorizationPattern>( 580 ctx, LinalgTransformationFilter() 581 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>()); 582 patterns.add<CopyVectorizationPattern>(ctx); 583 populatePadOpVectorizationPatterns(patterns); 584 populateConvolutionVectorizationPatterns(patterns); 585 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 586 } 587 588 static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) { 589 RewritePatternSet patterns(funcOp.getContext()); 590 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 591 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 592 } 593 594 static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) { 595 RewritePatternSet patterns(funcOp.getContext()); 596 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 597 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 598 } 599 600 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) { 601 RewritePatternSet patterns(funcOp.getContext()); 602 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 603 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 604 } 605 606 static void applyTilePattern(func::FuncOp funcOp, const std::string &loopType, 607 ArrayRef<int64_t> tileSizes, 608 ArrayRef<int64_t> peeledLoops, 609 bool scalarizeDynamicDims) { 610 MLIRContext *context = funcOp.getContext(); 611 RewritePatternSet tilingPattern(context); 612 LinalgTilingLoopType type = 613 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 614 .Case("for", LinalgTilingLoopType::Loops) 615 .Case("affine", LinalgTilingLoopType::AffineLoops) 616 .Case("parallel", LinalgTilingLoopType::ParallelLoops); 617 auto linalgTilingOptions = linalg::LinalgTilingOptions() 618 .setPeeledLoops(peeledLoops) 619 .setLoopType(type); 620 if (scalarizeDynamicDims) { 621 linalgTilingOptions.scalarizeDynamicDims(); 622 assert(tileSizes.empty() && 623 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 624 } else { 625 linalgTilingOptions.setTileSizes(tileSizes); 626 } 627 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 628 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 629 tilingPattern, linalgTilingOptions, f); 630 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 631 } 632 633 static void applySplitReduction(func::FuncOp funcOp) { 634 RewritePatternSet patterns(funcOp.getContext()); 635 linalg::populateSplitReductionPattern( 636 patterns, 637 [](LinalgOp op) { 638 unsigned insertDimIndex = op.getNumLoops() - 1; 639 return std::make_pair(4, insertDimIndex); 640 }, 641 LinalgTransformationFilter( 642 ArrayRef<StringAttr>{}, 643 StringAttr::get(funcOp.getContext(), "SPLIT"))); 644 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 645 } 646 647 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) { 648 RewritePatternSet patterns(funcOp.getContext()); 649 populateBubbleUpExtractSliceOpPatterns(patterns); 650 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 651 } 652 653 /// Apply transformations specified as patterns. 654 void TestLinalgTransforms::runOnOperation() { 655 auto lambda = [&](void *) { 656 getOperation().walk([](LinalgOp op) { 657 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 658 }); 659 }; 660 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 661 662 if (testPromotionOptions) { 663 RewritePatternSet patterns(&getContext()); 664 fillPromotionCallBackPatterns(&getContext(), patterns); 665 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 666 return; 667 } 668 if (testTileAndDistributionOptions) { 669 RewritePatternSet patterns(&getContext()); 670 fillTileAndDistributePatterns(&getContext(), patterns); 671 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 672 return; 673 } 674 if (testTileFuseAndDistributionOptions) { 675 RewritePatternSet patterns(&getContext()); 676 fillTileFuseAndDistributePatterns(&getContext(), patterns); 677 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 678 return; 679 } 680 if (testPatterns) 681 return applyPatterns(getOperation()); 682 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 683 return applyMatmulToVectorPatterns(getOperation(), 684 testMatmulToVectorPatterns1dTiling, 685 testMatmulToVectorPatterns2dTiling); 686 if (testVectorTransferForwardingPatterns) 687 return applyVectorTransferForwardingPatterns(getOperation()); 688 if (testGenericToVectorPattern) 689 return applyLinalgToVectorPatterns(getOperation()); 690 if (testTransformPadTensor) 691 return applyPadTensorToGenericPatterns(getOperation()); 692 if (testGeneralizePadTensor) 693 return applyGeneralizePadTensorPatterns(getOperation()); 694 if (testSwapSubTensorPadTensor) 695 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 696 if (testTilePattern) 697 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops, 698 /*scalarizeDynamicDims=*/false); 699 if (testTileScalarizeDynamicDims) 700 return applyTilePattern(getOperation(), loopType, tileSizes, 701 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 702 if (testSplitReduction) 703 return applySplitReduction(getOperation()); 704 if (testBubbleUpExtractSliceOpPattern) 705 return applyBubbleUpExtractSliceOpPattern(getOperation()); 706 } 707 708 namespace mlir { 709 namespace test { 710 void registerTestLinalgTransforms() { 711 PassRegistration<TestLinalgTransforms>(); 712 } 713 } // namespace test 714 } // namespace mlir 715