1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/GPU/GPUDialect.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/Vector/IR/VectorOps.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 namespace { 34 struct TestLinalgTransforms 35 : public PassWrapper<TestLinalgTransforms, OperationPass<FuncOp>> { 36 TestLinalgTransforms() = default; 37 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 38 39 void getDependentDialects(DialectRegistry ®istry) const override { 40 // clang-format off 41 registry.insert<AffineDialect, 42 memref::MemRefDialect, 43 scf::SCFDialect, 44 linalg::LinalgDialect, 45 vector::VectorDialect, 46 gpu::GPUDialect>(); 47 // clang-format on 48 } 49 StringRef getArgument() const final { 50 return "test-linalg-transform-patterns"; 51 } 52 StringRef getDescription() const final { 53 return "Test Linalg transformation patterns by applying them greedily."; 54 } 55 56 void runOnOperation() override; 57 58 Option<bool> testPatterns{*this, "test-patterns", 59 llvm::cl::desc("Test a mixed set of patterns"), 60 llvm::cl::init(false)}; 61 Option<bool> testMatmulToVectorPatterns1dTiling{ 62 *this, "test-matmul-to-vector-patterns-tile-1d", 63 llvm::cl::desc( 64 "Test a fused pass that applies patterns from matmul to vectors via " 65 "1-d tiling"), 66 llvm::cl::init(false)}; 67 Option<bool> testMatmulToVectorPatterns2dTiling{ 68 *this, "test-matmul-to-vector-patterns-tile-2d", 69 llvm::cl::desc( 70 "Test a fused pass that applies patterns from matmul to vectors via " 71 "2-d tiling"), 72 llvm::cl::init(false)}; 73 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 74 llvm::cl::desc("Test promotion options"), 75 llvm::cl::init(false)}; 76 Option<bool> testTileAndDistributionOptions{ 77 *this, "test-tile-and-distribute-options", 78 llvm::cl::desc("Test tile and distribute options"), 79 llvm::cl::init(false)}; 80 Option<bool> testTileFuseAndDistributionOptions{ 81 *this, "test-tile-fuse-and-distribute-options", 82 llvm::cl::desc("Test tile, fuse and distribute options"), 83 llvm::cl::init(false)}; 84 Option<bool> testVectorTransferForwardingPatterns{ 85 *this, "test-vector-transfer-forwarding-patterns", 86 llvm::cl::desc( 87 "Test a fused pass that forwards memref.copy to vector.transfer"), 88 llvm::cl::init(false)}; 89 Option<bool> testGenericToVectorPattern{ 90 *this, "test-linalg-to-vector-patterns", 91 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 92 "in vector.contract form"), 93 llvm::cl::init(false)}; 94 Option<bool> testTilePattern{*this, "test-tile-pattern", 95 llvm::cl::desc("Test tile pattern"), 96 llvm::cl::init(false)}; 97 Option<bool> testTileScalarizeDynamicDims{ 98 *this, "test-tile-scalarize-dynamic-dims", 99 llvm::cl::desc("Test tiling of dynamic dims by 1"), 100 llvm::cl::init(false)}; 101 Option<bool> testTransformPadTensor{ 102 *this, "test-transform-pad-tensor", 103 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 104 llvm::cl::init(false)}; 105 Option<bool> testGeneralizePadTensor{ 106 *this, "test-generalize-pad-tensor", 107 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 108 llvm::cl::init(false)}; 109 Option<bool> testSwapSubTensorPadTensor{ 110 *this, "test-swap-subtensor-padtensor", 111 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 112 "pad_tensor(subtensor)"), 113 llvm::cl::init(false)}; 114 ListOption<int64_t> peeledLoops{ 115 *this, "peeled-loops", 116 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 117 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 118 ListOption<int64_t> tileSizes{ 119 *this, "tile-sizes", 120 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 121 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 122 Option<bool> skipPartial{ 123 *this, "skip-partial", 124 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 125 llvm::cl::init(false)}; 126 Option<std::string> loopType{ 127 *this, "loop-type", 128 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 129 "tiled_loop"), 130 llvm::cl::init("for")}; 131 }; 132 } // namespace 133 134 static void applyPatterns(FuncOp funcOp) { 135 MLIRContext *ctx = funcOp.getContext(); 136 RewritePatternSet patterns(ctx); 137 138 //===--------------------------------------------------------------------===// 139 // Linalg tiling patterns. 140 //===--------------------------------------------------------------------===// 141 patterns.add<LinalgTilingPattern>( 142 MatmulOp::getOperationName(), ctx, 143 LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 144 LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), 145 StringAttr::get(ctx, "L3"))); 146 patterns.add<LinalgTilingPattern>( 147 MatmulOp::getOperationName(), ctx, 148 LinalgTilingOptions().setTileSizes({200, 300, 400}), 149 LinalgTransformationFilter(StringAttr::get(ctx, "L3"), 150 StringAttr::get(ctx, "L2"))); 151 patterns.add<LinalgTilingPattern>( 152 MatmulOp::getOperationName(), ctx, 153 LinalgTilingOptions().setTileSizes({20, 30, 40}), 154 LinalgTransformationFilter(StringAttr::get(ctx, "L2"), 155 StringAttr::get(ctx, "L1"))); 156 patterns.add<LinalgTilingPattern>( 157 MatmulOp::getOperationName(), ctx, 158 LinalgTilingOptions().setTileSizes({2, 3, 4}), 159 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 160 StringAttr::get(ctx, "REG"))); 161 162 patterns.add<LinalgTilingPattern>( 163 MatvecOp::getOperationName(), ctx, 164 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 165 LinalgTilingLoopType::ParallelLoops), 166 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 167 StringAttr::get(ctx, "L1"))); 168 169 patterns.add<LinalgTilingPattern>( 170 DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), 171 LinalgTransformationFilter( 172 ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"), 173 StringAttr::get(ctx, "L3"), 174 StringAttr::get(ctx, "L2")}, 175 StringAttr::get(ctx, "REG"))); 176 177 //===--------------------------------------------------------------------===// 178 // Linalg tiling and permutation patterns. 179 //===--------------------------------------------------------------------===// 180 patterns.add<LinalgTilingPattern>( 181 MatmulOp::getOperationName(), ctx, 182 LinalgTilingOptions() 183 .setTileSizes({2000, 3000, 4000}) 184 .setInterchange({1, 2, 0}), 185 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 186 StringAttr::get(ctx, "L2__with_perm__"))); 187 patterns.add<LinalgTilingPattern>( 188 MatmulOp::getOperationName(), ctx, 189 LinalgTilingOptions() 190 .setTileSizes({200, 300, 400}) 191 .setInterchange({1, 0, 2}), 192 LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), 193 StringAttr::get(ctx, "L1__with_perm__"))); 194 patterns.add<LinalgTilingPattern>( 195 MatmulOp::getOperationName(), ctx, 196 LinalgTilingOptions().setTileSizes({20, 30, 40}), 197 LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), 198 StringAttr::get(ctx, "REG__with_perm__"))); 199 200 patterns.add<LinalgTilingPattern>( 201 MatvecOp::getOperationName(), ctx, 202 LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 203 LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), 204 StringAttr::get(ctx, "L1__with_perm__"))); 205 206 patterns.add<LinalgTilingPattern>( 207 MatmulOp::getOperationName(), ctx, 208 LinalgTilingOptions() 209 .setTileSizes({16, 8, 4}) 210 .setInterchange({1, 2, 0}) 211 .setLoopType(LinalgTilingLoopType::ParallelLoops), 212 LinalgTransformationFilter( 213 StringAttr::get(ctx, "par__with_perm__"), 214 StringAttr::get(ctx, "after_par__with_perm__"))); 215 216 //===--------------------------------------------------------------------===// 217 // Linalg to loops patterns. 218 //===--------------------------------------------------------------------===// 219 patterns.add<LinalgLoweringPattern<DotOp>>( 220 ctx, 221 /*loweringType=*/LinalgLoweringType::Loops, 222 LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); 223 224 //===--------------------------------------------------------------------===// 225 // Linalg distribution patterns. 226 //===--------------------------------------------------------------------===// 227 LinalgLoopDistributionOptions distributionOptions; 228 229 //===--------------------------------------------------------------------===// 230 // Linalg to vector contraction patterns. 231 //===--------------------------------------------------------------------===// 232 patterns.add<LinalgVectorizationPattern>( 233 ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) 234 .addOpFilter<MatmulOp, FillOp, GenericOp>()); 235 patterns.add<CopyVectorizationPattern>(ctx); 236 237 //===--------------------------------------------------------------------===// 238 // Linalg generic interchange pattern. 239 //===--------------------------------------------------------------------===// 240 patterns.add<GenericOpInterchangePattern>( 241 ctx, 242 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 243 LinalgTransformationFilter(ArrayRef<StringAttr>{}, 244 StringAttr::get(ctx, "PERMUTED"))); 245 246 //===--------------------------------------------------------------------===// 247 // Linalg subview operands promotion. 248 //===--------------------------------------------------------------------===// 249 patterns.add<LinalgPromotionPattern<MatmulOp>>( 250 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 251 LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), 252 StringAttr::get(ctx, "_views_promoted_"))); 253 patterns.add<LinalgPromotionPattern<MatmulOp>>( 254 ctx, 255 LinalgPromotionOptions() 256 .setOperandsToPromote({0}) 257 .setUseFullTileBuffersByDefault(true), 258 LinalgTransformationFilter( 259 StringAttr::get(ctx, "_promote_first_view_"), 260 StringAttr::get(ctx, "_first_view_promoted_"))); 261 patterns.add<LinalgPromotionPattern<FillOp>>( 262 ctx, 263 LinalgPromotionOptions() 264 .setOperandsToPromote({1}) 265 .setUseFullTileBuffers({false, true}) 266 .setAlignment(32), 267 LinalgTransformationFilter( 268 StringAttr::get(ctx, "_promote_views_aligned_"), 269 StringAttr::get(ctx, "_views_aligned_promoted_"))); 270 271 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 272 273 // Drop the marker. 274 funcOp.walk([](LinalgOp op) { 275 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 276 }); 277 } 278 279 static void fillL1TilingAndMatmulToVectorPatterns( 280 FuncOp funcOp, StringRef startMarker, 281 SmallVectorImpl<RewritePatternSet> &patternsVector) { 282 MLIRContext *ctx = funcOp.getContext(); 283 patternsVector.emplace_back( 284 ctx, std::make_unique<LinalgTilingPattern>( 285 MatmulOp::getOperationName(), ctx, 286 LinalgTilingOptions() 287 .setTileSizes({8, 12, 16}) 288 .setInterchange({1, 0, 2}), 289 LinalgTransformationFilter(StringAttr::get(ctx, startMarker), 290 StringAttr::get(ctx, "L1")))); 291 292 patternsVector.emplace_back( 293 ctx, 294 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 295 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 296 LinalgTransformationFilter(StringAttr::get(ctx, "L1"), 297 StringAttr::get(ctx, "VEC")))); 298 299 patternsVector.emplace_back( 300 ctx, std::make_unique<LinalgVectorizationPattern>( 301 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 302 LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); 303 patternsVector.back().add<LinalgVectorizationPattern>( 304 ctx, LinalgTransformationFilter().addOpFilter<FillOp>()); 305 patternsVector.back().add<CopyVectorizationPattern>(ctx); 306 } 307 308 //===----------------------------------------------------------------------===// 309 // Test promotion callbacks 310 //===----------------------------------------------------------------------===// 311 312 // Allocation call back 313 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 314 ArrayRef<Value> boundingSubViewSize, 315 DataLayout &layout) { 316 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 317 return b 318 .create<memref::AllocOp>( 319 subView.getLoc(), 320 MemRefType::get(shape, subView.getType().getElementType(), 321 /*affineMapComposition =*/{}, 3), 322 boundingSubViewSize) 323 .getResult(); 324 } 325 326 // Deallocation callback 327 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 328 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 329 return success(); 330 } 331 332 // Copy in call back 333 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 334 bool isOutput) { 335 auto floatType = src.getType().cast<MemRefType>().getElementType(); 336 if (!floatType.isa<FloatType>()) 337 return failure(); 338 if (!isOutput) { 339 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 340 FloatAttr::get(floatType, 42.0)); 341 b.create<FillOp>(src.getLoc(), cst, dst); 342 } 343 b.create<memref::CopyOp>(src.getLoc(), src, dst); 344 return success(); 345 } 346 347 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 348 RewritePatternSet &patterns) { 349 patterns.add<LinalgTilingPattern>( 350 MatmulOp::getOperationName(), ctx, 351 LinalgTilingOptions().setTileSizes({16, 16, 16}), 352 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 353 StringAttr::get(ctx, "PROMOTE"))); 354 patterns.add<LinalgPromotionPattern<MatmulOp>>( 355 ctx, 356 LinalgPromotionOptions() 357 .setOperandsToPromote({0, 2}) 358 .setUseFullTileBuffers({false, false}) 359 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 360 .setCopyInOutFns( 361 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 362 return copyCallBackFn(b, src, dst, false); 363 }, 364 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 365 return copyCallBackFn(b, src, dst, true); 366 }), 367 LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); 368 } 369 370 template <typename IdOp, typename NProcsOp> 371 static SmallVector<ProcInfo, 2> 372 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 373 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 374 SmallVector<ProcInfo, 2> procInfo(count); 375 Type indexType = b.getIndexType(); 376 for (unsigned i = 0; i < count; ++i) { 377 gpu::Dimension dim = *gpu::symbolizeDimension(i); 378 procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim), 379 b.create<NProcsOp>(loc, indexType, dim)}; 380 } 381 return procInfo; 382 } 383 384 static void fillTileAndDistributePatterns(MLIRContext *context, 385 RewritePatternSet &patterns) { 386 { 387 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 388 cyclicNprocsEqNiters.distributionMethod.resize( 389 2, DistributionMethod::CyclicNumProcsEqNumIters); 390 cyclicNprocsEqNiters.procInfo = 391 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 392 patterns.add<LinalgTilingPattern>( 393 MatmulOp::getOperationName(), context, 394 LinalgTilingOptions() 395 .setTileSizes({8, 8, 4}) 396 .setLoopType(LinalgTilingLoopType::ParallelLoops) 397 .setDistributionOptions(cyclicNprocsEqNiters), 398 LinalgTransformationFilter( 399 StringAttr::get(context, "distribute1"), 400 StringAttr::get(context, "after_distribute1"))); 401 } 402 403 { 404 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 405 cyclicNprocsGeNiters.distributionMethod.resize( 406 2, DistributionMethod::CyclicNumProcsGeNumIters); 407 cyclicNprocsGeNiters.procInfo = 408 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 409 patterns.add<LinalgTilingPattern>( 410 MatmulOp::getOperationName(), context, 411 LinalgTilingOptions() 412 .setTileSizes({8, 8, 4}) 413 .setLoopType(LinalgTilingLoopType::ParallelLoops) 414 .setDistributionOptions(cyclicNprocsGeNiters), 415 LinalgTransformationFilter( 416 StringAttr::get(context, "distribute2"), 417 StringAttr::get(context, "after_distribute2"))); 418 } 419 420 { 421 LinalgLoopDistributionOptions cyclicNprocsDefault; 422 cyclicNprocsDefault.distributionMethod.resize(2, 423 DistributionMethod::Cyclic); 424 cyclicNprocsDefault.procInfo = 425 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 426 patterns.add<LinalgTilingPattern>( 427 MatmulOp::getOperationName(), context, 428 LinalgTilingOptions() 429 .setTileSizes({8, 8, 4}) 430 .setLoopType(LinalgTilingLoopType::ParallelLoops) 431 .setDistributionOptions(cyclicNprocsDefault), 432 LinalgTransformationFilter( 433 StringAttr::get(context, "distribute3"), 434 StringAttr::get(context, "after_distribute3"))); 435 } 436 437 { 438 LinalgLoopDistributionOptions cyclicNprocsMixed1; 439 cyclicNprocsMixed1.distributionMethod = { 440 DistributionMethod::CyclicNumProcsEqNumIters, 441 DistributionMethod::CyclicNumProcsGeNumIters}; 442 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 443 patterns.add<LinalgTilingPattern>( 444 MatmulOp::getOperationName(), context, 445 LinalgTilingOptions() 446 .setTileSizes({8, 8, 4}) 447 .setLoopType(LinalgTilingLoopType::ParallelLoops) 448 .setDistributionOptions(cyclicNprocsMixed1), 449 LinalgTransformationFilter( 450 StringAttr::get(context, "distribute4"), 451 StringAttr::get(context, "after_distribute4"))); 452 } 453 454 { 455 LinalgLoopDistributionOptions cyclicNprocsMixed2; 456 cyclicNprocsMixed2.distributionMethod = { 457 DistributionMethod::CyclicNumProcsGeNumIters, 458 DistributionMethod::Cyclic}; 459 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 460 patterns.add<LinalgTilingPattern>( 461 MatmulOp::getOperationName(), context, 462 LinalgTilingOptions() 463 .setTileSizes({8, 8, 4}) 464 .setLoopType(LinalgTilingLoopType::ParallelLoops) 465 .setDistributionOptions(cyclicNprocsMixed2), 466 LinalgTransformationFilter( 467 StringAttr::get(context, "distribute5"), 468 StringAttr::get(context, "after_distribute5"))); 469 } 470 471 { 472 LinalgLoopDistributionOptions cyclicNprocsMixed3; 473 cyclicNprocsMixed3.distributionMethod = { 474 DistributionMethod::Cyclic, 475 DistributionMethod::CyclicNumProcsEqNumIters}; 476 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 477 478 patterns.add<LinalgTilingPattern>( 479 MatmulOp::getOperationName(), context, 480 LinalgTilingOptions() 481 .setTileSizes({8, 8, 4}) 482 .setLoopType(LinalgTilingLoopType::ParallelLoops) 483 .setDistributionOptions(cyclicNprocsMixed3), 484 LinalgTransformationFilter( 485 StringAttr::get(context, "distribute6"), 486 StringAttr::get(context, "after_distribute6"))); 487 } 488 489 { 490 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 491 cyclicNprocsEqNiters.distributionMethod.resize(2, 492 DistributionMethod::Cyclic); 493 cyclicNprocsEqNiters.procInfo = 494 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 495 patterns.add<LinalgTilingPattern>( 496 MatmulOp::getOperationName(), context, 497 LinalgTilingOptions() 498 .setTileSizes({8, 8, 4}) 499 .setLoopType(LinalgTilingLoopType::Loops) 500 .setDistributionOptions(cyclicNprocsEqNiters), 501 LinalgTransformationFilter( 502 StringAttr::get(context, "tensors_distribute1"), 503 StringAttr::get(context, "tensors_after_distribute1"))); 504 } 505 } 506 507 static void fillTileFuseAndDistributePatterns(MLIRContext *context, 508 RewritePatternSet &patterns) { 509 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 510 cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); 511 cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 512 patterns.add<LinalgTileAndFuseTensorOpsPattern>( 513 MatmulOp::getOperationName(), context, 514 LinalgTilingAndFusionOptions() 515 .setTileSizes({8, 8, 4}) 516 .setDistributionOptions(cyclicNprocsEqNiters), 517 LinalgTransformationFilter( 518 StringAttr::get(context, "tensors_fuse_distribute1"), 519 StringAttr::get(context, "tensors_after_fuse_distribute1"))); 520 } 521 522 static void 523 applyMatmulToVectorPatterns(FuncOp funcOp, 524 bool testMatmulToVectorPatterns1dTiling, 525 bool testMatmulToVectorPatterns2dTiling) { 526 MLIRContext *ctx = funcOp.getContext(); 527 SmallVector<RewritePatternSet, 4> stage1Patterns; 528 if (testMatmulToVectorPatterns1dTiling) { 529 fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); 530 } else if (testMatmulToVectorPatterns2dTiling) { 531 stage1Patterns.emplace_back( 532 ctx, std::make_unique<LinalgTilingPattern>( 533 MatmulOp::getOperationName(), ctx, 534 LinalgTilingOptions() 535 .setTileSizes({768, 264, 768}) 536 .setInterchange({1, 2, 0}), 537 LinalgTransformationFilter(StringAttr::get(ctx, "START"), 538 StringAttr::get(ctx, "L2")))); 539 fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); 540 } 541 { 542 // Canonicalization patterns 543 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 544 vector::populateVectorTransferPermutationMapLoweringPatterns( 545 canonicalizationPatterns); 546 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 547 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 548 } 549 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 550 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 551 FrozenRewritePatternSet stage2Patterns = 552 getLinalgTilingCanonicalizationPatterns(ctx); 553 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); 554 } 555 556 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 557 RewritePatternSet forwardPattern(funcOp.getContext()); 558 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 559 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 560 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 561 } 562 563 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 564 RewritePatternSet patterns(funcOp.getContext()); 565 auto *ctx = funcOp.getContext(); 566 patterns.add<LinalgVectorizationPattern>( 567 ctx, LinalgTransformationFilter() 568 .addOpFilter<ContractionOpInterface, FillOp, GenericOp>()); 569 patterns.add<CopyVectorizationPattern>(ctx); 570 populatePadOpVectorizationPatterns(patterns); 571 populateConvolutionVectorizationPatterns(patterns); 572 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 573 } 574 575 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 576 RewritePatternSet patterns(funcOp.getContext()); 577 patterns.add<PadOpTransformationPattern>(funcOp.getContext()); 578 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 579 } 580 581 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 582 RewritePatternSet patterns(funcOp.getContext()); 583 patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); 584 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 585 } 586 587 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 588 RewritePatternSet patterns(funcOp.getContext()); 589 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 590 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 591 } 592 593 static void applyTilePattern(FuncOp funcOp, const std::string &loopType, 594 ArrayRef<int64_t> tileSizes, 595 ArrayRef<int64_t> peeledLoops, 596 bool scalarizeDynamicDims) { 597 MLIRContext *context = funcOp.getContext(); 598 RewritePatternSet tilingPattern(context); 599 LinalgTilingLoopType type = 600 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 601 .Case("for", LinalgTilingLoopType::Loops) 602 .Case("affine", LinalgTilingLoopType::AffineLoops) 603 .Case("parallel", LinalgTilingLoopType::ParallelLoops); 604 auto linalgTilingOptions = linalg::LinalgTilingOptions() 605 .setPeeledLoops(peeledLoops) 606 .setLoopType(type); 607 if (scalarizeDynamicDims) { 608 linalgTilingOptions.scalarizeDynamicDims(); 609 assert(tileSizes.empty() && 610 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 611 } else { 612 linalgTilingOptions.setTileSizes(tileSizes); 613 } 614 linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); 615 TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert( 616 tilingPattern, linalgTilingOptions, f); 617 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 618 } 619 620 /// Apply transformations specified as patterns. 621 void TestLinalgTransforms::runOnOperation() { 622 auto lambda = [&](void *) { 623 getOperation().walk([](LinalgOp op) { 624 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 625 }); 626 }; 627 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 628 629 if (testPromotionOptions) { 630 RewritePatternSet patterns(&getContext()); 631 fillPromotionCallBackPatterns(&getContext(), patterns); 632 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 633 return; 634 } 635 if (testTileAndDistributionOptions) { 636 RewritePatternSet patterns(&getContext()); 637 fillTileAndDistributePatterns(&getContext(), patterns); 638 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 639 return; 640 } 641 if (testTileFuseAndDistributionOptions) { 642 RewritePatternSet patterns(&getContext()); 643 fillTileFuseAndDistributePatterns(&getContext(), patterns); 644 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 645 return; 646 } 647 if (testPatterns) 648 return applyPatterns(getOperation()); 649 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 650 return applyMatmulToVectorPatterns(getOperation(), 651 testMatmulToVectorPatterns1dTiling, 652 testMatmulToVectorPatterns2dTiling); 653 if (testVectorTransferForwardingPatterns) 654 return applyVectorTransferForwardingPatterns(getOperation()); 655 if (testGenericToVectorPattern) 656 return applyLinalgToVectorPatterns(getOperation()); 657 if (testTransformPadTensor) 658 return applyPadTensorToGenericPatterns(getOperation()); 659 if (testGeneralizePadTensor) 660 return applyGeneralizePadTensorPatterns(getOperation()); 661 if (testSwapSubTensorPadTensor) 662 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 663 if (testTilePattern) 664 return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops, 665 /*scalarizeDynamicDims=*/false); 666 if (testTileScalarizeDynamicDims) 667 return applyTilePattern(getOperation(), loopType, tileSizes, 668 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 669 } 670 671 namespace mlir { 672 namespace test { 673 void registerTestLinalgTransforms() { 674 PassRegistration<TestLinalgTransforms>(); 675 } 676 } // namespace test 677 } // namespace mlir 678