1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/GPU/GPUDialect.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 18 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Vector/VectorOps.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 26 #include "llvm/ADT/SetVector.h" 27 #include "llvm/ADT/SmallVector.h" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 namespace { 33 struct TestLinalgTransforms 34 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 35 TestLinalgTransforms() = default; 36 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 37 38 void getDependentDialects(DialectRegistry ®istry) const override { 39 // clang-format off 40 registry.insert<AffineDialect, 41 memref::MemRefDialect, 42 scf::SCFDialect, 43 StandardOpsDialect, 44 vector::VectorDialect, 45 gpu::GPUDialect>(); 46 // clang-format on 47 } 48 StringRef getArgument() const final { 49 return "test-linalg-transform-patterns"; 50 } 51 StringRef getDescription() const final { 52 return "Test Linalg transformation patterns by applying them greedily."; 53 } 54 55 void runOnFunction() override; 56 57 Option<bool> testPatterns{*this, "test-patterns", 58 llvm::cl::desc("Test a mixed set of patterns"), 59 llvm::cl::init(false)}; 60 Option<bool> testMatmulToVectorPatterns1dTiling{ 61 *this, "test-matmul-to-vector-patterns-tile-1d", 62 llvm::cl::desc( 63 "Test a fused pass that applies patterns from matmul to vectors via " 64 "1-d tiling"), 65 llvm::cl::init(false)}; 66 Option<bool> testMatmulToVectorPatterns2dTiling{ 67 *this, "test-matmul-to-vector-patterns-tile-2d", 68 llvm::cl::desc( 69 "Test a fused pass that applies patterns from matmul to vectors via " 70 "2-d tiling"), 71 llvm::cl::init(false)}; 72 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 73 llvm::cl::desc("Test promotion options"), 74 llvm::cl::init(false)}; 75 Option<bool> testTileAndDistributionOptions{ 76 *this, "test-tile-and-distribute-options", 77 llvm::cl::desc("Test tile and distribute options"), 78 llvm::cl::init(false)}; 79 Option<bool> testVectorTransferForwardingPatterns{ 80 *this, "test-vector-transfer-forwarding-patterns", 81 llvm::cl::desc( 82 "Test a fused pass that forwards linalg.copy to vector.transfer"), 83 llvm::cl::init(false)}; 84 Option<bool> testGenericToVectorPattern{ 85 *this, "test-linalg-to-vector-patterns", 86 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 87 "in vector.contract form"), 88 llvm::cl::init(false)}; 89 Option<bool> testTilePattern{*this, "test-tile-pattern", 90 llvm::cl::desc("Test tile pattern"), 91 llvm::cl::init(false)}; 92 Option<bool> testTileScalarizeDynamicDims{ 93 *this, "test-tile-scalarize-dynamic-dims", 94 llvm::cl::desc("Test tiling of dynamic dims by 1"), 95 llvm::cl::init(false)}; 96 Option<int> testHoistPadding{*this, "test-hoist-padding", 97 llvm::cl::desc("Test hoist padding"), 98 llvm::cl::init(0)}; 99 Option<bool> testTransformPadTensor{ 100 *this, "test-transform-pad-tensor", 101 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 102 llvm::cl::init(false)}; 103 Option<bool> testGeneralizePadTensor{ 104 *this, "test-generalize-pad-tensor", 105 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 106 llvm::cl::init(false)}; 107 Option<bool> testSwapSubTensorPadTensor{ 108 *this, "test-swap-subtensor-padtensor", 109 llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " 110 "pad_tensor(subtensor)"), 111 llvm::cl::init(false)}; 112 ListOption<int64_t> paddedOperands{ 113 *this, "padded-operands", 114 llvm::cl::desc("Operands to pad when test-tile-pattern"), 115 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 116 ListOption<int64_t> nofoldOperands{ 117 *this, "nofold-operands", 118 llvm::cl::desc("Operands to set nofold when test-tile-pattern"), 119 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 120 ListOption<int64_t> peeledLoops{ 121 *this, "peeled-loops", 122 llvm::cl::desc("Loops to be peeled when test-tile-pattern"), 123 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 124 ListOption<int64_t> tileSizes{ 125 *this, "tile-sizes", 126 llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), 127 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 128 ListOption<unsigned> testInterchangePattern{ 129 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 130 llvm::cl::desc("Test the interchange pattern.")}; 131 ListOption<unsigned> testTiledLoopPeeling{ 132 *this, "test-tiled-loop-peeling", 133 llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), 134 llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; 135 Option<bool> skipPartial{ 136 *this, "skip-partial", 137 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 138 llvm::cl::init(false)}; 139 Option<std::string> loopType{ 140 *this, "loop-type", 141 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 142 "tiled_loop"), 143 llvm::cl::init("for")}; 144 }; 145 } // end anonymous namespace 146 147 static void applyPatterns(FuncOp funcOp) { 148 MLIRContext *ctx = funcOp.getContext(); 149 RewritePatternSet patterns(ctx); 150 151 //===--------------------------------------------------------------------===// 152 // Linalg tiling patterns. 153 //===--------------------------------------------------------------------===// 154 patterns.add<LinalgTilingPattern<MatmulOp>>( 155 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 156 LinalgTransformationFilter(Identifier::get("MEM", ctx), 157 Identifier::get("L3", ctx))); 158 patterns.add<LinalgTilingPattern<MatmulOp>>( 159 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 160 LinalgTransformationFilter(Identifier::get("L3", ctx), 161 Identifier::get("L2", ctx))); 162 patterns.add<LinalgTilingPattern<MatmulOp>>( 163 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 164 LinalgTransformationFilter(Identifier::get("L2", ctx), 165 Identifier::get("L1", ctx))); 166 patterns.add<LinalgTilingPattern<MatmulOp>>( 167 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 168 LinalgTransformationFilter(Identifier::get("L1", ctx), 169 Identifier::get("REG", ctx))); 170 171 patterns.add<LinalgTilingPattern<MatvecOp>>( 172 ctx, 173 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 174 LinalgTilingLoopType::ParallelLoops), 175 LinalgTransformationFilter(ArrayRef<Identifier>{}, 176 Identifier::get("L1", ctx))); 177 178 patterns.add<LinalgTilingPattern<DotOp>>( 179 ctx, LinalgTilingOptions().setTileSizes(8000), 180 LinalgTransformationFilter( 181 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 182 Identifier::get("L3", ctx), 183 Identifier::get("L2", ctx)}, 184 Identifier::get("REG", ctx))); 185 186 //===--------------------------------------------------------------------===// 187 // Linalg tiling and permutation patterns. 188 //===--------------------------------------------------------------------===// 189 patterns.add<LinalgTilingPattern<MatmulOp>>( 190 ctx, 191 LinalgTilingOptions() 192 .setTileSizes({2000, 3000, 4000}) 193 .setInterchange({1, 2, 0}), 194 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 195 Identifier::get("L2__with_perm__", ctx))); 196 patterns.add<LinalgTilingPattern<MatmulOp>>( 197 ctx, 198 LinalgTilingOptions() 199 .setTileSizes({200, 300, 400}) 200 .setInterchange({1, 0, 2}), 201 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 202 Identifier::get("L1__with_perm__", ctx))); 203 patterns.add<LinalgTilingPattern<MatmulOp>>( 204 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 205 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 206 Identifier::get("REG__with_perm__", ctx))); 207 208 patterns.add<LinalgTilingPattern<MatvecOp>>( 209 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 210 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 211 Identifier::get("L1__with_perm__", ctx))); 212 213 patterns.add<LinalgTilingPattern<MatmulOp>>( 214 ctx, 215 LinalgTilingOptions() 216 .setTileSizes({16, 8, 4}) 217 .setInterchange({1, 2, 0}) 218 .setLoopType(LinalgTilingLoopType::ParallelLoops), 219 LinalgTransformationFilter( 220 Identifier::get("par__with_perm__", ctx), 221 Identifier::get("after_par__with_perm__", ctx))); 222 223 //===--------------------------------------------------------------------===// 224 // Linalg to loops patterns. 225 //===--------------------------------------------------------------------===// 226 patterns.add<LinalgLoweringPattern<DotOp>>( 227 ctx, 228 /*loweringType=*/LinalgLoweringType::Loops, 229 LinalgTransformationFilter(Identifier::get("REG", ctx))); 230 231 //===--------------------------------------------------------------------===// 232 // Linalg distribution patterns. 233 //===--------------------------------------------------------------------===// 234 LinalgLoopDistributionOptions distributionOptions; 235 236 //===--------------------------------------------------------------------===// 237 // Linalg to vector contraction patterns. 238 //===--------------------------------------------------------------------===// 239 patterns.add<LinalgVectorizationPattern>( 240 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 241 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 242 243 //===--------------------------------------------------------------------===// 244 // Linalg generic interchange pattern. 245 //===--------------------------------------------------------------------===// 246 patterns.add<GenericOpInterchangePattern>( 247 ctx, 248 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 249 LinalgTransformationFilter(ArrayRef<Identifier>{}, 250 Identifier::get("PERMUTED", ctx))); 251 252 //===--------------------------------------------------------------------===// 253 // Linalg subview operands promotion. 254 //===--------------------------------------------------------------------===// 255 patterns.add<LinalgPromotionPattern<MatmulOp>>( 256 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 257 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 258 Identifier::get("_views_promoted_", ctx))); 259 patterns.add<LinalgPromotionPattern<MatmulOp>>( 260 ctx, 261 LinalgPromotionOptions() 262 .setOperandsToPromote({0}) 263 .setUseFullTileBuffersByDefault(true), 264 LinalgTransformationFilter( 265 Identifier::get("_promote_first_view_", ctx), 266 Identifier::get("_first_view_promoted_", ctx))); 267 patterns.add<LinalgPromotionPattern<FillOp>>( 268 ctx, 269 LinalgPromotionOptions() 270 .setOperandsToPromote({1}) 271 .setUseFullTileBuffers({false, true}) 272 .setAlignment(32), 273 LinalgTransformationFilter( 274 Identifier::get("_promote_views_aligned_", ctx), 275 Identifier::get("_views_aligned_promoted_", ctx))); 276 277 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 278 279 // Drop the marker. 280 funcOp.walk([](LinalgOp op) { 281 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 282 }); 283 } 284 285 static void fillL1TilingAndMatmulToVectorPatterns( 286 FuncOp funcOp, StringRef startMarker, 287 SmallVectorImpl<RewritePatternSet> &patternsVector) { 288 MLIRContext *ctx = funcOp.getContext(); 289 patternsVector.emplace_back( 290 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 291 ctx, 292 LinalgTilingOptions() 293 .setTileSizes({8, 12, 16}) 294 .setInterchange({1, 0, 2}), 295 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 296 Identifier::get("L1", ctx)))); 297 298 patternsVector.emplace_back( 299 ctx, 300 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 301 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 302 LinalgTransformationFilter(Identifier::get("L1", ctx), 303 Identifier::get("VEC", ctx)))); 304 305 patternsVector.emplace_back( 306 ctx, std::make_unique<LinalgVectorizationPattern>( 307 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 308 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 309 patternsVector.back().add<LinalgVectorizationPattern>( 310 ctx, LinalgTransformationFilter().addFilter( 311 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Test promotion callbacks 316 //===----------------------------------------------------------------------===// 317 318 // Allocation call back 319 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 320 ArrayRef<Value> boundingSubViewSize, 321 DataLayout &layout) { 322 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 323 return b 324 .create<memref::AllocOp>( 325 subView.getLoc(), 326 MemRefType::get(shape, subView.getType().getElementType(), 327 /*affineMapComposition =*/{}, 3), 328 boundingSubViewSize) 329 .getResult(); 330 } 331 332 // Deallocation callback 333 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 334 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 335 return success(); 336 } 337 338 // Copy in call back 339 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 340 bool isOutput) { 341 auto floatType = src.getType().cast<MemRefType>().getElementType(); 342 if (!floatType.isa<FloatType>()) 343 return failure(); 344 if (!isOutput) { 345 Value cst = b.create<arith::ConstantOp>(src.getLoc(), 346 FloatAttr::get(floatType, 42.0)); 347 b.create<FillOp>(src.getLoc(), cst, dst); 348 } 349 b.create<CopyOp>(src.getLoc(), src, dst); 350 return success(); 351 } 352 353 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 354 RewritePatternSet &patterns) { 355 patterns.add<LinalgTilingPattern<MatmulOp>>( 356 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 357 LinalgTransformationFilter(Identifier::get("START", ctx), 358 Identifier::get("PROMOTE", ctx))); 359 patterns.add<LinalgPromotionPattern<MatmulOp>>( 360 ctx, 361 LinalgPromotionOptions() 362 .setOperandsToPromote({0, 2}) 363 .setUseFullTileBuffers({false, false}) 364 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 365 .setCopyInOutFns( 366 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 367 return copyCallBackFn(b, src, dst, false); 368 }, 369 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 370 return copyCallBackFn(b, src, dst, true); 371 }), 372 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 373 } 374 375 template <typename IdOp, typename NProcsOp> 376 static SmallVector<ProcInfo, 2> 377 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 378 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 379 SmallVector<ProcInfo, 2> procInfo(count); 380 const char *xyz[] = {"x", "y", "z"}; 381 Type indexType = b.getIndexType(); 382 for (unsigned i = 0; i < count; ++i) { 383 procInfo[count - 1 - i] = { 384 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 385 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 386 } 387 return procInfo; 388 } 389 390 static void fillTileAndDistributePatterns(MLIRContext *context, 391 RewritePatternSet &patterns) { 392 { 393 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 394 cyclicNprocsEqNiters.distributionMethod.resize( 395 2, DistributionMethod::CyclicNumProcsEqNumIters); 396 cyclicNprocsEqNiters.procInfo = 397 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 398 patterns.add<LinalgTilingPattern<MatmulOp>>( 399 context, 400 LinalgTilingOptions() 401 .setTileSizes({8, 8, 4}) 402 .setLoopType(LinalgTilingLoopType::ParallelLoops) 403 .setDistributionOptions(cyclicNprocsEqNiters), 404 LinalgTransformationFilter( 405 Identifier::get("distribute1", context), 406 Identifier::get("after_distribute1", context))); 407 } 408 409 { 410 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 411 cyclicNprocsGeNiters.distributionMethod.resize( 412 2, DistributionMethod::CyclicNumProcsGeNumIters); 413 cyclicNprocsGeNiters.procInfo = 414 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 415 patterns.add<LinalgTilingPattern<MatmulOp>>( 416 context, 417 LinalgTilingOptions() 418 .setTileSizes({8, 8, 4}) 419 .setLoopType(LinalgTilingLoopType::ParallelLoops) 420 .setDistributionOptions(cyclicNprocsGeNiters), 421 LinalgTransformationFilter( 422 Identifier::get("distribute2", context), 423 Identifier::get("after_distribute2", context))); 424 } 425 426 { 427 LinalgLoopDistributionOptions cyclicNprocsDefault; 428 cyclicNprocsDefault.distributionMethod.resize(2, 429 DistributionMethod::Cyclic); 430 cyclicNprocsDefault.procInfo = 431 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 432 patterns.add<LinalgTilingPattern<MatmulOp>>( 433 context, 434 LinalgTilingOptions() 435 .setTileSizes({8, 8, 4}) 436 .setLoopType(LinalgTilingLoopType::ParallelLoops) 437 .setDistributionOptions(cyclicNprocsDefault), 438 LinalgTransformationFilter( 439 Identifier::get("distribute3", context), 440 Identifier::get("after_distribute3", context))); 441 } 442 443 { 444 LinalgLoopDistributionOptions cyclicNprocsMixed1; 445 cyclicNprocsMixed1.distributionMethod = { 446 DistributionMethod::CyclicNumProcsEqNumIters, 447 DistributionMethod::CyclicNumProcsGeNumIters}; 448 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 449 patterns.add<LinalgTilingPattern<MatmulOp>>( 450 context, 451 LinalgTilingOptions() 452 .setTileSizes({8, 8, 4}) 453 .setLoopType(LinalgTilingLoopType::ParallelLoops) 454 .setDistributionOptions(cyclicNprocsMixed1), 455 LinalgTransformationFilter( 456 Identifier::get("distribute4", context), 457 Identifier::get("after_distribute4", context))); 458 } 459 460 { 461 LinalgLoopDistributionOptions cyclicNprocsMixed2; 462 cyclicNprocsMixed2.distributionMethod = { 463 DistributionMethod::CyclicNumProcsGeNumIters, 464 DistributionMethod::Cyclic}; 465 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 466 patterns.add<LinalgTilingPattern<MatmulOp>>( 467 context, 468 LinalgTilingOptions() 469 .setTileSizes({8, 8, 4}) 470 .setLoopType(LinalgTilingLoopType::ParallelLoops) 471 .setDistributionOptions(cyclicNprocsMixed2), 472 LinalgTransformationFilter( 473 Identifier::get("distribute5", context), 474 Identifier::get("after_distribute5", context))); 475 } 476 477 { 478 LinalgLoopDistributionOptions cyclicNprocsMixed3; 479 cyclicNprocsMixed3.distributionMethod = { 480 DistributionMethod::Cyclic, 481 DistributionMethod::CyclicNumProcsEqNumIters}; 482 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 483 484 patterns.add<LinalgTilingPattern<MatmulOp>>( 485 context, 486 LinalgTilingOptions() 487 .setTileSizes({8, 8, 4}) 488 .setLoopType(LinalgTilingLoopType::ParallelLoops) 489 .setDistributionOptions(cyclicNprocsMixed3), 490 LinalgTransformationFilter( 491 Identifier::get("distribute6", context), 492 Identifier::get("after_distribute6", context))); 493 } 494 495 { 496 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 497 cyclicNprocsEqNiters.distributionMethod.resize(2, 498 DistributionMethod::Cyclic); 499 cyclicNprocsEqNiters.procInfo = 500 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 501 patterns.add<LinalgTilingPattern<MatmulOp>>( 502 context, 503 LinalgTilingOptions() 504 .setTileSizes({8, 8, 4}) 505 .setLoopType(LinalgTilingLoopType::Loops) 506 .setDistributionOptions(cyclicNprocsEqNiters), 507 LinalgTransformationFilter( 508 Identifier::get("tensors_distribute1", context), 509 Identifier::get("tensors_after_distribute1", context))); 510 } 511 } 512 513 static void 514 applyMatmulToVectorPatterns(FuncOp funcOp, 515 bool testMatmulToVectorPatterns1dTiling, 516 bool testMatmulToVectorPatterns2dTiling) { 517 MLIRContext *ctx = funcOp.getContext(); 518 SmallVector<RewritePatternSet, 4> stage1Patterns; 519 if (testMatmulToVectorPatterns1dTiling) { 520 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 521 stage1Patterns); 522 } else if (testMatmulToVectorPatterns2dTiling) { 523 stage1Patterns.emplace_back( 524 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 525 ctx, 526 LinalgTilingOptions() 527 .setTileSizes({768, 264, 768}) 528 .setInterchange({1, 2, 0}), 529 LinalgTransformationFilter(Identifier::get("START", ctx), 530 Identifier::get("L2", ctx)))); 531 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 532 stage1Patterns); 533 } 534 { 535 // Canonicalization patterns 536 RewritePatternSet canonicalizationPatterns(funcOp.getContext()); 537 vector::populateVectorTransferPermutationMapLoweringPatterns( 538 canonicalizationPatterns); 539 vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); 540 stage1Patterns.push_back(std::move(canonicalizationPatterns)); 541 } 542 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 543 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 544 FrozenRewritePatternSet stage2Patterns = 545 getLinalgTilingCanonicalizationPatterns(ctx); 546 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 547 std::move(stage2Patterns)); 548 } 549 550 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 551 RewritePatternSet forwardPattern(funcOp.getContext()); 552 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 553 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 554 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 555 } 556 557 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 558 RewritePatternSet patterns(funcOp.getContext()); 559 patterns.add<LinalgVectorizationPattern>( 560 funcOp.getContext(), 561 LinalgTransformationFilter() 562 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 563 populatePadTensorOpVectorizationPatterns(patterns); 564 populateConvolutionVectorizationPatterns(patterns); 565 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 566 } 567 568 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 569 RewritePatternSet patterns(funcOp.getContext()); 570 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 571 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 572 } 573 574 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { 575 RewritePatternSet patterns(funcOp.getContext()); 576 patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext()); 577 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 578 } 579 580 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { 581 RewritePatternSet patterns(funcOp.getContext()); 582 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 583 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 584 } 585 586 // For now, just assume it is the zero of type. 587 // In the future, it should be the zero of type + op. 588 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 589 auto t = getElementTypeOrSelf(op.get()); 590 return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t, 591 b.getZeroAttr(t)); 592 } 593 594 static void applyTilePattern(FuncOp funcOp, std::string loopType, 595 ArrayRef<int64_t> tileSizes, 596 ArrayRef<int64_t> paddedOperands, 597 ArrayRef<int64_t> nofoldOperands, 598 ArrayRef<int64_t> peeledLoops, 599 bool scalarizeDynamicDims) { 600 MLIRContext *context = funcOp.getContext(); 601 RewritePatternSet tilingPattern(context); 602 LinalgTilingLoopType type = 603 llvm::StringSwitch<LinalgTilingLoopType>(loopType) 604 .Case("for", LinalgTilingLoopType::Loops) 605 .Case("affine", LinalgTilingLoopType::AffineLoops) 606 .Case("parallel", LinalgTilingLoopType::ParallelLoops) 607 .Case("tiled_loop", LinalgTilingLoopType::TiledLoops); 608 auto linalgTilingOptions = linalg::LinalgTilingOptions() 609 .setPeeledLoops(peeledLoops) 610 .setLoopType(type); 611 if (scalarizeDynamicDims) { 612 linalgTilingOptions.scalarizeDynamicDims(); 613 assert(tileSizes.empty() && 614 "tileSizes and scalarizeDynamicDims is mutually exclusive"); 615 } else { 616 linalgTilingOptions.setTileSizes(tileSizes); 617 } 618 if (!paddedOperands.empty()) { 619 auto paddingFunc = [&](OpBuilder &b, 620 OpOperand &opOperand) -> FailureOr<Value> { 621 if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0) 622 return failure(); 623 return getNeutralOfLinalgOp(b, opOperand); 624 }; 625 auto nofoldFunc = [&](OpOperand &opOperand) { 626 if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0) 627 return true; 628 return false; 629 }; 630 linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc); 631 linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc); 632 } 633 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, 634 linalg::LinalgTilingPattern<linalg::GenericOp>>( 635 context, linalgTilingOptions, 636 linalg::LinalgTransformationFilter(Identifier::get("tile", context))); 637 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 638 } 639 640 static void applyInterchangePattern(FuncOp funcOp, 641 ArrayRef<unsigned> interchangeVector) { 642 MLIRContext *context = funcOp.getContext(); 643 RewritePatternSet interchangePattern(context); 644 interchangePattern.add<GenericOpInterchangePattern>( 645 context, interchangeVector, 646 LinalgTransformationFilter(ArrayRef<Identifier>{}, 647 Identifier::get("interchange", context))); 648 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 649 } 650 651 static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; 652 static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; 653 654 namespace { 655 /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the 656 /// `idx`-th loop contains only "full" iterations and a second loop for the 657 /// remaining partial iteration (if any). 658 struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { 659 TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) 660 : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) { 661 } 662 663 LogicalResult matchAndRewrite(TiledLoopOp loopOp, 664 PatternRewriter &rewriter) const override { 665 SmallVector<int64_t> peeledLoops; 666 if (loopOp->hasAttr(kPeeledLoopsLabel)) { 667 auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>(); 668 peeledLoops = 669 llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { 670 return attr.cast<IntegerAttr>().getInt(); 671 })); 672 // Check if the loop was already peeled. 673 if (llvm::find(peeledLoops, idx) != peeledLoops.end()) 674 return failure(); 675 } 676 if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) 677 // No peeling of loop nests with a partial iteration. 678 return failure(); 679 680 if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) 681 return failure(); 682 683 // Peel loop and canonicalize. 684 TiledLoopOp result; 685 if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, 686 result))) 687 return failure(); 688 689 // Apply label, so that the same loop is not rewritten a second time. 690 peeledLoops.push_back(idx); 691 rewriter.updateRootInPlace(loopOp, [&]() { 692 loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 693 }); 694 result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); 695 result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); 696 697 return success(); 698 } 699 700 /// Index of loop to peel. 701 int64_t idx; 702 703 /// If set to true, do not peel TiledLoopOps with a partial iteration. 704 bool skipPartial; 705 }; 706 } // namespace 707 708 static void applyTiledLoopPeelingPattern(FuncOp funcOp, 709 ArrayRef<unsigned> loops, 710 bool skipPartial) { 711 MLIRContext *ctx = funcOp.getContext(); 712 RewritePatternSet patterns(ctx); 713 for (unsigned idx : loops) 714 patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial); 715 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 716 717 // Drop the markers. 718 funcOp.walk([](TiledLoopOp op) { 719 op->removeAttr(kPeeledLoopsLabel); 720 op->removeAttr(kPartialIterationLabel); 721 }); 722 } 723 724 /// Apply transformations specified as patterns. 725 void TestLinalgTransforms::runOnFunction() { 726 auto lambda = [&](void *) { 727 getFunction().walk([](LinalgOp op) { 728 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 729 }); 730 }; 731 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 732 733 if (testPromotionOptions) { 734 RewritePatternSet patterns(&getContext()); 735 fillPromotionCallBackPatterns(&getContext(), patterns); 736 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 737 return; 738 } 739 if (testTileAndDistributionOptions) { 740 RewritePatternSet patterns(&getContext()); 741 fillTileAndDistributePatterns(&getContext(), patterns); 742 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 743 return; 744 } 745 if (testPatterns) 746 return applyPatterns(getFunction()); 747 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 748 return applyMatmulToVectorPatterns(getFunction(), 749 testMatmulToVectorPatterns1dTiling, 750 testMatmulToVectorPatterns2dTiling); 751 if (testVectorTransferForwardingPatterns) 752 return applyVectorTransferForwardingPatterns(getFunction()); 753 if (testGenericToVectorPattern) 754 return applyLinalgToVectorPatterns(getFunction()); 755 if (testTransformPadTensor) 756 return applyPadTensorToGenericPatterns(getFunction()); 757 if (testGeneralizePadTensor) 758 return applyGeneralizePadTensorPatterns(getFunction()); 759 if (testSwapSubTensorPadTensor) 760 return applyExtractSliceOfPadTensorSwapPattern(getFunction()); 761 if (testTiledLoopPeeling.hasValue()) 762 return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, 763 skipPartial); 764 if (testTilePattern) 765 return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, 766 nofoldOperands, peeledLoops, 767 /*scalarizeDynamicDims=*/false); 768 if (testTileScalarizeDynamicDims) 769 return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, 770 nofoldOperands, 771 /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); 772 if (testHoistPadding) { 773 getFunction().walk([&](linalg::PadTensorOp padTensorOp) { 774 (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); 775 }); 776 } 777 if (testInterchangePattern.hasValue()) 778 return applyInterchangePattern(getFunction(), testInterchangePattern); 779 } 780 781 namespace mlir { 782 namespace test { 783 void registerTestLinalgTransforms() { 784 PassRegistration<TestLinalgTransforms>(); 785 } 786 } // namespace test 787 } // namespace mlir 788