1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/GPU/GPUDialect.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 #include "llvm/ADT/SetVector.h" 25 #include "llvm/ADT/SmallVector.h" 26 27 using namespace mlir; 28 using namespace mlir::linalg; 29 30 namespace { 31 struct TestLinalgTransforms 32 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 33 TestLinalgTransforms() = default; 34 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 35 36 void getDependentDialects(DialectRegistry ®istry) const override { 37 // clang-format off 38 registry.insert<AffineDialect, 39 memref::MemRefDialect, 40 scf::SCFDialect, 41 StandardOpsDialect, 42 vector::VectorDialect, 43 gpu::GPUDialect>(); 44 // clang-format on 45 } 46 StringRef getArgument() const final { 47 return "test-linalg-transform-patterns"; 48 } 49 StringRef getDescription() const final { 50 return "Test Linalg transformation patterns by applying them greedily."; 51 } 52 53 void runOnFunction() override; 54 55 Option<bool> testPatterns{*this, "test-patterns", 56 llvm::cl::desc("Test a mixed set of patterns"), 57 llvm::cl::init(false)}; 58 Option<bool> testMatmulToVectorPatterns1dTiling{ 59 *this, "test-matmul-to-vector-patterns-tile-1d", 60 llvm::cl::desc( 61 "Test a fused pass that applies patterns from matmul to vectors via " 62 "1-d tiling"), 63 llvm::cl::init(false)}; 64 Option<bool> testMatmulToVectorPatterns2dTiling{ 65 *this, "test-matmul-to-vector-patterns-tile-2d", 66 llvm::cl::desc( 67 "Test a fused pass that applies patterns from matmul to vectors via " 68 "2-d tiling"), 69 llvm::cl::init(false)}; 70 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 71 llvm::cl::desc("Test promotion options"), 72 llvm::cl::init(false)}; 73 Option<bool> testTileAndDistributionOptions{ 74 *this, "test-tile-and-distribute-options", 75 llvm::cl::desc("Test tile and distribute options"), 76 llvm::cl::init(false)}; 77 Option<bool> testVectorTransferForwardingPatterns{ 78 *this, "test-vector-transfer-forwarding-patterns", 79 llvm::cl::desc( 80 "Test a fused pass that forwards linalg.copy to vector.transfer"), 81 llvm::cl::init(false)}; 82 Option<bool> testGenericToVectorPattern{ 83 *this, "test-linalg-to-vector-patterns", 84 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 85 "in vector.contract form"), 86 llvm::cl::init(false)}; 87 Option<bool> testTileAndPadPattern{ 88 *this, "test-tile-and-pad-pattern", 89 llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)}; 90 Option<int> testHoistPadding{*this, "test-hoist-padding", 91 llvm::cl::desc("Test hoist padding"), 92 llvm::cl::init(0)}; 93 Option<bool> testTransformPadTensor{ 94 *this, "test-transform-pad-tensor", 95 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 96 llvm::cl::init(false)}; 97 Option<bool> testGeneralizePadTensor{ 98 *this, "test-generalize-pad-tensor", 99 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 100 llvm::cl::init(false)}; 101 Option<bool> testSwapSubTensorPadTensor{ 102 *this, "test-swap-subtensor-padtensor", 103 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 104 "pad_tensor(subtensor)"), 105 llvm::cl::init(false)}; 106 ListOption<int64_t> tileSizesForPadding{ 107 *this, "tile-sizes-for-padding", 108 llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore, 109 llvm::cl::MiscFlags::CommaSeparated}; 110 ListOption<unsigned> testInterchangePattern{ 111 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 112 llvm::cl::desc("Test the interchange pattern.")}; 113 ListOption<unsigned> testTiledLoopPeeling{ 114 *this, "test-tiled-loop-peeling", 115 llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), 116 llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; 117 }; 118 } // end anonymous namespace 119 120 static void applyPatterns(FuncOp funcOp) { 121 MLIRContext *ctx = funcOp.getContext(); 122 RewritePatternSet patterns(ctx); 123 124 //===--------------------------------------------------------------------===// 125 // Linalg tiling patterns. 126 //===--------------------------------------------------------------------===// 127 patterns.add<LinalgTilingPattern<MatmulOp>>( 128 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 129 LinalgTransformationFilter(Identifier::get("MEM", ctx), 130 Identifier::get("L3", ctx))); 131 patterns.add<LinalgTilingPattern<MatmulOp>>( 132 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 133 LinalgTransformationFilter(Identifier::get("L3", ctx), 134 Identifier::get("L2", ctx))); 135 patterns.add<LinalgTilingPattern<MatmulOp>>( 136 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 137 LinalgTransformationFilter(Identifier::get("L2", ctx), 138 Identifier::get("L1", ctx))); 139 patterns.add<LinalgTilingPattern<MatmulOp>>( 140 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 141 LinalgTransformationFilter(Identifier::get("L1", ctx), 142 Identifier::get("REG", ctx))); 143 144 patterns.add<LinalgTilingPattern<MatvecOp>>( 145 ctx, 146 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 147 LinalgTilingLoopType::ParallelLoops), 148 LinalgTransformationFilter(ArrayRef<Identifier>{}, 149 Identifier::get("L1", ctx))); 150 151 patterns.add<LinalgTilingPattern<DotOp>>( 152 ctx, LinalgTilingOptions().setTileSizes(8000), 153 LinalgTransformationFilter( 154 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 155 Identifier::get("L3", ctx), 156 Identifier::get("L2", ctx)}, 157 Identifier::get("REG", ctx))); 158 159 //===--------------------------------------------------------------------===// 160 // Linalg tiling and permutation patterns. 161 //===--------------------------------------------------------------------===// 162 patterns.add<LinalgTilingPattern<MatmulOp>>( 163 ctx, 164 LinalgTilingOptions() 165 .setTileSizes({2000, 3000, 4000}) 166 .setInterchange({1, 2, 0}), 167 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 168 Identifier::get("L2__with_perm__", ctx))); 169 patterns.add<LinalgTilingPattern<MatmulOp>>( 170 ctx, 171 LinalgTilingOptions() 172 .setTileSizes({200, 300, 400}) 173 .setInterchange({1, 0, 2}), 174 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 175 Identifier::get("L1__with_perm__", ctx))); 176 patterns.add<LinalgTilingPattern<MatmulOp>>( 177 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 178 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 179 Identifier::get("REG__with_perm__", ctx))); 180 181 patterns.add<LinalgTilingPattern<MatvecOp>>( 182 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 183 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 184 Identifier::get("L1__with_perm__", ctx))); 185 186 patterns.add<LinalgTilingPattern<MatmulOp>>( 187 ctx, 188 LinalgTilingOptions() 189 .setTileSizes({16, 8, 4}) 190 .setInterchange({1, 2, 0}) 191 .setLoopType(LinalgTilingLoopType::ParallelLoops), 192 LinalgTransformationFilter( 193 Identifier::get("par__with_perm__", ctx), 194 Identifier::get("after_par__with_perm__", ctx))); 195 196 //===--------------------------------------------------------------------===// 197 // Linalg to loops patterns. 198 //===--------------------------------------------------------------------===// 199 patterns.add<LinalgLoweringPattern<DotOp>>( 200 ctx, 201 /*loweringType=*/LinalgLoweringType::Loops, 202 LinalgTransformationFilter(Identifier::get("REG", ctx))); 203 204 //===--------------------------------------------------------------------===// 205 // Linalg distribution patterns. 206 //===--------------------------------------------------------------------===// 207 LinalgLoopDistributionOptions distributionOptions; 208 209 //===--------------------------------------------------------------------===// 210 // Linalg to vector contraction patterns. 211 //===--------------------------------------------------------------------===// 212 patterns.add<LinalgVectorizationPattern>( 213 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 214 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 215 216 //===--------------------------------------------------------------------===// 217 // Linalg generic interchange pattern. 218 //===--------------------------------------------------------------------===// 219 patterns.add<GenericOpInterchangePattern>( 220 ctx, 221 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 222 LinalgTransformationFilter(ArrayRef<Identifier>{}, 223 Identifier::get("PERMUTED", ctx))); 224 225 //===--------------------------------------------------------------------===// 226 // Linalg subview operands promotion. 227 //===--------------------------------------------------------------------===// 228 patterns.add<LinalgPromotionPattern<MatmulOp>>( 229 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 230 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 231 Identifier::get("_views_promoted_", ctx))); 232 patterns.add<LinalgPromotionPattern<MatmulOp>>( 233 ctx, 234 LinalgPromotionOptions() 235 .setOperandsToPromote({0}) 236 .setUseFullTileBuffersByDefault(true), 237 LinalgTransformationFilter( 238 Identifier::get("_promote_first_view_", ctx), 239 Identifier::get("_first_view_promoted_", ctx))); 240 patterns.add<LinalgPromotionPattern<FillOp>>( 241 ctx, 242 LinalgPromotionOptions() 243 .setOperandsToPromote({1}) 244 .setUseFullTileBuffers({false, true}) 245 .setAlignment(32), 246 LinalgTransformationFilter( 247 Identifier::get("_promote_views_aligned_", ctx), 248 Identifier::get("_views_aligned_promoted_", ctx))); 249 250 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 251 252 // Drop the marker. 253 funcOp.walk([](LinalgOp op) { 254 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 255 }); 256 } 257 258 static void fillL1TilingAndMatmulToVectorPatterns( 259 FuncOp funcOp, StringRef startMarker, 260 SmallVectorImpl<RewritePatternSet> &patternsVector) { 261 MLIRContext *ctx = funcOp.getContext(); 262 patternsVector.emplace_back( 263 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 264 ctx, 265 LinalgTilingOptions() 266 .setTileSizes({8, 12, 16}) 267 .setInterchange({1, 0, 2}), 268 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 269 Identifier::get("L1", ctx)))); 270 271 patternsVector.emplace_back( 272 ctx, 273 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 274 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 275 LinalgTransformationFilter(Identifier::get("L1", ctx), 276 Identifier::get("VEC", ctx)))); 277 278 patternsVector.emplace_back( 279 ctx, std::make_unique<LinalgVectorizationPattern>( 280 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 281 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 282 patternsVector.back().add<LinalgVectorizationPattern>( 283 ctx, LinalgTransformationFilter().addFilter( 284 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 285 } 286 287 //===----------------------------------------------------------------------===// 288 // Test promotion callbacks 289 //===----------------------------------------------------------------------===// 290 291 // Allocation call back 292 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 293 ArrayRef<Value> boundingSubViewSize, 294 DataLayout &layout) { 295 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 296 return b 297 .create<memref::AllocOp>( 298 subView.getLoc(), 299 MemRefType::get(shape, subView.getType().getElementType(), 300 /*affineMapComposition =*/{}, 3), 301 boundingSubViewSize) 302 .getResult(); 303 } 304 305 // Deallocation callback 306 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 307 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 308 return success(); 309 } 310 311 // Copy in call back 312 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 313 bool isOutput) { 314 auto floatType = src.getType().cast<MemRefType>().getElementType(); 315 if (!floatType.isa<FloatType>()) 316 return failure(); 317 if (!isOutput) { 318 Value cst = 319 b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)); 320 b.create<FillOp>(src.getLoc(), cst, dst); 321 } 322 b.create<CopyOp>(src.getLoc(), src, dst); 323 return success(); 324 } 325 326 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 327 RewritePatternSet &patterns) { 328 patterns.add<LinalgTilingPattern<MatmulOp>>( 329 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 330 LinalgTransformationFilter(Identifier::get("START", ctx), 331 Identifier::get("PROMOTE", ctx))); 332 patterns.add<LinalgPromotionPattern<MatmulOp>>( 333 ctx, 334 LinalgPromotionOptions() 335 .setOperandsToPromote({0, 2}) 336 .setUseFullTileBuffers({false, false}) 337 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 338 .setCopyInOutFns( 339 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 340 return copyCallBackFn(b, src, dst, false); 341 }, 342 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 343 return copyCallBackFn(b, src, dst, true); 344 }), 345 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 346 } 347 348 template <typename IdOp, typename NProcsOp> 349 static SmallVector<ProcInfo, 2> 350 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 351 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 352 SmallVector<ProcInfo, 2> procInfo(count); 353 const char *xyz[] = {"x", "y", "z"}; 354 Type indexType = b.getIndexType(); 355 for (unsigned i = 0; i < count; ++i) { 356 procInfo[count - 1 - i] = { 357 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 358 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 359 } 360 return procInfo; 361 } 362 363 static void fillTileAndDistributePatterns(MLIRContext *context, 364 RewritePatternSet &patterns) { 365 { 366 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 367 cyclicNprocsEqNiters.distributionMethod.resize( 368 2, DistributionMethod::CyclicNumProcsEqNumIters); 369 cyclicNprocsEqNiters.procInfo = 370 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 371 patterns.add<LinalgTilingPattern<MatmulOp>>( 372 context, 373 LinalgTilingOptions() 374 .setTileSizes({8, 8, 4}) 375 .setLoopType(LinalgTilingLoopType::ParallelLoops) 376 .setDistributionOptions(cyclicNprocsEqNiters), 377 LinalgTransformationFilter( 378 Identifier::get("distribute1", context), 379 Identifier::get("after_distribute1", context))); 380 } 381 382 { 383 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 384 cyclicNprocsGeNiters.distributionMethod.resize( 385 2, DistributionMethod::CyclicNumProcsGeNumIters); 386 cyclicNprocsGeNiters.procInfo = 387 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 388 patterns.add<LinalgTilingPattern<MatmulOp>>( 389 context, 390 LinalgTilingOptions() 391 .setTileSizes({8, 8, 4}) 392 .setLoopType(LinalgTilingLoopType::ParallelLoops) 393 .setDistributionOptions(cyclicNprocsGeNiters), 394 LinalgTransformationFilter( 395 Identifier::get("distribute2", context), 396 Identifier::get("after_distribute2", context))); 397 } 398 399 { 400 LinalgLoopDistributionOptions cyclicNprocsDefault; 401 cyclicNprocsDefault.distributionMethod.resize(2, 402 DistributionMethod::Cyclic); 403 cyclicNprocsDefault.procInfo = 404 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 405 patterns.add<LinalgTilingPattern<MatmulOp>>( 406 context, 407 LinalgTilingOptions() 408 .setTileSizes({8, 8, 4}) 409 .setLoopType(LinalgTilingLoopType::ParallelLoops) 410 .setDistributionOptions(cyclicNprocsDefault), 411 LinalgTransformationFilter( 412 Identifier::get("distribute3", context), 413 Identifier::get("after_distribute3", context))); 414 } 415 416 { 417 LinalgLoopDistributionOptions cyclicNprocsMixed1; 418 cyclicNprocsMixed1.distributionMethod = { 419 DistributionMethod::CyclicNumProcsEqNumIters, 420 DistributionMethod::CyclicNumProcsGeNumIters}; 421 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 422 patterns.add<LinalgTilingPattern<MatmulOp>>( 423 context, 424 LinalgTilingOptions() 425 .setTileSizes({8, 8, 4}) 426 .setLoopType(LinalgTilingLoopType::ParallelLoops) 427 .setDistributionOptions(cyclicNprocsMixed1), 428 LinalgTransformationFilter( 429 Identifier::get("distribute4", context), 430 Identifier::get("after_distribute4", context))); 431 } 432 433 { 434 LinalgLoopDistributionOptions cyclicNprocsMixed2; 435 cyclicNprocsMixed2.distributionMethod = { 436 DistributionMethod::CyclicNumProcsGeNumIters, 437 DistributionMethod::Cyclic}; 438 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 439 patterns.add<LinalgTilingPattern<MatmulOp>>( 440 context, 441 LinalgTilingOptions() 442 .setTileSizes({8, 8, 4}) 443 .setLoopType(LinalgTilingLoopType::ParallelLoops) 444 .setDistributionOptions(cyclicNprocsMixed2), 445 LinalgTransformationFilter( 446 Identifier::get("distribute5", context), 447 Identifier::get("after_distribute5", context))); 448 } 449 450 { 451 LinalgLoopDistributionOptions cyclicNprocsMixed3; 452 cyclicNprocsMixed3.distributionMethod = { 453 DistributionMethod::Cyclic, 454 DistributionMethod::CyclicNumProcsEqNumIters}; 455 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 456 457 patterns.add<LinalgTilingPattern<MatmulOp>>( 458 context, 459 LinalgTilingOptions() 460 .setTileSizes({8, 8, 4}) 461 .setLoopType(LinalgTilingLoopType::ParallelLoops) 462 .setDistributionOptions(cyclicNprocsMixed3), 463 LinalgTransformationFilter( 464 Identifier::get("distribute6", context), 465 Identifier::get("after_distribute6", context))); 466 } 467 468 { 469 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 470 cyclicNprocsEqNiters.distributionMethod.resize(2, 471 DistributionMethod::Cyclic); 472 cyclicNprocsEqNiters.procInfo = 473 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 474 patterns.add<LinalgTilingPattern<MatmulOp>>( 475 context, 476 LinalgTilingOptions() 477 .setTileSizes({8, 8, 4}) 478 .setLoopType(LinalgTilingLoopType::Loops) 479 .setDistributionOptions(cyclicNprocsEqNiters), 480 LinalgTransformationFilter( 481 Identifier::get("tensors_distribute1", context), 482 Identifier::get("tensors_after_distribute1", context))); 483 } 484 } 485 486 static void 487 applyMatmulToVectorPatterns(FuncOp funcOp, 488 bool testMatmulToVectorPatterns1dTiling, 489 bool testMatmulToVectorPatterns2dTiling) { 490 MLIRContext *ctx = funcOp.getContext(); 491 SmallVector<RewritePatternSet, 4> stage1Patterns; 492 if (testMatmulToVectorPatterns1dTiling) { 493 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 494 stage1Patterns); 495 } else if (testMatmulToVectorPatterns2dTiling) { 496 stage1Patterns.emplace_back( 497 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 498 ctx, 499 LinalgTilingOptions() 500 .setTileSizes({768, 264, 768}) 501 .setInterchange({1, 2, 0}), 502 LinalgTransformationFilter(Identifier::get("START", ctx), 503 Identifier::get("L2", ctx)))); 504 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 505 stage1Patterns); 506 } 507 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 508 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 509 FrozenRewritePatternSet stage2Patterns = 510 getLinalgTilingCanonicalizationPatterns(ctx); 511 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 512 std::move(stage2Patterns)); 513 } 514 515 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 516 RewritePatternSet forwardPattern(funcOp.getContext()); 517 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 518 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 519 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 520 } 521 522 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 523 RewritePatternSet patterns(funcOp.getContext()); 524 patterns.add<LinalgVectorizationPattern>( 525 funcOp.getContext(), 526 LinalgTransformationFilter() 527 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 528 populatePadTensorOpVectorizationPatterns(patterns); 529 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 530 } 531 532 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 533 RewritePatternSet patterns(funcOp.getContext()); 534 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 535 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 536 } 537 538 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 539 RewritePatternSet patterns(funcOp.getContext()); 540 patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext()); 541 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 542 } 543 544 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 545 RewritePatternSet patterns(funcOp.getContext()); 546 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 547 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 548 } 549 550 // For now, just assume it is the zero of type. 551 // In the future, it should be the zero of type + op. 552 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 553 auto t = getElementTypeOrSelf(op.get()); 554 return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); 555 } 556 557 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) { 558 MLIRContext *context = funcOp.getContext(); 559 RewritePatternSet tilingPattern(context); 560 auto linalgTilingOptions = 561 linalg::LinalgTilingOptions() 562 .setTileSizes(tileSizes) 563 .setPaddingValueComputationFunction(getNeutralOfLinalgOp); 564 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>, 565 linalg::LinalgTilingPattern<linalg::GenericOp>>( 566 context, linalgTilingOptions, 567 linalg::LinalgTransformationFilter( 568 Identifier::get("tile-and-pad", context))); 569 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 570 } 571 572 static void applyInterchangePattern(FuncOp funcOp, 573 ArrayRef<unsigned> interchangeVector) { 574 MLIRContext *context = funcOp.getContext(); 575 RewritePatternSet interchangePattern(context); 576 interchangePattern.add<GenericOpInterchangePattern>( 577 context, interchangeVector, 578 LinalgTransformationFilter(ArrayRef<Identifier>{}, 579 Identifier::get("interchange", context))); 580 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 581 } 582 583 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 584 585 namespace { 586 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the 587 /// `idx`-th loop contains only "full" iterations and a second loop for the 588 /// remaining partial iteration (if any). 589 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { 590 TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx) 591 : OpRewritePattern<TiledLoopOp>(ctx), idx(idx) {} 592 593 LogicalResult matchAndRewrite(TiledLoopOp loopOp, 594 PatternRewriter &rewriter) const override { 595 SmallVector<int64_t> peeledLoops; 596 if (loopOp->hasAttr(kPeeledLoopsLabel)) { 597 auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>(); 598 peeledLoops = 599 llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { 600 return attr.cast<IntegerAttr>().getInt(); 601 })); 602 // Check if the loop was already peeled. 603 if (llvm::find(peeledLoops, idx) != peeledLoops.end()) 604 return failure(); 605 } 606 607 if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) 608 return failure(); 609 610 // Peel loop and canonicalize. 611 TiledLoopOp result; 612 if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, 613 result))) 614 return failure(); 615 peeledLoops.push_back(idx); 616 617 // Apply label, so that the same loop is not rewritten a second time. 618 rewriter.updateRootInPlace(loopOp, [&]() { 619 loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 620 }); 621 result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 622 return success(); 623 } 624 625 /// Index of loop to peel. 626 int64_t idx; 627 }; 628 } // namespace 629 630 static void applyTiledLoopPeelingPattern(FuncOp funcOp, 631 ArrayRef<unsigned> loops) { 632 MLIRContext *ctx = funcOp.getContext(); 633 RewritePatternSet patterns(ctx); 634 for (unsigned idx : loops) 635 patterns.add<TiledLoopPeelingPattern>(ctx, idx); 636 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 637 638 // Drop the marker. 639 funcOp.walk([](TiledLoopOp op) { op->removeAttr(kPeeledLoopsLabel); }); 640 } 641 642 /// Apply transformations specified as patterns. 643 void TestLinalgTransforms::runOnFunction() { 644 auto lambda = [&](void *) { 645 getFunction().walk([](LinalgOp op) { 646 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 647 }); 648 }; 649 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 650 651 if (testPromotionOptions) { 652 RewritePatternSet patterns(&getContext()); 653 fillPromotionCallBackPatterns(&getContext(), patterns); 654 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 655 return; 656 } 657 if (testTileAndDistributionOptions) { 658 RewritePatternSet patterns(&getContext()); 659 fillTileAndDistributePatterns(&getContext(), patterns); 660 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 661 return; 662 } 663 if (testPatterns) 664 return applyPatterns(getFunction()); 665 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 666 return applyMatmulToVectorPatterns(getFunction(), 667 testMatmulToVectorPatterns1dTiling, 668 testMatmulToVectorPatterns2dTiling); 669 if (testVectorTransferForwardingPatterns) 670 return applyVectorTransferForwardingPatterns(getFunction()); 671 if (testGenericToVectorPattern) 672 return applyLinalgToVectorPatterns(getFunction()); 673 if (testTransformPadTensor) 674 return applyPadTensorToGenericPatterns(getFunction()); 675 if (testGeneralizePadTensor) 676 return applyGeneralizePadTensorPatterns(getFunction()); 677 if (testSwapSubTensorPadTensor) 678 return applyExtractSliceOfPadTensorSwapPattern(getFunction()); 679 if (testTiledLoopPeeling.hasValue()) 680 return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling); 681 if (testTileAndPadPattern) 682 return applyTileAndPadPattern(getFunction(), tileSizesForPadding); 683 if (testHoistPadding) { 684 getFunction().walk([&](linalg::PadTensorOp padTensorOp) { 685 (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); 686 }); 687 } 688 if (testInterchangePattern.hasValue()) 689 return applyInterchangePattern(getFunction(), testInterchangePattern); 690 } 691 692 namespace mlir { 693 namespace test { 694 void registerTestLinalgTransforms() { 695 PassRegistration<TestLinalgTransforms>(); 696 } 697 } // namespace test 698 } // namespace mlir 699