1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/GPU/GPUDialect.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 #include "llvm/ADT/SetVector.h" 25 26 using namespace mlir; 27 using namespace mlir::linalg; 28 29 namespace { 30 struct TestLinalgTransforms 31 : public PassWrapper<TestLinalgTransforms, FunctionPass> { 32 TestLinalgTransforms() = default; 33 TestLinalgTransforms(const TestLinalgTransforms &pass) {} 34 35 void getDependentDialects(DialectRegistry ®istry) const override { 36 // clang-format off 37 registry.insert<AffineDialect, 38 memref::MemRefDialect, 39 scf::SCFDialect, 40 StandardOpsDialect, 41 vector::VectorDialect, 42 gpu::GPUDialect>(); 43 // clang-format on 44 } 45 46 void runOnFunction() override; 47 48 Option<bool> testPatterns{*this, "test-patterns", 49 llvm::cl::desc("Test a mixed set of patterns"), 50 llvm::cl::init(false)}; 51 Option<bool> testMatmulToVectorPatterns1dTiling{ 52 *this, "test-matmul-to-vector-patterns-tile-1d", 53 llvm::cl::desc( 54 "Test a fused pass that applies patterns from matmul to vectors via " 55 "1-d tiling"), 56 llvm::cl::init(false)}; 57 Option<bool> testMatmulToVectorPatterns2dTiling{ 58 *this, "test-matmul-to-vector-patterns-tile-2d", 59 llvm::cl::desc( 60 "Test a fused pass that applies patterns from matmul to vectors via " 61 "2-d tiling"), 62 llvm::cl::init(false)}; 63 Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options", 64 llvm::cl::desc("Test promotion options"), 65 llvm::cl::init(false)}; 66 Option<bool> testTileAndDistributionOptions{ 67 *this, "test-tile-and-distribute-options", 68 llvm::cl::desc("Test tile and distribute options"), 69 llvm::cl::init(false)}; 70 Option<bool> testVectorTransferForwardingPatterns{ 71 *this, "test-vector-transfer-forwarding-patterns", 72 llvm::cl::desc( 73 "Test a fused pass that forwards linalg.copy to vector.transfer"), 74 llvm::cl::init(false)}; 75 Option<bool> testGenericToVectorPattern{ 76 *this, "test-linalg-to-vector-patterns", 77 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 78 "in vector.contract form"), 79 llvm::cl::init(false)}; 80 Option<bool> testAffineMinSCFCanonicalizationPatterns{ 81 *this, "test-affine-min-scf-canonicalization-patterns", 82 llvm::cl::desc("Test affine-min + scf canonicalization patterns."), 83 llvm::cl::init(false)}; 84 Option<bool> testTileAndPadPattern{ 85 *this, "test-tile-and-pad-pattern", 86 llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)}; 87 Option<int> testHoistPadding{*this, "test-hoist-padding", 88 llvm::cl::desc("Test hoist padding"), 89 llvm::cl::init(0)}; 90 Option<bool> testTransformPadTensor{ 91 *this, "test-transform-pad-tensor", 92 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 93 llvm::cl::init(false)}; 94 ListOption<int64_t> tileSizesForPadding{ 95 *this, "tile-sizes-for-padding", 96 llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore, 97 llvm::cl::MiscFlags::CommaSeparated}; 98 ListOption<unsigned> testInterchangePattern{ 99 *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, 100 llvm::cl::desc("Test the interchange pattern.")}; 101 }; 102 } // end anonymous namespace 103 104 static void applyPatterns(FuncOp funcOp) { 105 MLIRContext *ctx = funcOp.getContext(); 106 RewritePatternSet patterns(ctx); 107 108 //===--------------------------------------------------------------------===// 109 // Linalg tiling patterns. 110 //===--------------------------------------------------------------------===// 111 patterns.add<LinalgTilingPattern<MatmulOp>>( 112 ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), 113 LinalgTransformationFilter(Identifier::get("MEM", ctx), 114 Identifier::get("L3", ctx))); 115 patterns.add<LinalgTilingPattern<MatmulOp>>( 116 ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), 117 LinalgTransformationFilter(Identifier::get("L3", ctx), 118 Identifier::get("L2", ctx))); 119 patterns.add<LinalgTilingPattern<MatmulOp>>( 120 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 121 LinalgTransformationFilter(Identifier::get("L2", ctx), 122 Identifier::get("L1", ctx))); 123 patterns.add<LinalgTilingPattern<MatmulOp>>( 124 ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), 125 LinalgTransformationFilter(Identifier::get("L1", ctx), 126 Identifier::get("REG", ctx))); 127 128 patterns.add<LinalgTilingPattern<MatvecOp>>( 129 ctx, 130 LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( 131 LinalgTilingLoopType::ParallelLoops), 132 LinalgTransformationFilter(ArrayRef<Identifier>{}, 133 Identifier::get("L1", ctx))); 134 135 patterns.add<LinalgTilingPattern<DotOp>>( 136 ctx, LinalgTilingOptions().setTileSizes(8000), 137 LinalgTransformationFilter( 138 ArrayRef<Identifier>{Identifier::get("MEM", ctx), 139 Identifier::get("L3", ctx), 140 Identifier::get("L2", ctx)}, 141 Identifier::get("REG", ctx))); 142 143 //===--------------------------------------------------------------------===// 144 // Linalg tiling and permutation patterns. 145 //===--------------------------------------------------------------------===// 146 patterns.add<LinalgTilingPattern<MatmulOp>>( 147 ctx, 148 LinalgTilingOptions() 149 .setTileSizes({2000, 3000, 4000}) 150 .setInterchange({1, 2, 0}), 151 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 152 Identifier::get("L2__with_perm__", ctx))); 153 patterns.add<LinalgTilingPattern<MatmulOp>>( 154 ctx, 155 LinalgTilingOptions() 156 .setTileSizes({200, 300, 400}) 157 .setInterchange({1, 0, 2}), 158 LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), 159 Identifier::get("L1__with_perm__", ctx))); 160 patterns.add<LinalgTilingPattern<MatmulOp>>( 161 ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), 162 LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), 163 Identifier::get("REG__with_perm__", ctx))); 164 165 patterns.add<LinalgTilingPattern<MatvecOp>>( 166 ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), 167 LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), 168 Identifier::get("L1__with_perm__", ctx))); 169 170 patterns.add<LinalgTilingPattern<MatmulOp>>( 171 ctx, 172 LinalgTilingOptions() 173 .setTileSizes({16, 8, 4}) 174 .setInterchange({1, 2, 0}) 175 .setLoopType(LinalgTilingLoopType::ParallelLoops), 176 LinalgTransformationFilter( 177 Identifier::get("par__with_perm__", ctx), 178 Identifier::get("after_par__with_perm__", ctx))); 179 180 //===--------------------------------------------------------------------===// 181 // Linalg to loops patterns. 182 //===--------------------------------------------------------------------===// 183 patterns.add<LinalgLoweringPattern<DotOp>>( 184 ctx, 185 /*loweringType=*/LinalgLoweringType::Loops, 186 LinalgTransformationFilter(Identifier::get("REG", ctx))); 187 188 //===--------------------------------------------------------------------===// 189 // Linalg distribution patterns. 190 //===--------------------------------------------------------------------===// 191 LinalgLoopDistributionOptions distributionOptions; 192 193 //===--------------------------------------------------------------------===// 194 // Linalg to vector contraction patterns. 195 //===--------------------------------------------------------------------===// 196 patterns.add<LinalgVectorizationPattern>( 197 ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) 198 .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>()); 199 200 //===--------------------------------------------------------------------===// 201 // Linalg generic interchange pattern. 202 //===--------------------------------------------------------------------===// 203 patterns.add<GenericOpInterchangePattern>( 204 ctx, 205 /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0}, 206 LinalgTransformationFilter(ArrayRef<Identifier>{}, 207 Identifier::get("PERMUTED", ctx))); 208 209 //===--------------------------------------------------------------------===// 210 // Linalg subview operands promotion. 211 //===--------------------------------------------------------------------===// 212 patterns.add<LinalgPromotionPattern<MatmulOp>>( 213 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 214 LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), 215 Identifier::get("_views_promoted_", ctx))); 216 patterns.add<LinalgPromotionPattern<MatmulOp>>( 217 ctx, 218 LinalgPromotionOptions() 219 .setOperandsToPromote({0}) 220 .setUseFullTileBuffersByDefault(true), 221 LinalgTransformationFilter( 222 Identifier::get("_promote_first_view_", ctx), 223 Identifier::get("_first_view_promoted_", ctx))); 224 patterns.add<LinalgPromotionPattern<FillOp>>( 225 ctx, 226 LinalgPromotionOptions() 227 .setOperandsToPromote({0}) 228 .setUseFullTileBuffers({true}) 229 .setAlignment(32), 230 LinalgTransformationFilter( 231 Identifier::get("_promote_views_aligned_", ctx), 232 Identifier::get("_views_aligned_promoted_", ctx))); 233 234 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 235 236 // Drop the marker. 237 funcOp.walk([](LinalgOp op) { 238 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 239 }); 240 } 241 242 static void fillL1TilingAndMatmulToVectorPatterns( 243 FuncOp funcOp, StringRef startMarker, 244 SmallVectorImpl<RewritePatternSet> &patternsVector) { 245 MLIRContext *ctx = funcOp.getContext(); 246 patternsVector.emplace_back( 247 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 248 ctx, 249 LinalgTilingOptions() 250 .setTileSizes({8, 12, 16}) 251 .setInterchange({1, 0, 2}), 252 LinalgTransformationFilter(Identifier::get(startMarker, ctx), 253 Identifier::get("L1", ctx)))); 254 255 patternsVector.emplace_back( 256 ctx, 257 std::make_unique<LinalgPromotionPattern<MatmulOp>>( 258 ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), 259 LinalgTransformationFilter(Identifier::get("L1", ctx), 260 Identifier::get("VEC", ctx)))); 261 262 patternsVector.emplace_back( 263 ctx, std::make_unique<LinalgVectorizationPattern>( 264 MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), 265 LinalgTransformationFilter(Identifier::get("VEC", ctx)))); 266 patternsVector.back().add<LinalgVectorizationPattern>( 267 ctx, LinalgTransformationFilter().addFilter( 268 [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); })); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // Test promotion callbacks 273 //===----------------------------------------------------------------------===// 274 275 // Allocation call back 276 static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, 277 ArrayRef<Value> boundingSubViewSize, 278 DataLayout &layout) { 279 SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1); 280 return b 281 .create<memref::AllocOp>( 282 subView.getLoc(), 283 MemRefType::get(shape, subView.getType().getElementType(), 284 /*affineMapComposition =*/{}, 3), 285 boundingSubViewSize) 286 .getResult(); 287 } 288 289 // Deallocation callback 290 static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { 291 b.create<memref::DeallocOp>(buffer.getLoc(), buffer); 292 return success(); 293 } 294 295 // Copy in call back 296 static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, 297 bool isOutput) { 298 auto floatType = src.getType().cast<MemRefType>().getElementType(); 299 if (!floatType.isa<FloatType>()) 300 return failure(); 301 if (!isOutput) 302 b.create<FillOp>( 303 src.getLoc(), dst, 304 b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0))); 305 b.create<CopyOp>(src.getLoc(), src, dst); 306 return success(); 307 } 308 309 static void fillPromotionCallBackPatterns(MLIRContext *ctx, 310 RewritePatternSet &patterns) { 311 patterns.add<LinalgTilingPattern<MatmulOp>>( 312 ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), 313 LinalgTransformationFilter(Identifier::get("START", ctx), 314 Identifier::get("PROMOTE", ctx))); 315 patterns.add<LinalgPromotionPattern<MatmulOp>>( 316 ctx, 317 LinalgPromotionOptions() 318 .setOperandsToPromote({0, 2}) 319 .setUseFullTileBuffers({false, false}) 320 .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) 321 .setCopyInOutFns( 322 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 323 return copyCallBackFn(b, src, dst, false); 324 }, 325 [](OpBuilder &b, Value src, Value dst) -> LogicalResult { 326 return copyCallBackFn(b, src, dst, true); 327 }), 328 LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); 329 } 330 331 template <typename IdOp, typename NProcsOp> 332 static SmallVector<ProcInfo, 2> 333 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) { 334 size_t count = std::min<size_t>(3, parallelLoopRanges.size()); 335 SmallVector<ProcInfo, 2> procInfo(count); 336 const char *xyz[] = {"x", "y", "z"}; 337 Type indexType = b.getIndexType(); 338 for (unsigned i = 0; i < count; ++i) { 339 procInfo[count - 1 - i] = { 340 b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])), 341 b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))}; 342 } 343 return procInfo; 344 } 345 346 static void fillTileAndDistributePatterns(MLIRContext *context, 347 RewritePatternSet &patterns) { 348 { 349 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 350 cyclicNprocsEqNiters.distributionMethod.resize( 351 2, DistributionMethod::CyclicNumProcsEqNumIters); 352 cyclicNprocsEqNiters.procInfo = 353 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 354 patterns.add<LinalgTilingPattern<MatmulOp>>( 355 context, 356 LinalgTilingOptions() 357 .setTileSizes({8, 8, 4}) 358 .setLoopType(LinalgTilingLoopType::ParallelLoops) 359 .setDistributionOptions(cyclicNprocsEqNiters), 360 LinalgTransformationFilter( 361 Identifier::get("distribute1", context), 362 Identifier::get("after_distribute1", context))); 363 } 364 365 { 366 LinalgLoopDistributionOptions cyclicNprocsGeNiters; 367 cyclicNprocsGeNiters.distributionMethod.resize( 368 2, DistributionMethod::CyclicNumProcsGeNumIters); 369 cyclicNprocsGeNiters.procInfo = 370 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 371 patterns.add<LinalgTilingPattern<MatmulOp>>( 372 context, 373 LinalgTilingOptions() 374 .setTileSizes({8, 8, 4}) 375 .setLoopType(LinalgTilingLoopType::ParallelLoops) 376 .setDistributionOptions(cyclicNprocsGeNiters), 377 LinalgTransformationFilter( 378 Identifier::get("distribute2", context), 379 Identifier::get("after_distribute2", context))); 380 } 381 382 { 383 LinalgLoopDistributionOptions cyclicNprocsDefault; 384 cyclicNprocsDefault.distributionMethod.resize(2, 385 DistributionMethod::Cyclic); 386 cyclicNprocsDefault.procInfo = 387 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 388 patterns.add<LinalgTilingPattern<MatmulOp>>( 389 context, 390 LinalgTilingOptions() 391 .setTileSizes({8, 8, 4}) 392 .setLoopType(LinalgTilingLoopType::ParallelLoops) 393 .setDistributionOptions(cyclicNprocsDefault), 394 LinalgTransformationFilter( 395 Identifier::get("distribute3", context), 396 Identifier::get("after_distribute3", context))); 397 } 398 399 { 400 LinalgLoopDistributionOptions cyclicNprocsMixed1; 401 cyclicNprocsMixed1.distributionMethod = { 402 DistributionMethod::CyclicNumProcsEqNumIters, 403 DistributionMethod::CyclicNumProcsGeNumIters}; 404 cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 405 patterns.add<LinalgTilingPattern<MatmulOp>>( 406 context, 407 LinalgTilingOptions() 408 .setTileSizes({8, 8, 4}) 409 .setLoopType(LinalgTilingLoopType::ParallelLoops) 410 .setDistributionOptions(cyclicNprocsMixed1), 411 LinalgTransformationFilter( 412 Identifier::get("distribute4", context), 413 Identifier::get("after_distribute4", context))); 414 } 415 416 { 417 LinalgLoopDistributionOptions cyclicNprocsMixed2; 418 cyclicNprocsMixed2.distributionMethod = { 419 DistributionMethod::CyclicNumProcsGeNumIters, 420 DistributionMethod::Cyclic}; 421 cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 422 patterns.add<LinalgTilingPattern<MatmulOp>>( 423 context, 424 LinalgTilingOptions() 425 .setTileSizes({8, 8, 4}) 426 .setLoopType(LinalgTilingLoopType::ParallelLoops) 427 .setDistributionOptions(cyclicNprocsMixed2), 428 LinalgTransformationFilter( 429 Identifier::get("distribute5", context), 430 Identifier::get("after_distribute5", context))); 431 } 432 433 { 434 LinalgLoopDistributionOptions cyclicNprocsMixed3; 435 cyclicNprocsMixed3.distributionMethod = { 436 DistributionMethod::Cyclic, 437 DistributionMethod::CyclicNumProcsEqNumIters}; 438 cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 439 440 patterns.add<LinalgTilingPattern<MatmulOp>>( 441 context, 442 LinalgTilingOptions() 443 .setTileSizes({8, 8, 4}) 444 .setLoopType(LinalgTilingLoopType::ParallelLoops) 445 .setDistributionOptions(cyclicNprocsMixed3), 446 LinalgTransformationFilter( 447 Identifier::get("distribute6", context), 448 Identifier::get("after_distribute6", context))); 449 } 450 451 { 452 LinalgLoopDistributionOptions cyclicNprocsEqNiters; 453 cyclicNprocsEqNiters.distributionMethod.resize(2, 454 DistributionMethod::Cyclic); 455 cyclicNprocsEqNiters.procInfo = 456 getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; 457 patterns.add<LinalgTilingPattern<MatmulOp>>( 458 context, 459 LinalgTilingOptions() 460 .setTileSizes({8, 8, 4}) 461 .setLoopType(LinalgTilingLoopType::Loops) 462 .setDistributionOptions(cyclicNprocsEqNiters), 463 LinalgTransformationFilter( 464 Identifier::get("tensors_distribute1", context), 465 Identifier::get("tensors_after_distribute1", context))); 466 } 467 } 468 469 static void 470 applyMatmulToVectorPatterns(FuncOp funcOp, 471 bool testMatmulToVectorPatterns1dTiling, 472 bool testMatmulToVectorPatterns2dTiling) { 473 MLIRContext *ctx = funcOp.getContext(); 474 SmallVector<RewritePatternSet, 4> stage1Patterns; 475 if (testMatmulToVectorPatterns1dTiling) { 476 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 477 stage1Patterns); 478 } else if (testMatmulToVectorPatterns2dTiling) { 479 stage1Patterns.emplace_back( 480 ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( 481 ctx, 482 LinalgTilingOptions() 483 .setTileSizes({768, 264, 768}) 484 .setInterchange({1, 2, 0}), 485 LinalgTransformationFilter(Identifier::get("START", ctx), 486 Identifier::get("L2", ctx)))); 487 fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), 488 stage1Patterns); 489 } 490 SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns; 491 llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); 492 FrozenRewritePatternSet stage2Patterns = 493 getLinalgTilingCanonicalizationPatterns(ctx); 494 (void)applyStagedPatterns(funcOp, frozenStage1Patterns, 495 std::move(stage2Patterns)); 496 } 497 498 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { 499 RewritePatternSet forwardPattern(funcOp.getContext()); 500 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 501 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 502 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); 503 } 504 505 static void applyLinalgToVectorPatterns(FuncOp funcOp) { 506 RewritePatternSet patterns(funcOp.getContext()); 507 patterns.add<LinalgVectorizationPattern>( 508 funcOp.getContext(), 509 LinalgTransformationFilter() 510 .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>()); 511 populatePadTensorOpVectorizationPatterns(patterns); 512 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 513 } 514 515 static void applyPadTensorToGenericPatterns(FuncOp funcOp) { 516 RewritePatternSet patterns(funcOp.getContext()); 517 patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext()); 518 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 519 } 520 521 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) { 522 RewritePatternSet foldPattern(funcOp.getContext()); 523 foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext()); 524 FrozenRewritePatternSet frozenPatterns(std::move(foldPattern)); 525 526 // Explicitly walk and apply the pattern locally to avoid more general folding 527 // on the rest of the IR. 528 funcOp.walk([&frozenPatterns](AffineMinOp minOp) { 529 (void)applyOpPatternsAndFold(minOp, frozenPatterns); 530 }); 531 } 532 533 // For now, just assume it is the zero of type. 534 // In the future, it should be the zero of type + op. 535 static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { 536 auto t = getElementTypeOrSelf(op.get().getType()); 537 return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); 538 } 539 540 static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) { 541 MLIRContext *context = funcOp.getContext(); 542 RewritePatternSet tilingPattern(context); 543 auto linalgTilingOptions = 544 linalg::LinalgTilingOptions() 545 .setTileSizes(tileSizes) 546 .setPaddingValueComputationFunction(getNeutralOfLinalgOp); 547 tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>( 548 context, linalgTilingOptions, 549 linalg::LinalgTransformationFilter( 550 Identifier::get("tile-and-pad", context))); 551 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 552 } 553 554 static void applyInterchangePattern(FuncOp funcOp, 555 ArrayRef<unsigned> interchangeVector) { 556 MLIRContext *context = funcOp.getContext(); 557 RewritePatternSet interchangePattern(context); 558 interchangePattern.add<GenericOpInterchangePattern>( 559 context, interchangeVector, 560 LinalgTransformationFilter(ArrayRef<Identifier>{}, 561 Identifier::get("interchange", context))); 562 (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); 563 } 564 565 /// Apply transformations specified as patterns. 566 void TestLinalgTransforms::runOnFunction() { 567 auto lambda = [&](void *) { 568 getFunction().walk([](LinalgOp op) { 569 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 570 }); 571 }; 572 std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda}; 573 574 if (testPromotionOptions) { 575 RewritePatternSet patterns(&getContext()); 576 fillPromotionCallBackPatterns(&getContext(), patterns); 577 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 578 return; 579 } 580 if (testTileAndDistributionOptions) { 581 RewritePatternSet patterns(&getContext()); 582 fillTileAndDistributePatterns(&getContext(), patterns); 583 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 584 return; 585 } 586 if (testPatterns) 587 return applyPatterns(getFunction()); 588 if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) 589 return applyMatmulToVectorPatterns(getFunction(), 590 testMatmulToVectorPatterns1dTiling, 591 testMatmulToVectorPatterns2dTiling); 592 if (testVectorTransferForwardingPatterns) 593 return applyVectorTransferForwardingPatterns(getFunction()); 594 if (testGenericToVectorPattern) 595 return applyLinalgToVectorPatterns(getFunction()); 596 if (testTransformPadTensor) 597 return applyPadTensorToGenericPatterns(getFunction()); 598 if (testAffineMinSCFCanonicalizationPatterns) 599 return applyAffineMinSCFCanonicalizationPatterns(getFunction()); 600 if (testTileAndPadPattern) 601 return applyTileAndPadPattern(getFunction(), tileSizesForPadding); 602 if (testHoistPadding) { 603 getFunction().walk([&](linalg::PadTensorOp padTensorOp) { 604 (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); 605 }); 606 } 607 if (testInterchangePattern.hasValue()) 608 return applyInterchangePattern(getFunction(), testInterchangePattern); 609 } 610 611 namespace mlir { 612 namespace test { 613 void registerTestLinalgTransforms() { 614 PassRegistration<TestLinalgTransforms> testTransformPatternsPass( 615 "test-linalg-transform-patterns", 616 "Test Linalg transformation patterns by applying them greedily."); 617 } 618 } // namespace test 619 } // namespace mlir 620