1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/GPU/GPUDialect.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 18 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Vector/VectorOps.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 26 #include "llvm/ADT/SetVector.h" 27 #include "llvm/ADT/SmallVector.h" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 namespace { 33 struct TestLinalgTransforms 34 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 35 TestLinalgTransforms() = default; 36 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 37 38 void getDependentDialects(DialectRegistry ®istry) const override { 39 // clang-format off 40 registry.insert<AffineDialect, 41 memref::MemRefDialect, 42 scf::SCFDialect, 43 StandardOpsDialect, 44 vector::VectorDialect, 45 gpu::GPUDialect>(); 46 // clang-format on 47 } 48 StringRef getArgument() const final { 49 return "test-linalg-transform-patterns"; 50 } 51 StringRef getDescription() const final { 52 return "Test Linalg transformation patterns by applying them greedily."; 53 } 54 55 void runOnFunction() override; 56 57 Option<bool> testPatterns{*this, "test-patterns", 58 llvm::cl::desc("Test a mixed set of patterns"), 59 llvm::cl::init(false)}; 60 Option<bool> testMatmulToVectorPatterns1dTiling{ 61 *this, "test-matmul-to-vector-patterns-tile-1d", 62 llvm::cl::desc( 63 "Test a fused pass that applies patterns from matmul to vectors via " 64 "1-d tiling"), 65 llvm::cl::init(false)}; 66 Option<bool> testMatmulToVectorPatterns2dTiling{ 67 *this, "test-matmul-to-vector-patterns-tile-2d", 68 llvm::cl::desc( 69 "Test a fused pass that applies patterns from matmul to vectors via " 70 "2-d tiling"), 71 llvm::cl::init(false)}; 72 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 73 llvm::cl::desc("Test promotion options"), 74 llvm::cl::init(false)}; 75 Option<bool> testTileAndDistributionOptions{ 76 *this, "test-tile-and-distribute-options", 77 llvm::cl::desc("Test tile and distribute options"), 78 llvm::cl::init(false)}; 79 Option<bool> testVectorTransferForwardingPatterns{ 80 *this, "test-vector-transfer-forwarding-patterns", 81 llvm::cl::desc( 82 "Test a fused pass that forwards linalg.copy to vector.transfer"), 83 llvm::cl::init(false)}; 84 Option<bool> testGenericToVectorPattern{ 85 *this, "test-linalg-to-vector-patterns", 86 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 87 "in vector.contract form"), 88 llvm::cl::init(false)}; 89 Option<bool> testTilePattern{*this, "test-tile-pattern", 90 llvm::cl::desc("Test tile pattern"), 91 llvm::cl::init(false)}; 92 Option<bool> testTileScalarizeDynamicDims{ 93 *this, "test-tile-scalarize-dynamic-dims", 94 llvm::cl::desc("Test tiling of dynamic dims by 1"), 95 llvm::cl::init(false)}; 96 Option<int> testHoistPadding{*this, "test-hoist-padding", 97 llvm::cl::desc("Test hoist padding"), 98 llvm::cl::init(0)}; 99 Option<bool> testTransformPadTensor{ 100 *this, "test-transform-pad-tensor", 101 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 102 llvm::cl::init(false)}; 103 Option<bool> testGeneralizePadTensor{ 104 *this, "test-generalize-pad-tensor", 105 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 106 llvm::cl::init(false)}; 107 Option<bool> testSwapSubTensorPadTensor{ 108 *this, "test-swap-subtensor-padtensor", 109 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 110 "pad_tensor(subtensor)"), 111 llvm::cl::init(false)}; 112 ListOption<int64_t> paddedOperands{ 113 *this, "padded-operands", 114 llvm::cl::desc("Operands to pad when test-tile-pattern"), 115 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 116 ListOption<int64_t> nofoldOperands{ 117 *this, "nofold-operands", 118 llvm::cl::desc("Operands to set nofold when test-tile-pattern"), 119 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 120 ListOption<int64_t> peeledLoops{ 121 *this, "peeled-loops", 122 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 123 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 124 ListOption<int64_t> tileSizes{ 125 *this, "tile-sizes", 126 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 127 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 128 ListOption<unsigned> testInterchangePattern{ 129 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 130 llvm::cl::desc("Test the interchange pattern.")}; 131 ListOption<unsigned> testTiledLoopPeeling{ 132 *this, "test-tiled-loop-peeling", 133 llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), 134 llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; 135 Option<bool> skipPartial{ 136 *this, "skip-partial", 137 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 138 llvm::cl::init(false)}; 139 Option<std::string> loopType{ 140 *this, "loop-type", 141 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 142 "tiled_loop"), 143 llvm::cl::init("for")}; 144 }; 145 } // end anonymous namespace 146 147 static void applyPatterns(FuncOp funcOp) { 148 MLIRContext *ctx = funcOp.getContext(); 149 RewritePatternSet patterns(ctx); 150 151 //===--------------------------------------------------------------------===// 152 // Linalg tiling patterns. 153 //===--------------------------------------------------------------------===// 154 patterns.add<LinalgTilingPattern<MatmulOp>>( 155 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 156 LinalgTransformationFilter(Identifier::get("MEM", ctx), 157 Identifier::get("L3", ctx))); 158 patterns.add<LinalgTilingPattern<MatmulOp>>( 159 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 160 LinalgTransformationFilter(Identifier::get("L3", ctx), 161 Identifier::get("L2", ctx))); 162 patterns.add<LinalgTilingPattern<MatmulOp>>( 163 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 164 LinalgTransformationFilter(Identifier::get("L2", ctx), 165 Identifier::get("L1", ctx))); 166 patterns.add<LinalgTilingPattern<MatmulOp>>( 167 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 168 LinalgTransformationFilter(Identifier::get("L1", ctx), 169 Identifier::get("REG", ctx))); 170 171 patterns.add<LinalgTilingPattern<MatvecOp>>( 172 ctx, 173 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 174 LinalgTilingLoopType::ParallelLoops), 175 LinalgTransformationFilter(ArrayRef<Identifier>{}, 176 Identifier::get("L1", ctx))); 177 178 patterns.add<LinalgTilingPattern<DotOp>>( 179 ctx, LinalgTilingOptions().setTileSizes(8000), 180 LinalgTransformationFilter( 181 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 182 Identifier::get("L3", ctx), 183 Identifier::get("L2", ctx)}, 184 Identifier::get("REG", ctx))); 185 186 //===--------------------------------------------------------------------===// 187 // Linalg tiling and permutation patterns. 188 //===--------------------------------------------------------------------===// 189 patterns.add<LinalgTilingPattern<MatmulOp>>( 190 ctx, 191 LinalgTilingOptions() 192 .setTileSizes({2000, 3000, 4000}) 193 .setInterchange({1, 2, 0}), 194 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 195 Identifier::get("L2__with_perm__", ctx))); 196 patterns.add<LinalgTilingPattern<MatmulOp>>( 197 ctx, 198 LinalgTilingOptions() 199 .setTileSizes({200, 300, 400}) 200 .setInterchange({1, 0, 2}), 201 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 202 Identifier::get("L1__with_perm__", ctx))); 203 patterns.add<LinalgTilingPattern<MatmulOp>>( 204 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 205 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 206 Identifier::get("REG__with_perm__", ctx))); 207 208 patterns.add<LinalgTilingPattern<MatvecOp>>( 209 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 210 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 211 Identifier::get("L1__with_perm__", ctx))); 212 213 patterns.add<LinalgTilingPattern<MatmulOp>>( 214 ctx, 215 LinalgTilingOptions() 216 .setTileSizes({16, 8, 4}) 217 .setInterchange({1, 2, 0}) 218 .setLoopType(LinalgTilingLoopType::ParallelLoops), 219 LinalgTransformationFilter( 220 Identifier::get("par__with_perm__", ctx), 221 Identifier::get("after_par__with_perm__", ctx))); 222 223 //===--------------------------------------------------------------------===// 224 // Linalg to loops patterns. 225 //===--------------------------------------------------------------------===// 226 patterns.add<LinalgLoweringPattern<DotOp>>( 227 ctx, 228 /*loweringType=*/LinalgLoweringType::Loops, 229 LinalgTransformationFilter(Identifier::get("REG", ctx))); 230 231 //===--------------------------------------------------------------------===// 232 // Linalg distribution patterns. 233 //===--------------------------------------------------------------------===// 234 LinalgLoopDistributionOptions distributionOptions; 235 236 //===--------------------------------------------------------------------===// 237 // Linalg to vector contraction patterns. 238 //===--------------------------------------------------------------------===// 239 patterns.add<LinalgVectorizationPattern>( 240 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 241 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 242 243 //===--------------------------------------------------------------------===// 244 // Linalg generic interchange pattern. 245 //===--------------------------------------------------------------------===// 246 patterns.add<GenericOpInterchangePattern>( 247 ctx, 248 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 249 LinalgTransformationFilter(ArrayRef<Identifier>{}, 250 Identifier::get("PERMUTED", ctx))); 251 252 //===--------------------------------------------------------------------===// 253 // Linalg subview operands promotion. 254 //===--------------------------------------------------------------------===// 255 patterns.add<LinalgPromotionPattern<MatmulOp>>( 256 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 257 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 258 Identifier::get("_views_promoted_", ctx))); 259 patterns.add<LinalgPromotionPattern<MatmulOp>>( 260 ctx, 261 LinalgPromotionOptions() 262 .setOperandsToPromote({0}) 263 .setUseFullTileBuffersByDefault(true), 264 LinalgTransformationFilter( 265 Identifier::get("_promote_first_view_", ctx), 266 Identifier::get("_first_view_promoted_", ctx))); 267 patterns.add<LinalgPromotionPattern<FillOp>>( 268 ctx, 269 LinalgPromotionOptions() 270 .setOperandsToPromote({1}) 271 .setUseFullTileBuffers({false, true}) 272 .setAlignment(32), 273 LinalgTransformationFilter( 274 Identifier::get("_promote_views_aligned_", ctx), 275 Identifier::get("_views_aligned_promoted_", ctx))); 276 277 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 278 279 // Drop the marker. 280 funcOp.walk([](LinalgOp op) { 281 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 282 }); 283 } 284 285 static void fillL1TilingAndMatmulToVectorPatterns( 286 FuncOp funcOp, StringRef startMarker, 287 SmallVectorImpl<RewritePatternSet> &patternsVector) { 288 MLIRContext *ctx = funcOp.getContext(); 289 patternsVector.emplace_back( 290 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 291 ctx, 292 LinalgTilingOptions() 293 .setTileSizes({8, 12, 16}) 294 .setInterchange({1, 0, 2}), 295 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 296 Identifier::get("L1", ctx)))); 297 298 patternsVector.emplace_back( 299 ctx, 300 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 301 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 302 LinalgTransformationFilter(Identifier::get("L1", ctx), 303 Identifier::get("VEC", ctx)))); 304 305 patternsVector.emplace_back( 306 ctx, std::make_unique<LinalgVectorizationPattern>( 307 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 308 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 309 patternsVector.back().add<LinalgVectorizationPattern>( 310 ctx, LinalgTransformationFilter().addFilter( 311 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Test promotion callbacks 316 //===----------------------------------------------------------------------===// 317 318 // Allocation call back 319 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 320 ArrayRef<Value> boundingSubViewSize, 321 DataLayout &layout) { 322 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 323 return b 324 .create<memref::AllocOp>( 325 subView.getLoc(), 326 MemRefType::get(shape, subView.getType().getElementType(), 327 /*affineMapComposition =*/{}, 3), 328 boundingSubViewSize) 329 .getResult(); 330 } 331 332 // Deallocation callback 333 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 334 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 335 return success(); 336 } 337 338 // Copy in call back 339 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 340 bool isOutput) { 341 auto floatType = src.getType().cast<MemRefType>().getElementType(); 342 if (!floatType.isa<FloatType>()) 343 return failure(); 344 if (!isOutput) { 345 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 346 FloatAttr::get(floatType, 42.0)); 347 b.create<FillOp>(src.getLoc(), cst, dst); 348 } 349 b.create<CopyOp>(src.getLoc(), src, dst); 350 return success(); 351 } 352 353 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 354 RewritePatternSet &patterns) { 355 patterns.add<LinalgTilingPattern<MatmulOp>>( 356 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 357 LinalgTransformationFilter(Identifier::get("START", ctx), 358 Identifier::get("PROMOTE", ctx))); 359 patterns.add<LinalgPromotionPattern<MatmulOp>>( 360 ctx, 361 LinalgPromotionOptions() 362 .setOperandsToPromote({0, 2}) 363 .setUseFullTileBuffers({false, false}) 364 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 365 .setCopyInOutFns( 366 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 367 return copyCallBackFn(b, src, dst, false); 368 }, 369 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 370 return copyCallBackFn(b, src, dst, true); 371 }), 372 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 373 } 374 375 template <typename IdOp, typename NProcsOp> 376 static SmallVector<ProcInfo, 2> 377 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 378 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 379 SmallVector<ProcInfo, 2> procInfo(count); 380 const char *xyz[] = {"x", "y", "z"}; 381 Type indexType = b.getIndexType(); 382 for (unsigned i = 0; i < count; ++i) { 383 procInfo[count - 1 - i] = { 384 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 385 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 386 } 387 return procInfo; 388 } 389 390 static void fillTileAndDistributePatterns(MLIRContext *context, 391 RewritePatternSet &patterns) { 392 { 393 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 394 cyclicNprocsEqNiters.distributionMethod.resize( 395 2, DistributionMethod::CyclicNumProcsEqNumIters); 396 cyclicNprocsEqNiters.procInfo = 397 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 398 patterns.add<LinalgTilingPattern<MatmulOp>>( 399 context, 400 LinalgTilingOptions() 401 .setTileSizes({8, 8, 4}) 402 .setLoopType(LinalgTilingLoopType::ParallelLoops) 403 .setDistributionOptions(cyclicNprocsEqNiters), 404 LinalgTransformationFilter( 405 Identifier::get("distribute1", context), 406 Identifier::get("after_distribute1", context))); 407 } 408 409 { 410 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 411 cyclicNprocsGeNiters.distributionMethod.resize( 412 2, DistributionMethod::CyclicNumProcsGeNumIters); 413 cyclicNprocsGeNiters.procInfo = 414 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 415 patterns.add<LinalgTilingPattern<MatmulOp>>( 416 context, 417 LinalgTilingOptions() 418 .setTileSizes({8, 8, 4}) 419 .setLoopType(LinalgTilingLoopType::ParallelLoops) 420 .setDistributionOptions(cyclicNprocsGeNiters), 421 LinalgTransformationFilter( 422 Identifier::get("distribute2", context), 423 Identifier::get("after_distribute2", context))); 424 } 425 426 { 427 LinalgLoopDistributionOptions cyclicNprocsDefault; 428 cyclicNprocsDefault.distributionMethod.resize(2, 429 DistributionMethod::Cyclic); 430 cyclicNprocsDefault.procInfo = 431 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 432 patterns.add<LinalgTilingPattern<MatmulOp>>( 433 context, 434 LinalgTilingOptions() 435 .setTileSizes({8, 8, 4}) 436 .setLoopType(LinalgTilingLoopType::ParallelLoops) 437 .setDistributionOptions(cyclicNprocsDefault), 438 LinalgTransformationFilter( 439 Identifier::get("distribute3", context), 440 Identifier::get("after_distribute3", context))); 441 } 442 443 { 444 LinalgLoopDistributionOptions cyclicNprocsMixed1; 445 cyclicNprocsMixed1.distributionMethod = { 446 DistributionMethod::CyclicNumProcsEqNumIters, 447 DistributionMethod::CyclicNumProcsGeNumIters}; 448 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 449 patterns.add<LinalgTilingPattern<MatmulOp>>( 450 context, 451 LinalgTilingOptions() 452 .setTileSizes({8, 8, 4}) 453 .setLoopType(LinalgTilingLoopType::ParallelLoops) 454 .setDistributionOptions(cyclicNprocsMixed1), 455 LinalgTransformationFilter( 456 Identifier::get("distribute4", context), 457 Identifier::get("after_distribute4", context))); 458 } 459 460 { 461 LinalgLoopDistributionOptions cyclicNprocsMixed2; 462 cyclicNprocsMixed2.distributionMethod = { 463 DistributionMethod::CyclicNumProcsGeNumIters, 464 DistributionMethod::Cyclic}; 465 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 466 patterns.add<LinalgTilingPattern<MatmulOp>>( 467 context, 468 LinalgTilingOptions() 469 .setTileSizes({8, 8, 4}) 470 .setLoopType(LinalgTilingLoopType::ParallelLoops) 471 .setDistributionOptions(cyclicNprocsMixed2), 472 LinalgTransformationFilter( 473 Identifier::get("distribute5", context), 474 Identifier::get("after_distribute5", context))); 475 } 476 477 { 478 LinalgLoopDistributionOptions cyclicNprocsMixed3; 479 cyclicNprocsMixed3.distributionMethod = { 480 DistributionMethod::Cyclic, 481 DistributionMethod::CyclicNumProcsEqNumIters}; 482 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 483 484 patterns.add<LinalgTilingPattern<MatmulOp>>( 485 context, 486 LinalgTilingOptions() 487 .setTileSizes({8, 8, 4}) 488 .setLoopType(LinalgTilingLoopType::ParallelLoops) 489 .setDistributionOptions(cyclicNprocsMixed3), 490 LinalgTransformationFilter( 491 Identifier::get("distribute6", context), 492 Identifier::get("after_distribute6", context))); 493 } 494 495 { 496 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 497 cyclicNprocsEqNiters.distributionMethod.resize(2, 498 DistributionMethod::Cyclic); 499 cyclicNprocsEqNiters.procInfo = 500 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 501 patterns.add<LinalgTilingPattern<MatmulOp>>( 502 context, 503 LinalgTilingOptions() 504 .setTileSizes({8, 8, 4}) 505 .setLoopType(LinalgTilingLoopType::Loops) 506 .setDistributionOptions(cyclicNprocsEqNiters), 507 LinalgTransformationFilter( 508 Identifier::get("tensors_distribute1", context), 509 Identifier::get("tensors_after_distribute1", context))); 510 } 511 } 512 513 static void 514 applyMatmulToVectorPatterns(FuncOp funcOp, 515 bool testMatmulToVectorPatterns1dTiling, 516 bool testMatmulToVectorPatterns2dTiling) { 517 MLIRContext *ctx = funcOp.getContext(); 518 SmallVector<RewritePatternSet, 4> stage1Patterns; 519 if (testMatmulToVectorPatterns1dTiling) { 520 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 521 stage1Patterns); 522 } else if (testMatmulToVectorPatterns2dTiling) { 523 stage1Patterns.emplace_back( 524 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 525 ctx, 526 LinalgTilingOptions() 527 .setTileSizes({768, 264, 768}) 528 .setInterchange({1, 2, 0}), 529 LinalgTransformationFilter(Identifier::get("START", ctx), 530 Identifier::get("L2", ctx)))); 531 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 532 stage1Patterns); 533 } 534 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 535 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 536 FrozenRewritePatternSet stage2Patterns = 537 getLinalgTilingCanonicalizationPatterns(ctx); 538 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 539 std::move(stage2Patterns)); 540 } 541 542 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 543 RewritePatternSet forwardPattern(funcOp.getContext()); 544 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 545 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 546 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 547 } 548 549 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 550 RewritePatternSet patterns(funcOp.getContext()); 551 patterns.add<LinalgVectorizationPattern>( 552 funcOp.getContext(), 553 LinalgTransformationFilter() 554 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 555 populatePadTensorOpVectorizationPatterns(patterns); 556 populateConvolutionVectorizationPatterns(patterns); 557 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 558 } 559 560 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 561 RewritePatternSet patterns(funcOp.getContext()); 562 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 563 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 564 } 565 566 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 567 RewritePatternSet patterns(funcOp.getContext()); 568 patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext()); 569 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 570 } 571 572 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 573 RewritePatternSet patterns(funcOp.getContext()); 574 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 575 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 576 } 577 578 // For now, just assume it is the zero of type. 579 // In the future, it should be the zero of type + op. 580 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 581 auto t = getElementTypeOrSelf(op.get()); 582 return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t, 583 b.getZeroAttr(t)); 584 } 585 586 static void applyTilePattern(FuncOp funcOp, std::string loopType, 587 ArrayRef<int64_t> tileSizes, 588 ArrayRef<int64_t> paddedOperands, 589 ArrayRef<int64_t> nofoldOperands, 590 ArrayRef<int64_t> peeledLoops, 591 bool scalarizeDynamicDims) { 592 MLIRContext *context = funcOp.getContext(); 593 RewritePatternSet tilingPattern(context); 594 LinalgTilingLoopType type = 595 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 596 .Case("for", LinalgTilingLoopType::Loops) 597 .Case("affine", LinalgTilingLoopType::AffineLoops) 598 .Case("parallel", LinalgTilingLoopType::ParallelLoops) 599 .Case("tiled_loop", LinalgTilingLoopType::TiledLoops); 600 auto linalgTilingOptions = linalg::LinalgTilingOptions() 601 .setPeeledLoops(peeledLoops) 602 .setLoopType(type); 603 if (scalarizeDynamicDims) { 604 linalgTilingOptions.scalarizeDynamicDims(); 605 assert(tileSizes.empty() && 606 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 607 } else { 608 linalgTilingOptions.setTileSizes(tileSizes); 609 } 610 if (!paddedOperands.empty()) { 611 auto paddingFunc = [&](OpBuilder &b, 612 OpOperand &opOperand) -> FailureOr<Value> { 613 if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0) 614 return failure(); 615 return getNeutralOfLinalgOp(b, opOperand); 616 }; 617 auto nofoldFunc = [&](OpOperand &opOperand) { 618 if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0) 619 return true; 620 return false; 621 }; 622 linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc); 623 linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc); 624 } 625 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, 626 linalg::LinalgTilingPattern<linalg::GenericOp>>( 627 context, linalgTilingOptions, 628 linalg::LinalgTransformationFilter(Identifier::get("tile", context))); 629 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 630 } 631 632 static void applyInterchangePattern(FuncOp funcOp, 633 ArrayRef<unsigned> interchangeVector) { 634 MLIRContext *context = funcOp.getContext(); 635 RewritePatternSet interchangePattern(context); 636 interchangePattern.add<GenericOpInterchangePattern>( 637 context, interchangeVector, 638 LinalgTransformationFilter(ArrayRef<Identifier>{}, 639 Identifier::get("interchange", context))); 640 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 641 } 642 643 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 644 static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; 645 646 namespace { 647 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the 648 /// `idx`-th loop contains only "full" iterations and a second loop for the 649 /// remaining partial iteration (if any). 650 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { 651 TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) 652 : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) { 653 } 654 655 LogicalResult matchAndRewrite(TiledLoopOp loopOp, 656 PatternRewriter &rewriter) const override { 657 SmallVector<int64_t> peeledLoops; 658 if (loopOp->hasAttr(kPeeledLoopsLabel)) { 659 auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>(); 660 peeledLoops = 661 llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { 662 return attr.cast<IntegerAttr>().getInt(); 663 })); 664 // Check if the loop was already peeled. 665 if (llvm::find(peeledLoops, idx) != peeledLoops.end()) 666 return failure(); 667 } 668 if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) 669 // No peeling of loop nests with a partial iteration. 670 return failure(); 671 672 if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) 673 return failure(); 674 675 // Peel loop and canonicalize. 676 TiledLoopOp result; 677 if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, 678 result))) 679 return failure(); 680 681 // Apply label, so that the same loop is not rewritten a second time. 682 peeledLoops.push_back(idx); 683 rewriter.updateRootInPlace(loopOp, [&]() { 684 loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 685 }); 686 result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 687 result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); 688 689 return success(); 690 } 691 692 /// Index of loop to peel. 693 int64_t idx; 694 695 /// If set to true, do not peel TiledLoopOps with a partial iteration. 696 bool skipPartial; 697 }; 698 } // namespace 699 700 static void applyTiledLoopPeelingPattern(FuncOp funcOp, 701 ArrayRef<unsigned> loops, 702 bool skipPartial) { 703 MLIRContext *ctx = funcOp.getContext(); 704 RewritePatternSet patterns(ctx); 705 for (unsigned idx : loops) 706 patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial); 707 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 708 709 // Drop the markers. 710 funcOp.walk([](TiledLoopOp op) { 711 op->removeAttr(kPeeledLoopsLabel); 712 op->removeAttr(kPartialIterationLabel); 713 }); 714 } 715 716 /// Apply transformations specified as patterns. 717 void TestLinalgTransforms::runOnFunction() { 718 auto lambda = [&](void *) { 719 getFunction().walk([](LinalgOp op) { 720 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 721 }); 722 }; 723 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 724 725 if (testPromotionOptions) { 726 RewritePatternSet patterns(&getContext()); 727 fillPromotionCallBackPatterns(&getContext(), patterns); 728 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 729 return; 730 } 731 if (testTileAndDistributionOptions) { 732 RewritePatternSet patterns(&getContext()); 733 fillTileAndDistributePatterns(&getContext(), patterns); 734 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 735 return; 736 } 737 if (testPatterns) 738 return applyPatterns(getFunction()); 739 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 740 return applyMatmulToVectorPatterns(getFunction(), 741 testMatmulToVectorPatterns1dTiling, 742 testMatmulToVectorPatterns2dTiling); 743 if (testVectorTransferForwardingPatterns) 744 return applyVectorTransferForwardingPatterns(getFunction()); 745 if (testGenericToVectorPattern) 746 return applyLinalgToVectorPatterns(getFunction()); 747 if (testTransformPadTensor) 748 return applyPadTensorToGenericPatterns(getFunction()); 749 if (testGeneralizePadTensor) 750 return applyGeneralizePadTensorPatterns(getFunction()); 751 if (testSwapSubTensorPadTensor) 752 return applyExtractSliceOfPadTensorSwapPattern(getFunction()); 753 if (testTiledLoopPeeling.hasValue()) 754 return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, 755 skipPartial); 756 if (testTilePattern) 757 return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, 758 nofoldOperands, peeledLoops, 759 /*scalarizeDynamicDims=*/false); 760 if (testTileScalarizeDynamicDims) 761 return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, 762 nofoldOperands, 763 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 764 if (testHoistPadding) { 765 getFunction().walk([&](linalg::PadTensorOp padTensorOp) { 766 (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); 767 }); 768 } 769 if (testInterchangePattern.hasValue()) 770 return applyInterchangePattern(getFunction(), testInterchangePattern); 771 } 772 773 namespace mlir { 774 namespace test { 775 void registerTestLinalgTransforms() { 776 PassRegistration<TestLinalgTransforms>(); 777 } 778 } // namespace test 779 } // namespace mlir 780