1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/GPU/GPUDialect.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 17 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/StandardOps/IR/Ops.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/ADT/SmallVector.h" 27 28 using namespace mlir; 29 using namespace mlir::linalg; 30 31 namespace { 32 struct TestLinalgTransforms 33 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 34 TestLinalgTransforms() = default; 35 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 36 37 void getDependentDialects(DialectRegistry ®istry) const override { 38 // clang-format off 39 registry.insert<AffineDialect, 40 memref::MemRefDialect, 41 scf::SCFDialect, 42 StandardOpsDialect, 43 vector::VectorDialect, 44 gpu::GPUDialect>(); 45 // clang-format on 46 } 47 StringRef getArgument() const final { 48 return "test-linalg-transform-patterns"; 49 } 50 StringRef getDescription() const final { 51 return "Test Linalg transformation patterns by applying them greedily."; 52 } 53 54 void runOnFunction() override; 55 56 Option<bool> 
testPatterns{*this, "test-patterns", 57 llvm::cl::desc("Test a mixed set of patterns"), 58 llvm::cl::init(false)}; 59 Option<bool> testMatmulToVectorPatterns1dTiling{ 60 *this, "test-matmul-to-vector-patterns-tile-1d", 61 llvm::cl::desc( 62 "Test a fused pass that applies patterns from matmul to vectors via " 63 "1-d tiling"), 64 llvm::cl::init(false)}; 65 Option<bool> testMatmulToVectorPatterns2dTiling{ 66 *this, "test-matmul-to-vector-patterns-tile-2d", 67 llvm::cl::desc( 68 "Test a fused pass that applies patterns from matmul to vectors via " 69 "2-d tiling"), 70 llvm::cl::init(false)}; 71 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 72 llvm::cl::desc("Test promotion options"), 73 llvm::cl::init(false)}; 74 Option<bool> testTileAndDistributionOptions{ 75 *this, "test-tile-and-distribute-options", 76 llvm::cl::desc("Test tile and distribute options"), 77 llvm::cl::init(false)}; 78 Option<bool> testVectorTransferForwardingPatterns{ 79 *this, "test-vector-transfer-forwarding-patterns", 80 llvm::cl::desc( 81 "Test a fused pass that forwards linalg.copy to vector.transfer"), 82 llvm::cl::init(false)}; 83 Option<bool> testGenericToVectorPattern{ 84 *this, "test-linalg-to-vector-patterns", 85 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 86 "in vector.contract form"), 87 llvm::cl::init(false)}; 88 Option<bool> testTilePattern{*this, "test-tile-pattern", 89 llvm::cl::desc("Test tile pattern"), 90 llvm::cl::init(false)}; 91 Option<bool> testTileScalarizeDynamicDims{ 92 *this, "test-tile-scalarize-dynamic-dims", 93 llvm::cl::desc("Test tiling of dynamic dims by 1"), 94 llvm::cl::init(false)}; 95 Option<int> testHoistPadding{*this, "test-hoist-padding", 96 llvm::cl::desc("Test hoist padding"), 97 llvm::cl::init(0)}; 98 Option<bool> testTransformPadTensor{ 99 *this, "test-transform-pad-tensor", 100 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 101 llvm::cl::init(false)}; 102 Option<bool> 
testGeneralizePadTensor{ 103 *this, "test-generalize-pad-tensor", 104 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 105 llvm::cl::init(false)}; 106 Option<bool> testSwapSubTensorPadTensor{ 107 *this, "test-swap-subtensor-padtensor", 108 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 109 "pad_tensor(subtensor)"), 110 llvm::cl::init(false)}; 111 ListOption<int64_t> paddedOperands{ 112 *this, "padded-operands", 113 llvm::cl::desc("Operands to pad when test-tile-pattern"), 114 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 115 ListOption<int64_t> peeledLoops{ 116 *this, "peeled-loops", 117 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 118 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 119 ListOption<int64_t> tileSizes{ 120 *this, "tile-sizes", 121 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 122 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 123 ListOption<unsigned> testInterchangePattern{ 124 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 125 llvm::cl::desc("Test the interchange pattern.")}; 126 ListOption<unsigned> testTiledLoopPeeling{ 127 *this, "test-tiled-loop-peeling", 128 llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), 129 llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; 130 Option<bool> skipPartial{ 131 *this, "skip-partial", 132 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 133 llvm::cl::init(false)}; 134 Option<std::string> loopType{ 135 *this, "loop-type", 136 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 137 "tiled_loop"), 138 llvm::cl::init("for")}; 139 }; 140 } // end anonymous namespace 141 142 static void applyPatterns(FuncOp funcOp) { 143 MLIRContext *ctx = funcOp.getContext(); 144 RewritePatternSet patterns(ctx); 145 146 //===--------------------------------------------------------------------===// 147 // Linalg tiling patterns. 
148 //===--------------------------------------------------------------------===// 149 patterns.add<LinalgTilingPattern<MatmulOp>>( 150 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 151 LinalgTransformationFilter(Identifier::get("MEM", ctx), 152 Identifier::get("L3", ctx))); 153 patterns.add<LinalgTilingPattern<MatmulOp>>( 154 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 155 LinalgTransformationFilter(Identifier::get("L3", ctx), 156 Identifier::get("L2", ctx))); 157 patterns.add<LinalgTilingPattern<MatmulOp>>( 158 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 159 LinalgTransformationFilter(Identifier::get("L2", ctx), 160 Identifier::get("L1", ctx))); 161 patterns.add<LinalgTilingPattern<MatmulOp>>( 162 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 163 LinalgTransformationFilter(Identifier::get("L1", ctx), 164 Identifier::get("REG", ctx))); 165 166 patterns.add<LinalgTilingPattern<MatvecOp>>( 167 ctx, 168 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 169 LinalgTilingLoopType::ParallelLoops), 170 LinalgTransformationFilter(ArrayRef<Identifier>{}, 171 Identifier::get("L1", ctx))); 172 173 patterns.add<LinalgTilingPattern<DotOp>>( 174 ctx, LinalgTilingOptions().setTileSizes(8000), 175 LinalgTransformationFilter( 176 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 177 Identifier::get("L3", ctx), 178 Identifier::get("L2", ctx)}, 179 Identifier::get("REG", ctx))); 180 181 //===--------------------------------------------------------------------===// 182 // Linalg tiling and permutation patterns. 
183 //===--------------------------------------------------------------------===// 184 patterns.add<LinalgTilingPattern<MatmulOp>>( 185 ctx, 186 LinalgTilingOptions() 187 .setTileSizes({2000, 3000, 4000}) 188 .setInterchange({1, 2, 0}), 189 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 190 Identifier::get("L2__with_perm__", ctx))); 191 patterns.add<LinalgTilingPattern<MatmulOp>>( 192 ctx, 193 LinalgTilingOptions() 194 .setTileSizes({200, 300, 400}) 195 .setInterchange({1, 0, 2}), 196 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 197 Identifier::get("L1__with_perm__", ctx))); 198 patterns.add<LinalgTilingPattern<MatmulOp>>( 199 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 200 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 201 Identifier::get("REG__with_perm__", ctx))); 202 203 patterns.add<LinalgTilingPattern<MatvecOp>>( 204 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 205 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 206 Identifier::get("L1__with_perm__", ctx))); 207 208 patterns.add<LinalgTilingPattern<MatmulOp>>( 209 ctx, 210 LinalgTilingOptions() 211 .setTileSizes({16, 8, 4}) 212 .setInterchange({1, 2, 0}) 213 .setLoopType(LinalgTilingLoopType::ParallelLoops), 214 LinalgTransformationFilter( 215 Identifier::get("par__with_perm__", ctx), 216 Identifier::get("after_par__with_perm__", ctx))); 217 218 //===--------------------------------------------------------------------===// 219 // Linalg to loops patterns. 220 //===--------------------------------------------------------------------===// 221 patterns.add<LinalgLoweringPattern<DotOp>>( 222 ctx, 223 /*loweringType=*/LinalgLoweringType::Loops, 224 LinalgTransformationFilter(Identifier::get("REG", ctx))); 225 226 //===--------------------------------------------------------------------===// 227 // Linalg distribution patterns. 
228 //===--------------------------------------------------------------------===// 229 LinalgLoopDistributionOptions distributionOptions; 230 231 //===--------------------------------------------------------------------===// 232 // Linalg to vector contraction patterns. 233 //===--------------------------------------------------------------------===// 234 patterns.add<LinalgVectorizationPattern>( 235 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 236 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 237 238 //===--------------------------------------------------------------------===// 239 // Linalg generic interchange pattern. 240 //===--------------------------------------------------------------------===// 241 patterns.add<GenericOpInterchangePattern>( 242 ctx, 243 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 244 LinalgTransformationFilter(ArrayRef<Identifier>{}, 245 Identifier::get("PERMUTED", ctx))); 246 247 //===--------------------------------------------------------------------===// 248 // Linalg subview operands promotion. 
249 //===--------------------------------------------------------------------===// 250 patterns.add<LinalgPromotionPattern<MatmulOp>>( 251 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 252 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 253 Identifier::get("_views_promoted_", ctx))); 254 patterns.add<LinalgPromotionPattern<MatmulOp>>( 255 ctx, 256 LinalgPromotionOptions() 257 .setOperandsToPromote({0}) 258 .setUseFullTileBuffersByDefault(true), 259 LinalgTransformationFilter( 260 Identifier::get("_promote_first_view_", ctx), 261 Identifier::get("_first_view_promoted_", ctx))); 262 patterns.add<LinalgPromotionPattern<FillOp>>( 263 ctx, 264 LinalgPromotionOptions() 265 .setOperandsToPromote({1}) 266 .setUseFullTileBuffers({false, true}) 267 .setAlignment(32), 268 LinalgTransformationFilter( 269 Identifier::get("_promote_views_aligned_", ctx), 270 Identifier::get("_views_aligned_promoted_", ctx))); 271 272 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 273 274 // Drop the marker. 
275 funcOp.walk([](LinalgOp op) { 276 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 277 }); 278 } 279 280 static void fillL1TilingAndMatmulToVectorPatterns( 281 FuncOp funcOp, StringRef startMarker, 282 SmallVectorImpl<RewritePatternSet> &patternsVector) { 283 MLIRContext *ctx = funcOp.getContext(); 284 patternsVector.emplace_back( 285 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 286 ctx, 287 LinalgTilingOptions() 288 .setTileSizes({8, 12, 16}) 289 .setInterchange({1, 0, 2}), 290 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 291 Identifier::get("L1", ctx)))); 292 293 patternsVector.emplace_back( 294 ctx, 295 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 296 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 297 LinalgTransformationFilter(Identifier::get("L1", ctx), 298 Identifier::get("VEC", ctx)))); 299 300 patternsVector.emplace_back( 301 ctx, std::make_unique<LinalgVectorizationPattern>( 302 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 303 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 304 patternsVector.back().add<LinalgVectorizationPattern>( 305 ctx, LinalgTransformationFilter().addFilter( 306 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // Test promotion callbacks 311 //===----------------------------------------------------------------------===// 312 313 // Allocation call back 314 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 315 ArrayRef<Value> boundingSubViewSize, 316 DataLayout &layout) { 317 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 318 return b 319 .create<memref::AllocOp>( 320 subView.getLoc(), 321 MemRefType::get(shape, subView.getType().getElementType(), 322 /*affineMapComposition =*/{}, 3), 323 boundingSubViewSize) 324 .getResult(); 325 } 326 327 // Deallocation callback 328 static 
LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 329 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 330 return success(); 331 } 332 333 // Copy in call back 334 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 335 bool isOutput) { 336 auto floatType = src.getType().cast<MemRefType>().getElementType(); 337 if (!floatType.isa<FloatType>()) 338 return failure(); 339 if (!isOutput) { 340 Value cst = 341 b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)); 342 b.create<FillOp>(src.getLoc(), cst, dst); 343 } 344 b.create<CopyOp>(src.getLoc(), src, dst); 345 return success(); 346 } 347 348 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 349 RewritePatternSet &patterns) { 350 patterns.add<LinalgTilingPattern<MatmulOp>>( 351 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 352 LinalgTransformationFilter(Identifier::get("START", ctx), 353 Identifier::get("PROMOTE", ctx))); 354 patterns.add<LinalgPromotionPattern<MatmulOp>>( 355 ctx, 356 LinalgPromotionOptions() 357 .setOperandsToPromote({0, 2}) 358 .setUseFullTileBuffers({false, false}) 359 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 360 .setCopyInOutFns( 361 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 362 return copyCallBackFn(b, src, dst, false); 363 }, 364 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 365 return copyCallBackFn(b, src, dst, true); 366 }), 367 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 368 } 369 370 template <typename IdOp, typename NProcsOp> 371 static SmallVector<ProcInfo, 2> 372 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 373 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 374 SmallVector<ProcInfo, 2> procInfo(count); 375 const char *xyz[] = {"x", "y", "z"}; 376 Type indexType = b.getIndexType(); 377 for (unsigned i = 0; i < count; ++i) { 378 procInfo[count - 1 - i] = { 379 b.create<IdOp>(loc, indexType, 
b.getStringAttr(xyz[i])), 380 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 381 } 382 return procInfo; 383 } 384 385 static void fillTileAndDistributePatterns(MLIRContext *context, 386 RewritePatternSet &patterns) { 387 { 388 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 389 cyclicNprocsEqNiters.distributionMethod.resize( 390 2, DistributionMethod::CyclicNumProcsEqNumIters); 391 cyclicNprocsEqNiters.procInfo = 392 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 393 patterns.add<LinalgTilingPattern<MatmulOp>>( 394 context, 395 LinalgTilingOptions() 396 .setTileSizes({8, 8, 4}) 397 .setLoopType(LinalgTilingLoopType::ParallelLoops) 398 .setDistributionOptions(cyclicNprocsEqNiters), 399 LinalgTransformationFilter( 400 Identifier::get("distribute1", context), 401 Identifier::get("after_distribute1", context))); 402 } 403 404 { 405 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 406 cyclicNprocsGeNiters.distributionMethod.resize( 407 2, DistributionMethod::CyclicNumProcsGeNumIters); 408 cyclicNprocsGeNiters.procInfo = 409 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 410 patterns.add<LinalgTilingPattern<MatmulOp>>( 411 context, 412 LinalgTilingOptions() 413 .setTileSizes({8, 8, 4}) 414 .setLoopType(LinalgTilingLoopType::ParallelLoops) 415 .setDistributionOptions(cyclicNprocsGeNiters), 416 LinalgTransformationFilter( 417 Identifier::get("distribute2", context), 418 Identifier::get("after_distribute2", context))); 419 } 420 421 { 422 LinalgLoopDistributionOptions cyclicNprocsDefault; 423 cyclicNprocsDefault.distributionMethod.resize(2, 424 DistributionMethod::Cyclic); 425 cyclicNprocsDefault.procInfo = 426 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 427 patterns.add<LinalgTilingPattern<MatmulOp>>( 428 context, 429 LinalgTilingOptions() 430 .setTileSizes({8, 8, 4}) 431 .setLoopType(LinalgTilingLoopType::ParallelLoops) 432 .setDistributionOptions(cyclicNprocsDefault), 433 LinalgTransformationFilter( 434 Identifier::get("distribute3", context), 
435 Identifier::get("after_distribute3", context))); 436 } 437 438 { 439 LinalgLoopDistributionOptions cyclicNprocsMixed1; 440 cyclicNprocsMixed1.distributionMethod = { 441 DistributionMethod::CyclicNumProcsEqNumIters, 442 DistributionMethod::CyclicNumProcsGeNumIters}; 443 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 444 patterns.add<LinalgTilingPattern<MatmulOp>>( 445 context, 446 LinalgTilingOptions() 447 .setTileSizes({8, 8, 4}) 448 .setLoopType(LinalgTilingLoopType::ParallelLoops) 449 .setDistributionOptions(cyclicNprocsMixed1), 450 LinalgTransformationFilter( 451 Identifier::get("distribute4", context), 452 Identifier::get("after_distribute4", context))); 453 } 454 455 { 456 LinalgLoopDistributionOptions cyclicNprocsMixed2; 457 cyclicNprocsMixed2.distributionMethod = { 458 DistributionMethod::CyclicNumProcsGeNumIters, 459 DistributionMethod::Cyclic}; 460 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 461 patterns.add<LinalgTilingPattern<MatmulOp>>( 462 context, 463 LinalgTilingOptions() 464 .setTileSizes({8, 8, 4}) 465 .setLoopType(LinalgTilingLoopType::ParallelLoops) 466 .setDistributionOptions(cyclicNprocsMixed2), 467 LinalgTransformationFilter( 468 Identifier::get("distribute5", context), 469 Identifier::get("after_distribute5", context))); 470 } 471 472 { 473 LinalgLoopDistributionOptions cyclicNprocsMixed3; 474 cyclicNprocsMixed3.distributionMethod = { 475 DistributionMethod::Cyclic, 476 DistributionMethod::CyclicNumProcsEqNumIters}; 477 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 478 479 patterns.add<LinalgTilingPattern<MatmulOp>>( 480 context, 481 LinalgTilingOptions() 482 .setTileSizes({8, 8, 4}) 483 .setLoopType(LinalgTilingLoopType::ParallelLoops) 484 .setDistributionOptions(cyclicNprocsMixed3), 485 LinalgTransformationFilter( 486 Identifier::get("distribute6", context), 487 Identifier::get("after_distribute6", context))); 488 } 489 490 { 491 
LinalgLoopDistributionOptions cyclicNprocsEqNiters; 492 cyclicNprocsEqNiters.distributionMethod.resize(2, 493 DistributionMethod::Cyclic); 494 cyclicNprocsEqNiters.procInfo = 495 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 496 patterns.add<LinalgTilingPattern<MatmulOp>>( 497 context, 498 LinalgTilingOptions() 499 .setTileSizes({8, 8, 4}) 500 .setLoopType(LinalgTilingLoopType::Loops) 501 .setDistributionOptions(cyclicNprocsEqNiters), 502 LinalgTransformationFilter( 503 Identifier::get("tensors_distribute1", context), 504 Identifier::get("tensors_after_distribute1", context))); 505 } 506 } 507 508 static void 509 applyMatmulToVectorPatterns(FuncOp funcOp, 510 bool testMatmulToVectorPatterns1dTiling, 511 bool testMatmulToVectorPatterns2dTiling) { 512 MLIRContext *ctx = funcOp.getContext(); 513 SmallVector<RewritePatternSet, 4> stage1Patterns; 514 if (testMatmulToVectorPatterns1dTiling) { 515 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 516 stage1Patterns); 517 } else if (testMatmulToVectorPatterns2dTiling) { 518 stage1Patterns.emplace_back( 519 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 520 ctx, 521 LinalgTilingOptions() 522 .setTileSizes({768, 264, 768}) 523 .setInterchange({1, 2, 0}), 524 LinalgTransformationFilter(Identifier::get("START", ctx), 525 Identifier::get("L2", ctx)))); 526 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 527 stage1Patterns); 528 } 529 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 530 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 531 FrozenRewritePatternSet stage2Patterns = 532 getLinalgTilingCanonicalizationPatterns(ctx); 533 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 534 std::move(stage2Patterns)); 535 } 536 537 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 538 RewritePatternSet forwardPattern(funcOp.getContext()); 539 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 540 
forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 541 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 542 } 543 544 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 545 RewritePatternSet patterns(funcOp.getContext()); 546 patterns.add<LinalgVectorizationPattern>( 547 funcOp.getContext(), 548 LinalgTransformationFilter() 549 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 550 populatePadTensorOpVectorizationPatterns(patterns); 551 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 552 } 553 554 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 555 RewritePatternSet patterns(funcOp.getContext()); 556 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 557 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 558 } 559 560 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 561 RewritePatternSet patterns(funcOp.getContext()); 562 patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext()); 563 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 564 } 565 566 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 567 RewritePatternSet patterns(funcOp.getContext()); 568 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 569 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 570 } 571 572 // For now, just assume it is the zero of type. 573 // In the future, it should be the zero of type + op. 
574 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 575 auto t = getElementTypeOrSelf(op.get()); 576 return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); 577 } 578 579 static void applyTilePattern(FuncOp funcOp, std::string loopType, 580 ArrayRef<int64_t> tileSizes, 581 ArrayRef<int64_t> paddedOperands, 582 ArrayRef<int64_t> peeledLoops, 583 bool scalarizeDynamicDims) { 584 MLIRContext *context = funcOp.getContext(); 585 RewritePatternSet tilingPattern(context); 586 LinalgTilingLoopType type = 587 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 588 .Case("for", LinalgTilingLoopType::Loops) 589 .Case("affine", LinalgTilingLoopType::AffineLoops) 590 .Case("parallel", LinalgTilingLoopType::ParallelLoops) 591 .Case("tiled_loop", LinalgTilingLoopType::TiledLoops); 592 auto linalgTilingOptions = linalg::LinalgTilingOptions() 593 .setPeeledLoops(peeledLoops) 594 .setLoopType(type); 595 if (scalarizeDynamicDims) { 596 linalgTilingOptions.scalarizeDynamicDims(); 597 assert(tileSizes.empty() && 598 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 599 } else { 600 linalgTilingOptions.setTileSizes(tileSizes); 601 } 602 if (!paddedOperands.empty()) { 603 auto paddingFunc = [&](OpBuilder &b, 604 OpOperand &opOperand) -> FailureOr<Value> { 605 if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0) 606 return failure(); 607 return getNeutralOfLinalgOp(b, opOperand); 608 }; 609 linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc); 610 } 611 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, 612 linalg::LinalgTilingPattern<linalg::GenericOp>>( 613 context, linalgTilingOptions, 614 linalg::LinalgTransformationFilter(Identifier::get("tile", context))); 615 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 616 } 617 618 static void applyInterchangePattern(FuncOp funcOp, 619 ArrayRef<unsigned> interchangeVector) { 620 MLIRContext *context = funcOp.getContext(); 621 
RewritePatternSet interchangePattern(context); 622 interchangePattern.add<GenericOpInterchangePattern>( 623 context, interchangeVector, 624 LinalgTransformationFilter(ArrayRef<Identifier>{}, 625 Identifier::get("interchange", context))); 626 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 627 } 628 629 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 630 static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; 631 632 namespace { 633 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the 634 /// `idx`-th loop contains only "full" iterations and a second loop for the 635 /// remaining partial iteration (if any). 636 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { 637 TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) 638 : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) { 639 } 640 641 LogicalResult matchAndRewrite(TiledLoopOp loopOp, 642 PatternRewriter &rewriter) const override { 643 SmallVector<int64_t> peeledLoops; 644 if (loopOp->hasAttr(kPeeledLoopsLabel)) { 645 auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>(); 646 peeledLoops = 647 llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { 648 return attr.cast<IntegerAttr>().getInt(); 649 })); 650 // Check if the loop was already peeled. 651 if (llvm::find(peeledLoops, idx) != peeledLoops.end()) 652 return failure(); 653 } 654 if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) 655 // No peeling of loop nests with a partial iteration. 656 return failure(); 657 658 if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) 659 return failure(); 660 661 // Peel loop and canonicalize. 662 TiledLoopOp result; 663 if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, 664 result))) 665 return failure(); 666 667 // Apply label, so that the same loop is not rewritten a second time. 
668 peeledLoops.push_back(idx); 669 rewriter.updateRootInPlace(loopOp, [&]() { 670 loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 671 }); 672 result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 673 result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); 674 675 return success(); 676 } 677 678 /// Index of loop to peel. 679 int64_t idx; 680 681 /// If set to true, do not peel TiledLoopOps with a partial iteration. 682 bool skipPartial; 683 }; 684 } // namespace 685 686 static void applyTiledLoopPeelingPattern(FuncOp funcOp, 687 ArrayRef<unsigned> loops, 688 bool skipPartial) { 689 MLIRContext *ctx = funcOp.getContext(); 690 RewritePatternSet patterns(ctx); 691 for (unsigned idx : loops) 692 patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial); 693 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 694 695 // Drop the markers. 696 funcOp.walk([](TiledLoopOp op) { 697 op->removeAttr(kPeeledLoopsLabel); 698 op->removeAttr(kPartialIterationLabel); 699 }); 700 } 701 702 /// Apply transformations specified as patterns. 
void TestLinalgTransforms::runOnFunction() {
  // Scope guard: on every exit path (note the early returns below), strip the
  // Linalg transformation marker attribute from all ops so test output is not
  // polluted by bookkeeping attributes. The (void *)1 payload is a dummy
  // non-null pointer; only the custom deleter matters.
  auto lambda = [&](void *) {
    getFunction().walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });
  };
  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};

  // Dispatch on the first test option that is set; options are intended to be
  // mutually exclusive.
  if (testPromotionOptions) {
    RewritePatternSet patterns(&getContext());
    fillPromotionCallBackPatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testTileAndDistributionOptions) {
    RewritePatternSet patterns(&getContext());
    fillTileAndDistributePatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testPatterns)
    return applyPatterns(getFunction());
  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
    return applyMatmulToVectorPatterns(getFunction(),
                                       testMatmulToVectorPatterns1dTiling,
                                       testMatmulToVectorPatterns2dTiling);
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getFunction());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getFunction());
  if (testTransformPadTensor)
    return applyPadTensorToGenericPatterns(getFunction());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getFunction());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getFunction());
  if (testTiledLoopPeeling.hasValue())
    return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
                                        skipPartial);
  if (testTilePattern)
    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                            peeledLoops, /*scalarizeDynamicDims=*/false);
  if (testTileScalarizeDynamicDims)
    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                            /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
  // testHoistPadding doubles as the hoisting depth (0 = disabled).
  // NOTE(review): unlike the other branches this one does not return, so it
  // falls through to the interchange check below — presumably intentional so
  // the two can compose; confirm before relying on it.
  if (testHoistPadding) {
    getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
      (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
    });
  }
  if (testInterchangePattern.hasValue())
    return applyInterchangePattern(getFunction(), testInterchangePattern);
}

namespace mlir {
namespace test {
/// Register the test pass with the global pass registry (used by test tools).
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir