1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/GPU/GPUDialect.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 19 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/Linalg/Utils/Utils.h" 22 #include "mlir/Dialect/Vector/IR/VectorOps.h" 23 #include "mlir/Pass/PassManager.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 26 #include "llvm/ADT/SetVector.h" 27 #include "llvm/ADT/SmallVector.h" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 namespace { 33 struct TestLinalgTransforms 34 : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> { 35 TestLinalgTransforms() = default; 36 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 37 38 void getDependentDialects(DialectRegistry ®istry) const override { 39 // clang-format off 40 registry.insert<AffineDialect, 41 memref::MemRefDialect, 42 scf::SCFDialect, 43 linalg::LinalgDialect, 44 vector::VectorDialect, 45 gpu::GPUDialect>(); 46 // clang-format on 47 } 48 StringRef getArgument() const final { 49 return "test-linalg-transform-patterns"; 50 } 51 StringRef getDescription() const final { 52 return "Test Linalg transformation patterns by applying them greedily."; 53 } 54 55 void runOnOperation() override; 56 57 Option<bool> testPatterns{*this, "test-patterns", 58 llvm::cl::desc("Test a mixed set of patterns"), 59 llvm::cl::init(false)}; 60 Option<bool> testMatmulToVectorPatterns1dTiling{ 61 *this, "test-matmul-to-vector-patterns-tile-1d", 62 llvm::cl::desc( 63 "Test a fused pass that applies patterns from matmul to vectors via " 64 "1-d tiling"), 65 llvm::cl::init(false)}; 66 Option<bool> testMatmulToVectorPatterns2dTiling{ 67 *this, "test-matmul-to-vector-patterns-tile-2d", 68 llvm::cl::desc( 69 "Test a fused pass that applies patterns from matmul to vectors via " 70 "2-d tiling"), 71 llvm::cl::init(false)}; 72 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 73 llvm::cl::desc("Test promotion options"), 74 llvm::cl::init(false)}; 75 Option<bool> testTileAndDistributionOptions{ 76 *this, "test-tile-and-distribute-options", 77 llvm::cl::desc("Test tile and distribute options"), 78 llvm::cl::init(false)}; 79 Option<bool> testTileFuseAndDistributionOptions{ 80 *this, "test-tile-fuse-and-distribute-options", 81 llvm::cl::desc("Test tile, fuse and distribute options"), 82 llvm::cl::init(false)}; 83 Option<bool> testVectorTransferForwardingPatterns{ 84 *this, "test-vector-transfer-forwarding-patterns", 85 llvm::cl::desc( 86 "Test a fused pass that forwards memref.copy to vector.transfer"), 87 llvm::cl::init(false)}; 88 Option<bool> testGenericToVectorPattern{ 89 *this, "test-linalg-to-vector-patterns", 90 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 91 "in vector.contract form"), 92 llvm::cl::init(false)}; 93 Option<bool> testTilePattern{*this, "test-tile-pattern", 94 llvm::cl::desc("Test tile pattern"), 95 llvm::cl::init(false)}; 96 Option<bool> testTileScalarizeDynamicDims{ 97 *this, "test-tile-scalarize-dynamic-dims", 98 llvm::cl::desc("Test tiling of dynamic dims by 1"), 99 llvm::cl::init(false)}; 100 Option<bool> testTransformPadTensor{ 101 *this, "test-transform-pad-tensor", 102 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 103 llvm::cl::init(false)}; 104 Option<bool> testGeneralizePadTensor{ 105 *this, "test-generalize-pad-tensor", 106 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 107 llvm::cl::init(false)}; 108 Option<bool> testSwapSubTensorPadTensor{ 109 *this, "test-swap-subtensor-padtensor", 110 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 111 "pad_tensor(subtensor)"), 112 llvm::cl::init(false)}; 113 ListOption<int64_t> peeledLoops{ 114 *this, "peeled-loops", 115 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 116 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 117 ListOption<int64_t> tileSizes{ 118 *this, "tile-sizes", 119 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 120 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 121 Option<bool> skipPartial{ 122 *this, "skip-partial", 123 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 124 llvm::cl::init(false)}; 125 Option<std::string> loopType{ 126 *this, "loop-type", 127 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 128 "tiled_loop"), 129 llvm::cl::init("for")}; 130 }; 131 } // namespace 132 133 static void applyPatterns(FuncOp funcOp) { 134 MLIRContext *ctx = funcOp.getContext(); 135 RewritePatternSet patterns(ctx); 136 137 //===--------------------------------------------------------------------===// 138 // Linalg tiling patterns. 139 //===--------------------------------------------------------------------===// 140 patterns.add<LinalgTilingPattern>( 141 MatmulOp::getOperationName(), ctx, 142 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 143 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 144 StringAttr::get(ctx, "L3"))); 145 patterns.add<LinalgTilingPattern>( 146 MatmulOp::getOperationName(), ctx, 147 LinalgTilingOptions().setTileSizes({200, 300, 400}), 148 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 149 StringAttr::get(ctx, "L2"))); 150 patterns.add<LinalgTilingPattern>( 151 MatmulOp::getOperationName(), ctx, 152 LinalgTilingOptions().setTileSizes({20, 30, 40}), 153 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 154 StringAttr::get(ctx, "L1"))); 155 patterns.add<LinalgTilingPattern>( 156 MatmulOp::getOperationName(), ctx, 157 LinalgTilingOptions().setTileSizes({2, 3, 4}), 158 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 159 StringAttr::get(ctx, "REG"))); 160 161 patterns.add<LinalgTilingPattern>( 162 MatvecOp::getOperationName(), ctx, 163 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 164 LinalgTilingLoopType::ParallelLoops), 165 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 166 StringAttr::get(ctx, "L1"))); 167 168 patterns.add<LinalgTilingPattern>( 169 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), 170 LinalgTransformationFilter( 171 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 172 StringAttr::get(ctx, "L3"), 173 StringAttr::get(ctx, "L2")}, 174 StringAttr::get(ctx, "REG"))); 175 176 //===--------------------------------------------------------------------===// 177 // Linalg tiling and permutation patterns. 178 //===--------------------------------------------------------------------===// 179 patterns.add<LinalgTilingPattern>( 180 MatmulOp::getOperationName(), ctx, 181 LinalgTilingOptions() 182 .setTileSizes({2000, 3000, 4000}) 183 .setInterchange({1, 2, 0}), 184 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 185 StringAttr::get(ctx, "L2__with_perm__"))); 186 patterns.add<LinalgTilingPattern>( 187 MatmulOp::getOperationName(), ctx, 188 LinalgTilingOptions() 189 .setTileSizes({200, 300, 400}) 190 .setInterchange({1, 0, 2}), 191 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 192 StringAttr::get(ctx, "L1__with_perm__"))); 193 patterns.add<LinalgTilingPattern>( 194 MatmulOp::getOperationName(), ctx, 195 LinalgTilingOptions().setTileSizes({20, 30, 40}), 196 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 197 StringAttr::get(ctx, "REG__with_perm__"))); 198 199 patterns.add<LinalgTilingPattern>( 200 MatvecOp::getOperationName(), ctx, 201 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 202 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 203 StringAttr::get(ctx, "L1__with_perm__"))); 204 205 patterns.add<LinalgTilingPattern>( 206 MatmulOp::getOperationName(), ctx, 207 LinalgTilingOptions() 208 .setTileSizes({16, 8, 4}) 209 .setInterchange({1, 2, 0}) 210 .setLoopType(LinalgTilingLoopType::ParallelLoops), 211 LinalgTransformationFilter( 212 StringAttr::get(ctx, "par__with_perm__"), 213 StringAttr::get(ctx, "after_par__with_perm__"))); 214 215 //===--------------------------------------------------------------------===// 216 // Linalg to loops patterns. 217 //===--------------------------------------------------------------------===// 218 patterns.add<LinalgLoweringPattern<DotOp>>( 219 ctx, 220 /*loweringType=*/LinalgLoweringType::Loops, 221 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 222 223 //===--------------------------------------------------------------------===// 224 // Linalg distribution patterns. 225 //===--------------------------------------------------------------------===// 226 LinalgLoopDistributionOptions distributionOptions; 227 228 //===--------------------------------------------------------------------===// 229 // Linalg to vector contraction patterns. 230 //===--------------------------------------------------------------------===// 231 patterns.add<LinalgVectorizationPattern>( 232 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 233 .addOpFilter<MatmulOp, FillOp, GenericOp>()); 234 patterns.add<CopyVectorizationPattern>(ctx); 235 236 //===--------------------------------------------------------------------===// 237 // Linalg generic interchange pattern. 238 //===--------------------------------------------------------------------===// 239 patterns.add<GenericOpInterchangePattern>( 240 ctx, 241 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 242 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 243 StringAttr::get(ctx, "PERMUTED"))); 244 245 //===--------------------------------------------------------------------===// 246 // Linalg subview operands promotion. 247 //===--------------------------------------------------------------------===// 248 patterns.add<LinalgPromotionPattern<MatmulOp>>( 249 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 250 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 251 StringAttr::get(ctx, "_views_promoted_"))); 252 patterns.add<LinalgPromotionPattern<MatmulOp>>( 253 ctx, 254 LinalgPromotionOptions() 255 .setOperandsToPromote({0}) 256 .setUseFullTileBuffersByDefault(true), 257 LinalgTransformationFilter( 258 StringAttr::get(ctx, "_promote_first_view_"), 259 StringAttr::get(ctx, "_first_view_promoted_"))); 260 patterns.add<LinalgPromotionPattern<FillOp>>( 261 ctx, 262 LinalgPromotionOptions() 263 .setOperandsToPromote({1}) 264 .setUseFullTileBuffers({false, true}) 265 .setAlignment(32), 266 LinalgTransformationFilter( 267 StringAttr::get(ctx, "_promote_views_aligned_"), 268 StringAttr::get(ctx, "_views_aligned_promoted_"))); 269 270 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 271 272 // Drop the marker. 273 funcOp.walk([](LinalgOp op) { 274 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 275 }); 276 } 277 278 static void fillL1TilingAndMatmulToVectorPatterns( 279 FuncOp funcOp, StringRef startMarker, 280 SmallVectorImpl<RewritePatternSet> &patternsVector) { 281 MLIRContext *ctx = funcOp.getContext(); 282 patternsVector.emplace_back( 283 ctx, std::make_unique<LinalgTilingPattern>( 284 MatmulOp::getOperationName(), ctx, 285 LinalgTilingOptions() 286 .setTileSizes({8, 12, 16}) 287 .setInterchange({1, 0, 2}), 288 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 289 StringAttr::get(ctx, "L1")))); 290 291 patternsVector.emplace_back( 292 ctx, 293 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 294 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 295 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 296 StringAttr::get(ctx, "VEC")))); 297 298 patternsVector.emplace_back( 299 ctx, std::make_unique<LinalgVectorizationPattern>( 300 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 301 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 302 patternsVector.back().add<LinalgVectorizationPattern>( 303 ctx, LinalgTransformationFilter().addOpFilter<FillOp>()); 304 patternsVector.back().add<CopyVectorizationPattern>(ctx); 305 } 306 307 //===----------------------------------------------------------------------===// 308 // Test promotion callbacks 309 //===----------------------------------------------------------------------===// 310 311 // Allocation call back 312 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 313 ArrayRef<Value> boundingSubViewSize, 314 DataLayout &layout) { 315 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 316 return b 317 .create<memref::AllocOp>( 318 subView.getLoc(), 319 MemRefType::get(shape, subView.getType().getElementType(), 320 /*affineMapComposition =*/{}, 3), 321 boundingSubViewSize) 322 .getResult(); 323 } 324 325 // Deallocation callback 326 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 327 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 328 return success(); 329 } 330 331 // Copy in call back 332 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 333 bool isOutput) { 334 auto floatType = src.getType().cast<MemRefType>().getElementType(); 335 if (!floatType.isa<FloatType>()) 336 return failure(); 337 if (!isOutput) { 338 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 339 FloatAttr::get(floatType, 42.0)); 340 b.create<FillOp>(src.getLoc(), cst, dst); 341 } 342 b.create<memref::CopyOp>(src.getLoc(), src, dst); 343 return success(); 344 } 345 346 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 347 RewritePatternSet &patterns) { 348 patterns.add<LinalgTilingPattern>( 349 MatmulOp::getOperationName(), ctx, 350 LinalgTilingOptions().setTileSizes({16, 16, 16}), 351 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 352 StringAttr::get(ctx, "PROMOTE"))); 353 patterns.add<LinalgPromotionPattern<MatmulOp>>( 354 ctx, 355 LinalgPromotionOptions() 356 .setOperandsToPromote({0, 2}) 357 .setUseFullTileBuffers({false, false}) 358 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 359 .setCopyInOutFns( 360 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 361 return copyCallBackFn(b, src, dst, false); 362 }, 363 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 364 return copyCallBackFn(b, src, dst, true); 365 }), 366 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 367 } 368 369 template <typename IdOp, typename NProcsOp> 370 static SmallVector<ProcInfo, 2> 371 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 372 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 373 SmallVector<ProcInfo, 2> procInfo(count); 374 Type indexType = b.getIndexType(); 375 for (unsigned i = 0; i < count; ++i) { 376 gpu::Dimension dim = *gpu::symbolizeDimension(i); 377 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 378 b.create<NProcsOp>(loc, indexType, dim)}; 379 } 380 return procInfo; 381 } 382 383 static void fillTileAndDistributePatterns(MLIRContext *context, 384 RewritePatternSet &patterns) { 385 { 386 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 387 cyclicNprocsEqNiters.distributionMethod.resize( 388 2, DistributionMethod::CyclicNumProcsEqNumIters); 389 cyclicNprocsEqNiters.procInfo = 390 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 391 patterns.add<LinalgTilingPattern>( 392 MatmulOp::getOperationName(), context, 393 LinalgTilingOptions() 394 .setTileSizes({8, 8, 4}) 395 .setLoopType(LinalgTilingLoopType::ParallelLoops) 396 .setDistributionOptions(cyclicNprocsEqNiters), 397 LinalgTransformationFilter( 398 StringAttr::get(context, "distribute1"), 399 StringAttr::get(context, "after_distribute1"))); 400 } 401 402 { 403 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 404 cyclicNprocsGeNiters.distributionMethod.resize( 405 2, DistributionMethod::CyclicNumProcsGeNumIters); 406 cyclicNprocsGeNiters.procInfo = 407 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 408 patterns.add<LinalgTilingPattern>( 409 MatmulOp::getOperationName(), context, 410 LinalgTilingOptions() 411 .setTileSizes({8, 8, 4}) 412 .setLoopType(LinalgTilingLoopType::ParallelLoops) 413 .setDistributionOptions(cyclicNprocsGeNiters), 414 LinalgTransformationFilter( 415 StringAttr::get(context, "distribute2"), 416 StringAttr::get(context, "after_distribute2"))); 417 } 418 419 { 420 LinalgLoopDistributionOptions cyclicNprocsDefault; 421 cyclicNprocsDefault.distributionMethod.resize(2, 422 DistributionMethod::Cyclic); 423 cyclicNprocsDefault.procInfo = 424 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 425 patterns.add<LinalgTilingPattern>( 426 MatmulOp::getOperationName(), context, 427 LinalgTilingOptions() 428 .setTileSizes({8, 8, 4}) 429 .setLoopType(LinalgTilingLoopType::ParallelLoops) 430 .setDistributionOptions(cyclicNprocsDefault), 431 LinalgTransformationFilter( 432 StringAttr::get(context, "distribute3"), 433 StringAttr::get(context, "after_distribute3"))); 434 } 435 436 { 437 LinalgLoopDistributionOptions cyclicNprocsMixed1; 438 cyclicNprocsMixed1.distributionMethod = { 439 DistributionMethod::CyclicNumProcsEqNumIters, 440 DistributionMethod::CyclicNumProcsGeNumIters}; 441 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 442 patterns.add<LinalgTilingPattern>( 443 MatmulOp::getOperationName(), context, 444 LinalgTilingOptions() 445 .setTileSizes({8, 8, 4}) 446 .setLoopType(LinalgTilingLoopType::ParallelLoops) 447 .setDistributionOptions(cyclicNprocsMixed1), 448 LinalgTransformationFilter( 449 StringAttr::get(context, "distribute4"), 450 StringAttr::get(context, "after_distribute4"))); 451 } 452 453 { 454 LinalgLoopDistributionOptions cyclicNprocsMixed2; 455 cyclicNprocsMixed2.distributionMethod = { 456 DistributionMethod::CyclicNumProcsGeNumIters, 457 DistributionMethod::Cyclic}; 458 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 459 patterns.add<LinalgTilingPattern>( 460 MatmulOp::getOperationName(), context, 461 LinalgTilingOptions() 462 .setTileSizes({8, 8, 4}) 463 .setLoopType(LinalgTilingLoopType::ParallelLoops) 464 .setDistributionOptions(cyclicNprocsMixed2), 465 LinalgTransformationFilter( 466 StringAttr::get(context, "distribute5"), 467 StringAttr::get(context, "after_distribute5"))); 468 } 469 470 { 471 LinalgLoopDistributionOptions cyclicNprocsMixed3; 472 cyclicNprocsMixed3.distributionMethod = { 473 DistributionMethod::Cyclic, 474 DistributionMethod::CyclicNumProcsEqNumIters}; 475 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 476 477 patterns.add<LinalgTilingPattern>( 478 MatmulOp::getOperationName(), context, 479 LinalgTilingOptions() 480 .setTileSizes({8, 8, 4}) 481 .setLoopType(LinalgTilingLoopType::ParallelLoops) 482 .setDistributionOptions(cyclicNprocsMixed3), 483 LinalgTransformationFilter( 484 StringAttr::get(context, "distribute6"), 485 StringAttr::get(context, "after_distribute6"))); 486 } 487 488 { 489 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 490 cyclicNprocsEqNiters.distributionMethod.resize(2, 491 DistributionMethod::Cyclic); 492 cyclicNprocsEqNiters.procInfo = 493 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 494 patterns.add<LinalgTilingPattern>( 495 MatmulOp::getOperationName(), context, 496 LinalgTilingOptions() 497 .setTileSizes({8, 8, 4}) 498 .setLoopType(LinalgTilingLoopType::Loops) 499 .setDistributionOptions(cyclicNprocsEqNiters), 500 LinalgTransformationFilter( 501 StringAttr::get(context, "tensors_distribute1"), 502 StringAttr::get(context, "tensors_after_distribute1"))); 503 } 504 } 505 506 static void fillTileFuseAndDistributePatterns(MLIRContext *context, 507 RewritePatternSet &patterns) { 508 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 509 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); 510 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 511 patterns.add<LinalgTileAndFuseTensorOpsPattern>( 512 MatmulOp::getOperationName(), context, 513 LinalgTilingAndFusionOptions() 514 .setTileSizes({8, 8, 4}) 515 .setDistributionOptions(cyclicNprocsEqNiters), 516 LinalgTransformationFilter( 517 StringAttr::get(context, "tensors_fuse_distribute1"), 518 StringAttr::get(context, "tensors_after_fuse_distribute1"))); 519 } 520 521 static void 522 applyMatmulToVectorPatterns(FuncOp funcOp, 523 bool testMatmulToVectorPatterns1dTiling, 524 bool testMatmulToVectorPatterns2dTiling) { 525 MLIRContext *ctx = funcOp.getContext(); 526 SmallVector<RewritePatternSet, 4> stage1Patterns; 527 if (testMatmulToVectorPatterns1dTiling) { 528 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 529 } else if (testMatmulToVectorPatterns2dTiling) { 530 stage1Patterns.emplace_back( 531 ctx, std::make_unique<LinalgTilingPattern>( 532 MatmulOp::getOperationName(), ctx, 533 LinalgTilingOptions() 534 .setTileSizes({768, 264, 768}) 535 .setInterchange({1, 2, 0}), 536 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 537 StringAttr::get(ctx, "L2")))); 538 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 539 } 540 { 541 // Canonicalization patterns 542 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 543 vector::populateVectorTransferPermutationMapLoweringPatterns( 544 canonicalizationPatterns); 545 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 546 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 547 } 548 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 549 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 550 FrozenRewritePatternSet stage2Patterns = 551 getLinalgTilingCanonicalizationPatterns(ctx); 552 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); 553 } 554 555 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 556 RewritePatternSet forwardPattern(funcOp.getContext()); 557 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 558 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 559 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 560 } 561 562 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 563 RewritePatternSet patterns(funcOp.getContext()); 564 auto *ctx = funcOp.getContext(); 565 patterns.add<LinalgVectorizationPattern>( 566 ctx, LinalgTransformationFilter() 567 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>()); 568 patterns.add<CopyVectorizationPattern>(ctx); 569 populatePadOpVectorizationPatterns(patterns); 570 populateConvolutionVectorizationPatterns(patterns); 571 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 572 } 573 574 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 575 RewritePatternSet patterns(funcOp.getContext()); 576 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 577 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 578 } 579 580 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 581 RewritePatternSet patterns(funcOp.getContext()); 582 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 583 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 584 } 585 586 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 587 RewritePatternSet patterns(funcOp.getContext()); 588 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 589 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 590 } 591 592 static void applyTilePattern(FuncOp funcOp, const std::string &loopType, 593 ArrayRef<int64_t> tileSizes, 594 ArrayRef<int64_t> peeledLoops, 595 bool scalarizeDynamicDims) { 596 MLIRContext *context = funcOp.getContext(); 597 RewritePatternSet tilingPattern(context); 598 LinalgTilingLoopType type = 599 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 600 .Case("for", LinalgTilingLoopType::Loops) 601 .Case("affine", LinalgTilingLoopType::AffineLoops) 602 .Case("parallel", LinalgTilingLoopType::ParallelLoops); 603 auto linalgTilingOptions = linalg::LinalgTilingOptions() 604 .setPeeledLoops(peeledLoops) 605 .setLoopType(type); 606 if (scalarizeDynamicDims) { 607 linalgTilingOptions.scalarizeDynamicDims(); 608 assert(tileSizes.empty() && 609 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 610 } else { 611 linalgTilingOptions.setTileSizes(tileSizes); 612 } 613 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 614 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 615 tilingPattern, linalgTilingOptions, f); 616 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 617 } 618 619 /// Apply transformations specified as patterns. 620 void TestLinalgTransforms::runOnOperation() { 621 auto lambda = [&](void *) { 622 getOperation().walk([](LinalgOp op) { 623 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 624 }); 625 }; 626 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 627 628 if (testPromotionOptions) { 629 RewritePatternSet patterns(&getContext()); 630 fillPromotionCallBackPatterns(&getContext(), patterns); 631 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 632 return; 633 } 634 if (testTileAndDistributionOptions) { 635 RewritePatternSet patterns(&getContext()); 636 fillTileAndDistributePatterns(&getContext(), patterns); 637 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 638 return; 639 } 640 if (testTileFuseAndDistributionOptions) { 641 RewritePatternSet patterns(&getContext()); 642 fillTileFuseAndDistributePatterns(&getContext(), patterns); 643 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 644 return; 645 } 646 if (testPatterns) 647 return applyPatterns(getOperation()); 648 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 649 return applyMatmulToVectorPatterns(getOperation(), 650 testMatmulToVectorPatterns1dTiling, 651 testMatmulToVectorPatterns2dTiling); 652 if (testVectorTransferForwardingPatterns) 653 return applyVectorTransferForwardingPatterns(getOperation()); 654 if (testGenericToVectorPattern) 655 return applyLinalgToVectorPatterns(getOperation()); 656 if (testTransformPadTensor) 657 return applyPadTensorToGenericPatterns(getOperation()); 658 if (testGeneralizePadTensor) 659 return applyGeneralizePadTensorPatterns(getOperation()); 660 if (testSwapSubTensorPadTensor) 661 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 662 if (testTilePattern) 663 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops, 664 /*scalarizeDynamicDims=*/false); 665 if (testTileScalarizeDynamicDims) 666 return applyTilePattern(getOperation(), loopType, tileSizes, 667 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 668 } 669 670 namespace mlir { 671 namespace test { 672 void registerTestLinalgTransforms() { 673 PassRegistration<TestLinalgTransforms>(); 674 } 675 } // namespace test 676 } // namespace mlir 677