1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/GPU/GPUDialect.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 #include "llvm/ADT/SetVector.h" 25 26 using namespace mlir; 27 using namespace mlir::linalg; 28 29 namespace { 30 struct TestLinalgTransforms 31 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 32 TestLinalgTransforms() = default; 33 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 34 35 void getDependentDialects(DialectRegistry ®istry) const override { 36 // clang-format off 37 registry.insert<AffineDialect, 38 memref::MemRefDialect, 39 scf::SCFDialect, 40 StandardOpsDialect, 41 vector::VectorDialect, 42 gpu::GPUDialect>(); 43 // clang-format on 44 } 45 46 void runOnFunction() override; 47 48 Option<bool> testPatterns{*this, "test-patterns", 49 llvm::cl::desc("Test a mixed set of patterns"), 50 llvm::cl::init(false)}; 51 Option<bool> testMatmulToVectorPatterns1dTiling{ 52 *this, "test-matmul-to-vector-patterns-tile-1d", 53 llvm::cl::desc( 54 "Test a fused pass that applies patterns from matmul to vectors via " 55 "1-d tiling"), 56 llvm::cl::init(false)}; 57 Option<bool> testMatmulToVectorPatterns2dTiling{ 58 *this, "test-matmul-to-vector-patterns-tile-2d", 59 llvm::cl::desc( 60 "Test a fused pass that applies patterns from matmul to vectors via " 61 "2-d tiling"), 62 llvm::cl::init(false)}; 63 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 64 llvm::cl::desc("Test promotion options"), 65 llvm::cl::init(false)}; 66 Option<bool> testTileAndDistributionOptions{ 67 *this, "test-tile-and-distribute-options", 68 llvm::cl::desc("Test tile and distribute options"), 69 llvm::cl::init(false)}; 70 Option<bool> testVectorTransferForwardingPatterns{ 71 *this, "test-vector-transfer-forwarding-patterns", 72 llvm::cl::desc( 73 "Test a fused pass that forwards linalg.copy to vector.transfer"), 74 llvm::cl::init(false)}; 75 Option<bool> testGenericToVectorPattern{ 76 *this, "test-linalg-to-vector-patterns", 77 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 78 "in vector.contract form"), 79 llvm::cl::init(false)}; 80 Option<bool> testAffineMinSCFCanonicalizationPatterns{ 81 *this, "test-affine-min-scf-canonicalization-patterns", 82 llvm::cl::desc("Test affine-min + scf canonicalization patterns."), 83 llvm::cl::init(false)}; 84 Option<bool> testTileAndPadPattern{ 85 *this, "test-tile-and-pad-pattern", 86 llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)}; 87 Option<int> testHoistPadding{*this, "test-hoist-padding", 88 llvm::cl::desc("Test hoist padding"), 89 llvm::cl::init(0)}; 90 ListOption<int64_t> tileSizesForPadding{ 91 *this, "tile-sizes-for-padding", 92 llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore, 93 llvm::cl::MiscFlags::CommaSeparated}; 94 ListOption<unsigned> testInterchangePattern{ 95 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 96 llvm::cl::desc("Test the interchange pattern.")}; 97 }; 98 } // end anonymous namespace 99 100 static void applyPatterns(FuncOp funcOp) { 101 MLIRContext *ctx = funcOp.getContext(); 102 RewritePatternSet patterns(ctx); 103 104 //===--------------------------------------------------------------------===// 105 // Linalg tiling patterns. 106 //===--------------------------------------------------------------------===// 107 patterns.add<LinalgTilingPattern<MatmulOp>>( 108 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 109 LinalgTransformationFilter(Identifier::get("MEM", ctx), 110 Identifier::get("L3", ctx))); 111 patterns.add<LinalgTilingPattern<MatmulOp>>( 112 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 113 LinalgTransformationFilter(Identifier::get("L3", ctx), 114 Identifier::get("L2", ctx))); 115 patterns.add<LinalgTilingPattern<MatmulOp>>( 116 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 117 LinalgTransformationFilter(Identifier::get("L2", ctx), 118 Identifier::get("L1", ctx))); 119 patterns.add<LinalgTilingPattern<MatmulOp>>( 120 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 121 LinalgTransformationFilter(Identifier::get("L1", ctx), 122 Identifier::get("REG", ctx))); 123 124 patterns.add<LinalgTilingPattern<MatvecOp>>( 125 ctx, 126 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 127 LinalgTilingLoopType::ParallelLoops), 128 LinalgTransformationFilter(ArrayRef<Identifier>{}, 129 Identifier::get("L1", ctx))); 130 131 patterns.add<LinalgTilingPattern<DotOp>>( 132 ctx, LinalgTilingOptions().setTileSizes(8000), 133 LinalgTransformationFilter( 134 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 135 Identifier::get("L3", ctx), 136 Identifier::get("L2", ctx)}, 137 Identifier::get("REG", ctx))); 138 139 //===--------------------------------------------------------------------===// 140 // Linalg tiling and permutation patterns. 141 //===--------------------------------------------------------------------===// 142 patterns.add<LinalgTilingPattern<MatmulOp>>( 143 ctx, 144 LinalgTilingOptions() 145 .setTileSizes({2000, 3000, 4000}) 146 .setInterchange({1, 2, 0}), 147 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 148 Identifier::get("L2__with_perm__", ctx))); 149 patterns.add<LinalgTilingPattern<MatmulOp>>( 150 ctx, 151 LinalgTilingOptions() 152 .setTileSizes({200, 300, 400}) 153 .setInterchange({1, 0, 2}), 154 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 155 Identifier::get("L1__with_perm__", ctx))); 156 patterns.add<LinalgTilingPattern<MatmulOp>>( 157 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 158 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 159 Identifier::get("REG__with_perm__", ctx))); 160 161 patterns.add<LinalgTilingPattern<MatvecOp>>( 162 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 163 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 164 Identifier::get("L1__with_perm__", ctx))); 165 166 patterns.add<LinalgTilingPattern<MatmulOp>>( 167 ctx, 168 LinalgTilingOptions() 169 .setTileSizes({16, 8, 4}) 170 .setInterchange({1, 2, 0}) 171 .setLoopType(LinalgTilingLoopType::ParallelLoops), 172 LinalgTransformationFilter( 173 Identifier::get("par__with_perm__", ctx), 174 Identifier::get("after_par__with_perm__", ctx))); 175 176 //===--------------------------------------------------------------------===// 177 // Linalg to loops patterns. 178 //===--------------------------------------------------------------------===// 179 patterns.add<LinalgLoweringPattern<DotOp>>( 180 ctx, 181 /*loweringType=*/LinalgLoweringType::Loops, 182 LinalgTransformationFilter(Identifier::get("REG", ctx))); 183 184 //===--------------------------------------------------------------------===// 185 // Linalg distribution patterns. 186 //===--------------------------------------------------------------------===// 187 LinalgLoopDistributionOptions distributionOptions; 188 189 //===--------------------------------------------------------------------===// 190 // Linalg to vector contraction patterns. 191 //===--------------------------------------------------------------------===// 192 patterns.add<LinalgVectorizationPattern>( 193 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 194 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 195 196 //===--------------------------------------------------------------------===// 197 // Linalg generic interchange pattern. 198 //===--------------------------------------------------------------------===// 199 patterns.add<GenericOpInterchangePattern>( 200 ctx, 201 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 202 LinalgTransformationFilter(ArrayRef<Identifier>{}, 203 Identifier::get("PERMUTED", ctx))); 204 205 //===--------------------------------------------------------------------===// 206 // Linalg subview operands promotion. 207 //===--------------------------------------------------------------------===// 208 patterns.add<LinalgPromotionPattern<MatmulOp>>( 209 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 210 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 211 Identifier::get("_views_promoted_", ctx))); 212 patterns.add<LinalgPromotionPattern<MatmulOp>>( 213 ctx, 214 LinalgPromotionOptions() 215 .setOperandsToPromote({0}) 216 .setUseFullTileBuffersByDefault(true), 217 LinalgTransformationFilter( 218 Identifier::get("_promote_first_view_", ctx), 219 Identifier::get("_first_view_promoted_", ctx))); 220 patterns.add<LinalgPromotionPattern<FillOp>>( 221 ctx, 222 LinalgPromotionOptions() 223 .setOperandsToPromote({0}) 224 .setUseFullTileBuffers({true}) 225 .setAlignment(32), 226 LinalgTransformationFilter( 227 Identifier::get("_promote_views_aligned_", ctx), 228 Identifier::get("_views_aligned_promoted_", ctx))); 229 230 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 231 232 // Drop the marker. 233 funcOp.walk([](LinalgOp op) { 234 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 235 }); 236 } 237 238 static void fillL1TilingAndMatmulToVectorPatterns( 239 FuncOp funcOp, StringRef startMarker, 240 SmallVectorImpl<RewritePatternSet> &patternsVector) { 241 MLIRContext *ctx = funcOp.getContext(); 242 patternsVector.emplace_back( 243 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 244 ctx, 245 LinalgTilingOptions() 246 .setTileSizes({8, 12, 16}) 247 .setInterchange({1, 0, 2}), 248 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 249 Identifier::get("L1", ctx)))); 250 251 patternsVector.emplace_back( 252 ctx, 253 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 254 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 255 LinalgTransformationFilter(Identifier::get("L1", ctx), 256 Identifier::get("VEC", ctx)))); 257 258 patternsVector.emplace_back( 259 ctx, std::make_unique<LinalgVectorizationPattern>( 260 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 261 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 262 patternsVector.back().add<LinalgVectorizationPattern>( 263 ctx, LinalgTransformationFilter().addFilter( 264 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 265 } 266 267 //===----------------------------------------------------------------------===// 268 // Test promotion callbacks 269 //===----------------------------------------------------------------------===// 270 271 // Allocation call back 272 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 273 ArrayRef<Value> boundingSubViewSize, 274 DataLayout &layout, 275 OperationFolder *folder) { 276 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 277 return b 278 .create<memref::AllocOp>( 279 subView.getLoc(), 280 MemRefType::get(shape, subView.getType().getElementType(), 281 /*affineMapComposition =*/{}, 3), 282 boundingSubViewSize) 283 .getResult(); 284 } 285 286 // Deallocation callback 287 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 288 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 289 return success(); 290 } 291 292 // Copy in call back 293 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 294 bool isOutput) { 295 auto floatType = src.getType().cast<MemRefType>().getElementType(); 296 if (!floatType.isa<FloatType>()) 297 return failure(); 298 if (!isOutput) 299 b.create<FillOp>( 300 src.getLoc(), dst, 301 b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0))); 302 b.create<CopyOp>(src.getLoc(), src, dst); 303 return success(); 304 } 305 306 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 307 RewritePatternSet &patterns) { 308 patterns.add<LinalgTilingPattern<MatmulOp>>( 309 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 310 LinalgTransformationFilter(Identifier::get("START", ctx), 311 Identifier::get("PROMOTE", ctx))); 312 patterns.add<LinalgPromotionPattern<MatmulOp>>( 313 ctx, 314 LinalgPromotionOptions() 315 .setOperandsToPromote({0, 2}) 316 .setUseFullTileBuffers({false, false}) 317 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 318 .setCopyInOutFns( 319 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 320 return copyCallBackFn(b, src, dst, false); 321 }, 322 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 323 return copyCallBackFn(b, src, dst, true); 324 }), 325 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 326 } 327 328 template <typename IdOp, typename NProcsOp> 329 static SmallVector<ProcInfo, 2> 330 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 331 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 332 SmallVector<ProcInfo, 2> procInfo(count); 333 const char *xyz[] = {"x", "y", "z"}; 334 Type indexType = b.getIndexType(); 335 for (unsigned i = 0; i < count; ++i) { 336 procInfo[count - 1 - i] = { 337 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 338 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 339 } 340 return procInfo; 341 } 342 343 static void fillTileAndDistributePatterns(MLIRContext *context, 344 RewritePatternSet &patterns) { 345 { 346 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 347 cyclicNprocsEqNiters.distributionMethod.resize( 348 2, DistributionMethod::CyclicNumProcsEqNumIters); 349 cyclicNprocsEqNiters.procInfo = 350 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 351 patterns.add<LinalgTilingPattern<MatmulOp>>( 352 context, 353 LinalgTilingOptions() 354 .setTileSizes({8, 8, 4}) 355 .setLoopType(LinalgTilingLoopType::ParallelLoops) 356 .setDistributionOptions(cyclicNprocsEqNiters), 357 LinalgTransformationFilter( 358 Identifier::get("distribute1", context), 359 Identifier::get("after_distribute1", context))); 360 } 361 362 { 363 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 364 cyclicNprocsGeNiters.distributionMethod.resize( 365 2, DistributionMethod::CyclicNumProcsGeNumIters); 366 cyclicNprocsGeNiters.procInfo = 367 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 368 patterns.add<LinalgTilingPattern<MatmulOp>>( 369 context, 370 LinalgTilingOptions() 371 .setTileSizes({8, 8, 4}) 372 .setLoopType(LinalgTilingLoopType::ParallelLoops) 373 .setDistributionOptions(cyclicNprocsGeNiters), 374 LinalgTransformationFilter( 375 Identifier::get("distribute2", context), 376 Identifier::get("after_distribute2", context))); 377 } 378 379 { 380 LinalgLoopDistributionOptions cyclicNprocsDefault; 381 cyclicNprocsDefault.distributionMethod.resize(2, 382 DistributionMethod::Cyclic); 383 cyclicNprocsDefault.procInfo = 384 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 385 patterns.add<LinalgTilingPattern<MatmulOp>>( 386 context, 387 LinalgTilingOptions() 388 .setTileSizes({8, 8, 4}) 389 .setLoopType(LinalgTilingLoopType::ParallelLoops) 390 .setDistributionOptions(cyclicNprocsDefault), 391 LinalgTransformationFilter( 392 Identifier::get("distribute3", context), 393 Identifier::get("after_distribute3", context))); 394 } 395 396 { 397 LinalgLoopDistributionOptions cyclicNprocsMixed1; 398 cyclicNprocsMixed1.distributionMethod = { 399 DistributionMethod::CyclicNumProcsEqNumIters, 400 DistributionMethod::CyclicNumProcsGeNumIters}; 401 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 402 patterns.add<LinalgTilingPattern<MatmulOp>>( 403 context, 404 LinalgTilingOptions() 405 .setTileSizes({8, 8, 4}) 406 .setLoopType(LinalgTilingLoopType::ParallelLoops) 407 .setDistributionOptions(cyclicNprocsMixed1), 408 LinalgTransformationFilter( 409 Identifier::get("distribute4", context), 410 Identifier::get("after_distribute4", context))); 411 } 412 413 { 414 LinalgLoopDistributionOptions cyclicNprocsMixed2; 415 cyclicNprocsMixed2.distributionMethod = { 416 DistributionMethod::CyclicNumProcsGeNumIters, 417 DistributionMethod::Cyclic}; 418 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 419 patterns.add<LinalgTilingPattern<MatmulOp>>( 420 context, 421 LinalgTilingOptions() 422 .setTileSizes({8, 8, 4}) 423 .setLoopType(LinalgTilingLoopType::ParallelLoops) 424 .setDistributionOptions(cyclicNprocsMixed2), 425 LinalgTransformationFilter( 426 Identifier::get("distribute5", context), 427 Identifier::get("after_distribute5", context))); 428 } 429 430 { 431 LinalgLoopDistributionOptions cyclicNprocsMixed3; 432 cyclicNprocsMixed3.distributionMethod = { 433 DistributionMethod::Cyclic, 434 DistributionMethod::CyclicNumProcsEqNumIters}; 435 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 436 437 patterns.add<LinalgTilingPattern<MatmulOp>>( 438 context, 439 LinalgTilingOptions() 440 .setTileSizes({8, 8, 4}) 441 .setLoopType(LinalgTilingLoopType::ParallelLoops) 442 .setDistributionOptions(cyclicNprocsMixed3), 443 LinalgTransformationFilter( 444 Identifier::get("distribute6", context), 445 Identifier::get("after_distribute6", context))); 446 } 447 448 { 449 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 450 cyclicNprocsEqNiters.distributionMethod.resize(2, 451 DistributionMethod::Cyclic); 452 cyclicNprocsEqNiters.procInfo = 453 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 454 patterns.add<LinalgTilingPattern<MatmulOp>>( 455 context, 456 LinalgTilingOptions() 457 .setTileSizes({8, 8, 4}) 458 .setLoopType(LinalgTilingLoopType::Loops) 459 .setDistributionOptions(cyclicNprocsEqNiters), 460 LinalgTransformationFilter( 461 Identifier::get("tensors_distribute1", context), 462 Identifier::get("tensors_after_distribute1", context))); 463 } 464 } 465 466 static void 467 applyMatmulToVectorPatterns(FuncOp funcOp, 468 bool testMatmulToVectorPatterns1dTiling, 469 bool testMatmulToVectorPatterns2dTiling) { 470 MLIRContext *ctx = funcOp.getContext(); 471 SmallVector<RewritePatternSet, 4> stage1Patterns; 472 if (testMatmulToVectorPatterns1dTiling) { 473 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 474 stage1Patterns); 475 } else if (testMatmulToVectorPatterns2dTiling) { 476 stage1Patterns.emplace_back( 477 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 478 ctx, 479 LinalgTilingOptions() 480 .setTileSizes({768, 264, 768}) 481 .setInterchange({1, 2, 0}), 482 LinalgTransformationFilter(Identifier::get("START", ctx), 483 Identifier::get("L2", ctx)))); 484 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 485 stage1Patterns); 486 } 487 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 488 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 489 FrozenRewritePatternSet stage2Patterns = 490 getLinalgTilingCanonicalizationPatterns(ctx); 491 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 492 std::move(stage2Patterns)); 493 } 494 495 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 496 RewritePatternSet forwardPattern(funcOp.getContext()); 497 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 498 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 499 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 500 } 501 502 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 503 RewritePatternSet patterns(funcOp.getContext()); 504 patterns.add<LinalgVectorizationPattern>( 505 funcOp.getContext(), 506 LinalgTransformationFilter() 507 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 508 patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext()); 509 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 510 } 511 512 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) { 513 RewritePatternSet foldPattern(funcOp.getContext()); 514 foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext()); 515 FrozenRewritePatternSet frozenPatterns(std::move(foldPattern)); 516 517 // Explicitly walk and apply the pattern locally to avoid more general folding 518 // on the rest of the IR. 519 funcOp.walk([&frozenPatterns](AffineMinOp minOp) { 520 (void)applyOpPatternsAndFold(minOp, frozenPatterns); 521 }); 522 } 523 524 // For now, just assume it is the zero of type. 525 // In the future, it should be the zero of type + op. 526 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 527 auto t = getElementTypeOrSelf(op.get().getType()); 528 return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); 529 } 530 531 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) { 532 MLIRContext *context = funcOp.getContext(); 533 RewritePatternSet tilingPattern(context); 534 auto linalgTilingOptions = 535 linalg::LinalgTilingOptions() 536 .setTileSizes(tileSizes) 537 .setPaddingValueComputationFunction(getNeutralOfLinalgOp); 538 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>( 539 context, linalgTilingOptions, 540 linalg::LinalgTransformationFilter( 541 Identifier::get("tile-and-pad", context))); 542 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 543 } 544 545 static void applyInterchangePattern(FuncOp funcOp, 546 ArrayRef<unsigned> interchangeVector) { 547 MLIRContext *context = funcOp.getContext(); 548 RewritePatternSet interchangePattern(context); 549 interchangePattern.add<GenericOpInterchangePattern>( 550 context, interchangeVector, 551 LinalgTransformationFilter(ArrayRef<Identifier>{}, 552 Identifier::get("interchange", context))); 553 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 554 } 555 556 /// Apply transformations specified as patterns. 557 void TestLinalgTransforms::runOnFunction() { 558 auto lambda = [&](void *) { 559 getFunction().walk([](LinalgOp op) { 560 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 561 }); 562 }; 563 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 564 565 if (testPromotionOptions) { 566 RewritePatternSet patterns(&getContext()); 567 fillPromotionCallBackPatterns(&getContext(), patterns); 568 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 569 return; 570 } 571 if (testTileAndDistributionOptions) { 572 RewritePatternSet patterns(&getContext()); 573 fillTileAndDistributePatterns(&getContext(), patterns); 574 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 575 return; 576 } 577 if (testPatterns) 578 return applyPatterns(getFunction()); 579 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 580 return applyMatmulToVectorPatterns(getFunction(), 581 testMatmulToVectorPatterns1dTiling, 582 testMatmulToVectorPatterns2dTiling); 583 if (testVectorTransferForwardingPatterns) 584 return applyVectorTransferForwardingPatterns(getFunction()); 585 if (testGenericToVectorPattern) 586 return applyLinalgToVectorPatterns(getFunction()); 587 if (testAffineMinSCFCanonicalizationPatterns) 588 return applyAffineMinSCFCanonicalizationPatterns(getFunction()); 589 if (testTileAndPadPattern) 590 return applyTileAndPadPattern(getFunction(), tileSizesForPadding); 591 if (testHoistPadding) { 592 getFunction().walk([&](linalg::PadTensorOp padTensorOp) { 593 (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); 594 }); 595 } 596 if (testInterchangePattern.hasValue()) 597 return applyInterchangePattern(getFunction(), testInterchangePattern); 598 } 599 600 namespace mlir { 601 namespace test { 602 void registerTestLinalgTransforms() { 603 PassRegistration<TestLinalgTransforms> testTransformPatternsPass( 604 "test-linalg-transform-patterns", 605 "Test Linalg transformation patterns by applying them greedily."); 606 } 607 } // namespace test 608 } // namespace mlir 609