//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/SetVector.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
struct TestLinalgTransforms
    : public PassWrapper<TestLinalgTransforms, FunctionPass> {
  TestLinalgTransforms() = default;
  TestLinalgTransforms(const TestLinalgTransforms &pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    // clang-format off
    registry.insert<AffineDialect,
                    memref::MemRefDialect,
                    scf::SCFDialect,
                    StandardOpsDialect,
                    vector::VectorDialect,
                    gpu::GPUDialect>();
    // clang-format on
  }
  StringRef getArgument() const final {
    return "test-linalg-transform-patterns";
  }
  StringRef getDescription() const final {
    return "Test Linalg transformation patterns by applying them greedily.";
  }

  void runOnFunction() override;

  Option<bool> testPatterns{*this, "test-patterns",
                            llvm::cl::desc("Test a mixed set of patterns"),
                            llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns1dTiling{
      *this, "test-matmul-to-vector-patterns-tile-1d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "1-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns2dTiling{
      *this, "test-matmul-to-vector-patterns-tile-2d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "2-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                    llvm::cl::desc("Test promotion options"),
                                    llvm::cl::init(false)};
  Option<bool> testTileAndDistributionOptions{
      *this, "test-tile-and-distribute-options",
      llvm::cl::desc("Test tile and distribute options"),
      llvm::cl::init(false)};
  Option<bool> testVectorTransferForwardingPatterns{
      *this, "test-vector-transfer-forwarding-patterns",
      llvm::cl::desc(
          "Test a fused pass that forwards linalg.copy to vector.transfer"),
      llvm::cl::init(false)};
  Option<bool> testGenericToVectorPattern{
      *this, "test-linalg-to-vector-patterns",
      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                     "in vector.contract form"),
      llvm::cl::init(false)};
  Option<bool> testAffineMinSCFCanonicalizationPatterns{
      *this, "test-affine-min-scf-canonicalization-patterns",
      llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
      llvm::cl::init(false)};
  Option<bool> testTileAndPadPattern{
      *this, "test-tile-and-pad-pattern",
      llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
  Option<int> testHoistPadding{*this, "test-hoist-padding",
                               llvm::cl::desc("Test hoist padding"),
                               llvm::cl::init(0)};
  Option<bool> testTransformPadTensor{
      *this, "test-transform-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizePadTensor{
      *this, "test-generalize-pad-tensor",
      llvm::cl::desc("Test generalization of pad_tensor ops"),
      llvm::cl::init(false)};
  Option<bool> testSwapSubTensorPadTensor{
      *this, "test-swap-subtensor-padtensor",
      llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                     "pad_tensor(subtensor)"),
      llvm::cl::init(false)};
  ListOption<int64_t> tileSizesForPadding{
      *this, "tile-sizes-for-padding",
      llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
      llvm::cl::MiscFlags::CommaSeparated};
  ListOption<unsigned> testInterchangePattern{
      *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::desc("Test the interchange pattern.")};
};
} // end anonymous namespace

static void applyPatterns(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);

  //===--------------------------------------------------------------------===//
  // Linalg tiling patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
      LinalgTransformationFilter(Identifier::get("MEM", ctx),
                                 Identifier::get("L3", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
      LinalgTransformationFilter(Identifier::get("L3", ctx),
                                 Identifier::get("L2", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L2", ctx),
                                 Identifier::get("L1", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
      LinalgTransformationFilter(Identifier::get("L1", ctx),
                                 Identifier::get("REG", ctx)));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx,
      LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
          LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("L1", ctx)));

  patterns.add<LinalgTilingPattern<DotOp>>(
      ctx, LinalgTilingOptions().setTileSizes(8000),
      LinalgTransformationFilter(
          ArrayRef<Identifier>{Identifier::get("MEM", ctx),
                               Identifier::get("L3", ctx),
                               Identifier::get("L2", ctx)},
          Identifier::get("REG", ctx)));
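
  // Note on the tiling patterns above: a LinalgTransformationFilter only fires
  // on ops whose transformation marker (the string attribute named by
  // LinalgTransforms::kLinalgTransformMarker) matches one of its "from"
  // identifiers, and re-tags the rewritten op with the "to" identifier. For
  // matmul this forms the chain MEM -> L3 -> L2 -> L1 -> REG, so the four
  // patterns compose into a multi-level tiling under the greedy driver. A
  // filter built with an empty "from" list (as for matvec above) matches ops
  // that carry no marker yet.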

  //===--------------------------------------------------------------------===//
  // Linalg tiling and permutation patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({2000, 3000, 4000})
          .setInterchange({1, 2, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L2__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({200, 300, 400})
          .setInterchange({1, 0, 2}),
      LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
                                 Identifier::get("REG__with_perm__", ctx)));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));

  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({16, 8, 4})
          .setInterchange({1, 2, 0})
          .setLoopType(LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(
          Identifier::get("par__with_perm__", ctx),
          Identifier::get("after_par__with_perm__", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg to loops patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgLoweringPattern<DotOp>>(
      ctx,
      /*loweringType=*/LinalgLoweringType::Loops,
      LinalgTransformationFilter(Identifier::get("REG", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg distribution patterns.
  //===--------------------------------------------------------------------===//
  LinalgLoopDistributionOptions distributionOptions;

  //===--------------------------------------------------------------------===//
  // Linalg to vector contraction patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
               .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());

  //===--------------------------------------------------------------------===//
  // Linalg generic interchange pattern.
  //===--------------------------------------------------------------------===//
  patterns.add<GenericOpInterchangePattern>(
      ctx,
      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("PERMUTED", ctx)));
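
  // Note on the two groups above: the vectorization filter combines the
  // "VECTORIZE" marker check with an op allow-list, so only matmul, fill, copy
  // and generic ops are rewritten into vector form. The interchange pattern
  // permutes the iteration dimensions of matching linalg.generic ops with the
  // permutation {1, 2, 0} and tags the result "PERMUTED" so the greedy driver
  // does not reapply it.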

  //===--------------------------------------------------------------------===//
  // Linalg subview operands promotion.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
                                 Identifier::get("_views_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0})
          .setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(
          Identifier::get("_promote_first_view_", ctx),
          Identifier::get("_first_view_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<FillOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({1})
          .setUseFullTileBuffers({false, true})
          .setAlignment(32),
      LinalgTransformationFilter(
          Identifier::get("_promote_views_aligned_", ctx),
          Identifier::get("_views_aligned_promoted_", ctx)));

  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the marker.
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}

static void fillL1TilingAndMatmulToVectorPatterns(
    FuncOp funcOp, StringRef startMarker,
    SmallVectorImpl<RewritePatternSet> &patternsVector) {
  MLIRContext *ctx = funcOp.getContext();
  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
               ctx,
               LinalgTilingOptions()
                   .setTileSizes({8, 12, 16})
                   .setInterchange({1, 0, 2}),
               LinalgTransformationFilter(Identifier::get(startMarker, ctx),
                                          Identifier::get("L1", ctx))));

  patternsVector.emplace_back(
      ctx,
      std::make_unique<LinalgPromotionPattern<MatmulOp>>(
          ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
          LinalgTransformationFilter(Identifier::get("L1", ctx),
                                     Identifier::get("VEC", ctx))));

  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgVectorizationPattern>(
               MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
               LinalgTransformationFilter(Identifier::get("VEC", ctx))));
  patternsVector.back().add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter().addFilter(
               [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
}

//===----------------------------------------------------------------------===//
// Test promotion callbacks
//===----------------------------------------------------------------------===//

// Allocation callback.
static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
                                       ArrayRef<Value> boundingSubViewSize,
                                       DataLayout &layout) {
  SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
  return b
      .create<memref::AllocOp>(
          subView.getLoc(),
          MemRefType::get(shape, subView.getType().getElementType(),
                          /*affineMapComposition=*/{}, 3),
          boundingSubViewSize)
      .getResult();
}

// Deallocation callback.
static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
  b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
  return success();
}

// Copy-in / copy-out callback.
static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
                                    bool isOutput) {
  auto floatType = src.getType().cast<MemRefType>().getElementType();
  if (!floatType.isa<FloatType>())
    return failure();
  if (!isOutput) {
    Value cst =
        b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0));
    b.create<FillOp>(src.getLoc(), cst, dst);
  }
  b.create<CopyOp>(src.getLoc(), src, dst);
  return success();
}
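
// The three callbacks above are wired into LinalgPromotionOptions below:
// allocCallBackFn materializes the promoted buffer as a dynamically shaped
// memref in memory space 3 (used here to stand in for a faster local memory),
// and the copy callbacks replace the default copy-in/copy-out behavior,
// additionally filling input buffers with a recognizable constant (42.0),
// presumably so the FileCheck tests can easily spot the promoted path.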

static void fillPromotionCallBackPatterns(MLIRContext *ctx,
                                          RewritePatternSet &patterns) {
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
      LinalgTransformationFilter(Identifier::get("START", ctx),
                                 Identifier::get("PROMOTE", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0, 2})
          .setUseFullTileBuffers({false, false})
          .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
          .setCopyInOutFns(
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, false);
              },
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, true);
              }),
      LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
}

template <typename IdOp, typename NProcsOp>
static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
  size_t count = std::min<size_t>(3, parallelLoopRanges.size());
  SmallVector<ProcInfo, 2> procInfo(count);
  const char *xyz[] = {"x", "y", "z"};
  Type indexType = b.getIndexType();
  for (unsigned i = 0; i < count; ++i) {
    procInfo[count - 1 - i] = {
        b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
        b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
  }
  return procInfo;
}

static void fillTileAndDistributePatterns(MLIRContext *context,
                                          RewritePatternSet &patterns) {
  {
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsEqNumIters);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute1", context),
            Identifier::get("after_distribute1", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsGeNiters;
    cyclicNprocsGeNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsGeNumIters);
    cyclicNprocsGeNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsGeNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute2", context),
            Identifier::get("after_distribute2", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsDefault;
    cyclicNprocsDefault.distributionMethod.resize(2,
                                                  DistributionMethod::Cyclic);
    cyclicNprocsDefault.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsDefault),
        LinalgTransformationFilter(
            Identifier::get("distribute3", context),
            Identifier::get("after_distribute3", context)));
  }
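
  // The remaining cases mix distribution methods across the two distributed
  // loops. Roughly: Cyclic keeps a loop that strides by the number of
  // processors, CyclicNumProcsGeNumIters drops the loop but keeps an in-bounds
  // guard, and CyclicNumProcsEqNumIters drops both and uses the processor id
  // directly; see the DistributionMethod documentation for the exact
  // semantics.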
  {
    LinalgLoopDistributionOptions cyclicNprocsMixed1;
    cyclicNprocsMixed1.distributionMethod = {
        DistributionMethod::CyclicNumProcsEqNumIters,
        DistributionMethod::CyclicNumProcsGeNumIters};
    cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed1),
        LinalgTransformationFilter(
            Identifier::get("distribute4", context),
            Identifier::get("after_distribute4", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed2;
    cyclicNprocsMixed2.distributionMethod = {
        DistributionMethod::CyclicNumProcsGeNumIters,
        DistributionMethod::Cyclic};
    cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed2),
        LinalgTransformationFilter(
            Identifier::get("distribute5", context),
            Identifier::get("after_distribute5", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed3;
    cyclicNprocsMixed3.distributionMethod = {
        DistributionMethod::Cyclic,
        DistributionMethod::CyclicNumProcsEqNumIters};
    cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;

    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed3),
        LinalgTransformationFilter(
            Identifier::get("distribute6", context),
            Identifier::get("after_distribute6", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(2,
                                                   DistributionMethod::Cyclic);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::Loops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("tensors_distribute1", context),
            Identifier::get("tensors_after_distribute1", context)));
  }
}

static void
applyMatmulToVectorPatterns(FuncOp funcOp,
                            bool testMatmulToVectorPatterns1dTiling,
                            bool testMatmulToVectorPatterns2dTiling) {
  MLIRContext *ctx = funcOp.getContext();
  SmallVector<RewritePatternSet, 4> stage1Patterns;
  if (testMatmulToVectorPatterns1dTiling) {
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
                                          stage1Patterns);
  } else if (testMatmulToVectorPatterns2dTiling) {
    stage1Patterns.emplace_back(
        ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
                 ctx,
                 LinalgTilingOptions()
                     .setTileSizes({768, 264, 768})
                     .setInterchange({1, 2, 0}),
                 LinalgTransformationFilter(Identifier::get("START", ctx),
                                            Identifier::get("L2", ctx))));
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                          stage1Patterns);
  }
  SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
  llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
  FrozenRewritePatternSet stage2Patterns =
      getLinalgTilingCanonicalizationPatterns(ctx);
  (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
                            std::move(stage2Patterns));
}
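
// Note on applyMatmulToVectorPatterns above: applyStagedPatterns applies the
// stage-1 pattern sets one after another (tile, then promote, then vectorize
// in the 1-d case), running the stage-2 tiling canonicalization patterns after
// each stage so that every stage starts from cleaned-up IR instead of relying
// on a single greedy application at the end.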

static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
  RewritePatternSet forwardPattern(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}

static void applyLinalgToVectorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<LinalgVectorizationPattern>(
      funcOp.getContext(),
      LinalgTransformationFilter()
          .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
  populatePadTensorOpVectorizationPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
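
// The helper below canonicalizes affine.min ops produced by tiling when they
// can be shown to take a single value. Schematically (illustrative IR, not
// taken from an actual test):
//
//   scf.for %iv = %c0 to %c12 step %c4 {
//     %m = affine.min affine_map<(d0) -> (4, -d0 + 12)>(%iv)
//     ...
//   }
//
// Here %m is always 4 because 12 - %iv >= 4 for every value %iv can take, so
// the affine.min can be replaced by that single result.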
static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
  RewritePatternSet foldPattern(funcOp.getContext());
  foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
  FrozenRewritePatternSet frozenPatterns(std::move(foldPattern));

  // Explicitly walk and apply the pattern locally to avoid more general
  // folding on the rest of the IR.
  funcOp.walk([&frozenPatterns](AffineMinOp minOp) {
    (void)applyOpPatternsAndFold(minOp, frozenPatterns);
  });
}

// For now, just use the zero value of the element type as the padding value.
// In the future, it should be derived from both the type and the op.
static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
  auto t = getElementTypeOrSelf(op.get());
  return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
}

static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet tilingPattern(context);
  auto linalgTilingOptions =
      linalg::LinalgTilingOptions()
          .setTileSizes(tileSizes)
          .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
                    linalg::LinalgTilingPattern<linalg::GenericOp>>(
      context, linalgTilingOptions,
      linalg::LinalgTransformationFilter(
          Identifier::get("tile-and-pad", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}

static void applyInterchangePattern(FuncOp funcOp,
                                    ArrayRef<unsigned> interchangeVector) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet interchangePattern(context);
  interchangePattern.add<GenericOpInterchangePattern>(
      context, interchangeVector,
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("interchange", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
}

/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
  auto lambda = [&](void *) {
    getFunction().walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });
  };
  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};

  if (testPromotionOptions) {
    RewritePatternSet patterns(&getContext());
    fillPromotionCallBackPatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testTileAndDistributionOptions) {
    RewritePatternSet patterns(&getContext());
    fillTileAndDistributePatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testPatterns)
    return applyPatterns(getFunction());
  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
    return applyMatmulToVectorPatterns(getFunction(),
                                       testMatmulToVectorPatterns1dTiling,
                                       testMatmulToVectorPatterns2dTiling);
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getFunction());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getFunction());
  if (testTransformPadTensor)
    return applyPadTensorToGenericPatterns(getFunction());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getFunction());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getFunction());
  if (testAffineMinSCFCanonicalizationPatterns)
    return applyAffineMinSCFCanonicalizationPatterns(getFunction());
  if (testTileAndPadPattern)
    return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
  if (testHoistPadding) {
    getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
      (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
    });
  }
  if (testInterchangePattern.hasValue())
    return applyInterchangePattern(getFunction(), testInterchangePattern);
}
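
// Illustrative invocations (the lit tests under mlir/test/Dialect/Linalg are
// the authoritative reference for the exact RUN lines):
//   mlir-opt foo.mlir -test-linalg-transform-patterns=test-patterns
//   mlir-opt foo.mlir \
//     -test-linalg-transform-patterns="test-tile-and-pad-pattern tile-sizes-for-padding=2,4,8"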

namespace mlir {
namespace test {
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir