1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/Vector/IR/VectorOps.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 namespace { 34 struct TestLinalgTransforms 35 : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> { 36 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms) 37 38 TestLinalgTransforms() = default; 39 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 40 41 void getDependentDialects(DialectRegistry ®istry) const override { 42 // clang-format off 43 registry.insert<AffineDialect, 44 memref::MemRefDialect, 45 scf::SCFDialect, 46 linalg::LinalgDialect, 47 vector::VectorDialect, 48 gpu::GPUDialect>(); 49 // clang-format on 50 } 51 StringRef getArgument() const final { 52 return "test-linalg-transform-patterns"; 53 } 54 StringRef getDescription() const final { 55 return "Test Linalg transformation patterns by applying them greedily."; 56 } 57 58 void runOnOperation() override; 59 60 Option<bool> testPatterns{*this, "test-patterns", 61 llvm::cl::desc("Test a mixed set of patterns"), 62 llvm::cl::init(false)}; 63 Option<bool> testMatmulToVectorPatterns1dTiling{ 64 *this, "test-matmul-to-vector-patterns-tile-1d", 65 llvm::cl::desc( 66 "Test a fused pass that applies patterns from matmul to vectors via " 67 "1-d tiling"), 68 llvm::cl::init(false)}; 69 Option<bool> testMatmulToVectorPatterns2dTiling{ 70 *this, "test-matmul-to-vector-patterns-tile-2d", 71 llvm::cl::desc( 72 "Test a fused pass that applies patterns from matmul to vectors via " 73 "2-d tiling"), 74 llvm::cl::init(false)}; 75 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 76 llvm::cl::desc("Test promotion options"), 77 llvm::cl::init(false)}; 78 Option<bool> testTileAndDistributionOptions{ 79 *this, "test-tile-and-distribute-options", 80 llvm::cl::desc("Test tile and distribute options"), 81 llvm::cl::init(false)}; 82 Option<bool> testTileFuseAndDistributionOptions{ 83 *this, "test-tile-fuse-and-distribute-options", 84 llvm::cl::desc("Test tile, fuse and distribute options"), 85 llvm::cl::init(false)}; 86 Option<bool> testVectorTransferForwardingPatterns{ 87 *this, "test-vector-transfer-forwarding-patterns", 88 llvm::cl::desc( 89 "Test a fused pass that forwards memref.copy to vector.transfer"), 90 llvm::cl::init(false)}; 91 Option<bool> testGenericToVectorPattern{ 92 *this, "test-linalg-to-vector-patterns", 93 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 94 "in vector.contract form"), 95 llvm::cl::init(false)}; 96 Option<bool> testTilePattern{*this, "test-tile-pattern", 97 llvm::cl::desc("Test tile pattern"), 98 llvm::cl::init(false)}; 99 Option<bool> testTileScalarizeDynamicDims{ 100 *this, "test-tile-scalarize-dynamic-dims", 101 llvm::cl::desc("Test tiling of dynamic dims by 1"), 102 llvm::cl::init(false)}; 103 Option<bool> testTransformPadTensor{ 104 *this, "test-transform-pad-tensor", 105 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 106 llvm::cl::init(false)}; 107 Option<bool> testGeneralizePadTensor{ 108 *this, "test-generalize-pad-tensor", 109 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 110 llvm::cl::init(false)}; 111 Option<bool> testSwapSubTensorPadTensor{ 112 *this, "test-swap-subtensor-padtensor", 113 llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into " 114 "tensor.pad(subtensor)"), 115 llvm::cl::init(false)}; 116 Option<bool> testSplitReduction{ 117 *this, "test-split-reduction", 118 llvm::cl::desc("Test split reduction transformation"), 119 llvm::cl::init(false)}; 120 ListOption<int64_t> peeledLoops{ 121 *this, "peeled-loops", 122 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 123 llvm::cl::ZeroOrMore}; 124 ListOption<int64_t> tileSizes{ 125 *this, "tile-sizes", 126 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 127 llvm::cl::ZeroOrMore}; 128 Option<bool> skipPartial{ 129 *this, "skip-partial", 130 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 131 llvm::cl::init(false)}; 132 Option<std::string> loopType{ 133 *this, "loop-type", 134 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 135 "tiled_loop"), 136 llvm::cl::init("for")}; 137 Option<bool> testBubbleUpExtractSliceOpPattern{ 138 *this, "test-bubble-up-extract-slice-op-pattern", 139 llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " 140 "extract_slice + linalgOp"), 141 llvm::cl::init(false)}; 142 }; 143 } // namespace 144 145 static void applyPatterns(func::FuncOp funcOp) { 146 MLIRContext *ctx = funcOp.getContext(); 147 RewritePatternSet patterns(ctx); 148 149 //===--------------------------------------------------------------------===// 150 // Linalg tiling patterns. 151 //===--------------------------------------------------------------------===// 152 patterns.add<LinalgTilingPattern>( 153 MatmulOp::getOperationName(), ctx, 154 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 155 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 156 StringAttr::get(ctx, "L3"))); 157 patterns.add<LinalgTilingPattern>( 158 MatmulOp::getOperationName(), ctx, 159 LinalgTilingOptions().setTileSizes({200, 300, 400}), 160 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 161 StringAttr::get(ctx, "L2"))); 162 patterns.add<LinalgTilingPattern>( 163 MatmulOp::getOperationName(), ctx, 164 LinalgTilingOptions().setTileSizes({20, 30, 40}), 165 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 166 StringAttr::get(ctx, "L1"))); 167 patterns.add<LinalgTilingPattern>( 168 MatmulOp::getOperationName(), ctx, 169 LinalgTilingOptions().setTileSizes({2, 3, 4}), 170 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 171 StringAttr::get(ctx, "REG"))); 172 173 patterns.add<LinalgTilingPattern>( 174 MatvecOp::getOperationName(), ctx, 175 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 176 LinalgTilingLoopType::ParallelLoops), 177 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 178 StringAttr::get(ctx, "L1"))); 179 180 patterns.add<LinalgTilingPattern>( 181 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), 182 LinalgTransformationFilter( 183 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 184 StringAttr::get(ctx, "L3"), 185 StringAttr::get(ctx, "L2")}, 186 StringAttr::get(ctx, "REG"))); 187 188 //===--------------------------------------------------------------------===// 189 // Linalg tiling and permutation patterns. 190 //===--------------------------------------------------------------------===// 191 patterns.add<LinalgTilingPattern>( 192 MatmulOp::getOperationName(), ctx, 193 LinalgTilingOptions() 194 .setTileSizes({2000, 3000, 4000}) 195 .setInterchange({1, 2, 0}), 196 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 197 StringAttr::get(ctx, "L2__with_perm__"))); 198 patterns.add<LinalgTilingPattern>( 199 MatmulOp::getOperationName(), ctx, 200 LinalgTilingOptions() 201 .setTileSizes({200, 300, 400}) 202 .setInterchange({1, 0, 2}), 203 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 204 StringAttr::get(ctx, "L1__with_perm__"))); 205 patterns.add<LinalgTilingPattern>( 206 MatmulOp::getOperationName(), ctx, 207 LinalgTilingOptions().setTileSizes({20, 30, 40}), 208 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 209 StringAttr::get(ctx, "REG__with_perm__"))); 210 211 patterns.add<LinalgTilingPattern>( 212 MatvecOp::getOperationName(), ctx, 213 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 214 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 215 StringAttr::get(ctx, "L1__with_perm__"))); 216 217 patterns.add<LinalgTilingPattern>( 218 MatmulOp::getOperationName(), ctx, 219 LinalgTilingOptions() 220 .setTileSizes({16, 8, 4}) 221 .setInterchange({1, 2, 0}) 222 .setLoopType(LinalgTilingLoopType::ParallelLoops), 223 LinalgTransformationFilter( 224 StringAttr::get(ctx, "par__with_perm__"), 225 StringAttr::get(ctx, "after_par__with_perm__"))); 226 227 //===--------------------------------------------------------------------===// 228 // Linalg to loops patterns. 229 //===--------------------------------------------------------------------===// 230 patterns.add<LinalgLoweringPattern<DotOp>>( 231 ctx, 232 /*loweringType=*/LinalgLoweringType::Loops, 233 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 234 235 //===--------------------------------------------------------------------===// 236 // Linalg distribution patterns. 237 //===--------------------------------------------------------------------===// 238 LinalgLoopDistributionOptions distributionOptions; 239 240 //===--------------------------------------------------------------------===// 241 // Linalg to vector contraction patterns. 242 //===--------------------------------------------------------------------===// 243 patterns.add<LinalgVectorizationPattern>( 244 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 245 .addOpFilter<MatmulOp, FillOp, GenericOp>()); 246 patterns.add<CopyVectorizationPattern>(ctx); 247 248 //===--------------------------------------------------------------------===// 249 // Linalg generic interchange pattern. 250 //===--------------------------------------------------------------------===// 251 patterns.add<GenericOpInterchangePattern>( 252 ctx, 253 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 254 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 255 StringAttr::get(ctx, "PERMUTED"))); 256 257 //===--------------------------------------------------------------------===// 258 // Linalg subview operands promotion. 259 //===--------------------------------------------------------------------===// 260 patterns.add<LinalgPromotionPattern<MatmulOp>>( 261 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 262 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 263 StringAttr::get(ctx, "_views_promoted_"))); 264 patterns.add<LinalgPromotionPattern<MatmulOp>>( 265 ctx, 266 LinalgPromotionOptions() 267 .setOperandsToPromote({0}) 268 .setUseFullTileBuffersByDefault(true), 269 LinalgTransformationFilter( 270 StringAttr::get(ctx, "_promote_first_view_"), 271 StringAttr::get(ctx, "_first_view_promoted_"))); 272 patterns.add<LinalgPromotionPattern<FillOp>>( 273 ctx, 274 LinalgPromotionOptions() 275 .setOperandsToPromote({1}) 276 .setUseFullTileBuffers({false, true}) 277 .setAlignment(32), 278 LinalgTransformationFilter( 279 StringAttr::get(ctx, "_promote_views_aligned_"), 280 StringAttr::get(ctx, "_views_aligned_promoted_"))); 281 282 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 283 284 // Drop the marker. 285 funcOp.walk([](LinalgOp op) { 286 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 287 }); 288 } 289 290 static void fillL1TilingAndMatmulToVectorPatterns( 291 func::FuncOp funcOp, StringRef startMarker, 292 SmallVectorImpl<RewritePatternSet> &patternsVector) { 293 MLIRContext *ctx = funcOp.getContext(); 294 patternsVector.emplace_back( 295 ctx, std::make_unique<LinalgTilingPattern>( 296 MatmulOp::getOperationName(), ctx, 297 LinalgTilingOptions() 298 .setTileSizes({8, 12, 16}) 299 .setInterchange({1, 0, 2}), 300 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 301 StringAttr::get(ctx, "L1")))); 302 303 patternsVector.emplace_back( 304 ctx, 305 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 306 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 307 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 308 StringAttr::get(ctx, "VEC")))); 309 310 patternsVector.emplace_back( 311 ctx, std::make_unique<LinalgVectorizationPattern>( 312 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 313 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 314 patternsVector.back().add<LinalgVectorizationPattern>( 315 ctx, LinalgTransformationFilter().addOpFilter<FillOp>()); 316 patternsVector.back().add<CopyVectorizationPattern>(ctx); 317 } 318 319 //===----------------------------------------------------------------------===// 320 // Test promotion callbacks 321 //===----------------------------------------------------------------------===// 322 323 // Allocation call back 324 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 325 ArrayRef<Value> boundingSubViewSize, 326 DataLayout &layout) { 327 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 328 return b 329 .create<memref::AllocOp>( 330 subView.getLoc(), 331 MemRefType::get(shape, subView.getType().getElementType(), 332 /*affineMapComposition =*/{}, 3), 333 boundingSubViewSize) 334 .getResult(); 335 } 336 337 // Deallocation callback 338 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 339 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 340 return success(); 341 } 342 343 // Copy in call back 344 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 345 bool isOutput) { 346 auto floatType = src.getType().cast<MemRefType>().getElementType(); 347 if (!floatType.isa<FloatType>()) 348 return failure(); 349 if (!isOutput) { 350 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 351 FloatAttr::get(floatType, 42.0)); 352 b.create<FillOp>(src.getLoc(), cst, dst); 353 } 354 b.create<memref::CopyOp>(src.getLoc(), src, dst); 355 return success(); 356 } 357 358 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 359 RewritePatternSet &patterns) { 360 patterns.add<LinalgTilingPattern>( 361 MatmulOp::getOperationName(), ctx, 362 LinalgTilingOptions().setTileSizes({16, 16, 16}), 363 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 364 StringAttr::get(ctx, "PROMOTE"))); 365 patterns.add<LinalgPromotionPattern<MatmulOp>>( 366 ctx, 367 LinalgPromotionOptions() 368 .setOperandsToPromote({0, 2}) 369 .setUseFullTileBuffers({false, false}) 370 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 371 .setCopyInOutFns( 372 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 373 return copyCallBackFn(b, src, dst, false); 374 }, 375 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 376 return copyCallBackFn(b, src, dst, true); 377 }), 378 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 379 } 380 381 template <typename IdOp, typename NProcsOp> 382 static SmallVector<ProcInfo, 2> 383 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 384 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 385 SmallVector<ProcInfo, 2> procInfo(count); 386 Type indexType = b.getIndexType(); 387 for (unsigned i = 0; i < count; ++i) { 388 gpu::Dimension dim = *gpu::symbolizeDimension(i); 389 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 390 b.create<NProcsOp>(loc, indexType, dim)}; 391 } 392 return procInfo; 393 } 394 395 static void fillTileAndDistributePatterns(MLIRContext *context, 396 RewritePatternSet &patterns) { 397 { 398 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 399 cyclicNprocsEqNiters.distributionMethod.resize( 400 2, DistributionMethod::CyclicNumProcsEqNumIters); 401 cyclicNprocsEqNiters.procInfo = 402 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 403 patterns.add<LinalgTilingPattern>( 404 MatmulOp::getOperationName(), context, 405 LinalgTilingOptions() 406 .setTileSizes({8, 8, 4}) 407 .setLoopType(LinalgTilingLoopType::ParallelLoops) 408 .setDistributionOptions(cyclicNprocsEqNiters), 409 LinalgTransformationFilter( 410 StringAttr::get(context, "distribute1"), 411 StringAttr::get(context, "after_distribute1"))); 412 } 413 414 { 415 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 416 cyclicNprocsGeNiters.distributionMethod.resize( 417 2, DistributionMethod::CyclicNumProcsGeNumIters); 418 cyclicNprocsGeNiters.procInfo = 419 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 420 patterns.add<LinalgTilingPattern>( 421 MatmulOp::getOperationName(), context, 422 LinalgTilingOptions() 423 .setTileSizes({8, 8, 4}) 424 .setLoopType(LinalgTilingLoopType::ParallelLoops) 425 .setDistributionOptions(cyclicNprocsGeNiters), 426 LinalgTransformationFilter( 427 StringAttr::get(context, "distribute2"), 428 StringAttr::get(context, "after_distribute2"))); 429 } 430 431 { 432 LinalgLoopDistributionOptions cyclicNprocsDefault; 433 cyclicNprocsDefault.distributionMethod.resize(2, 434 DistributionMethod::Cyclic); 435 cyclicNprocsDefault.procInfo = 436 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 437 patterns.add<LinalgTilingPattern>( 438 MatmulOp::getOperationName(), context, 439 LinalgTilingOptions() 440 .setTileSizes({8, 8, 4}) 441 .setLoopType(LinalgTilingLoopType::ParallelLoops) 442 .setDistributionOptions(cyclicNprocsDefault), 443 LinalgTransformationFilter( 444 StringAttr::get(context, "distribute3"), 445 StringAttr::get(context, "after_distribute3"))); 446 } 447 448 { 449 LinalgLoopDistributionOptions cyclicNprocsMixed1; 450 cyclicNprocsMixed1.distributionMethod = { 451 DistributionMethod::CyclicNumProcsEqNumIters, 452 DistributionMethod::CyclicNumProcsGeNumIters}; 453 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 454 patterns.add<LinalgTilingPattern>( 455 MatmulOp::getOperationName(), context, 456 LinalgTilingOptions() 457 .setTileSizes({8, 8, 4}) 458 .setLoopType(LinalgTilingLoopType::ParallelLoops) 459 .setDistributionOptions(cyclicNprocsMixed1), 460 LinalgTransformationFilter( 461 StringAttr::get(context, "distribute4"), 462 StringAttr::get(context, "after_distribute4"))); 463 } 464 465 { 466 LinalgLoopDistributionOptions cyclicNprocsMixed2; 467 cyclicNprocsMixed2.distributionMethod = { 468 DistributionMethod::CyclicNumProcsGeNumIters, 469 DistributionMethod::Cyclic}; 470 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 471 patterns.add<LinalgTilingPattern>( 472 MatmulOp::getOperationName(), context, 473 LinalgTilingOptions() 474 .setTileSizes({8, 8, 4}) 475 .setLoopType(LinalgTilingLoopType::ParallelLoops) 476 .setDistributionOptions(cyclicNprocsMixed2), 477 LinalgTransformationFilter( 478 StringAttr::get(context, "distribute5"), 479 StringAttr::get(context, "after_distribute5"))); 480 } 481 482 { 483 LinalgLoopDistributionOptions cyclicNprocsMixed3; 484 cyclicNprocsMixed3.distributionMethod = { 485 DistributionMethod::Cyclic, 486 DistributionMethod::CyclicNumProcsEqNumIters}; 487 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 488 489 patterns.add<LinalgTilingPattern>( 490 MatmulOp::getOperationName(), context, 491 LinalgTilingOptions() 492 .setTileSizes({8, 8, 4}) 493 .setLoopType(LinalgTilingLoopType::ParallelLoops) 494 .setDistributionOptions(cyclicNprocsMixed3), 495 LinalgTransformationFilter( 496 StringAttr::get(context, "distribute6"), 497 StringAttr::get(context, "after_distribute6"))); 498 } 499 500 { 501 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 502 cyclicNprocsEqNiters.distributionMethod.resize(2, 503 DistributionMethod::Cyclic); 504 cyclicNprocsEqNiters.procInfo = 505 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 506 patterns.add<LinalgTilingPattern>( 507 MatmulOp::getOperationName(), context, 508 LinalgTilingOptions() 509 .setTileSizes({8, 8, 4}) 510 .setLoopType(LinalgTilingLoopType::Loops) 511 .setDistributionOptions(cyclicNprocsEqNiters), 512 LinalgTransformationFilter( 513 StringAttr::get(context, "tensors_distribute1"), 514 StringAttr::get(context, "tensors_after_distribute1"))); 515 } 516 } 517 518 static void fillTileFuseAndDistributePatterns(MLIRContext *context, 519 RewritePatternSet &patterns) { 520 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 521 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); 522 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 523 patterns.add<LinalgTileAndFuseTensorOpsPattern>( 524 MatmulOp::getOperationName(), context, 525 LinalgTilingAndFusionOptions() 526 .setTileSizes({8, 8, 4}) 527 .setDistributionOptions(cyclicNprocsEqNiters), 528 LinalgTransformationFilter( 529 StringAttr::get(context, "tensors_fuse_distribute1"), 530 StringAttr::get(context, "tensors_after_fuse_distribute1"))); 531 } 532 533 static void 534 applyMatmulToVectorPatterns(func::FuncOp funcOp, 535 bool testMatmulToVectorPatterns1dTiling, 536 bool testMatmulToVectorPatterns2dTiling) { 537 MLIRContext *ctx = funcOp.getContext(); 538 SmallVector<RewritePatternSet, 4> stage1Patterns; 539 if (testMatmulToVectorPatterns1dTiling) { 540 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 541 } else if (testMatmulToVectorPatterns2dTiling) { 542 stage1Patterns.emplace_back( 543 ctx, std::make_unique<LinalgTilingPattern>( 544 MatmulOp::getOperationName(), ctx, 545 LinalgTilingOptions() 546 .setTileSizes({768, 264, 768}) 547 .setInterchange({1, 2, 0}), 548 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 549 StringAttr::get(ctx, "L2")))); 550 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 551 } 552 { 553 // Canonicalization patterns 554 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 555 vector::populateVectorTransferPermutationMapLoweringPatterns( 556 canonicalizationPatterns); 557 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 558 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 559 } 560 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 561 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 562 FrozenRewritePatternSet stage2Patterns = 563 getLinalgTilingCanonicalizationPatterns(ctx); 564 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); 565 } 566 567 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { 568 RewritePatternSet forwardPattern(funcOp.getContext()); 569 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 570 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 571 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 572 } 573 574 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { 575 RewritePatternSet patterns(funcOp.getContext()); 576 auto *ctx = funcOp.getContext(); 577 patterns.add<LinalgVectorizationPattern>( 578 ctx, LinalgTransformationFilter() 579 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>()); 580 patterns.add<CopyVectorizationPattern>(ctx); 581 populatePadOpVectorizationPatterns(patterns); 582 populateConvolutionVectorizationPatterns(patterns); 583 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 584 } 585 586 static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) { 587 RewritePatternSet patterns(funcOp.getContext()); 588 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 589 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 590 } 591 592 static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) { 593 RewritePatternSet patterns(funcOp.getContext()); 594 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 595 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 596 } 597 598 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) { 599 RewritePatternSet patterns(funcOp.getContext()); 600 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 601 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 602 } 603 604 static void applyTilePattern(func::FuncOp funcOp, const std::string &loopType, 605 ArrayRef<int64_t> tileSizes, 606 ArrayRef<int64_t> peeledLoops, 607 bool scalarizeDynamicDims) { 608 MLIRContext *context = funcOp.getContext(); 609 RewritePatternSet tilingPattern(context); 610 LinalgTilingLoopType type = 611 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 612 .Case("for", LinalgTilingLoopType::Loops) 613 .Case("affine", LinalgTilingLoopType::AffineLoops) 614 .Case("parallel", LinalgTilingLoopType::ParallelLoops); 615 auto linalgTilingOptions = linalg::LinalgTilingOptions() 616 .setPeeledLoops(peeledLoops) 617 .setLoopType(type); 618 if (scalarizeDynamicDims) { 619 linalgTilingOptions.scalarizeDynamicDims(); 620 assert(tileSizes.empty() && 621 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 622 } else { 623 linalgTilingOptions.setTileSizes(tileSizes); 624 } 625 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 626 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 627 tilingPattern, linalgTilingOptions, f); 628 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 629 } 630 631 static void applySplitReduction(func::FuncOp funcOp) { 632 RewritePatternSet patterns(funcOp.getContext()); 633 linalg::populateSplitReductionPattern( 634 patterns, 635 [](LinalgOp op) { 636 unsigned insertDimIndex = op.getNumLoops() - 1; 637 return std::make_pair(4, insertDimIndex); 638 }, 639 LinalgTransformationFilter( 640 ArrayRef<StringAttr>{}, 641 StringAttr::get(funcOp.getContext(), "SPLIT"))); 642 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 643 } 644 645 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) { 646 RewritePatternSet patterns(funcOp.getContext()); 647 populateBubbleUpExtractSliceOpPatterns(patterns); 648 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 649 } 650 651 /// Apply transformations specified as patterns. 652 void TestLinalgTransforms::runOnOperation() { 653 auto lambda = [&](void *) { 654 getOperation().walk([](LinalgOp op) { 655 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 656 }); 657 }; 658 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 659 660 if (testPromotionOptions) { 661 RewritePatternSet patterns(&getContext()); 662 fillPromotionCallBackPatterns(&getContext(), patterns); 663 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 664 return; 665 } 666 if (testTileAndDistributionOptions) { 667 RewritePatternSet patterns(&getContext()); 668 fillTileAndDistributePatterns(&getContext(), patterns); 669 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 670 return; 671 } 672 if (testTileFuseAndDistributionOptions) { 673 RewritePatternSet patterns(&getContext()); 674 fillTileFuseAndDistributePatterns(&getContext(), patterns); 675 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 676 return; 677 } 678 if (testPatterns) 679 return applyPatterns(getOperation()); 680 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 681 return applyMatmulToVectorPatterns(getOperation(), 682 testMatmulToVectorPatterns1dTiling, 683 testMatmulToVectorPatterns2dTiling); 684 if (testVectorTransferForwardingPatterns) 685 return applyVectorTransferForwardingPatterns(getOperation()); 686 if (testGenericToVectorPattern) 687 return applyLinalgToVectorPatterns(getOperation()); 688 if (testTransformPadTensor) 689 return applyPadTensorToGenericPatterns(getOperation()); 690 if (testGeneralizePadTensor) 691 return applyGeneralizePadTensorPatterns(getOperation()); 692 if (testSwapSubTensorPadTensor) 693 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 694 if (testTilePattern) 695 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops, 696 /*scalarizeDynamicDims=*/false); 697 if (testTileScalarizeDynamicDims) 698 return applyTilePattern(getOperation(), loopType, tileSizes, 699 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 700 if (testSplitReduction) 701 return applySplitReduction(getOperation()); 702 if (testBubbleUpExtractSliceOpPattern) 703 return applyBubbleUpExtractSliceOpPattern(getOperation()); 704 } 705 706 namespace mlir { 707 namespace test { 708 void registerTestLinalgTransforms() { 709 PassRegistration<TestLinalgTransforms>(); 710 } 711 } // namespace test 712 } // namespace mlir 713