1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/GPU/GPUDialect.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/Vector/IR/VectorOps.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 namespace { 34 struct TestLinalgTransforms 35 : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> { 36 TestLinalgTransforms() = default; 37 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 38 39 void getDependentDialects(DialectRegistry ®istry) const override { 40 // clang-format off 41 registry.insert<AffineDialect, 42 memref::MemRefDialect, 43 scf::SCFDialect, 44 linalg::LinalgDialect, 45 vector::VectorDialect, 46 gpu::GPUDialect>(); 47 // clang-format on 48 } 49 StringRef getArgument() const final { 50 return "test-linalg-transform-patterns"; 51 } 52 StringRef getDescription() const final { 53 return "Test Linalg transformation patterns by applying them greedily."; 54 } 55 56 void runOnOperation() override; 57 58 Option<bool> testPatterns{*this, "test-patterns", 59 llvm::cl::desc("Test a mixed set of patterns"), 60 llvm::cl::init(false)}; 61 Option<bool> testMatmulToVectorPatterns1dTiling{ 62 *this, "test-matmul-to-vector-patterns-tile-1d", 63 llvm::cl::desc( 64 "Test a fused pass that applies patterns from matmul to vectors via " 65 "1-d tiling"), 66 llvm::cl::init(false)}; 67 Option<bool> testMatmulToVectorPatterns2dTiling{ 68 *this, "test-matmul-to-vector-patterns-tile-2d", 69 llvm::cl::desc( 70 "Test a fused pass that applies patterns from matmul to vectors via " 71 "2-d tiling"), 72 llvm::cl::init(false)}; 73 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 74 llvm::cl::desc("Test promotion options"), 75 llvm::cl::init(false)}; 76 Option<bool> testTileAndDistributionOptions{ 77 *this, "test-tile-and-distribute-options", 78 llvm::cl::desc("Test tile and distribute options"), 79 llvm::cl::init(false)}; 80 Option<bool> testTileFuseAndDistributionOptions{ 81 *this, "test-tile-fuse-and-distribute-options", 82 llvm::cl::desc("Test tile, fuse and distribute options"), 83 llvm::cl::init(false)}; 84 Option<bool> testVectorTransferForwardingPatterns{ 85 *this, "test-vector-transfer-forwarding-patterns", 86 llvm::cl::desc( 87 "Test a fused pass that forwards memref.copy to vector.transfer"), 88 llvm::cl::init(false)}; 89 Option<bool> testGenericToVectorPattern{ 90 *this, "test-linalg-to-vector-patterns", 91 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 92 "in vector.contract form"), 93 llvm::cl::init(false)}; 94 Option<bool> testTilePattern{*this, "test-tile-pattern", 95 llvm::cl::desc("Test tile pattern"), 96 llvm::cl::init(false)}; 97 Option<bool> testTileScalarizeDynamicDims{ 98 *this, "test-tile-scalarize-dynamic-dims", 99 llvm::cl::desc("Test tiling of dynamic dims by 1"), 100 llvm::cl::init(false)}; 101 Option<bool> testTransformPadTensor{ 102 *this, "test-transform-pad-tensor", 103 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 104 llvm::cl::init(false)}; 105 Option<bool> testGeneralizePadTensor{ 106 *this, "test-generalize-pad-tensor", 107 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 108 llvm::cl::init(false)}; 109 Option<bool> testSwapSubTensorPadTensor{ 110 *this, "test-swap-subtensor-padtensor", 111 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 112 "pad_tensor(subtensor)"), 113 llvm::cl::init(false)}; 114 Option<bool> testSplitReduction{ 115 *this, "test-split-reduction", 116 llvm::cl::desc("Test split reduction transformation"), 117 llvm::cl::init(false)}; 118 ListOption<int64_t> peeledLoops{ 119 *this, "peeled-loops", 120 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 121 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 122 ListOption<int64_t> tileSizes{ 123 *this, "tile-sizes", 124 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 125 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 126 Option<bool> skipPartial{ 127 *this, "skip-partial", 128 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 129 llvm::cl::init(false)}; 130 Option<std::string> loopType{ 131 *this, "loop-type", 132 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 133 "tiled_loop"), 134 llvm::cl::init("for")}; 135 }; 136 } // namespace 137 138 static void applyPatterns(FuncOp funcOp) { 139 MLIRContext *ctx = funcOp.getContext(); 140 RewritePatternSet patterns(ctx); 141 142 //===--------------------------------------------------------------------===// 143 // Linalg tiling patterns. 144 //===--------------------------------------------------------------------===// 145 patterns.add<LinalgTilingPattern>( 146 MatmulOp::getOperationName(), ctx, 147 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 148 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 149 StringAttr::get(ctx, "L3"))); 150 patterns.add<LinalgTilingPattern>( 151 MatmulOp::getOperationName(), ctx, 152 LinalgTilingOptions().setTileSizes({200, 300, 400}), 153 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 154 StringAttr::get(ctx, "L2"))); 155 patterns.add<LinalgTilingPattern>( 156 MatmulOp::getOperationName(), ctx, 157 LinalgTilingOptions().setTileSizes({20, 30, 40}), 158 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 159 StringAttr::get(ctx, "L1"))); 160 patterns.add<LinalgTilingPattern>( 161 MatmulOp::getOperationName(), ctx, 162 LinalgTilingOptions().setTileSizes({2, 3, 4}), 163 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 164 StringAttr::get(ctx, "REG"))); 165 166 patterns.add<LinalgTilingPattern>( 167 MatvecOp::getOperationName(), ctx, 168 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 169 LinalgTilingLoopType::ParallelLoops), 170 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 171 StringAttr::get(ctx, "L1"))); 172 173 patterns.add<LinalgTilingPattern>( 174 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), 175 LinalgTransformationFilter( 176 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 177 StringAttr::get(ctx, "L3"), 178 StringAttr::get(ctx, "L2")}, 179 StringAttr::get(ctx, "REG"))); 180 181 //===--------------------------------------------------------------------===// 182 // Linalg tiling and permutation patterns. 183 //===--------------------------------------------------------------------===// 184 patterns.add<LinalgTilingPattern>( 185 MatmulOp::getOperationName(), ctx, 186 LinalgTilingOptions() 187 .setTileSizes({2000, 3000, 4000}) 188 .setInterchange({1, 2, 0}), 189 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 190 StringAttr::get(ctx, "L2__with_perm__"))); 191 patterns.add<LinalgTilingPattern>( 192 MatmulOp::getOperationName(), ctx, 193 LinalgTilingOptions() 194 .setTileSizes({200, 300, 400}) 195 .setInterchange({1, 0, 2}), 196 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 197 StringAttr::get(ctx, "L1__with_perm__"))); 198 patterns.add<LinalgTilingPattern>( 199 MatmulOp::getOperationName(), ctx, 200 LinalgTilingOptions().setTileSizes({20, 30, 40}), 201 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 202 StringAttr::get(ctx, "REG__with_perm__"))); 203 204 patterns.add<LinalgTilingPattern>( 205 MatvecOp::getOperationName(), ctx, 206 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 207 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 208 StringAttr::get(ctx, "L1__with_perm__"))); 209 210 patterns.add<LinalgTilingPattern>( 211 MatmulOp::getOperationName(), ctx, 212 LinalgTilingOptions() 213 .setTileSizes({16, 8, 4}) 214 .setInterchange({1, 2, 0}) 215 .setLoopType(LinalgTilingLoopType::ParallelLoops), 216 LinalgTransformationFilter( 217 StringAttr::get(ctx, "par__with_perm__"), 218 StringAttr::get(ctx, "after_par__with_perm__"))); 219 220 //===--------------------------------------------------------------------===// 221 // Linalg to loops patterns. 222 //===--------------------------------------------------------------------===// 223 patterns.add<LinalgLoweringPattern<DotOp>>( 224 ctx, 225 /*loweringType=*/LinalgLoweringType::Loops, 226 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 227 228 //===--------------------------------------------------------------------===// 229 // Linalg distribution patterns. 230 //===--------------------------------------------------------------------===// 231 LinalgLoopDistributionOptions distributionOptions; 232 233 //===--------------------------------------------------------------------===// 234 // Linalg to vector contraction patterns. 235 //===--------------------------------------------------------------------===// 236 patterns.add<LinalgVectorizationPattern>( 237 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 238 .addOpFilter<MatmulOp, FillOp, GenericOp>()); 239 patterns.add<CopyVectorizationPattern>(ctx); 240 241 //===--------------------------------------------------------------------===// 242 // Linalg generic interchange pattern. 243 //===--------------------------------------------------------------------===// 244 patterns.add<GenericOpInterchangePattern>( 245 ctx, 246 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 247 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 248 StringAttr::get(ctx, "PERMUTED"))); 249 250 //===--------------------------------------------------------------------===// 251 // Linalg subview operands promotion. 252 //===--------------------------------------------------------------------===// 253 patterns.add<LinalgPromotionPattern<MatmulOp>>( 254 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 255 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 256 StringAttr::get(ctx, "_views_promoted_"))); 257 patterns.add<LinalgPromotionPattern<MatmulOp>>( 258 ctx, 259 LinalgPromotionOptions() 260 .setOperandsToPromote({0}) 261 .setUseFullTileBuffersByDefault(true), 262 LinalgTransformationFilter( 263 StringAttr::get(ctx, "_promote_first_view_"), 264 StringAttr::get(ctx, "_first_view_promoted_"))); 265 patterns.add<LinalgPromotionPattern<FillOp>>( 266 ctx, 267 LinalgPromotionOptions() 268 .setOperandsToPromote({1}) 269 .setUseFullTileBuffers({false, true}) 270 .setAlignment(32), 271 LinalgTransformationFilter( 272 StringAttr::get(ctx, "_promote_views_aligned_"), 273 StringAttr::get(ctx, "_views_aligned_promoted_"))); 274 275 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 276 277 // Drop the marker. 278 funcOp.walk([](LinalgOp op) { 279 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 280 }); 281 } 282 283 static void fillL1TilingAndMatmulToVectorPatterns( 284 FuncOp funcOp, StringRef startMarker, 285 SmallVectorImpl<RewritePatternSet> &patternsVector) { 286 MLIRContext *ctx = funcOp.getContext(); 287 patternsVector.emplace_back( 288 ctx, std::make_unique<LinalgTilingPattern>( 289 MatmulOp::getOperationName(), ctx, 290 LinalgTilingOptions() 291 .setTileSizes({8, 12, 16}) 292 .setInterchange({1, 0, 2}), 293 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 294 StringAttr::get(ctx, "L1")))); 295 296 patternsVector.emplace_back( 297 ctx, 298 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 299 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 300 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 301 StringAttr::get(ctx, "VEC")))); 302 303 patternsVector.emplace_back( 304 ctx, std::make_unique<LinalgVectorizationPattern>( 305 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 306 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 307 patternsVector.back().add<LinalgVectorizationPattern>( 308 ctx, LinalgTransformationFilter().addOpFilter<FillOp>()); 309 patternsVector.back().add<CopyVectorizationPattern>(ctx); 310 } 311 312 //===----------------------------------------------------------------------===// 313 // Test promotion callbacks 314 //===----------------------------------------------------------------------===// 315 316 // Allocation call back 317 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 318 ArrayRef<Value> boundingSubViewSize, 319 DataLayout &layout) { 320 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 321 return b 322 .create<memref::AllocOp>( 323 subView.getLoc(), 324 MemRefType::get(shape, subView.getType().getElementType(), 325 /*affineMapComposition =*/{}, 3), 326 boundingSubViewSize) 327 .getResult(); 328 } 329 330 // Deallocation callback 331 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 332 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 333 return success(); 334 } 335 336 // Copy in call back 337 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 338 bool isOutput) { 339 auto floatType = src.getType().cast<MemRefType>().getElementType(); 340 if (!floatType.isa<FloatType>()) 341 return failure(); 342 if (!isOutput) { 343 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 344 FloatAttr::get(floatType, 42.0)); 345 b.create<FillOp>(src.getLoc(), cst, dst); 346 } 347 b.create<memref::CopyOp>(src.getLoc(), src, dst); 348 return success(); 349 } 350 351 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 352 RewritePatternSet &patterns) { 353 patterns.add<LinalgTilingPattern>( 354 MatmulOp::getOperationName(), ctx, 355 LinalgTilingOptions().setTileSizes({16, 16, 16}), 356 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 357 StringAttr::get(ctx, "PROMOTE"))); 358 patterns.add<LinalgPromotionPattern<MatmulOp>>( 359 ctx, 360 LinalgPromotionOptions() 361 .setOperandsToPromote({0, 2}) 362 .setUseFullTileBuffers({false, false}) 363 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 364 .setCopyInOutFns( 365 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 366 return copyCallBackFn(b, src, dst, false); 367 }, 368 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 369 return copyCallBackFn(b, src, dst, true); 370 }), 371 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 372 } 373 374 template <typename IdOp, typename NProcsOp> 375 static SmallVector<ProcInfo, 2> 376 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 377 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 378 SmallVector<ProcInfo, 2> procInfo(count); 379 Type indexType = b.getIndexType(); 380 for (unsigned i = 0; i < count; ++i) { 381 gpu::Dimension dim = *gpu::symbolizeDimension(i); 382 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 383 b.create<NProcsOp>(loc, indexType, dim)}; 384 } 385 return procInfo; 386 } 387 388 static void fillTileAndDistributePatterns(MLIRContext *context, 389 RewritePatternSet &patterns) { 390 { 391 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 392 cyclicNprocsEqNiters.distributionMethod.resize( 393 2, DistributionMethod::CyclicNumProcsEqNumIters); 394 cyclicNprocsEqNiters.procInfo = 395 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 396 patterns.add<LinalgTilingPattern>( 397 MatmulOp::getOperationName(), context, 398 LinalgTilingOptions() 399 .setTileSizes({8, 8, 4}) 400 .setLoopType(LinalgTilingLoopType::ParallelLoops) 401 .setDistributionOptions(cyclicNprocsEqNiters), 402 LinalgTransformationFilter( 403 StringAttr::get(context, "distribute1"), 404 StringAttr::get(context, "after_distribute1"))); 405 } 406 407 { 408 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 409 cyclicNprocsGeNiters.distributionMethod.resize( 410 2, DistributionMethod::CyclicNumProcsGeNumIters); 411 cyclicNprocsGeNiters.procInfo = 412 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 413 patterns.add<LinalgTilingPattern>( 414 MatmulOp::getOperationName(), context, 415 LinalgTilingOptions() 416 .setTileSizes({8, 8, 4}) 417 .setLoopType(LinalgTilingLoopType::ParallelLoops) 418 .setDistributionOptions(cyclicNprocsGeNiters), 419 LinalgTransformationFilter( 420 StringAttr::get(context, "distribute2"), 421 StringAttr::get(context, "after_distribute2"))); 422 } 423 424 { 425 LinalgLoopDistributionOptions cyclicNprocsDefault; 426 cyclicNprocsDefault.distributionMethod.resize(2, 427 DistributionMethod::Cyclic); 428 cyclicNprocsDefault.procInfo = 429 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 430 patterns.add<LinalgTilingPattern>( 431 MatmulOp::getOperationName(), context, 432 LinalgTilingOptions() 433 .setTileSizes({8, 8, 4}) 434 .setLoopType(LinalgTilingLoopType::ParallelLoops) 435 .setDistributionOptions(cyclicNprocsDefault), 436 LinalgTransformationFilter( 437 StringAttr::get(context, "distribute3"), 438 StringAttr::get(context, "after_distribute3"))); 439 } 440 441 { 442 LinalgLoopDistributionOptions cyclicNprocsMixed1; 443 cyclicNprocsMixed1.distributionMethod = { 444 DistributionMethod::CyclicNumProcsEqNumIters, 445 DistributionMethod::CyclicNumProcsGeNumIters}; 446 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 447 patterns.add<LinalgTilingPattern>( 448 MatmulOp::getOperationName(), context, 449 LinalgTilingOptions() 450 .setTileSizes({8, 8, 4}) 451 .setLoopType(LinalgTilingLoopType::ParallelLoops) 452 .setDistributionOptions(cyclicNprocsMixed1), 453 LinalgTransformationFilter( 454 StringAttr::get(context, "distribute4"), 455 StringAttr::get(context, "after_distribute4"))); 456 } 457 458 { 459 LinalgLoopDistributionOptions cyclicNprocsMixed2; 460 cyclicNprocsMixed2.distributionMethod = { 461 DistributionMethod::CyclicNumProcsGeNumIters, 462 DistributionMethod::Cyclic}; 463 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 464 patterns.add<LinalgTilingPattern>( 465 MatmulOp::getOperationName(), context, 466 LinalgTilingOptions() 467 .setTileSizes({8, 8, 4}) 468 .setLoopType(LinalgTilingLoopType::ParallelLoops) 469 .setDistributionOptions(cyclicNprocsMixed2), 470 LinalgTransformationFilter( 471 StringAttr::get(context, "distribute5"), 472 StringAttr::get(context, "after_distribute5"))); 473 } 474 475 { 476 LinalgLoopDistributionOptions cyclicNprocsMixed3; 477 cyclicNprocsMixed3.distributionMethod = { 478 DistributionMethod::Cyclic, 479 DistributionMethod::CyclicNumProcsEqNumIters}; 480 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 481 482 patterns.add<LinalgTilingPattern>( 483 MatmulOp::getOperationName(), context, 484 LinalgTilingOptions() 485 .setTileSizes({8, 8, 4}) 486 .setLoopType(LinalgTilingLoopType::ParallelLoops) 487 .setDistributionOptions(cyclicNprocsMixed3), 488 LinalgTransformationFilter( 489 StringAttr::get(context, "distribute6"), 490 StringAttr::get(context, "after_distribute6"))); 491 } 492 493 { 494 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 495 cyclicNprocsEqNiters.distributionMethod.resize(2, 496 DistributionMethod::Cyclic); 497 cyclicNprocsEqNiters.procInfo = 498 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 499 patterns.add<LinalgTilingPattern>( 500 MatmulOp::getOperationName(), context, 501 LinalgTilingOptions() 502 .setTileSizes({8, 8, 4}) 503 .setLoopType(LinalgTilingLoopType::Loops) 504 .setDistributionOptions(cyclicNprocsEqNiters), 505 LinalgTransformationFilter( 506 StringAttr::get(context, "tensors_distribute1"), 507 StringAttr::get(context, "tensors_after_distribute1"))); 508 } 509 } 510 511 static void fillTileFuseAndDistributePatterns(MLIRContext *context, 512 RewritePatternSet &patterns) { 513 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 514 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); 515 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 516 patterns.add<LinalgTileAndFuseTensorOpsPattern>( 517 MatmulOp::getOperationName(), context, 518 LinalgTilingAndFusionOptions() 519 .setTileSizes({8, 8, 4}) 520 .setDistributionOptions(cyclicNprocsEqNiters), 521 LinalgTransformationFilter( 522 StringAttr::get(context, "tensors_fuse_distribute1"), 523 StringAttr::get(context, "tensors_after_fuse_distribute1"))); 524 } 525 526 static void 527 applyMatmulToVectorPatterns(FuncOp funcOp, 528 bool testMatmulToVectorPatterns1dTiling, 529 bool testMatmulToVectorPatterns2dTiling) { 530 MLIRContext *ctx = funcOp.getContext(); 531 SmallVector<RewritePatternSet, 4> stage1Patterns; 532 if (testMatmulToVectorPatterns1dTiling) { 533 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 534 } else if (testMatmulToVectorPatterns2dTiling) { 535 stage1Patterns.emplace_back( 536 ctx, std::make_unique<LinalgTilingPattern>( 537 MatmulOp::getOperationName(), ctx, 538 LinalgTilingOptions() 539 .setTileSizes({768, 264, 768}) 540 .setInterchange({1, 2, 0}), 541 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 542 StringAttr::get(ctx, "L2")))); 543 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 544 } 545 { 546 // Canonicalization patterns 547 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 548 vector::populateVectorTransferPermutationMapLoweringPatterns( 549 canonicalizationPatterns); 550 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 551 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 552 } 553 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 554 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 555 FrozenRewritePatternSet stage2Patterns = 556 getLinalgTilingCanonicalizationPatterns(ctx); 557 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); 558 } 559 560 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 561 RewritePatternSet forwardPattern(funcOp.getContext()); 562 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 563 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 564 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 565 } 566 567 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 568 RewritePatternSet patterns(funcOp.getContext()); 569 auto *ctx = funcOp.getContext(); 570 patterns.add<LinalgVectorizationPattern>( 571 ctx, LinalgTransformationFilter() 572 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>()); 573 patterns.add<CopyVectorizationPattern>(ctx); 574 populatePadOpVectorizationPatterns(patterns); 575 populateConvolutionVectorizationPatterns(patterns); 576 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 577 } 578 579 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 580 RewritePatternSet patterns(funcOp.getContext()); 581 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 582 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 583 } 584 585 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 586 RewritePatternSet patterns(funcOp.getContext()); 587 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 588 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 589 } 590 591 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 592 RewritePatternSet patterns(funcOp.getContext()); 593 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 594 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 595 } 596 597 static void applyTilePattern(FuncOp funcOp, const std::string &loopType, 598 ArrayRef<int64_t> tileSizes, 599 ArrayRef<int64_t> peeledLoops, 600 bool scalarizeDynamicDims) { 601 MLIRContext *context = funcOp.getContext(); 602 RewritePatternSet tilingPattern(context); 603 LinalgTilingLoopType type = 604 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 605 .Case("for", LinalgTilingLoopType::Loops) 606 .Case("affine", LinalgTilingLoopType::AffineLoops) 607 .Case("parallel", LinalgTilingLoopType::ParallelLoops); 608 auto linalgTilingOptions = linalg::LinalgTilingOptions() 609 .setPeeledLoops(peeledLoops) 610 .setLoopType(type); 611 if (scalarizeDynamicDims) { 612 linalgTilingOptions.scalarizeDynamicDims(); 613 assert(tileSizes.empty() && 614 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 615 } else { 616 linalgTilingOptions.setTileSizes(tileSizes); 617 } 618 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 619 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 620 tilingPattern, linalgTilingOptions, f); 621 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 622 } 623 624 static void applySplitReduction(FuncOp funcOp) { 625 RewritePatternSet patterns(funcOp.getContext()); 626 linalg::populateSplitReductionPattern( 627 patterns, 628 [](LinalgOp op) { 629 unsigned insertDimIndex = op.getNumLoops() - 1; 630 return std::make_pair(4, insertDimIndex); 631 }, 632 LinalgTransformationFilter( 633 ArrayRef<StringAttr>{}, 634 StringAttr::get(funcOp.getContext(), "SPLIT"))); 635 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 636 } 637 638 /// Apply transformations specified as patterns. 639 void TestLinalgTransforms::runOnOperation() { 640 auto lambda = [&](void *) { 641 getOperation().walk([](LinalgOp op) { 642 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 643 }); 644 }; 645 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 646 647 if (testPromotionOptions) { 648 RewritePatternSet patterns(&getContext()); 649 fillPromotionCallBackPatterns(&getContext(), patterns); 650 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 651 return; 652 } 653 if (testTileAndDistributionOptions) { 654 RewritePatternSet patterns(&getContext()); 655 fillTileAndDistributePatterns(&getContext(), patterns); 656 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 657 return; 658 } 659 if (testTileFuseAndDistributionOptions) { 660 RewritePatternSet patterns(&getContext()); 661 fillTileFuseAndDistributePatterns(&getContext(), patterns); 662 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 663 return; 664 } 665 if (testPatterns) 666 return applyPatterns(getOperation()); 667 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 668 return applyMatmulToVectorPatterns(getOperation(), 669 testMatmulToVectorPatterns1dTiling, 670 testMatmulToVectorPatterns2dTiling); 671 if (testVectorTransferForwardingPatterns) 672 return applyVectorTransferForwardingPatterns(getOperation()); 673 if (testGenericToVectorPattern) 674 return applyLinalgToVectorPatterns(getOperation()); 675 if (testTransformPadTensor) 676 return applyPadTensorToGenericPatterns(getOperation()); 677 if (testGeneralizePadTensor) 678 return applyGeneralizePadTensorPatterns(getOperation()); 679 if (testSwapSubTensorPadTensor) 680 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 681 if (testTilePattern) 682 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops, 683 /*scalarizeDynamicDims=*/false); 684 if (testTileScalarizeDynamicDims) 685 return applyTilePattern(getOperation(), loopType, tileSizes, 686 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 687 if (testSplitReduction) 688 return applySplitReduction(getOperation()); 689 } 690 691 namespace mlir { 692 namespace test { 693 void registerTestLinalgTransforms() { 694 PassRegistration<TestLinalgTransforms>(); 695 } 696 } // namespace test 697 } // namespace mlir 698