1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/GPU/GPUDialect.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 #include "llvm/ADT/SetVector.h" 25 26 using namespace mlir; 27 using namespace mlir::linalg; 28 29 namespace { 30 struct TestLinalgTransforms 31 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 32 TestLinalgTransforms() = default; 33 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 34 35 void getDependentDialects(DialectRegistry ®istry) const override { 36 // clang-format off 37 registry.insert<AffineDialect, 38 memref::MemRefDialect, 39 scf::SCFDialect, 40 StandardOpsDialect, 41 vector::VectorDialect, 42 gpu::GPUDialect>(); 43 // clang-format on 44 } 45 46 void runOnFunction() override; 47 48 Option<bool> testPatterns{*this, "test-patterns", 49 llvm::cl::desc("Test a mixed set of patterns"), 50 llvm::cl::init(false)}; 51 Option<bool> testMatmulToVectorPatterns1dTiling{ 52 *this, "test-matmul-to-vector-patterns-tile-1d", 53 llvm::cl::desc( 54 "Test a fused pass that applies patterns from matmul to vectors via " 55 "1-d tiling"), 56 llvm::cl::init(false)}; 57 Option<bool> testMatmulToVectorPatterns2dTiling{ 58 *this, "test-matmul-to-vector-patterns-tile-2d", 59 llvm::cl::desc( 60 "Test a fused pass that applies patterns from matmul to vectors via " 61 "2-d tiling"), 62 llvm::cl::init(false)}; 63 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 64 llvm::cl::desc("Test promotion options"), 65 llvm::cl::init(false)}; 66 Option<bool> testTileAndDistributionOptions{ 67 *this, "test-tile-and-distribute-options", 68 llvm::cl::desc("Test tile and distribute options"), 69 llvm::cl::init(false)}; 70 Option<bool> testVectorTransferForwardingPatterns{ 71 *this, "test-vector-transfer-forwarding-patterns", 72 llvm::cl::desc( 73 "Test a fused pass that forwards linalg.copy to vector.transfer"), 74 llvm::cl::init(false)}; 75 Option<bool> testGenericToVectorPattern{ 76 *this, "test-linalg-to-vector-patterns", 77 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 78 "in vector.contract form"), 79 llvm::cl::init(false)}; 80 Option<bool> testAffineMinSCFCanonicalizationPatterns{ 81 *this, "test-affine-min-scf-canonicalization-patterns", 82 llvm::cl::desc("Test affine-min + scf canonicalization patterns."), 83 llvm::cl::init(false)}; 84 Option<bool> testTileAndPadPattern{ 85 *this, "test-tile-and-pad-pattern", 86 llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)}; 87 Option<int> testHoistPadding{*this, "test-hoist-padding", 88 llvm::cl::desc("Test hoist padding"), 89 llvm::cl::init(0)}; 90 ListOption<int64_t> tileSizesForPadding{ 91 *this, "tile-sizes-for-padding", 92 llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore, 93 llvm::cl::MiscFlags::CommaSeparated}; 94 ListOption<unsigned> testInterchangePattern{ 95 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 96 llvm::cl::desc("Test the interchange pattern.")}; 97 }; 98 } // end anonymous namespace 99 100 static void applyPatterns(FuncOp funcOp) { 101 MLIRContext *ctx = funcOp.getContext(); 102 RewritePatternSet patterns(ctx); 103 104 //===--------------------------------------------------------------------===// 105 // Linalg tiling patterns. 106 //===--------------------------------------------------------------------===// 107 patterns.add<LinalgTilingPattern<MatmulOp>>( 108 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 109 LinalgTransformationFilter(Identifier::get("MEM", ctx), 110 Identifier::get("L3", ctx))); 111 patterns.add<LinalgTilingPattern<MatmulOp>>( 112 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 113 LinalgTransformationFilter(Identifier::get("L3", ctx), 114 Identifier::get("L2", ctx))); 115 patterns.add<LinalgTilingPattern<MatmulOp>>( 116 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 117 LinalgTransformationFilter(Identifier::get("L2", ctx), 118 Identifier::get("L1", ctx))); 119 patterns.add<LinalgTilingPattern<MatmulOp>>( 120 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 121 LinalgTransformationFilter(Identifier::get("L1", ctx), 122 Identifier::get("REG", ctx))); 123 124 patterns.add<LinalgTilingPattern<MatvecOp>>( 125 ctx, 126 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 127 LinalgTilingLoopType::ParallelLoops), 128 LinalgTransformationFilter(ArrayRef<Identifier>{}, 129 Identifier::get("L1", ctx))); 130 131 patterns.add<LinalgTilingPattern<DotOp>>( 132 ctx, LinalgTilingOptions().setTileSizes(8000), 133 LinalgTransformationFilter( 134 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 135 Identifier::get("L3", ctx), 136 Identifier::get("L2", ctx)}, 137 Identifier::get("REG", ctx))); 138 139 //===--------------------------------------------------------------------===// 140 // Linalg tiling and permutation patterns. 141 //===--------------------------------------------------------------------===// 142 patterns.add<LinalgTilingPattern<MatmulOp>>( 143 ctx, 144 LinalgTilingOptions() 145 .setTileSizes({2000, 3000, 4000}) 146 .setInterchange({1, 2, 0}), 147 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 148 Identifier::get("L2__with_perm__", ctx))); 149 patterns.add<LinalgTilingPattern<MatmulOp>>( 150 ctx, 151 LinalgTilingOptions() 152 .setTileSizes({200, 300, 400}) 153 .setInterchange({1, 0, 2}), 154 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 155 Identifier::get("L1__with_perm__", ctx))); 156 patterns.add<LinalgTilingPattern<MatmulOp>>( 157 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 158 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 159 Identifier::get("REG__with_perm__", ctx))); 160 161 patterns.add<LinalgTilingPattern<MatvecOp>>( 162 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 163 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 164 Identifier::get("L1__with_perm__", ctx))); 165 166 patterns.add<LinalgTilingPattern<MatmulOp>>( 167 ctx, 168 LinalgTilingOptions() 169 .setTileSizes({16, 8, 4}) 170 .setInterchange({1, 2, 0}) 171 .setLoopType(LinalgTilingLoopType::ParallelLoops), 172 LinalgTransformationFilter( 173 Identifier::get("par__with_perm__", ctx), 174 Identifier::get("after_par__with_perm__", ctx))); 175 176 //===--------------------------------------------------------------------===// 177 // Linalg to loops patterns. 178 //===--------------------------------------------------------------------===// 179 patterns.add<LinalgLoweringPattern<DotOp>>( 180 ctx, 181 /*loweringType=*/LinalgLoweringType::Loops, 182 LinalgTransformationFilter(Identifier::get("REG", ctx))); 183 184 //===--------------------------------------------------------------------===// 185 // Linalg distribution patterns. 186 //===--------------------------------------------------------------------===// 187 LinalgLoopDistributionOptions distributionOptions; 188 189 //===--------------------------------------------------------------------===// 190 // Linalg to vector contraction patterns. 191 //===--------------------------------------------------------------------===// 192 patterns.add<LinalgVectorizationPattern>( 193 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 194 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 195 196 //===--------------------------------------------------------------------===// 197 // Linalg generic interchange pattern. 198 //===--------------------------------------------------------------------===// 199 patterns.add<GenericOpInterchangePattern>( 200 ctx, 201 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 202 LinalgTransformationFilter(ArrayRef<Identifier>{}, 203 Identifier::get("PERMUTED", ctx))); 204 205 //===--------------------------------------------------------------------===// 206 // Linalg subview operands promotion. 207 //===--------------------------------------------------------------------===// 208 patterns.add<LinalgPromotionPattern<MatmulOp>>( 209 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 210 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 211 Identifier::get("_views_promoted_", ctx))); 212 patterns.add<LinalgPromotionPattern<MatmulOp>>( 213 ctx, 214 LinalgPromotionOptions() 215 .setOperandsToPromote({0}) 216 .setUseFullTileBuffersByDefault(true), 217 LinalgTransformationFilter( 218 Identifier::get("_promote_first_view_", ctx), 219 Identifier::get("_first_view_promoted_", ctx))); 220 patterns.add<LinalgPromotionPattern<FillOp>>( 221 ctx, 222 LinalgPromotionOptions() 223 .setOperandsToPromote({0}) 224 .setUseFullTileBuffers({true}) 225 .setAlignment(32), 226 LinalgTransformationFilter( 227 Identifier::get("_promote_views_aligned_", ctx), 228 Identifier::get("_views_aligned_promoted_", ctx))); 229 230 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 231 232 // Drop the marker. 233 funcOp.walk([](LinalgOp op) { 234 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 235 }); 236 } 237 238 static void fillL1TilingAndMatmulToVectorPatterns( 239 FuncOp funcOp, StringRef startMarker, 240 SmallVectorImpl<RewritePatternSet> &patternsVector) { 241 MLIRContext *ctx = funcOp.getContext(); 242 patternsVector.emplace_back( 243 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 244 ctx, 245 LinalgTilingOptions() 246 .setTileSizes({8, 12, 16}) 247 .setInterchange({1, 0, 2}), 248 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 249 Identifier::get("L1", ctx)))); 250 251 patternsVector.emplace_back( 252 ctx, 253 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 254 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 255 LinalgTransformationFilter(Identifier::get("L1", ctx), 256 Identifier::get("VEC", ctx)))); 257 258 patternsVector.emplace_back( 259 ctx, std::make_unique<LinalgVectorizationPattern>( 260 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 261 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 262 patternsVector.back().add<LinalgVectorizationPattern>( 263 ctx, LinalgTransformationFilter().addFilter( 264 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 265 } 266 267 //===----------------------------------------------------------------------===// 268 // Test promotion callbacks 269 //===----------------------------------------------------------------------===// 270 271 // Allocation call back 272 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 273 ArrayRef<Value> boundingSubViewSize, 274 DataLayout &layout) { 275 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 276 return b 277 .create<memref::AllocOp>( 278 subView.getLoc(), 279 MemRefType::get(shape, subView.getType().getElementType(), 280 /*affineMapComposition =*/{}, 3), 281 boundingSubViewSize) 282 .getResult(); 283 } 284 285 // Deallocation callback 286 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 287 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 288 return success(); 289 } 290 291 // Copy in call back 292 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 293 bool isOutput) { 294 auto floatType = src.getType().cast<MemRefType>().getElementType(); 295 if (!floatType.isa<FloatType>()) 296 return failure(); 297 if (!isOutput) 298 b.create<FillOp>( 299 src.getLoc(), dst, 300 b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0))); 301 b.create<CopyOp>(src.getLoc(), src, dst); 302 return success(); 303 } 304 305 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 306 RewritePatternSet &patterns) { 307 patterns.add<LinalgTilingPattern<MatmulOp>>( 308 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 309 LinalgTransformationFilter(Identifier::get("START", ctx), 310 Identifier::get("PROMOTE", ctx))); 311 patterns.add<LinalgPromotionPattern<MatmulOp>>( 312 ctx, 313 LinalgPromotionOptions() 314 .setOperandsToPromote({0, 2}) 315 .setUseFullTileBuffers({false, false}) 316 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 317 .setCopyInOutFns( 318 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 319 return copyCallBackFn(b, src, dst, false); 320 }, 321 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 322 return copyCallBackFn(b, src, dst, true); 323 }), 324 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 325 } 326 327 template <typename IdOp, typename NProcsOp> 328 static SmallVector<ProcInfo, 2> 329 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 330 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 331 SmallVector<ProcInfo, 2> procInfo(count); 332 const char *xyz[] = {"x", "y", "z"}; 333 Type indexType = b.getIndexType(); 334 for (unsigned i = 0; i < count; ++i) { 335 procInfo[count - 1 - i] = { 336 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 337 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 338 } 339 return procInfo; 340 } 341 342 static void fillTileAndDistributePatterns(MLIRContext *context, 343 RewritePatternSet &patterns) { 344 { 345 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 346 cyclicNprocsEqNiters.distributionMethod.resize( 347 2, DistributionMethod::CyclicNumProcsEqNumIters); 348 cyclicNprocsEqNiters.procInfo = 349 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 350 patterns.add<LinalgTilingPattern<MatmulOp>>( 351 context, 352 LinalgTilingOptions() 353 .setTileSizes({8, 8, 4}) 354 .setLoopType(LinalgTilingLoopType::ParallelLoops) 355 .setDistributionOptions(cyclicNprocsEqNiters), 356 LinalgTransformationFilter( 357 Identifier::get("distribute1", context), 358 Identifier::get("after_distribute1", context))); 359 } 360 361 { 362 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 363 cyclicNprocsGeNiters.distributionMethod.resize( 364 2, DistributionMethod::CyclicNumProcsGeNumIters); 365 cyclicNprocsGeNiters.procInfo = 366 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 367 patterns.add<LinalgTilingPattern<MatmulOp>>( 368 context, 369 LinalgTilingOptions() 370 .setTileSizes({8, 8, 4}) 371 .setLoopType(LinalgTilingLoopType::ParallelLoops) 372 .setDistributionOptions(cyclicNprocsGeNiters), 373 LinalgTransformationFilter( 374 Identifier::get("distribute2", context), 375 Identifier::get("after_distribute2", context))); 376 } 377 378 { 379 LinalgLoopDistributionOptions cyclicNprocsDefault; 380 cyclicNprocsDefault.distributionMethod.resize(2, 381 DistributionMethod::Cyclic); 382 cyclicNprocsDefault.procInfo = 383 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 384 patterns.add<LinalgTilingPattern<MatmulOp>>( 385 context, 386 LinalgTilingOptions() 387 .setTileSizes({8, 8, 4}) 388 .setLoopType(LinalgTilingLoopType::ParallelLoops) 389 .setDistributionOptions(cyclicNprocsDefault), 390 LinalgTransformationFilter( 391 Identifier::get("distribute3", context), 392 Identifier::get("after_distribute3", context))); 393 } 394 395 { 396 LinalgLoopDistributionOptions cyclicNprocsMixed1; 397 cyclicNprocsMixed1.distributionMethod = { 398 DistributionMethod::CyclicNumProcsEqNumIters, 399 DistributionMethod::CyclicNumProcsGeNumIters}; 400 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 401 patterns.add<LinalgTilingPattern<MatmulOp>>( 402 context, 403 LinalgTilingOptions() 404 .setTileSizes({8, 8, 4}) 405 .setLoopType(LinalgTilingLoopType::ParallelLoops) 406 .setDistributionOptions(cyclicNprocsMixed1), 407 LinalgTransformationFilter( 408 Identifier::get("distribute4", context), 409 Identifier::get("after_distribute4", context))); 410 } 411 412 { 413 LinalgLoopDistributionOptions cyclicNprocsMixed2; 414 cyclicNprocsMixed2.distributionMethod = { 415 DistributionMethod::CyclicNumProcsGeNumIters, 416 DistributionMethod::Cyclic}; 417 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 418 patterns.add<LinalgTilingPattern<MatmulOp>>( 419 context, 420 LinalgTilingOptions() 421 .setTileSizes({8, 8, 4}) 422 .setLoopType(LinalgTilingLoopType::ParallelLoops) 423 .setDistributionOptions(cyclicNprocsMixed2), 424 LinalgTransformationFilter( 425 Identifier::get("distribute5", context), 426 Identifier::get("after_distribute5", context))); 427 } 428 429 { 430 LinalgLoopDistributionOptions cyclicNprocsMixed3; 431 cyclicNprocsMixed3.distributionMethod = { 432 DistributionMethod::Cyclic, 433 DistributionMethod::CyclicNumProcsEqNumIters}; 434 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 435 436 patterns.add<LinalgTilingPattern<MatmulOp>>( 437 context, 438 LinalgTilingOptions() 439 .setTileSizes({8, 8, 4}) 440 .setLoopType(LinalgTilingLoopType::ParallelLoops) 441 .setDistributionOptions(cyclicNprocsMixed3), 442 LinalgTransformationFilter( 443 Identifier::get("distribute6", context), 444 Identifier::get("after_distribute6", context))); 445 } 446 447 { 448 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 449 cyclicNprocsEqNiters.distributionMethod.resize(2, 450 DistributionMethod::Cyclic); 451 cyclicNprocsEqNiters.procInfo = 452 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 453 patterns.add<LinalgTilingPattern<MatmulOp>>( 454 context, 455 LinalgTilingOptions() 456 .setTileSizes({8, 8, 4}) 457 .setLoopType(LinalgTilingLoopType::Loops) 458 .setDistributionOptions(cyclicNprocsEqNiters), 459 LinalgTransformationFilter( 460 Identifier::get("tensors_distribute1", context), 461 Identifier::get("tensors_after_distribute1", context))); 462 } 463 } 464 465 static void 466 applyMatmulToVectorPatterns(FuncOp funcOp, 467 bool testMatmulToVectorPatterns1dTiling, 468 bool testMatmulToVectorPatterns2dTiling) { 469 MLIRContext *ctx = funcOp.getContext(); 470 SmallVector<RewritePatternSet, 4> stage1Patterns; 471 if (testMatmulToVectorPatterns1dTiling) { 472 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 473 stage1Patterns); 474 } else if (testMatmulToVectorPatterns2dTiling) { 475 stage1Patterns.emplace_back( 476 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 477 ctx, 478 LinalgTilingOptions() 479 .setTileSizes({768, 264, 768}) 480 .setInterchange({1, 2, 0}), 481 LinalgTransformationFilter(Identifier::get("START", ctx), 482 Identifier::get("L2", ctx)))); 483 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 484 stage1Patterns); 485 } 486 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 487 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 488 FrozenRewritePatternSet stage2Patterns = 489 getLinalgTilingCanonicalizationPatterns(ctx); 490 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 491 std::move(stage2Patterns)); 492 } 493 494 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 495 RewritePatternSet forwardPattern(funcOp.getContext()); 496 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 497 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 498 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 499 } 500 501 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 502 RewritePatternSet patterns(funcOp.getContext()); 503 patterns.add<LinalgVectorizationPattern>( 504 funcOp.getContext(), 505 LinalgTransformationFilter() 506 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 507 patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext()); 508 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 509 } 510 511 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) { 512 RewritePatternSet foldPattern(funcOp.getContext()); 513 foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext()); 514 FrozenRewritePatternSet frozenPatterns(std::move(foldPattern)); 515 516 // Explicitly walk and apply the pattern locally to avoid more general folding 517 // on the rest of the IR. 518 funcOp.walk([&frozenPatterns](AffineMinOp minOp) { 519 (void)applyOpPatternsAndFold(minOp, frozenPatterns); 520 }); 521 } 522 523 // For now, just assume it is the zero of type. 524 // In the future, it should be the zero of type + op. 525 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 526 auto t = getElementTypeOrSelf(op.get().getType()); 527 return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); 528 } 529 530 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) { 531 MLIRContext *context = funcOp.getContext(); 532 RewritePatternSet tilingPattern(context); 533 auto linalgTilingOptions = 534 linalg::LinalgTilingOptions() 535 .setTileSizes(tileSizes) 536 .setPaddingValueComputationFunction(getNeutralOfLinalgOp); 537 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>( 538 context, linalgTilingOptions, 539 linalg::LinalgTransformationFilter( 540 Identifier::get("tile-and-pad", context))); 541 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 542 } 543 544 static void applyInterchangePattern(FuncOp funcOp, 545 ArrayRef<unsigned> interchangeVector) { 546 MLIRContext *context = funcOp.getContext(); 547 RewritePatternSet interchangePattern(context); 548 interchangePattern.add<GenericOpInterchangePattern>( 549 context, interchangeVector, 550 LinalgTransformationFilter(ArrayRef<Identifier>{}, 551 Identifier::get("interchange", context))); 552 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 553 } 554 555 /// Apply transformations specified as patterns. 556 void TestLinalgTransforms::runOnFunction() { 557 auto lambda = [&](void *) { 558 getFunction().walk([](LinalgOp op) { 559 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 560 }); 561 }; 562 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 563 564 if (testPromotionOptions) { 565 RewritePatternSet patterns(&getContext()); 566 fillPromotionCallBackPatterns(&getContext(), patterns); 567 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 568 return; 569 } 570 if (testTileAndDistributionOptions) { 571 RewritePatternSet patterns(&getContext()); 572 fillTileAndDistributePatterns(&getContext(), patterns); 573 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 574 return; 575 } 576 if (testPatterns) 577 return applyPatterns(getFunction()); 578 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 579 return applyMatmulToVectorPatterns(getFunction(), 580 testMatmulToVectorPatterns1dTiling, 581 testMatmulToVectorPatterns2dTiling); 582 if (testVectorTransferForwardingPatterns) 583 return applyVectorTransferForwardingPatterns(getFunction()); 584 if (testGenericToVectorPattern) 585 return applyLinalgToVectorPatterns(getFunction()); 586 if (testAffineMinSCFCanonicalizationPatterns) 587 return applyAffineMinSCFCanonicalizationPatterns(getFunction()); 588 if (testTileAndPadPattern) 589 return applyTileAndPadPattern(getFunction(), tileSizesForPadding); 590 if (testHoistPadding) { 591 getFunction().walk([&](linalg::PadTensorOp padTensorOp) { 592 (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); 593 }); 594 } 595 if (testInterchangePattern.hasValue()) 596 return applyInterchangePattern(getFunction(), testInterchangePattern); 597 } 598 599 namespace mlir { 600 namespace test { 601 void registerTestLinalgTransforms() { 602 PassRegistration<TestLinalgTransforms> testTransformPatternsPass( 603 "test-linalg-transform-patterns", 604 "Test Linalg transformation patterns by applying them greedily."); 605 } 606 } // namespace test 607 } // namespace mlir 608