1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/GPU/GPUDialect.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/Vector/IR/VectorOps.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 namespace { 34 struct TestLinalgTransforms 35 : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> { 36 TestLinalgTransforms() = default; 37 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 38 39 void getDependentDialects(DialectRegistry ®istry) const override { 40 // clang-format off 41 registry.insert<AffineDialect, 42 memref::MemRefDialect, 43 scf::SCFDialect, 44 linalg::LinalgDialect, 45 vector::VectorDialect, 46 gpu::GPUDialect>(); 47 // clang-format on 48 } 49 StringRef getArgument() const final { 50 return "test-linalg-transform-patterns"; 51 } 52 StringRef getDescription() const final { 53 return "Test Linalg transformation patterns by applying them greedily."; 54 } 55 56 void runOnOperation() override; 57 58 Option<bool> testPatterns{*this, "test-patterns", 59 llvm::cl::desc("Test a mixed set of patterns"), 60 llvm::cl::init(false)}; 61 Option<bool> testMatmulToVectorPatterns1dTiling{ 62 *this, "test-matmul-to-vector-patterns-tile-1d", 63 llvm::cl::desc( 64 "Test a fused pass that applies patterns from matmul to vectors via " 65 "1-d tiling"), 66 llvm::cl::init(false)}; 67 Option<bool> testMatmulToVectorPatterns2dTiling{ 68 *this, "test-matmul-to-vector-patterns-tile-2d", 69 llvm::cl::desc( 70 "Test a fused pass that applies patterns from matmul to vectors via " 71 "2-d tiling"), 72 llvm::cl::init(false)}; 73 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 74 llvm::cl::desc("Test promotion options"), 75 llvm::cl::init(false)}; 76 Option<bool> testTileAndDistributionOptions{ 77 *this, "test-tile-and-distribute-options", 78 llvm::cl::desc("Test tile and distribute options"), 79 llvm::cl::init(false)}; 80 Option<bool> testTileFuseAndDistributionOptions{ 81 *this, "test-tile-fuse-and-distribute-options", 82 llvm::cl::desc("Test tile, fuse and distribute options"), 83 llvm::cl::init(false)}; 84 Option<bool> testVectorTransferForwardingPatterns{ 85 *this, "test-vector-transfer-forwarding-patterns", 86 llvm::cl::desc( 87 "Test a fused pass that forwards memref.copy to vector.transfer"), 88 llvm::cl::init(false)}; 89 Option<bool> testGenericToVectorPattern{ 90 *this, "test-linalg-to-vector-patterns", 91 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 92 "in vector.contract form"), 93 llvm::cl::init(false)}; 94 Option<bool> testTilePattern{*this, "test-tile-pattern", 95 llvm::cl::desc("Test tile pattern"), 96 llvm::cl::init(false)}; 97 Option<bool> testTileScalarizeDynamicDims{ 98 *this, "test-tile-scalarize-dynamic-dims", 99 llvm::cl::desc("Test tiling of dynamic dims by 1"), 100 llvm::cl::init(false)}; 101 Option<bool> testTransformPadTensor{ 102 *this, "test-transform-pad-tensor", 103 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 104 llvm::cl::init(false)}; 105 Option<bool> testGeneralizePadTensor{ 106 *this, "test-generalize-pad-tensor", 107 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 108 llvm::cl::init(false)}; 109 Option<bool> testSwapSubTensorPadTensor{ 110 *this, "test-swap-subtensor-padtensor", 111 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 112 "pad_tensor(subtensor)"), 113 llvm::cl::init(false)}; 114 Option<bool> testSplitReduction{ 115 *this, "test-split-reduction", 116 llvm::cl::desc("Test split reduction transformation"), 117 llvm::cl::init(false)}; 118 ListOption<int64_t> peeledLoops{ 119 *this, "peeled-loops", 120 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 121 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 122 ListOption<int64_t> tileSizes{ 123 *this, "tile-sizes", 124 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 125 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 126 Option<bool> skipPartial{ 127 *this, "skip-partial", 128 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 129 llvm::cl::init(false)}; 130 Option<std::string> loopType{ 131 *this, "loop-type", 132 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 133 "tiled_loop"), 134 llvm::cl::init("for")}; 135 Option<bool> testBubbleUpExtractSliceOpPattern{ 136 *this, "test-bubble-up-extract-slice-op-pattern", 137 llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " 138 "extract_slice + linalgOp"), 139 llvm::cl::init(false)}; 140 }; 141 } // namespace 142 143 static void applyPatterns(FuncOp funcOp) { 144 MLIRContext *ctx = funcOp.getContext(); 145 RewritePatternSet patterns(ctx); 146 147 //===--------------------------------------------------------------------===// 148 // Linalg tiling patterns. 149 //===--------------------------------------------------------------------===// 150 patterns.add<LinalgTilingPattern>( 151 MatmulOp::getOperationName(), ctx, 152 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 153 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 154 StringAttr::get(ctx, "L3"))); 155 patterns.add<LinalgTilingPattern>( 156 MatmulOp::getOperationName(), ctx, 157 LinalgTilingOptions().setTileSizes({200, 300, 400}), 158 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 159 StringAttr::get(ctx, "L2"))); 160 patterns.add<LinalgTilingPattern>( 161 MatmulOp::getOperationName(), ctx, 162 LinalgTilingOptions().setTileSizes({20, 30, 40}), 163 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 164 StringAttr::get(ctx, "L1"))); 165 patterns.add<LinalgTilingPattern>( 166 MatmulOp::getOperationName(), ctx, 167 LinalgTilingOptions().setTileSizes({2, 3, 4}), 168 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 169 StringAttr::get(ctx, "REG"))); 170 171 patterns.add<LinalgTilingPattern>( 172 MatvecOp::getOperationName(), ctx, 173 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 174 LinalgTilingLoopType::ParallelLoops), 175 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 176 StringAttr::get(ctx, "L1"))); 177 178 patterns.add<LinalgTilingPattern>( 179 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), 180 LinalgTransformationFilter( 181 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 182 StringAttr::get(ctx, "L3"), 183 StringAttr::get(ctx, "L2")}, 184 StringAttr::get(ctx, "REG"))); 185 186 //===--------------------------------------------------------------------===// 187 // Linalg tiling and permutation patterns. 188 //===--------------------------------------------------------------------===// 189 patterns.add<LinalgTilingPattern>( 190 MatmulOp::getOperationName(), ctx, 191 LinalgTilingOptions() 192 .setTileSizes({2000, 3000, 4000}) 193 .setInterchange({1, 2, 0}), 194 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 195 StringAttr::get(ctx, "L2__with_perm__"))); 196 patterns.add<LinalgTilingPattern>( 197 MatmulOp::getOperationName(), ctx, 198 LinalgTilingOptions() 199 .setTileSizes({200, 300, 400}) 200 .setInterchange({1, 0, 2}), 201 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 202 StringAttr::get(ctx, "L1__with_perm__"))); 203 patterns.add<LinalgTilingPattern>( 204 MatmulOp::getOperationName(), ctx, 205 LinalgTilingOptions().setTileSizes({20, 30, 40}), 206 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 207 StringAttr::get(ctx, "REG__with_perm__"))); 208 209 patterns.add<LinalgTilingPattern>( 210 MatvecOp::getOperationName(), ctx, 211 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 212 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 213 StringAttr::get(ctx, "L1__with_perm__"))); 214 215 patterns.add<LinalgTilingPattern>( 216 MatmulOp::getOperationName(), ctx, 217 LinalgTilingOptions() 218 .setTileSizes({16, 8, 4}) 219 .setInterchange({1, 2, 0}) 220 .setLoopType(LinalgTilingLoopType::ParallelLoops), 221 LinalgTransformationFilter( 222 StringAttr::get(ctx, "par__with_perm__"), 223 StringAttr::get(ctx, "after_par__with_perm__"))); 224 225 //===--------------------------------------------------------------------===// 226 // Linalg to loops patterns. 227 //===--------------------------------------------------------------------===// 228 patterns.add<LinalgLoweringPattern<DotOp>>( 229 ctx, 230 /*loweringType=*/LinalgLoweringType::Loops, 231 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 232 233 //===--------------------------------------------------------------------===// 234 // Linalg distribution patterns. 235 //===--------------------------------------------------------------------===// 236 LinalgLoopDistributionOptions distributionOptions; 237 238 //===--------------------------------------------------------------------===// 239 // Linalg to vector contraction patterns. 240 //===--------------------------------------------------------------------===// 241 patterns.add<LinalgVectorizationPattern>( 242 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 243 .addOpFilter<MatmulOp, FillOp, GenericOp>()); 244 patterns.add<CopyVectorizationPattern>(ctx); 245 246 //===--------------------------------------------------------------------===// 247 // Linalg generic interchange pattern. 248 //===--------------------------------------------------------------------===// 249 patterns.add<GenericOpInterchangePattern>( 250 ctx, 251 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 252 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 253 StringAttr::get(ctx, "PERMUTED"))); 254 255 //===--------------------------------------------------------------------===// 256 // Linalg subview operands promotion. 257 //===--------------------------------------------------------------------===// 258 patterns.add<LinalgPromotionPattern<MatmulOp>>( 259 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 260 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 261 StringAttr::get(ctx, "_views_promoted_"))); 262 patterns.add<LinalgPromotionPattern<MatmulOp>>( 263 ctx, 264 LinalgPromotionOptions() 265 .setOperandsToPromote({0}) 266 .setUseFullTileBuffersByDefault(true), 267 LinalgTransformationFilter( 268 StringAttr::get(ctx, "_promote_first_view_"), 269 StringAttr::get(ctx, "_first_view_promoted_"))); 270 patterns.add<LinalgPromotionPattern<FillOp>>( 271 ctx, 272 LinalgPromotionOptions() 273 .setOperandsToPromote({1}) 274 .setUseFullTileBuffers({false, true}) 275 .setAlignment(32), 276 LinalgTransformationFilter( 277 StringAttr::get(ctx, "_promote_views_aligned_"), 278 StringAttr::get(ctx, "_views_aligned_promoted_"))); 279 280 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 281 282 // Drop the marker. 283 funcOp.walk([](LinalgOp op) { 284 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 285 }); 286 } 287 288 static void fillL1TilingAndMatmulToVectorPatterns( 289 FuncOp funcOp, StringRef startMarker, 290 SmallVectorImpl<RewritePatternSet> &patternsVector) { 291 MLIRContext *ctx = funcOp.getContext(); 292 patternsVector.emplace_back( 293 ctx, std::make_unique<LinalgTilingPattern>( 294 MatmulOp::getOperationName(), ctx, 295 LinalgTilingOptions() 296 .setTileSizes({8, 12, 16}) 297 .setInterchange({1, 0, 2}), 298 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 299 StringAttr::get(ctx, "L1")))); 300 301 patternsVector.emplace_back( 302 ctx, 303 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 304 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 305 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 306 StringAttr::get(ctx, "VEC")))); 307 308 patternsVector.emplace_back( 309 ctx, std::make_unique<LinalgVectorizationPattern>( 310 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 311 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 312 patternsVector.back().add<LinalgVectorizationPattern>( 313 ctx, LinalgTransformationFilter().addOpFilter<FillOp>()); 314 patternsVector.back().add<CopyVectorizationPattern>(ctx); 315 } 316 317 //===----------------------------------------------------------------------===// 318 // Test promotion callbacks 319 //===----------------------------------------------------------------------===// 320 321 // Allocation call back 322 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 323 ArrayRef<Value> boundingSubViewSize, 324 DataLayout &layout) { 325 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 326 return b 327 .create<memref::AllocOp>( 328 subView.getLoc(), 329 MemRefType::get(shape, subView.getType().getElementType(), 330 /*affineMapComposition =*/{}, 3), 331 boundingSubViewSize) 332 .getResult(); 333 } 334 335 // Deallocation callback 336 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 337 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 338 return success(); 339 } 340 341 // Copy in call back 342 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 343 bool isOutput) { 344 auto floatType = src.getType().cast<MemRefType>().getElementType(); 345 if (!floatType.isa<FloatType>()) 346 return failure(); 347 if (!isOutput) { 348 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 349 FloatAttr::get(floatType, 42.0)); 350 b.create<FillOp>(src.getLoc(), cst, dst); 351 } 352 b.create<memref::CopyOp>(src.getLoc(), src, dst); 353 return success(); 354 } 355 356 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 357 RewritePatternSet &patterns) { 358 patterns.add<LinalgTilingPattern>( 359 MatmulOp::getOperationName(), ctx, 360 LinalgTilingOptions().setTileSizes({16, 16, 16}), 361 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 362 StringAttr::get(ctx, "PROMOTE"))); 363 patterns.add<LinalgPromotionPattern<MatmulOp>>( 364 ctx, 365 LinalgPromotionOptions() 366 .setOperandsToPromote({0, 2}) 367 .setUseFullTileBuffers({false, false}) 368 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 369 .setCopyInOutFns( 370 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 371 return copyCallBackFn(b, src, dst, false); 372 }, 373 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 374 return copyCallBackFn(b, src, dst, true); 375 }), 376 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 377 } 378 379 template <typename IdOp, typename NProcsOp> 380 static SmallVector<ProcInfo, 2> 381 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 382 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 383 SmallVector<ProcInfo, 2> procInfo(count); 384 Type indexType = b.getIndexType(); 385 for (unsigned i = 0; i < count; ++i) { 386 gpu::Dimension dim = *gpu::symbolizeDimension(i); 387 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 388 b.create<NProcsOp>(loc, indexType, dim)}; 389 } 390 return procInfo; 391 } 392 393 static void fillTileAndDistributePatterns(MLIRContext *context, 394 RewritePatternSet &patterns) { 395 { 396 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 397 cyclicNprocsEqNiters.distributionMethod.resize( 398 2, DistributionMethod::CyclicNumProcsEqNumIters); 399 cyclicNprocsEqNiters.procInfo = 400 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 401 patterns.add<LinalgTilingPattern>( 402 MatmulOp::getOperationName(), context, 403 LinalgTilingOptions() 404 .setTileSizes({8, 8, 4}) 405 .setLoopType(LinalgTilingLoopType::ParallelLoops) 406 .setDistributionOptions(cyclicNprocsEqNiters), 407 LinalgTransformationFilter( 408 StringAttr::get(context, "distribute1"), 409 StringAttr::get(context, "after_distribute1"))); 410 } 411 412 { 413 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 414 cyclicNprocsGeNiters.distributionMethod.resize( 415 2, DistributionMethod::CyclicNumProcsGeNumIters); 416 cyclicNprocsGeNiters.procInfo = 417 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 418 patterns.add<LinalgTilingPattern>( 419 MatmulOp::getOperationName(), context, 420 LinalgTilingOptions() 421 .setTileSizes({8, 8, 4}) 422 .setLoopType(LinalgTilingLoopType::ParallelLoops) 423 .setDistributionOptions(cyclicNprocsGeNiters), 424 LinalgTransformationFilter( 425 StringAttr::get(context, "distribute2"), 426 StringAttr::get(context, "after_distribute2"))); 427 } 428 429 { 430 LinalgLoopDistributionOptions cyclicNprocsDefault; 431 cyclicNprocsDefault.distributionMethod.resize(2, 432 DistributionMethod::Cyclic); 433 cyclicNprocsDefault.procInfo = 434 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 435 patterns.add<LinalgTilingPattern>( 436 MatmulOp::getOperationName(), context, 437 LinalgTilingOptions() 438 .setTileSizes({8, 8, 4}) 439 .setLoopType(LinalgTilingLoopType::ParallelLoops) 440 .setDistributionOptions(cyclicNprocsDefault), 441 LinalgTransformationFilter( 442 StringAttr::get(context, "distribute3"), 443 StringAttr::get(context, "after_distribute3"))); 444 } 445 446 { 447 LinalgLoopDistributionOptions cyclicNprocsMixed1; 448 cyclicNprocsMixed1.distributionMethod = { 449 DistributionMethod::CyclicNumProcsEqNumIters, 450 DistributionMethod::CyclicNumProcsGeNumIters}; 451 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 452 patterns.add<LinalgTilingPattern>( 453 MatmulOp::getOperationName(), context, 454 LinalgTilingOptions() 455 .setTileSizes({8, 8, 4}) 456 .setLoopType(LinalgTilingLoopType::ParallelLoops) 457 .setDistributionOptions(cyclicNprocsMixed1), 458 LinalgTransformationFilter( 459 StringAttr::get(context, "distribute4"), 460 StringAttr::get(context, "after_distribute4"))); 461 } 462 463 { 464 LinalgLoopDistributionOptions cyclicNprocsMixed2; 465 cyclicNprocsMixed2.distributionMethod = { 466 DistributionMethod::CyclicNumProcsGeNumIters, 467 DistributionMethod::Cyclic}; 468 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 469 patterns.add<LinalgTilingPattern>( 470 MatmulOp::getOperationName(), context, 471 LinalgTilingOptions() 472 .setTileSizes({8, 8, 4}) 473 .setLoopType(LinalgTilingLoopType::ParallelLoops) 474 .setDistributionOptions(cyclicNprocsMixed2), 475 LinalgTransformationFilter( 476 StringAttr::get(context, "distribute5"), 477 StringAttr::get(context, "after_distribute5"))); 478 } 479 480 { 481 LinalgLoopDistributionOptions cyclicNprocsMixed3; 482 cyclicNprocsMixed3.distributionMethod = { 483 DistributionMethod::Cyclic, 484 DistributionMethod::CyclicNumProcsEqNumIters}; 485 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 486 487 patterns.add<LinalgTilingPattern>( 488 MatmulOp::getOperationName(), context, 489 LinalgTilingOptions() 490 .setTileSizes({8, 8, 4}) 491 .setLoopType(LinalgTilingLoopType::ParallelLoops) 492 .setDistributionOptions(cyclicNprocsMixed3), 493 LinalgTransformationFilter( 494 StringAttr::get(context, "distribute6"), 495 StringAttr::get(context, "after_distribute6"))); 496 } 497 498 { 499 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 500 cyclicNprocsEqNiters.distributionMethod.resize(2, 501 DistributionMethod::Cyclic); 502 cyclicNprocsEqNiters.procInfo = 503 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 504 patterns.add<LinalgTilingPattern>( 505 MatmulOp::getOperationName(), context, 506 LinalgTilingOptions() 507 .setTileSizes({8, 8, 4}) 508 .setLoopType(LinalgTilingLoopType::Loops) 509 .setDistributionOptions(cyclicNprocsEqNiters), 510 LinalgTransformationFilter( 511 StringAttr::get(context, "tensors_distribute1"), 512 StringAttr::get(context, "tensors_after_distribute1"))); 513 } 514 } 515 516 static void fillTileFuseAndDistributePatterns(MLIRContext *context, 517 RewritePatternSet &patterns) { 518 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 519 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); 520 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 521 patterns.add<LinalgTileAndFuseTensorOpsPattern>( 522 MatmulOp::getOperationName(), context, 523 LinalgTilingAndFusionOptions() 524 .setTileSizes({8, 8, 4}) 525 .setDistributionOptions(cyclicNprocsEqNiters), 526 LinalgTransformationFilter( 527 StringAttr::get(context, "tensors_fuse_distribute1"), 528 StringAttr::get(context, "tensors_after_fuse_distribute1"))); 529 } 530 531 static void 532 applyMatmulToVectorPatterns(FuncOp funcOp, 533 bool testMatmulToVectorPatterns1dTiling, 534 bool testMatmulToVectorPatterns2dTiling) { 535 MLIRContext *ctx = funcOp.getContext(); 536 SmallVector<RewritePatternSet, 4> stage1Patterns; 537 if (testMatmulToVectorPatterns1dTiling) { 538 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 539 } else if (testMatmulToVectorPatterns2dTiling) { 540 stage1Patterns.emplace_back( 541 ctx, std::make_unique<LinalgTilingPattern>( 542 MatmulOp::getOperationName(), ctx, 543 LinalgTilingOptions() 544 .setTileSizes({768, 264, 768}) 545 .setInterchange({1, 2, 0}), 546 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 547 StringAttr::get(ctx, "L2")))); 548 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 549 } 550 { 551 // Canonicalization patterns 552 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 553 vector::populateVectorTransferPermutationMapLoweringPatterns( 554 canonicalizationPatterns); 555 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 556 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 557 } 558 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 559 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 560 FrozenRewritePatternSet stage2Patterns = 561 getLinalgTilingCanonicalizationPatterns(ctx); 562 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); 563 } 564 565 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 566 RewritePatternSet forwardPattern(funcOp.getContext()); 567 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 568 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 569 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 570 } 571 572 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 573 RewritePatternSet patterns(funcOp.getContext()); 574 auto *ctx = funcOp.getContext(); 575 patterns.add<LinalgVectorizationPattern>( 576 ctx, LinalgTransformationFilter() 577 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>()); 578 patterns.add<CopyVectorizationPattern>(ctx); 579 populatePadOpVectorizationPatterns(patterns); 580 populateConvolutionVectorizationPatterns(patterns); 581 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 582 } 583 584 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 585 RewritePatternSet patterns(funcOp.getContext()); 586 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 587 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 588 } 589 590 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 591 RewritePatternSet patterns(funcOp.getContext()); 592 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 593 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 594 } 595 596 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 597 RewritePatternSet patterns(funcOp.getContext()); 598 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 599 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 600 } 601 602 static void applyTilePattern(FuncOp funcOp, const std::string &loopType, 603 ArrayRef<int64_t> tileSizes, 604 ArrayRef<int64_t> peeledLoops, 605 bool scalarizeDynamicDims) { 606 MLIRContext *context = funcOp.getContext(); 607 RewritePatternSet tilingPattern(context); 608 LinalgTilingLoopType type = 609 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 610 .Case("for", LinalgTilingLoopType::Loops) 611 .Case("affine", LinalgTilingLoopType::AffineLoops) 612 .Case("parallel", LinalgTilingLoopType::ParallelLoops); 613 auto linalgTilingOptions = linalg::LinalgTilingOptions() 614 .setPeeledLoops(peeledLoops) 615 .setLoopType(type); 616 if (scalarizeDynamicDims) { 617 linalgTilingOptions.scalarizeDynamicDims(); 618 assert(tileSizes.empty() && 619 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 620 } else { 621 linalgTilingOptions.setTileSizes(tileSizes); 622 } 623 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 624 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 625 tilingPattern, linalgTilingOptions, f); 626 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 627 } 628 629 static void applySplitReduction(FuncOp funcOp) { 630 RewritePatternSet patterns(funcOp.getContext()); 631 linalg::populateSplitReductionPattern( 632 patterns, 633 [](LinalgOp op) { 634 unsigned insertDimIndex = op.getNumLoops() - 1; 635 return std::make_pair(4, insertDimIndex); 636 }, 637 LinalgTransformationFilter( 638 ArrayRef<StringAttr>{}, 639 StringAttr::get(funcOp.getContext(), "SPLIT"))); 640 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 641 } 642 643 static void applyBubbleUpExtractSliceOpPattern(FuncOp funcOp) { 644 RewritePatternSet patterns(funcOp.getContext()); 645 populateBubbleUpExtractSliceOpPatterns(patterns); 646 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 647 } 648 649 /// Apply transformations specified as patterns. 650 void TestLinalgTransforms::runOnOperation() { 651 auto lambda = [&](void *) { 652 getOperation().walk([](LinalgOp op) { 653 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 654 }); 655 }; 656 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 657 658 if (testPromotionOptions) { 659 RewritePatternSet patterns(&getContext()); 660 fillPromotionCallBackPatterns(&getContext(), patterns); 661 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 662 return; 663 } 664 if (testTileAndDistributionOptions) { 665 RewritePatternSet patterns(&getContext()); 666 fillTileAndDistributePatterns(&getContext(), patterns); 667 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 668 return; 669 } 670 if (testTileFuseAndDistributionOptions) { 671 RewritePatternSet patterns(&getContext()); 672 fillTileFuseAndDistributePatterns(&getContext(), patterns); 673 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 674 return; 675 } 676 if (testPatterns) 677 return applyPatterns(getOperation()); 678 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 679 return applyMatmulToVectorPatterns(getOperation(), 680 testMatmulToVectorPatterns1dTiling, 681 testMatmulToVectorPatterns2dTiling); 682 if (testVectorTransferForwardingPatterns) 683 return applyVectorTransferForwardingPatterns(getOperation()); 684 if (testGenericToVectorPattern) 685 return applyLinalgToVectorPatterns(getOperation()); 686 if (testTransformPadTensor) 687 return applyPadTensorToGenericPatterns(getOperation()); 688 if (testGeneralizePadTensor) 689 return applyGeneralizePadTensorPatterns(getOperation()); 690 if (testSwapSubTensorPadTensor) 691 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 692 if (testTilePattern) 693 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops, 694 /*scalarizeDynamicDims=*/false); 695 if (testTileScalarizeDynamicDims) 696 return applyTilePattern(getOperation(), loopType, tileSizes, 697 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 698 if (testSplitReduction) 699 return applySplitReduction(getOperation()); 700 if (testBubbleUpExtractSliceOpPattern) 701 return applyBubbleUpExtractSliceOpPattern(getOperation()); 702 } 703 704 namespace mlir { 705 namespace test { 706 void registerTestLinalgTransforms() { 707 PassRegistration<TestLinalgTransforms>(); 708 } 709 } // namespace test 710 } // namespace mlir 711