1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/GPU/GPUDialect.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 #include "llvm/ADT/SetVector.h" 25 26 using namespace mlir; 27 using namespace mlir::linalg; 28 29 namespace { 30 struct TestLinalgTransforms 31 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 32 TestLinalgTransforms() = default; 33 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 34 35 void getDependentDialects(DialectRegistry ®istry) const override { 36 // clang-format off 37 registry.insert<AffineDialect, 38 memref::MemRefDialect, 39 scf::SCFDialect, 40 StandardOpsDialect, 41 vector::VectorDialect, 42 gpu::GPUDialect>(); 43 // clang-format on 44 } 45 StringRef getArgument() const final { 46 return "test-linalg-transform-patterns"; 47 } 48 StringRef getDescription() const final { 49 return "Test Linalg transformation patterns by applying them greedily."; 50 } 51 52 void runOnFunction() override; 53 54 Option<bool> testPatterns{*this, "test-patterns", 55 llvm::cl::desc("Test a mixed set of patterns"), 56 llvm::cl::init(false)}; 57 Option<bool> testMatmulToVectorPatterns1dTiling{ 58 *this, "test-matmul-to-vector-patterns-tile-1d", 59 llvm::cl::desc( 60 "Test a fused pass that applies patterns from matmul to vectors via " 61 "1-d tiling"), 62 llvm::cl::init(false)}; 63 Option<bool> testMatmulToVectorPatterns2dTiling{ 64 *this, "test-matmul-to-vector-patterns-tile-2d", 65 llvm::cl::desc( 66 "Test a fused pass that applies patterns from matmul to vectors via " 67 "2-d tiling"), 68 llvm::cl::init(false)}; 69 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 70 llvm::cl::desc("Test promotion options"), 71 llvm::cl::init(false)}; 72 Option<bool> testTileAndDistributionOptions{ 73 *this, "test-tile-and-distribute-options", 74 llvm::cl::desc("Test tile and distribute options"), 75 llvm::cl::init(false)}; 76 Option<bool> testVectorTransferForwardingPatterns{ 77 *this, "test-vector-transfer-forwarding-patterns", 78 llvm::cl::desc( 79 "Test a fused pass that forwards linalg.copy to vector.transfer"), 80 llvm::cl::init(false)}; 81 Option<bool> testGenericToVectorPattern{ 82 *this, "test-linalg-to-vector-patterns", 83 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 84 "in vector.contract form"), 85 llvm::cl::init(false)}; 86 Option<bool> testTileAndPadPattern{ 87 *this, "test-tile-and-pad-pattern", 88 llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)}; 89 Option<int> testHoistPadding{*this, "test-hoist-padding", 90 llvm::cl::desc("Test hoist padding"), 91 llvm::cl::init(0)}; 92 Option<bool> testTransformPadTensor{ 93 *this, "test-transform-pad-tensor", 94 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 95 llvm::cl::init(false)}; 96 Option<bool> testGeneralizePadTensor{ 97 *this, "test-generalize-pad-tensor", 98 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 99 llvm::cl::init(false)}; 100 Option<bool> testSwapSubTensorPadTensor{ 101 *this, "test-swap-subtensor-padtensor", 102 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 103 "pad_tensor(subtensor)"), 104 llvm::cl::init(false)}; 105 ListOption<int64_t> tileSizesForPadding{ 106 *this, "tile-sizes-for-padding", 107 llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore, 108 llvm::cl::MiscFlags::CommaSeparated}; 109 ListOption<unsigned> testInterchangePattern{ 110 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 111 llvm::cl::desc("Test the interchange pattern.")}; 112 }; 113 } // end anonymous namespace 114 115 static void applyPatterns(FuncOp funcOp) { 116 MLIRContext *ctx = funcOp.getContext(); 117 RewritePatternSet patterns(ctx); 118 119 //===--------------------------------------------------------------------===// 120 // Linalg tiling patterns. 121 //===--------------------------------------------------------------------===// 122 patterns.add<LinalgTilingPattern<MatmulOp>>( 123 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 124 LinalgTransformationFilter(Identifier::get("MEM", ctx), 125 Identifier::get("L3", ctx))); 126 patterns.add<LinalgTilingPattern<MatmulOp>>( 127 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 128 LinalgTransformationFilter(Identifier::get("L3", ctx), 129 Identifier::get("L2", ctx))); 130 patterns.add<LinalgTilingPattern<MatmulOp>>( 131 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 132 LinalgTransformationFilter(Identifier::get("L2", ctx), 133 Identifier::get("L1", ctx))); 134 patterns.add<LinalgTilingPattern<MatmulOp>>( 135 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 136 LinalgTransformationFilter(Identifier::get("L1", ctx), 137 Identifier::get("REG", ctx))); 138 139 patterns.add<LinalgTilingPattern<MatvecOp>>( 140 ctx, 141 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 142 LinalgTilingLoopType::ParallelLoops), 143 LinalgTransformationFilter(ArrayRef<Identifier>{}, 144 Identifier::get("L1", ctx))); 145 146 patterns.add<LinalgTilingPattern<DotOp>>( 147 ctx, LinalgTilingOptions().setTileSizes(8000), 148 LinalgTransformationFilter( 149 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 150 Identifier::get("L3", ctx), 151 Identifier::get("L2", ctx)}, 152 Identifier::get("REG", ctx))); 153 154 //===--------------------------------------------------------------------===// 155 // Linalg tiling and permutation patterns. 156 //===--------------------------------------------------------------------===// 157 patterns.add<LinalgTilingPattern<MatmulOp>>( 158 ctx, 159 LinalgTilingOptions() 160 .setTileSizes({2000, 3000, 4000}) 161 .setInterchange({1, 2, 0}), 162 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 163 Identifier::get("L2__with_perm__", ctx))); 164 patterns.add<LinalgTilingPattern<MatmulOp>>( 165 ctx, 166 LinalgTilingOptions() 167 .setTileSizes({200, 300, 400}) 168 .setInterchange({1, 0, 2}), 169 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 170 Identifier::get("L1__with_perm__", ctx))); 171 patterns.add<LinalgTilingPattern<MatmulOp>>( 172 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 173 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 174 Identifier::get("REG__with_perm__", ctx))); 175 176 patterns.add<LinalgTilingPattern<MatvecOp>>( 177 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 178 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 179 Identifier::get("L1__with_perm__", ctx))); 180 181 patterns.add<LinalgTilingPattern<MatmulOp>>( 182 ctx, 183 LinalgTilingOptions() 184 .setTileSizes({16, 8, 4}) 185 .setInterchange({1, 2, 0}) 186 .setLoopType(LinalgTilingLoopType::ParallelLoops), 187 LinalgTransformationFilter( 188 Identifier::get("par__with_perm__", ctx), 189 Identifier::get("after_par__with_perm__", ctx))); 190 191 //===--------------------------------------------------------------------===// 192 // Linalg to loops patterns. 193 //===--------------------------------------------------------------------===// 194 patterns.add<LinalgLoweringPattern<DotOp>>( 195 ctx, 196 /*loweringType=*/LinalgLoweringType::Loops, 197 LinalgTransformationFilter(Identifier::get("REG", ctx))); 198 199 //===--------------------------------------------------------------------===// 200 // Linalg distribution patterns. 201 //===--------------------------------------------------------------------===// 202 LinalgLoopDistributionOptions distributionOptions; 203 204 //===--------------------------------------------------------------------===// 205 // Linalg to vector contraction patterns. 206 //===--------------------------------------------------------------------===// 207 patterns.add<LinalgVectorizationPattern>( 208 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 209 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 210 211 //===--------------------------------------------------------------------===// 212 // Linalg generic interchange pattern. 213 //===--------------------------------------------------------------------===// 214 patterns.add<GenericOpInterchangePattern>( 215 ctx, 216 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 217 LinalgTransformationFilter(ArrayRef<Identifier>{}, 218 Identifier::get("PERMUTED", ctx))); 219 220 //===--------------------------------------------------------------------===// 221 // Linalg subview operands promotion. 222 //===--------------------------------------------------------------------===// 223 patterns.add<LinalgPromotionPattern<MatmulOp>>( 224 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 225 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 226 Identifier::get("_views_promoted_", ctx))); 227 patterns.add<LinalgPromotionPattern<MatmulOp>>( 228 ctx, 229 LinalgPromotionOptions() 230 .setOperandsToPromote({0}) 231 .setUseFullTileBuffersByDefault(true), 232 LinalgTransformationFilter( 233 Identifier::get("_promote_first_view_", ctx), 234 Identifier::get("_first_view_promoted_", ctx))); 235 patterns.add<LinalgPromotionPattern<FillOp>>( 236 ctx, 237 LinalgPromotionOptions() 238 .setOperandsToPromote({1}) 239 .setUseFullTileBuffers({false, true}) 240 .setAlignment(32), 241 LinalgTransformationFilter( 242 Identifier::get("_promote_views_aligned_", ctx), 243 Identifier::get("_views_aligned_promoted_", ctx))); 244 245 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 246 247 // Drop the marker. 248 funcOp.walk([](LinalgOp op) { 249 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 250 }); 251 } 252 253 static void fillL1TilingAndMatmulToVectorPatterns( 254 FuncOp funcOp, StringRef startMarker, 255 SmallVectorImpl<RewritePatternSet> &patternsVector) { 256 MLIRContext *ctx = funcOp.getContext(); 257 patternsVector.emplace_back( 258 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 259 ctx, 260 LinalgTilingOptions() 261 .setTileSizes({8, 12, 16}) 262 .setInterchange({1, 0, 2}), 263 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 264 Identifier::get("L1", ctx)))); 265 266 patternsVector.emplace_back( 267 ctx, 268 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 269 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 270 LinalgTransformationFilter(Identifier::get("L1", ctx), 271 Identifier::get("VEC", ctx)))); 272 273 patternsVector.emplace_back( 274 ctx, std::make_unique<LinalgVectorizationPattern>( 275 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 276 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 277 patternsVector.back().add<LinalgVectorizationPattern>( 278 ctx, LinalgTransformationFilter().addFilter( 279 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 280 } 281 282 //===----------------------------------------------------------------------===// 283 // Test promotion callbacks 284 //===----------------------------------------------------------------------===// 285 286 // Allocation call back 287 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 288 ArrayRef<Value> boundingSubViewSize, 289 DataLayout &layout) { 290 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 291 return b 292 .create<memref::AllocOp>( 293 subView.getLoc(), 294 MemRefType::get(shape, subView.getType().getElementType(), 295 /*affineMapComposition =*/{}, 3), 296 boundingSubViewSize) 297 .getResult(); 298 } 299 300 // Deallocation callback 301 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 302 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 303 return success(); 304 } 305 306 // Copy in call back 307 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 308 bool isOutput) { 309 auto floatType = src.getType().cast<MemRefType>().getElementType(); 310 if (!floatType.isa<FloatType>()) 311 return failure(); 312 if (!isOutput) { 313 Value cst = 314 b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)); 315 b.create<FillOp>(src.getLoc(), cst, dst); 316 } 317 b.create<CopyOp>(src.getLoc(), src, dst); 318 return success(); 319 } 320 321 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 322 RewritePatternSet &patterns) { 323 patterns.add<LinalgTilingPattern<MatmulOp>>( 324 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 325 LinalgTransformationFilter(Identifier::get("START", ctx), 326 Identifier::get("PROMOTE", ctx))); 327 patterns.add<LinalgPromotionPattern<MatmulOp>>( 328 ctx, 329 LinalgPromotionOptions() 330 .setOperandsToPromote({0, 2}) 331 .setUseFullTileBuffers({false, false}) 332 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 333 .setCopyInOutFns( 334 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 335 return copyCallBackFn(b, src, dst, false); 336 }, 337 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 338 return copyCallBackFn(b, src, dst, true); 339 }), 340 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 341 } 342 343 template <typename IdOp, typename NProcsOp> 344 static SmallVector<ProcInfo, 2> 345 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 346 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 347 SmallVector<ProcInfo, 2> procInfo(count); 348 const char *xyz[] = {"x", "y", "z"}; 349 Type indexType = b.getIndexType(); 350 for (unsigned i = 0; i < count; ++i) { 351 procInfo[count - 1 - i] = { 352 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 353 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 354 } 355 return procInfo; 356 } 357 358 static void fillTileAndDistributePatterns(MLIRContext *context, 359 RewritePatternSet &patterns) { 360 { 361 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 362 cyclicNprocsEqNiters.distributionMethod.resize( 363 2, DistributionMethod::CyclicNumProcsEqNumIters); 364 cyclicNprocsEqNiters.procInfo = 365 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 366 patterns.add<LinalgTilingPattern<MatmulOp>>( 367 context, 368 LinalgTilingOptions() 369 .setTileSizes({8, 8, 4}) 370 .setLoopType(LinalgTilingLoopType::ParallelLoops) 371 .setDistributionOptions(cyclicNprocsEqNiters), 372 LinalgTransformationFilter( 373 Identifier::get("distribute1", context), 374 Identifier::get("after_distribute1", context))); 375 } 376 377 { 378 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 379 cyclicNprocsGeNiters.distributionMethod.resize( 380 2, DistributionMethod::CyclicNumProcsGeNumIters); 381 cyclicNprocsGeNiters.procInfo = 382 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 383 patterns.add<LinalgTilingPattern<MatmulOp>>( 384 context, 385 LinalgTilingOptions() 386 .setTileSizes({8, 8, 4}) 387 .setLoopType(LinalgTilingLoopType::ParallelLoops) 388 .setDistributionOptions(cyclicNprocsGeNiters), 389 LinalgTransformationFilter( 390 Identifier::get("distribute2", context), 391 Identifier::get("after_distribute2", context))); 392 } 393 394 { 395 LinalgLoopDistributionOptions cyclicNprocsDefault; 396 cyclicNprocsDefault.distributionMethod.resize(2, 397 DistributionMethod::Cyclic); 398 cyclicNprocsDefault.procInfo = 399 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 400 patterns.add<LinalgTilingPattern<MatmulOp>>( 401 context, 402 LinalgTilingOptions() 403 .setTileSizes({8, 8, 4}) 404 .setLoopType(LinalgTilingLoopType::ParallelLoops) 405 .setDistributionOptions(cyclicNprocsDefault), 406 LinalgTransformationFilter( 407 Identifier::get("distribute3", context), 408 Identifier::get("after_distribute3", context))); 409 } 410 411 { 412 LinalgLoopDistributionOptions cyclicNprocsMixed1; 413 cyclicNprocsMixed1.distributionMethod = { 414 DistributionMethod::CyclicNumProcsEqNumIters, 415 DistributionMethod::CyclicNumProcsGeNumIters}; 416 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 417 patterns.add<LinalgTilingPattern<MatmulOp>>( 418 context, 419 LinalgTilingOptions() 420 .setTileSizes({8, 8, 4}) 421 .setLoopType(LinalgTilingLoopType::ParallelLoops) 422 .setDistributionOptions(cyclicNprocsMixed1), 423 LinalgTransformationFilter( 424 Identifier::get("distribute4", context), 425 Identifier::get("after_distribute4", context))); 426 } 427 428 { 429 LinalgLoopDistributionOptions cyclicNprocsMixed2; 430 cyclicNprocsMixed2.distributionMethod = { 431 DistributionMethod::CyclicNumProcsGeNumIters, 432 DistributionMethod::Cyclic}; 433 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 434 patterns.add<LinalgTilingPattern<MatmulOp>>( 435 context, 436 LinalgTilingOptions() 437 .setTileSizes({8, 8, 4}) 438 .setLoopType(LinalgTilingLoopType::ParallelLoops) 439 .setDistributionOptions(cyclicNprocsMixed2), 440 LinalgTransformationFilter( 441 Identifier::get("distribute5", context), 442 Identifier::get("after_distribute5", context))); 443 } 444 445 { 446 LinalgLoopDistributionOptions cyclicNprocsMixed3; 447 cyclicNprocsMixed3.distributionMethod = { 448 DistributionMethod::Cyclic, 449 DistributionMethod::CyclicNumProcsEqNumIters}; 450 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 451 452 patterns.add<LinalgTilingPattern<MatmulOp>>( 453 context, 454 LinalgTilingOptions() 455 .setTileSizes({8, 8, 4}) 456 .setLoopType(LinalgTilingLoopType::ParallelLoops) 457 .setDistributionOptions(cyclicNprocsMixed3), 458 LinalgTransformationFilter( 459 Identifier::get("distribute6", context), 460 Identifier::get("after_distribute6", context))); 461 } 462 463 { 464 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 465 cyclicNprocsEqNiters.distributionMethod.resize(2, 466 DistributionMethod::Cyclic); 467 cyclicNprocsEqNiters.procInfo = 468 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 469 patterns.add<LinalgTilingPattern<MatmulOp>>( 470 context, 471 LinalgTilingOptions() 472 .setTileSizes({8, 8, 4}) 473 .setLoopType(LinalgTilingLoopType::Loops) 474 .setDistributionOptions(cyclicNprocsEqNiters), 475 LinalgTransformationFilter( 476 Identifier::get("tensors_distribute1", context), 477 Identifier::get("tensors_after_distribute1", context))); 478 } 479 } 480 481 static void 482 applyMatmulToVectorPatterns(FuncOp funcOp, 483 bool testMatmulToVectorPatterns1dTiling, 484 bool testMatmulToVectorPatterns2dTiling) { 485 MLIRContext *ctx = funcOp.getContext(); 486 SmallVector<RewritePatternSet, 4> stage1Patterns; 487 if (testMatmulToVectorPatterns1dTiling) { 488 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 489 stage1Patterns); 490 } else if (testMatmulToVectorPatterns2dTiling) { 491 stage1Patterns.emplace_back( 492 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 493 ctx, 494 LinalgTilingOptions() 495 .setTileSizes({768, 264, 768}) 496 .setInterchange({1, 2, 0}), 497 LinalgTransformationFilter(Identifier::get("START", ctx), 498 Identifier::get("L2", ctx)))); 499 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 500 stage1Patterns); 501 } 502 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 503 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 504 FrozenRewritePatternSet stage2Patterns = 505 getLinalgTilingCanonicalizationPatterns(ctx); 506 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 507 std::move(stage2Patterns)); 508 } 509 510 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 511 RewritePatternSet forwardPattern(funcOp.getContext()); 512 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 513 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 514 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 515 } 516 517 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 518 RewritePatternSet patterns(funcOp.getContext()); 519 patterns.add<LinalgVectorizationPattern>( 520 funcOp.getContext(), 521 LinalgTransformationFilter() 522 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 523 populatePadTensorOpVectorizationPatterns(patterns); 524 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 525 } 526 527 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 528 RewritePatternSet patterns(funcOp.getContext()); 529 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 530 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 531 } 532 533 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 534 RewritePatternSet patterns(funcOp.getContext()); 535 patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext()); 536 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 537 } 538 539 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 540 RewritePatternSet patterns(funcOp.getContext()); 541 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 542 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 543 } 544 545 // For now, just assume it is the zero of type. 546 // In the future, it should be the zero of type + op. 547 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 548 auto t = getElementTypeOrSelf(op.get()); 549 return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); 550 } 551 552 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) { 553 MLIRContext *context = funcOp.getContext(); 554 RewritePatternSet tilingPattern(context); 555 auto linalgTilingOptions = 556 linalg::LinalgTilingOptions() 557 .setTileSizes(tileSizes) 558 .setPaddingValueComputationFunction(getNeutralOfLinalgOp); 559 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>, 560 linalg::LinalgTilingPattern<linalg::GenericOp>>( 561 context, linalgTilingOptions, 562 linalg::LinalgTransformationFilter( 563 Identifier::get("tile-and-pad", context))); 564 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 565 } 566 567 static void applyInterchangePattern(FuncOp funcOp, 568 ArrayRef<unsigned> interchangeVector) { 569 MLIRContext *context = funcOp.getContext(); 570 RewritePatternSet interchangePattern(context); 571 interchangePattern.add<GenericOpInterchangePattern>( 572 context, interchangeVector, 573 LinalgTransformationFilter(ArrayRef<Identifier>{}, 574 Identifier::get("interchange", context))); 575 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 576 } 577 578 /// Apply transformations specified as patterns. 579 void TestLinalgTransforms::runOnFunction() { 580 auto lambda = [&](void *) { 581 getFunction().walk([](LinalgOp op) { 582 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 583 }); 584 }; 585 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 586 587 if (testPromotionOptions) { 588 RewritePatternSet patterns(&getContext()); 589 fillPromotionCallBackPatterns(&getContext(), patterns); 590 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 591 return; 592 } 593 if (testTileAndDistributionOptions) { 594 RewritePatternSet patterns(&getContext()); 595 fillTileAndDistributePatterns(&getContext(), patterns); 596 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 597 return; 598 } 599 if (testPatterns) 600 return applyPatterns(getFunction()); 601 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 602 return applyMatmulToVectorPatterns(getFunction(), 603 testMatmulToVectorPatterns1dTiling, 604 testMatmulToVectorPatterns2dTiling); 605 if (testVectorTransferForwardingPatterns) 606 return applyVectorTransferForwardingPatterns(getFunction()); 607 if (testGenericToVectorPattern) 608 return applyLinalgToVectorPatterns(getFunction()); 609 if (testTransformPadTensor) 610 return applyPadTensorToGenericPatterns(getFunction()); 611 if (testGeneralizePadTensor) 612 return applyGeneralizePadTensorPatterns(getFunction()); 613 if (testSwapSubTensorPadTensor) 614 return applyExtractSliceOfPadTensorSwapPattern(getFunction()); 615 if (testTileAndPadPattern) 616 return applyTileAndPadPattern(getFunction(), tileSizesForPadding); 617 if (testHoistPadding) { 618 getFunction().walk([&](linalg::PadTensorOp padTensorOp) { 619 (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); 620 }); 621 } 622 if (testInterchangePattern.hasValue()) 623 return applyInterchangePattern(getFunction(), testInterchangePattern); 624 } 625 626 namespace mlir { 627 namespace test { 628 void registerTestLinalgTransforms() { 629 PassRegistration<TestLinalgTransforms>(); 630 } 631 } // namespace test 632 } // namespace mlir 633