//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
struct TestLinalgTransforms
    : public PassWrapper<TestLinalgTransforms, FunctionPass> {
  TestLinalgTransforms() = default;
  TestLinalgTransforms(const TestLinalgTransforms &pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    // clang-format off
    registry.insert<AffineDialect,
                    memref::MemRefDialect,
                    scf::SCFDialect,
                    StandardOpsDialect,
                    vector::VectorDialect,
                    gpu::GPUDialect>();
    // clang-format on
  }
  StringRef getArgument() const final {
    return "test-linalg-transform-patterns";
  }
  StringRef getDescription() const final {
    return "Test Linalg transformation patterns by applying them greedily.";
  }

  void runOnFunction() override;

  Option<bool> testPatterns{*this, "test-patterns",
                            llvm::cl::desc("Test a mixed set of patterns"),
                            llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns1dTiling{
      *this, "test-matmul-to-vector-patterns-tile-1d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "1-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns2dTiling{
      *this, "test-matmul-to-vector-patterns-tile-2d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "2-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                    llvm::cl::desc("Test promotion options"),
                                    llvm::cl::init(false)};
  Option<bool> testTileAndDistributionOptions{
      *this, "test-tile-and-distribute-options",
      llvm::cl::desc("Test tile and distribute options"),
      llvm::cl::init(false)};
  Option<bool> testVectorTransferForwardingPatterns{
      *this, "test-vector-transfer-forwarding-patterns",
      llvm::cl::desc(
          "Test a fused pass that forwards linalg.copy to vector.transfer"),
      llvm::cl::init(false)};
  Option<bool> testGenericToVectorPattern{
      *this, "test-linalg-to-vector-patterns",
      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                     "in vector.contract form"),
      llvm::cl::init(false)};
  Option<bool> testTilePattern{*this, "test-tile-pattern",
                               llvm::cl::desc("Test tile pattern"),
                               llvm::cl::init(false)};
  Option<bool> testTileScalarizeDynamicDims{
      *this, "test-tile-scalarize-dynamic-dims",
      llvm::cl::desc("Test tiling of dynamic dims by 1"),
      llvm::cl::init(false)};
  Option<int> testHoistPadding{*this, "test-hoist-padding",
                               llvm::cl::desc("Test hoist padding"),
                               llvm::cl::init(0)};
  Option<bool> testTransformPadTensor{
      *this, "test-transform-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizePadTensor{
      *this, "test-generalize-pad-tensor",
      llvm::cl::desc("Test generalization of pad tensor ops"),
      llvm::cl::init(false)};
  Option<bool> testSwapSubTensorPadTensor{
      *this, "test-swap-subtensor-padtensor",
      llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                     "pad_tensor(subtensor)"),
      llvm::cl::init(false)};
  ListOption<int64_t> paddedOperands{
      *this, "padded-operands",
      llvm::cl::desc("Operands to pad when test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> peeledLoops{
      *this, "peeled-loops",
      llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> tileSizes{
      *this, "tile-sizes",
      llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<unsigned> testInterchangePattern{
      *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::desc("Test the interchange pattern.")};
  ListOption<unsigned> testTiledLoopPeeling{
      *this, "test-tiled-loop-peeling",
      llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
      llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
  Option<bool> skipPartial{
      *this, "skip-partial",
      llvm::cl::desc("Skip loops inside partial iterations during peeling"),
      llvm::cl::init(false)};
  Option<std::string> loopType{
      *this, "loop-type",
      llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                     "tiled_loop"),
      llvm::cl::init("for")};
};
} // end anonymous namespace

static void applyPatterns(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);

  //===--------------------------------------------------------------------===//
  // Linalg tiling patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
      LinalgTransformationFilter(Identifier::get("MEM", ctx),
                                 Identifier::get("L3", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
      LinalgTransformationFilter(Identifier::get("L3", ctx),
                                 Identifier::get("L2", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L2", ctx),
                                 Identifier::get("L1", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
      LinalgTransformationFilter(Identifier::get("L1", ctx),
                                 Identifier::get("REG", ctx)));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx,
      LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
          LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("L1", ctx)));

  patterns.add<LinalgTilingPattern<DotOp>>(
      ctx, LinalgTilingOptions().setTileSizes(8000),
      LinalgTransformationFilter(
          ArrayRef<Identifier>{Identifier::get("MEM", ctx),
                               Identifier::get("L3", ctx),
                               Identifier::get("L2", ctx)},
          Identifier::get("REG", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg tiling and permutation patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({2000, 3000, 4000})
          .setInterchange({1, 2, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L2__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({200, 300, 400})
          .setInterchange({1, 0, 2}),
      LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
                                 Identifier::get("REG__with_perm__", ctx)));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));

  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({16, 8, 4})
          .setInterchange({1, 2, 0})
          .setLoopType(LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(
          Identifier::get("par__with_perm__", ctx),
          Identifier::get("after_par__with_perm__", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg to loops patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgLoweringPattern<DotOp>>(
      ctx,
      /*loweringType=*/LinalgLoweringType::Loops,
      LinalgTransformationFilter(Identifier::get("REG", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg distribution patterns.
  //===--------------------------------------------------------------------===//
  LinalgLoopDistributionOptions distributionOptions;

  //===--------------------------------------------------------------------===//
  // Linalg to vector contraction patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
               .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());

  //===--------------------------------------------------------------------===//
  // Linalg generic interchange pattern.
  //===--------------------------------------------------------------------===//
  patterns.add<GenericOpInterchangePattern>(
      ctx,
      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("PERMUTED", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg subview operands promotion.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
                                 Identifier::get("_views_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0})
          .setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(
          Identifier::get("_promote_first_view_", ctx),
          Identifier::get("_first_view_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<FillOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({1})
          .setUseFullTileBuffers({false, true})
          .setAlignment(32),
      LinalgTransformationFilter(
          Identifier::get("_promote_views_aligned_", ctx),
          Identifier::get("_views_aligned_promoted_", ctx)));

  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the marker.
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}

static void fillL1TilingAndMatmulToVectorPatterns(
    FuncOp funcOp, StringRef startMarker,
    SmallVectorImpl<RewritePatternSet> &patternsVector) {
  MLIRContext *ctx = funcOp.getContext();
  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
               ctx,
               LinalgTilingOptions()
                   .setTileSizes({8, 12, 16})
                   .setInterchange({1, 0, 2}),
               LinalgTransformationFilter(Identifier::get(startMarker, ctx),
                                          Identifier::get("L1", ctx))));

  patternsVector.emplace_back(
      ctx,
      std::make_unique<LinalgPromotionPattern<MatmulOp>>(
          ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
          LinalgTransformationFilter(Identifier::get("L1", ctx),
                                     Identifier::get("VEC", ctx))));

  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgVectorizationPattern>(
               MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
               LinalgTransformationFilter(Identifier::get("VEC", ctx))));
  patternsVector.back().add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter().addFilter(
               [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
}

//===----------------------------------------------------------------------===//
// Test promotion callbacks
//===----------------------------------------------------------------------===//

// Allocation callback: allocate a dynamically shaped buffer in memory space 3,
// sized by the bounding subview.
static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
                                       ArrayRef<Value> boundingSubViewSize,
                                       DataLayout &layout) {
  SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
  return b
      .create<memref::AllocOp>(
          subView.getLoc(),
          MemRefType::get(shape, subView.getType().getElementType(),
                          /*affineMapComposition=*/{}, 3),
          boundingSubViewSize)
      .getResult();
}

// Deallocation callback.
static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
  b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
  return success();
}

// Copy-in/copy-out callback: bail out on non-float element types; for inputs,
// fill the destination with the constant 42.0 before copying.
static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
                                    bool isOutput) {
  auto floatType = src.getType().cast<MemRefType>().getElementType();
  if (!floatType.isa<FloatType>())
    return failure();
  if (!isOutput) {
    Value cst = b.create<arith::ConstantOp>(src.getLoc(),
                                            FloatAttr::get(floatType, 42.0));
    b.create<FillOp>(src.getLoc(), cst, dst);
  }
  b.create<CopyOp>(src.getLoc(), src, dst);
  return success();
}

static void fillPromotionCallBackPatterns(MLIRContext *ctx,
                                          RewritePatternSet &patterns) {
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
      LinalgTransformationFilter(Identifier::get("START", ctx),
                                 Identifier::get("PROMOTE", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0, 2})
          .setUseFullTileBuffers({false, false})
          .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
          .setCopyInOutFns(
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, false);
              },
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, true);
              }),
      LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
}

template <typename IdOp, typename NProcsOp>
static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
  size_t count = std::min<size_t>(3, parallelLoopRanges.size());
  SmallVector<ProcInfo, 2> procInfo(count);
  const char *xyz[] = {"x", "y", "z"};
  Type indexType = b.getIndexType();
  // Fill procInfo back to front so that GPU dimension "x" is paired with the
  // last parallel loop, "y" with the one before it, and so on.
  for (unsigned i = 0; i < count; ++i) {
    procInfo[count - 1 - i] = {
        b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
        b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
  }
  return procInfo;
}

static void fillTileAndDistributePatterns(MLIRContext *context,
                                          RewritePatternSet &patterns) {
  {
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsEqNumIters);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute1", context),
            Identifier::get("after_distribute1", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsGeNiters;
    cyclicNprocsGeNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsGeNumIters);
    cyclicNprocsGeNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsGeNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute2", context),
            Identifier::get("after_distribute2", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsDefault;
    cyclicNprocsDefault.distributionMethod.resize(2,
                                                  DistributionMethod::Cyclic);
    cyclicNprocsDefault.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsDefault),
        LinalgTransformationFilter(
            Identifier::get("distribute3", context),
            Identifier::get("after_distribute3", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed1;
    cyclicNprocsMixed1.distributionMethod = {
        DistributionMethod::CyclicNumProcsEqNumIters,
        DistributionMethod::CyclicNumProcsGeNumIters};
    cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed1),
        LinalgTransformationFilter(
            Identifier::get("distribute4", context),
            Identifier::get("after_distribute4", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed2;
    cyclicNprocsMixed2.distributionMethod = {
        DistributionMethod::CyclicNumProcsGeNumIters,
        DistributionMethod::Cyclic};
    cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed2),
        LinalgTransformationFilter(
            Identifier::get("distribute5", context),
            Identifier::get("after_distribute5", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed3;
    cyclicNprocsMixed3.distributionMethod = {
        DistributionMethod::Cyclic,
        DistributionMethod::CyclicNumProcsEqNumIters};
    cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;

    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed3),
        LinalgTransformationFilter(
            Identifier::get("distribute6", context),
            Identifier::get("after_distribute6", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(2,
                                                   DistributionMethod::Cyclic);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::Loops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("tensors_distribute1", context),
            Identifier::get("tensors_after_distribute1", context)));
  }
}

static void
applyMatmulToVectorPatterns(FuncOp funcOp,
                            bool testMatmulToVectorPatterns1dTiling,
                            bool testMatmulToVectorPatterns2dTiling) {
  MLIRContext *ctx = funcOp.getContext();
  SmallVector<RewritePatternSet, 4> stage1Patterns;
  if (testMatmulToVectorPatterns1dTiling) {
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
                                          stage1Patterns);
  } else if (testMatmulToVectorPatterns2dTiling) {
    stage1Patterns.emplace_back(
        ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
                 ctx,
                 LinalgTilingOptions()
                     .setTileSizes({768, 264, 768})
                     .setInterchange({1, 2, 0}),
                 LinalgTransformationFilter(Identifier::get("START", ctx),
                                            Identifier::get("L2", ctx))));
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                          stage1Patterns);
  }
  SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
  llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
  FrozenRewritePatternSet stage2Patterns =
      getLinalgTilingCanonicalizationPatterns(ctx);
  (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
                            std::move(stage2Patterns));
}

static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
  RewritePatternSet forwardPattern(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}

static void applyLinalgToVectorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<LinalgVectorizationPattern>(
      funcOp.getContext(),
      LinalgTransformationFilter()
          .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
  populatePadTensorOpVectorizationPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

// For now, just assume it is the zero of type.
// In the future, it should be the zero of type + op.
static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
  auto t = getElementTypeOrSelf(op.get());
  return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
                                     b.getZeroAttr(t));
}

static void applyTilePattern(FuncOp funcOp, std::string loopType,
                             ArrayRef<int64_t> tileSizes,
                             ArrayRef<int64_t> paddedOperands,
                             ArrayRef<int64_t> peeledLoops,
                             bool scalarizeDynamicDims) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet tilingPattern(context);
  LinalgTilingLoopType type =
      llvm::StringSwitch<LinalgTilingLoopType>(loopType)
          .Case("for", LinalgTilingLoopType::Loops)
          .Case("affine", LinalgTilingLoopType::AffineLoops)
          .Case("parallel", LinalgTilingLoopType::ParallelLoops)
          .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
  auto linalgTilingOptions = linalg::LinalgTilingOptions()
                                 .setPeeledLoops(peeledLoops)
                                 .setLoopType(type);
  if (scalarizeDynamicDims) {
    linalgTilingOptions.scalarizeDynamicDims();
    assert(tileSizes.empty() &&
           "tileSizes and scalarizeDynamicDims are mutually exclusive");
  } else {
    linalgTilingOptions.setTileSizes(tileSizes);
  }
  if (!paddedOperands.empty()) {
    auto paddingFunc = [&](OpBuilder &b,
                           OpOperand &opOperand) -> FailureOr<Value> {
      if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0)
        return failure();
      return getNeutralOfLinalgOp(b, opOperand);
    };
    linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
  }
  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
                    linalg::LinalgTilingPattern<linalg::GenericOp>>(
      context, linalgTilingOptions,
      linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}

static void applyInterchangePattern(FuncOp funcOp,
                                    ArrayRef<unsigned> interchangeVector) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet interchangePattern(context);
  interchangePattern.add<GenericOpInterchangePattern>(
      context, interchangeVector,
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("interchange", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
}

static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";

namespace {
/// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
/// `idx`-th loop contains only "full" iterations and a second loop for the
/// remaining partial iteration (if any).
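///
/// Sketch of the intended effect on a one-dimensional loop (operands and
/// attributes elided; the exact IR depends on the input):
///
///   linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step) { ... }
///
/// is split into a main linalg.tiled_loop that runs only full %step-sized
/// iterations and a trailing linalg.tiled_loop for the leftover iteration,
/// which this pattern additionally tags with the `__partial_iteration__`
/// attribute.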
struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
  TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
      : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
  }

  LogicalResult matchAndRewrite(TiledLoopOp loopOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> peeledLoops;
    if (loopOp->hasAttr(kPeeledLoopsLabel)) {
      auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
      peeledLoops =
          llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
            return attr.cast<IntegerAttr>().getInt();
          }));
      // Check if the loop was already peeled.
      if (llvm::find(peeledLoops, idx) != peeledLoops.end())
        return failure();
    }
    if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
      // No peeling of loop nests with a partial iteration.
      return failure();

    if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
      return failure();

    // Peel loop and canonicalize.
    TiledLoopOp result;
    if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
                                                    result)))
      return failure();

    // Apply label, so that the same loop is not rewritten a second time.
    peeledLoops.push_back(idx);
    rewriter.updateRootInPlace(loopOp, [&]() {
      loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    });
    result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());

    return success();
  }

  /// Index of loop to peel.
  int64_t idx;

  /// If set to true, do not peel TiledLoopOps with a partial iteration.
  bool skipPartial;
};
} // namespace

static void applyTiledLoopPeelingPattern(FuncOp funcOp,
                                         ArrayRef<unsigned> loops,
                                         bool skipPartial) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  for (unsigned idx : loops)
    patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the markers.
  funcOp.walk([](TiledLoopOp op) {
    op->removeAttr(kPeeledLoopsLabel);
    op->removeAttr(kPartialIterationLabel);
  });
}

/// Apply transformations specified as patterns.
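/// Usage sketch (the authoritative RUN-line flags live in the lit tests; the
/// invocation below is only illustrative):
///   mlir-opt foo.mlir -test-linalg-transform-patterns=test-patterns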
void TestLinalgTransforms::runOnFunction() {
  // Scope guard that strips the transformation markers from all Linalg ops
  // when this function returns, whichever of the paths below is taken.
  auto lambda = [&](void *) {
    getFunction().walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });
  };
  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};

  if (testPromotionOptions) {
    RewritePatternSet patterns(&getContext());
    fillPromotionCallBackPatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testTileAndDistributionOptions) {
    RewritePatternSet patterns(&getContext());
    fillTileAndDistributePatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testPatterns)
    return applyPatterns(getFunction());
  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
    return applyMatmulToVectorPatterns(getFunction(),
                                       testMatmulToVectorPatterns1dTiling,
                                       testMatmulToVectorPatterns2dTiling);
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getFunction());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getFunction());
  if (testTransformPadTensor)
    return applyPadTensorToGenericPatterns(getFunction());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getFunction());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getFunction());
  if (testTiledLoopPeeling.hasValue())
    return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
                                        skipPartial);
  if (testTilePattern)
    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                            peeledLoops, /*scalarizeDynamicDims=*/false);
  if (testTileScalarizeDynamicDims)
    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                            /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
  if (testHoistPadding) {
    getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
      (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
    });
  }
  if (testInterchangePattern.hasValue())
    return applyInterchangePattern(getFunction(), testInterchangePattern);
}

namespace mlir {
namespace test {
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir