1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/GPU/GPUDialect.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 18 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Vector/VectorOps.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 26 #include "llvm/ADT/SetVector.h" 27 #include "llvm/ADT/SmallVector.h" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 namespace { 33 struct TestLinalgTransforms 34 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 35 TestLinalgTransforms() = default; 36 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 37 38 void getDependentDialects(DialectRegistry ®istry) const override { 39 // clang-format off 40 registry.insert<AffineDialect, 41 memref::MemRefDialect, 42 scf::SCFDialect, 43 StandardOpsDialect, 44 vector::VectorDialect, 45 gpu::GPUDialect>(); 46 // clang-format on 47 } 48 StringRef getArgument() const final { 49 return "test-linalg-transform-patterns"; 50 } 51 StringRef getDescription() const final { 52 return "Test Linalg transformation patterns by applying them greedily."; 53 } 54 55 void runOnFunction() override; 56 57 Option<bool> testPatterns{*this, "test-patterns", 58 llvm::cl::desc("Test a mixed set of patterns"), 59 llvm::cl::init(false)}; 60 Option<bool> testMatmulToVectorPatterns1dTiling{ 61 *this, "test-matmul-to-vector-patterns-tile-1d", 62 llvm::cl::desc( 63 "Test a fused pass that applies patterns from matmul to vectors via " 64 "1-d tiling"), 65 llvm::cl::init(false)}; 66 Option<bool> testMatmulToVectorPatterns2dTiling{ 67 *this, "test-matmul-to-vector-patterns-tile-2d", 68 llvm::cl::desc( 69 "Test a fused pass that applies patterns from matmul to vectors via " 70 "2-d tiling"), 71 llvm::cl::init(false)}; 72 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 73 llvm::cl::desc("Test promotion options"), 74 llvm::cl::init(false)}; 75 Option<bool> testTileAndDistributionOptions{ 76 *this, "test-tile-and-distribute-options", 77 llvm::cl::desc("Test tile and distribute options"), 78 llvm::cl::init(false)}; 79 Option<bool> testVectorTransferForwardingPatterns{ 80 *this, "test-vector-transfer-forwarding-patterns", 81 llvm::cl::desc( 82 "Test a fused pass that forwards linalg.copy to vector.transfer"), 83 llvm::cl::init(false)}; 84 Option<bool> testGenericToVectorPattern{ 85 *this, "test-linalg-to-vector-patterns", 86 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 87 "in vector.contract form"), 88 llvm::cl::init(false)}; 89 Option<bool> testTilePattern{*this, "test-tile-pattern", 90 llvm::cl::desc("Test tile pattern"), 91 llvm::cl::init(false)}; 92 Option<bool> testTileScalarizeDynamicDims{ 93 *this, "test-tile-scalarize-dynamic-dims", 94 llvm::cl::desc("Test tiling of dynamic dims by 1"), 95 llvm::cl::init(false)}; 96 Option<int> testHoistPadding{*this, "test-hoist-padding", 97 llvm::cl::desc("Test hoist padding"), 98 llvm::cl::init(0)}; 99 Option<bool> testTransformPadTensor{ 100 *this, "test-transform-pad-tensor", 101 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 102 llvm::cl::init(false)}; 103 Option<bool> testGeneralizePadTensor{ 104 *this, "test-generalize-pad-tensor", 105 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 106 llvm::cl::init(false)}; 107 Option<bool> testSwapSubTensorPadTensor{ 108 *this, "test-swap-subtensor-padtensor", 109 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 110 "pad_tensor(subtensor)"), 111 llvm::cl::init(false)}; 112 ListOption<int64_t> paddedOperands{ 113 *this, "padded-operands", 114 llvm::cl::desc("Operands to pad when test-tile-pattern"), 115 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 116 ListOption<int64_t> nofoldOperands{ 117 *this, "nofold-operands", 118 llvm::cl::desc("Operands to set nofold when test-tile-pattern"), 119 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 120 ListOption<int64_t> peeledLoops{ 121 *this, "peeled-loops", 122 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 123 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 124 ListOption<int64_t> tileSizes{ 125 *this, "tile-sizes", 126 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 127 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 128 ListOption<unsigned> testInterchangePattern{ 129 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 130 llvm::cl::desc("Test the interchange pattern.")}; 131 ListOption<unsigned> testTiledLoopPeeling{ 132 *this, "test-tiled-loop-peeling", 133 llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), 134 llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; 135 Option<bool> skipPartial{ 136 *this, "skip-partial", 137 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 138 llvm::cl::init(false)}; 139 Option<std::string> loopType{ 140 *this, "loop-type", 141 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 142 "tiled_loop"), 143 llvm::cl::init("for")}; 144 }; 145 } // end anonymous namespace 146 147 static void applyPatterns(FuncOp funcOp) { 148 MLIRContext *ctx = funcOp.getContext(); 149 RewritePatternSet patterns(ctx); 150 151 //===--------------------------------------------------------------------===// 152 // Linalg tiling patterns. 153 //===--------------------------------------------------------------------===// 154 patterns.add<LinalgTilingPattern<MatmulOp>>( 155 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 156 LinalgTransformationFilter(Identifier::get("MEM", ctx), 157 Identifier::get("L3", ctx))); 158 patterns.add<LinalgTilingPattern<MatmulOp>>( 159 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 160 LinalgTransformationFilter(Identifier::get("L3", ctx), 161 Identifier::get("L2", ctx))); 162 patterns.add<LinalgTilingPattern<MatmulOp>>( 163 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 164 LinalgTransformationFilter(Identifier::get("L2", ctx), 165 Identifier::get("L1", ctx))); 166 patterns.add<LinalgTilingPattern<MatmulOp>>( 167 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 168 LinalgTransformationFilter(Identifier::get("L1", ctx), 169 Identifier::get("REG", ctx))); 170 171 patterns.add<LinalgTilingPattern<MatvecOp>>( 172 ctx, 173 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 174 LinalgTilingLoopType::ParallelLoops), 175 LinalgTransformationFilter(ArrayRef<Identifier>{}, 176 Identifier::get("L1", ctx))); 177 178 patterns.add<LinalgTilingPattern<DotOp>>( 179 ctx, LinalgTilingOptions().setTileSizes(8000), 180 LinalgTransformationFilter( 181 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 182 Identifier::get("L3", ctx), 183 Identifier::get("L2", ctx)}, 184 Identifier::get("REG", ctx))); 185 186 //===--------------------------------------------------------------------===// 187 // Linalg tiling and permutation patterns. 188 //===--------------------------------------------------------------------===// 189 patterns.add<LinalgTilingPattern<MatmulOp>>( 190 ctx, 191 LinalgTilingOptions() 192 .setTileSizes({2000, 3000, 4000}) 193 .setInterchange({1, 2, 0}), 194 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 195 Identifier::get("L2__with_perm__", ctx))); 196 patterns.add<LinalgTilingPattern<MatmulOp>>( 197 ctx, 198 LinalgTilingOptions() 199 .setTileSizes({200, 300, 400}) 200 .setInterchange({1, 0, 2}), 201 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 202 Identifier::get("L1__with_perm__", ctx))); 203 patterns.add<LinalgTilingPattern<MatmulOp>>( 204 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 205 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 206 Identifier::get("REG__with_perm__", ctx))); 207 208 patterns.add<LinalgTilingPattern<MatvecOp>>( 209 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 210 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 211 Identifier::get("L1__with_perm__", ctx))); 212 213 patterns.add<LinalgTilingPattern<MatmulOp>>( 214 ctx, 215 LinalgTilingOptions() 216 .setTileSizes({16, 8, 4}) 217 .setInterchange({1, 2, 0}) 218 .setLoopType(LinalgTilingLoopType::ParallelLoops), 219 LinalgTransformationFilter( 220 Identifier::get("par__with_perm__", ctx), 221 Identifier::get("after_par__with_perm__", ctx))); 222 223 //===--------------------------------------------------------------------===// 224 // Linalg to loops patterns. 225 //===--------------------------------------------------------------------===// 226 patterns.add<LinalgLoweringPattern<DotOp>>( 227 ctx, 228 /*loweringType=*/LinalgLoweringType::Loops, 229 LinalgTransformationFilter(Identifier::get("REG", ctx))); 230 231 //===--------------------------------------------------------------------===// 232 // Linalg distribution patterns. 233 //===--------------------------------------------------------------------===// 234 LinalgLoopDistributionOptions distributionOptions; 235 236 //===--------------------------------------------------------------------===// 237 // Linalg to vector contraction patterns. 238 //===--------------------------------------------------------------------===// 239 patterns.add<LinalgVectorizationPattern>( 240 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 241 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 242 243 //===--------------------------------------------------------------------===// 244 // Linalg generic interchange pattern. 245 //===--------------------------------------------------------------------===// 246 patterns.add<GenericOpInterchangePattern>( 247 ctx, 248 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 249 LinalgTransformationFilter(ArrayRef<Identifier>{}, 250 Identifier::get("PERMUTED", ctx))); 251 252 //===--------------------------------------------------------------------===// 253 // Linalg subview operands promotion. 254 //===--------------------------------------------------------------------===// 255 patterns.add<LinalgPromotionPattern<MatmulOp>>( 256 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 257 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 258 Identifier::get("_views_promoted_", ctx))); 259 patterns.add<LinalgPromotionPattern<MatmulOp>>( 260 ctx, 261 LinalgPromotionOptions() 262 .setOperandsToPromote({0}) 263 .setUseFullTileBuffersByDefault(true), 264 LinalgTransformationFilter( 265 Identifier::get("_promote_first_view_", ctx), 266 Identifier::get("_first_view_promoted_", ctx))); 267 patterns.add<LinalgPromotionPattern<FillOp>>( 268 ctx, 269 LinalgPromotionOptions() 270 .setOperandsToPromote({1}) 271 .setUseFullTileBuffers({false, true}) 272 .setAlignment(32), 273 LinalgTransformationFilter( 274 Identifier::get("_promote_views_aligned_", ctx), 275 Identifier::get("_views_aligned_promoted_", ctx))); 276 277 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 278 279 // Drop the marker. 280 funcOp.walk([](LinalgOp op) { 281 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 282 }); 283 } 284 285 static void fillL1TilingAndMatmulToVectorPatterns( 286 FuncOp funcOp, StringRef startMarker, 287 SmallVectorImpl<RewritePatternSet> &patternsVector) { 288 MLIRContext *ctx = funcOp.getContext(); 289 patternsVector.emplace_back( 290 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 291 ctx, 292 LinalgTilingOptions() 293 .setTileSizes({8, 12, 16}) 294 .setInterchange({1, 0, 2}), 295 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 296 Identifier::get("L1", ctx)))); 297 298 patternsVector.emplace_back( 299 ctx, 300 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 301 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 302 LinalgTransformationFilter(Identifier::get("L1", ctx), 303 Identifier::get("VEC", ctx)))); 304 305 patternsVector.emplace_back( 306 ctx, std::make_unique<LinalgVectorizationPattern>( 307 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 308 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 309 patternsVector.back().add<LinalgVectorizationPattern>( 310 ctx, LinalgTransformationFilter().addFilter( 311 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Test promotion callbacks 316 //===----------------------------------------------------------------------===// 317 318 // Allocation call back 319 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 320 ArrayRef<Value> boundingSubViewSize, 321 DataLayout &layout) { 322 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 323 return b 324 .create<memref::AllocOp>( 325 subView.getLoc(), 326 MemRefType::get(shape, subView.getType().getElementType(), 327 /*affineMapComposition =*/{}, 3), 328 boundingSubViewSize) 329 .getResult(); 330 } 331 332 // Deallocation callback 333 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 334 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 335 return success(); 336 } 337 338 // Copy in call back 339 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 340 bool isOutput) { 341 auto floatType = src.getType().cast<MemRefType>().getElementType(); 342 if (!floatType.isa<FloatType>()) 343 return failure(); 344 if (!isOutput) { 345 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 346 FloatAttr::get(floatType, 42.0)); 347 b.create<FillOp>(src.getLoc(), cst, dst); 348 } 349 b.create<CopyOp>(src.getLoc(), src, dst); 350 return success(); 351 } 352 353 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 354 RewritePatternSet &patterns) { 355 patterns.add<LinalgTilingPattern<MatmulOp>>( 356 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 357 LinalgTransformationFilter(Identifier::get("START", ctx), 358 Identifier::get("PROMOTE", ctx))); 359 patterns.add<LinalgPromotionPattern<MatmulOp>>( 360 ctx, 361 LinalgPromotionOptions() 362 .setOperandsToPromote({0, 2}) 363 .setUseFullTileBuffers({false, false}) 364 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 365 .setCopyInOutFns( 366 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 367 return copyCallBackFn(b, src, dst, false); 368 }, 369 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 370 return copyCallBackFn(b, src, dst, true); 371 }), 372 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 373 } 374 375 template <typename IdOp, typename NProcsOp> 376 static SmallVector<ProcInfo, 2> 377 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 378 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 379 SmallVector<ProcInfo, 2> procInfo(count); 380 const char *xyz[] = {"x", "y", "z"}; 381 Type indexType = b.getIndexType(); 382 for (unsigned i = 0; i < count; ++i) { 383 procInfo[count - 1 - i] = { 384 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 385 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 386 } 387 return procInfo; 388 } 389 390 static void fillTileAndDistributePatterns(MLIRContext *context, 391 RewritePatternSet &patterns) { 392 { 393 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 394 cyclicNprocsEqNiters.distributionMethod.resize( 395 2, DistributionMethod::CyclicNumProcsEqNumIters); 396 cyclicNprocsEqNiters.procInfo = 397 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 398 patterns.add<LinalgTilingPattern<MatmulOp>>( 399 context, 400 LinalgTilingOptions() 401 .setTileSizes({8, 8, 4}) 402 .setLoopType(LinalgTilingLoopType::ParallelLoops) 403 .setDistributionOptions(cyclicNprocsEqNiters), 404 LinalgTransformationFilter( 405 Identifier::get("distribute1", context), 406 Identifier::get("after_distribute1", context))); 407 } 408 409 { 410 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 411 cyclicNprocsGeNiters.distributionMethod.resize( 412 2, DistributionMethod::CyclicNumProcsGeNumIters); 413 cyclicNprocsGeNiters.procInfo = 414 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 415 patterns.add<LinalgTilingPattern<MatmulOp>>( 416 context, 417 LinalgTilingOptions() 418 .setTileSizes({8, 8, 4}) 419 .setLoopType(LinalgTilingLoopType::ParallelLoops) 420 .setDistributionOptions(cyclicNprocsGeNiters), 421 LinalgTransformationFilter( 422 Identifier::get("distribute2", context), 423 Identifier::get("after_distribute2", context))); 424 } 425 426 { 427 LinalgLoopDistributionOptions cyclicNprocsDefault; 428 cyclicNprocsDefault.distributionMethod.resize(2, 429 DistributionMethod::Cyclic); 430 cyclicNprocsDefault.procInfo = 431 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 432 patterns.add<LinalgTilingPattern<MatmulOp>>( 433 context, 434 LinalgTilingOptions() 435 .setTileSizes({8, 8, 4}) 436 .setLoopType(LinalgTilingLoopType::ParallelLoops) 437 .setDistributionOptions(cyclicNprocsDefault), 438 LinalgTransformationFilter( 439 Identifier::get("distribute3", context), 440 Identifier::get("after_distribute3", context))); 441 } 442 443 { 444 LinalgLoopDistributionOptions cyclicNprocsMixed1; 445 cyclicNprocsMixed1.distributionMethod = { 446 DistributionMethod::CyclicNumProcsEqNumIters, 447 DistributionMethod::CyclicNumProcsGeNumIters}; 448 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 449 patterns.add<LinalgTilingPattern<MatmulOp>>( 450 context, 451 LinalgTilingOptions() 452 .setTileSizes({8, 8, 4}) 453 .setLoopType(LinalgTilingLoopType::ParallelLoops) 454 .setDistributionOptions(cyclicNprocsMixed1), 455 LinalgTransformationFilter( 456 Identifier::get("distribute4", context), 457 Identifier::get("after_distribute4", context))); 458 } 459 460 { 461 LinalgLoopDistributionOptions cyclicNprocsMixed2; 462 cyclicNprocsMixed2.distributionMethod = { 463 DistributionMethod::CyclicNumProcsGeNumIters, 464 DistributionMethod::Cyclic}; 465 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 466 patterns.add<LinalgTilingPattern<MatmulOp>>( 467 context, 468 LinalgTilingOptions() 469 .setTileSizes({8, 8, 4}) 470 .setLoopType(LinalgTilingLoopType::ParallelLoops) 471 .setDistributionOptions(cyclicNprocsMixed2), 472 LinalgTransformationFilter( 473 Identifier::get("distribute5", context), 474 Identifier::get("after_distribute5", context))); 475 } 476 477 { 478 LinalgLoopDistributionOptions cyclicNprocsMixed3; 479 cyclicNprocsMixed3.distributionMethod = { 480 DistributionMethod::Cyclic, 481 DistributionMethod::CyclicNumProcsEqNumIters}; 482 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 483 484 patterns.add<LinalgTilingPattern<MatmulOp>>( 485 context, 486 LinalgTilingOptions() 487 .setTileSizes({8, 8, 4}) 488 .setLoopType(LinalgTilingLoopType::ParallelLoops) 489 .setDistributionOptions(cyclicNprocsMixed3), 490 LinalgTransformationFilter( 491 Identifier::get("distribute6", context), 492 Identifier::get("after_distribute6", context))); 493 } 494 495 { 496 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 497 cyclicNprocsEqNiters.distributionMethod.resize(2, 498 DistributionMethod::Cyclic); 499 cyclicNprocsEqNiters.procInfo = 500 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 501 patterns.add<LinalgTilingPattern<MatmulOp>>( 502 context, 503 LinalgTilingOptions() 504 .setTileSizes({8, 8, 4}) 505 .setLoopType(LinalgTilingLoopType::Loops) 506 .setDistributionOptions(cyclicNprocsEqNiters), 507 LinalgTransformationFilter( 508 Identifier::get("tensors_distribute1", context), 509 Identifier::get("tensors_after_distribute1", context))); 510 } 511 } 512 513 static void 514 applyMatmulToVectorPatterns(FuncOp funcOp, 515 bool testMatmulToVectorPatterns1dTiling, 516 bool testMatmulToVectorPatterns2dTiling) { 517 MLIRContext *ctx = funcOp.getContext(); 518 SmallVector<RewritePatternSet, 4> stage1Patterns; 519 if (testMatmulToVectorPatterns1dTiling) { 520 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 521 stage1Patterns); 522 } else if (testMatmulToVectorPatterns2dTiling) { 523 stage1Patterns.emplace_back( 524 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 525 ctx, 526 LinalgTilingOptions() 527 .setTileSizes({768, 264, 768}) 528 .setInterchange({1, 2, 0}), 529 LinalgTransformationFilter(Identifier::get("START", ctx), 530 Identifier::get("L2", ctx)))); 531 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 532 stage1Patterns); 533 } 534 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 535 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 536 FrozenRewritePatternSet stage2Patterns = 537 getLinalgTilingCanonicalizationPatterns(ctx); 538 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 539 std::move(stage2Patterns)); 540 } 541 542 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 543 RewritePatternSet forwardPattern(funcOp.getContext()); 544 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 545 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 546 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 547 } 548 549 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 550 RewritePatternSet patterns(funcOp.getContext()); 551 patterns.add<LinalgVectorizationPattern>( 552 funcOp.getContext(), 553 LinalgTransformationFilter() 554 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 555 populatePadTensorOpVectorizationPatterns(patterns); 556 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 557 } 558 559 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 560 RewritePatternSet patterns(funcOp.getContext()); 561 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 562 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 563 } 564 565 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 566 RewritePatternSet patterns(funcOp.getContext()); 567 patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext()); 568 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 569 } 570 571 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 572 RewritePatternSet patterns(funcOp.getContext()); 573 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 574 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 575 } 576 577 // For now, just assume it is the zero of type. 578 // In the future, it should be the zero of type + op. 579 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 580 auto t = getElementTypeOrSelf(op.get()); 581 return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t, 582 b.getZeroAttr(t)); 583 } 584 585 static void applyTilePattern(FuncOp funcOp, std::string loopType, 586 ArrayRef<int64_t> tileSizes, 587 ArrayRef<int64_t> paddedOperands, 588 ArrayRef<int64_t> nofoldOperands, 589 ArrayRef<int64_t> peeledLoops, 590 bool scalarizeDynamicDims) { 591 MLIRContext *context = funcOp.getContext(); 592 RewritePatternSet tilingPattern(context); 593 LinalgTilingLoopType type = 594 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 595 .Case("for", LinalgTilingLoopType::Loops) 596 .Case("affine", LinalgTilingLoopType::AffineLoops) 597 .Case("parallel", LinalgTilingLoopType::ParallelLoops) 598 .Case("tiled_loop", LinalgTilingLoopType::TiledLoops); 599 auto linalgTilingOptions = linalg::LinalgTilingOptions() 600 .setPeeledLoops(peeledLoops) 601 .setLoopType(type); 602 if (scalarizeDynamicDims) { 603 linalgTilingOptions.scalarizeDynamicDims(); 604 assert(tileSizes.empty() && 605 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 606 } else { 607 linalgTilingOptions.setTileSizes(tileSizes); 608 } 609 if (!paddedOperands.empty()) { 610 auto paddingFunc = [&](OpBuilder &b, 611 OpOperand &opOperand) -> FailureOr<Value> { 612 if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0) 613 return failure(); 614 return getNeutralOfLinalgOp(b, opOperand); 615 }; 616 auto nofoldFunc = [&](OpOperand &opOperand) { 617 if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0) 618 return true; 619 return false; 620 }; 621 linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc); 622 linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc); 623 } 624 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, 625 linalg::LinalgTilingPattern<linalg::GenericOp>>( 626 context, linalgTilingOptions, 627 linalg::LinalgTransformationFilter(Identifier::get("tile", context))); 628 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 629 } 630 631 static void applyInterchangePattern(FuncOp funcOp, 632 ArrayRef<unsigned> interchangeVector) { 633 MLIRContext *context = funcOp.getContext(); 634 RewritePatternSet interchangePattern(context); 635 interchangePattern.add<GenericOpInterchangePattern>( 636 context, interchangeVector, 637 LinalgTransformationFilter(ArrayRef<Identifier>{}, 638 Identifier::get("interchange", context))); 639 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 640 } 641 642 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 643 static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; 644 645 namespace { 646 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the 647 /// `idx`-th loop contains only "full" iterations and a second loop for the 648 /// remaining partial iteration (if any). 649 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { 650 TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) 651 : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) { 652 } 653 654 LogicalResult matchAndRewrite(TiledLoopOp loopOp, 655 PatternRewriter &rewriter) const override { 656 SmallVector<int64_t> peeledLoops; 657 if (loopOp->hasAttr(kPeeledLoopsLabel)) { 658 auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>(); 659 peeledLoops = 660 llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { 661 return attr.cast<IntegerAttr>().getInt(); 662 })); 663 // Check if the loop was already peeled. 664 if (llvm::find(peeledLoops, idx) != peeledLoops.end()) 665 return failure(); 666 } 667 if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) 668 // No peeling of loop nests with a partial iteration. 669 return failure(); 670 671 if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) 672 return failure(); 673 674 // Peel loop and canonicalize. 675 TiledLoopOp result; 676 if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, 677 result))) 678 return failure(); 679 680 // Apply label, so that the same loop is not rewritten a second time. 681 peeledLoops.push_back(idx); 682 rewriter.updateRootInPlace(loopOp, [&]() { 683 loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 684 }); 685 result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 686 result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); 687 688 return success(); 689 } 690 691 /// Index of loop to peel. 692 int64_t idx; 693 694 /// If set to true, do not peel TiledLoopOps with a partial iteration. 695 bool skipPartial; 696 }; 697 } // namespace 698 699 static void applyTiledLoopPeelingPattern(FuncOp funcOp, 700 ArrayRef<unsigned> loops, 701 bool skipPartial) { 702 MLIRContext *ctx = funcOp.getContext(); 703 RewritePatternSet patterns(ctx); 704 for (unsigned idx : loops) 705 patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial); 706 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 707 708 // Drop the markers. 709 funcOp.walk([](TiledLoopOp op) { 710 op->removeAttr(kPeeledLoopsLabel); 711 op->removeAttr(kPartialIterationLabel); 712 }); 713 } 714 715 /// Apply transformations specified as patterns. 716 void TestLinalgTransforms::runOnFunction() { 717 auto lambda = [&](void *) { 718 getFunction().walk([](LinalgOp op) { 719 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 720 }); 721 }; 722 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 723 724 if (testPromotionOptions) { 725 RewritePatternSet patterns(&getContext()); 726 fillPromotionCallBackPatterns(&getContext(), patterns); 727 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 728 return; 729 } 730 if (testTileAndDistributionOptions) { 731 RewritePatternSet patterns(&getContext()); 732 fillTileAndDistributePatterns(&getContext(), patterns); 733 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 734 return; 735 } 736 if (testPatterns) 737 return applyPatterns(getFunction()); 738 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 739 return applyMatmulToVectorPatterns(getFunction(), 740 testMatmulToVectorPatterns1dTiling, 741 testMatmulToVectorPatterns2dTiling); 742 if (testVectorTransferForwardingPatterns) 743 return applyVectorTransferForwardingPatterns(getFunction()); 744 if (testGenericToVectorPattern) 745 return applyLinalgToVectorPatterns(getFunction()); 746 if (testTransformPadTensor) 747 return applyPadTensorToGenericPatterns(getFunction()); 748 if (testGeneralizePadTensor) 749 return applyGeneralizePadTensorPatterns(getFunction()); 750 if (testSwapSubTensorPadTensor) 751 return applyExtractSliceOfPadTensorSwapPattern(getFunction()); 752 if (testTiledLoopPeeling.hasValue()) 753 return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, 754 skipPartial); 755 if (testTilePattern) 756 return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, 757 nofoldOperands, peeledLoops, 758 /*scalarizeDynamicDims=*/false); 759 if (testTileScalarizeDynamicDims) 760 return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, 761 nofoldOperands, 762 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 763 if (testHoistPadding) { 764 getFunction().walk([&](linalg::PadTensorOp padTensorOp) { 765 (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); 766 }); 767 } 768 if (testInterchangePattern.hasValue()) 769 return applyInterchangePattern(getFunction(), testInterchangePattern); 770 } 771 772 namespace mlir { 773 namespace test { 774 void registerTestLinalgTransforms() { 775 PassRegistration<TestLinalgTransforms>(); 776 } 777 } // namespace test 778 } // namespace mlir 779