1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "PassDetail.h" 17 #include "mlir/Analysis/SliceAnalysis.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Affine/LoopUtils.h" 20 #include "mlir/Dialect/Affine/Utils.h" 21 #include "mlir/Dialect/Linalg/IR/Linalg.h" 22 #include "mlir/Dialect/Linalg/Passes.h" 23 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 24 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 25 #include "mlir/Dialect/Linalg/Utils/Utils.h" 26 #include "mlir/Dialect/SCF/Transforms.h" 27 #include "mlir/Dialect/Tensor/IR/Tensor.h" 28 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 29 #include "mlir/IR/AffineExpr.h" 30 #include "mlir/IR/AffineMap.h" 31 #include "mlir/Pass/PassManager.h" 32 #include "mlir/Support/LLVM.h" 33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 34 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" 35 #include "mlir/Transforms/Passes.h" 36 37 using namespace mlir; 38 using namespace mlir::vector; 39 using namespace linalg; 40 41 namespace { 42 43 /// Configurable pass to apply pattern-based tiling and fusion. 44 struct LinalgStrategyTileAndFusePass 45 : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> { 46 47 LinalgStrategyTileAndFusePass() = default; 48 49 LinalgStrategyTileAndFusePass(StringRef opName, 50 LinalgTilingAndFusionOptions opt, 51 LinalgTransformationFilter filt) 52 : options(std::move(opt)), filter(std::move(filt)) { 53 this->anchorOpName.setValue(opName.str()); 54 } 55 56 void runOnOperation() override { 57 auto funcOp = getOperation(); 58 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 59 return; 60 61 RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); 62 if (!anchorOpName.empty()) { 63 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 64 anchorOpName, funcOp.getContext(), options, filter); 65 } else { 66 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 67 funcOp.getContext(), options, filter); 68 } 69 // Search the root operation using bottom up traversal. 70 GreedyRewriteConfig config; 71 config.useTopDownTraversal = false; 72 (void)applyPatternsAndFoldGreedily( 73 funcOp, std::move(tilingAndFusionPattern), config); 74 } 75 76 LinalgTilingAndFusionOptions options; 77 LinalgTransformationFilter filter; 78 }; 79 80 /// Configurable pass to apply pattern-based linalg tiling. 81 struct LinalgStrategyTilePass 82 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 83 84 LinalgStrategyTilePass() = default; 85 86 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 87 LinalgTransformationFilter filt) 88 : options(std::move(opt)), filter(std::move(filt)) { 89 this->anchorOpName.setValue(opName.str()); 90 } 91 92 void runOnOperation() override { 93 auto funcOp = getOperation(); 94 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 95 return; 96 97 MLIRContext *ctx = funcOp.getContext(); 98 RewritePatternSet tilingPattern(ctx); 99 if (!anchorOpName.empty()) 100 tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options, 101 filter); 102 else 103 tilingPattern.add<LinalgTilingPattern>(ctx, options, filter); 104 if (anchorOpName == tensor::PadOp::getOperationName()) 105 populatePadTensorTilingPatterns(tilingPattern, options); 106 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 107 } 108 109 LinalgTilingOptions options; 110 LinalgTransformationFilter filter; 111 }; 112 113 /// Configurable pass to apply hoisting and padding. 114 struct LinalgStrategyPadPass 115 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 116 117 LinalgStrategyPadPass() = default; 118 119 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 120 LinalgTransformationFilter filt) 121 : options(std::move(opt)), filter(std::move(filt)) { 122 this->anchorOpName.setValue(opName.str()); 123 } 124 125 void runOnOperation() override { 126 auto funcOp = getOperation(); 127 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 128 return; 129 130 RewritePatternSet paddingPattern(funcOp.getContext()); 131 if (!anchorOpName.empty()) { 132 paddingPattern.add<LinalgPaddingPattern>( 133 anchorOpName, funcOp.getContext(), options, filter); 134 } else { 135 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 136 filter); 137 } 138 (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)); 139 } 140 141 LinalgPaddingOptions options; 142 LinalgTransformationFilter filter; 143 }; 144 145 /// Configurable pass to apply pattern-based linalg generalization. 146 struct LinalgStrategyGeneralizePass 147 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 148 149 LinalgStrategyGeneralizePass() = default; 150 151 LinalgStrategyGeneralizePass(StringRef opName, 152 LinalgTransformationFilter filter) 153 : filter(std::move(filter)) { 154 this->anchorOpName.setValue(opName.str()); 155 } 156 157 void runOnOperation() override { 158 auto funcOp = getOperation(); 159 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 160 return; 161 162 RewritePatternSet generalizationPattern(funcOp.getContext()); 163 if (!anchorOpName.empty()) { 164 generalizationPattern.add<LinalgGeneralizationPattern>( 165 anchorOpName, funcOp.getContext(), filter); 166 } else { 167 generalizationPattern.add<LinalgGeneralizationPattern>( 168 funcOp.getContext(), filter); 169 } 170 if (failed(applyPatternsAndFoldGreedily(funcOp, 171 std::move(generalizationPattern)))) 172 signalPassFailure(); 173 } 174 175 LinalgTransformationFilter filter; 176 }; 177 178 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 179 /// finer-grained named versions. 180 struct LinalgStrategyDecomposePass 181 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 182 183 LinalgStrategyDecomposePass() = default; 184 185 LinalgStrategyDecomposePass(LinalgTransformationFilter filter) 186 : filter(std::move(filter)) {} 187 188 void runOnOperation() override { 189 auto funcOp = getOperation(); 190 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 191 return; 192 RewritePatternSet decompositionPattern(funcOp.getContext()); 193 populateDecomposeConvolutionPatterns(decompositionPattern, filter); 194 if (failed(applyPatternsAndFoldGreedily(funcOp, 195 std::move(decompositionPattern)))) 196 signalPassFailure(); 197 } 198 199 LinalgTransformationFilter filter; 200 }; 201 202 /// Configurable pass to apply pattern-based linalg generalization. 203 struct LinalgStrategyInterchangePass 204 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 205 206 LinalgStrategyInterchangePass() = default; 207 208 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 209 LinalgTransformationFilter filter) 210 : iteratorInterchange(iteratorInterchange.begin(), 211 iteratorInterchange.end()), 212 filter(std::move(filter)) {} 213 214 void runOnOperation() override { 215 auto funcOp = getOperation(); 216 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 217 return; 218 219 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 220 iteratorInterchange.end()); 221 RewritePatternSet interchangePattern(funcOp.getContext()); 222 interchangePattern.add<GenericOpInterchangePattern>( 223 funcOp.getContext(), interchangeVector, filter); 224 if (failed(applyPatternsAndFoldGreedily(funcOp, 225 std::move(interchangePattern)))) 226 signalPassFailure(); 227 } 228 229 SmallVector<int64_t> iteratorInterchange; 230 LinalgTransformationFilter filter; 231 }; 232 233 /// Configurable pass to apply pattern-based linalg promotion. 234 struct LinalgStrategyPromotePass 235 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 236 237 LinalgStrategyPromotePass() = default; 238 239 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 240 LinalgTransformationFilter filt) 241 : options(std::move(opt)), filter(std::move(filt)) { 242 this->anchorOpName.setValue(opName.str()); 243 } 244 245 void runOnOperation() override { 246 auto funcOp = getOperation(); 247 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 248 return; 249 250 RewritePatternSet promotionPattern(funcOp.getContext()); 251 if (!anchorOpName.empty()) { 252 promotionPattern.add<LinalgBasePromotionPattern>( 253 anchorOpName, funcOp.getContext(), options, filter); 254 } else { 255 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 256 filter, options); 257 } 258 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 259 } 260 261 LinalgPromotionOptions options; 262 LinalgTransformationFilter filter; 263 }; 264 265 /// Configurable pass to apply pattern-based linalg peeling. 266 struct LinalgStrategyPeelPass 267 : public LinalgStrategyPeelPassBase<LinalgStrategyPeelPass> { 268 269 LinalgStrategyPeelPass() = default; 270 271 LinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt, 272 LinalgTransformationFilter filt) 273 : options(opt), filter(std::move(filt)) { 274 this->anchorOpName.setValue(opName.str()); 275 } 276 277 void runOnOperation() override { 278 auto funcOp = getOperation(); 279 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 280 return; 281 282 RewritePatternSet peelingPatterns(funcOp.getContext()); 283 if (!anchorOpName.empty()) { 284 peelingPatterns.add<LinalgPeelingPattern>( 285 anchorOpName, funcOp.getContext(), options, filter); 286 } else { 287 peelingPatterns.add<LinalgPeelingPattern>(funcOp.getContext(), filter, 288 options); 289 } 290 if (failed( 291 applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns)))) 292 return signalPassFailure(); 293 } 294 295 LinalgPeelOptions options; 296 LinalgTransformationFilter filter; 297 }; 298 299 /// Configurable pass to apply pattern-based linalg vectorization. 300 struct LinalgStrategyVectorizePass 301 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 302 303 LinalgStrategyVectorizePass() = default; 304 305 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 306 LinalgTransformationFilter filt, 307 bool padVectorize = false) 308 : options(opt), filter(std::move(filt)) { 309 this->anchorOpName.setValue(opName.str()); 310 this->vectorizePadding.setValue(padVectorize); 311 } 312 313 void runOnOperation() override { 314 auto funcOp = getOperation(); 315 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 316 return; 317 318 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 319 if (!anchorOpName.empty()) { 320 vectorizationPatterns.add<LinalgVectorizationPattern>( 321 anchorOpName, funcOp.getContext(), options, filter); 322 } else { 323 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 324 filter, options); 325 } 326 vector::populateVectorTransferPermutationMapLoweringPatterns( 327 vectorizationPatterns); 328 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 329 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 330 linalg::LinalgCopyVTWForwardingPattern>( 331 funcOp.getContext(), /*benefit=*/2); 332 TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns, 333 funcOp.getContext()); 334 TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns, 335 funcOp.getContext()); 336 (void)applyPatternsAndFoldGreedily(funcOp, 337 std::move(vectorizationPatterns)); 338 339 // Apply the pad tensor op vectorization separately to avoid running the 340 // GenericPadOpVectorizationPattern too early. 341 // TODO: Improve once we have better infrastructure to control pattern 342 // application. 343 if (vectorizePadding) { 344 RewritePatternSet patterns(funcOp.getContext()); 345 linalg::populatePadOpVectorizationPatterns(patterns); 346 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 347 } 348 } 349 350 LinalgVectorizationOptions options; 351 LinalgTransformationFilter filter; 352 }; 353 354 /// Configurable pass to enable the application of other pattern-based linalg 355 /// passes. 356 struct LinalgStrategyEnablePass 357 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 358 359 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 360 LinalgTransformationFilter filt) 361 : options(opt), filter(std::move(filt)) {} 362 363 void runOnOperation() override { 364 auto funcOp = getOperation(); 365 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 366 return; 367 368 MLIRContext *context = funcOp.getContext(); 369 RewritePatternSet patterns = 370 linalg::getLinalgTilingCanonicalizationPatterns(context); 371 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 372 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 373 return signalPassFailure(); 374 375 if (options.licm) { 376 funcOp->walk([&](LoopLikeOpInterface loopLike) { 377 moveLoopInvariantCode(loopLike); 378 }); 379 } 380 381 // Gathers all innermost loops through a post order pruned walk. 382 funcOp.walk([](Operation *op) { 383 if (auto forOp = dyn_cast<AffineForOp>(op)) 384 (void)promoteIfSingleIteration(forOp); 385 else if (auto forOp = dyn_cast<scf::ForOp>(op)) 386 (void)promoteIfSingleIteration(forOp); 387 }); 388 if (options.hoistRedundantVectorTransfers) 389 hoistRedundantVectorTransfers(funcOp); 390 391 if (options.hoistRedundantVectorTransfersOnTensor) 392 hoistRedundantVectorTransfersOnTensor(funcOp); 393 394 // Run CSE to cleanup after canonicalization. 395 OpPassManager dynamicPM("func.func"); 396 dynamicPM.addPass(createCSEPass()); 397 if (failed(runPipeline(dynamicPM, funcOp))) 398 return signalPassFailure(); 399 } 400 401 LinalgEnablingOptions options; 402 LinalgTransformationFilter filter; 403 }; 404 405 /// Configurable pass to lower vector operations. 406 struct LinalgStrategyLowerVectorsPass 407 : public LinalgStrategyLowerVectorsPassBase< 408 LinalgStrategyLowerVectorsPass> { 409 410 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 411 LinalgTransformationFilter filt) 412 : options(opt), filter(std::move(filt)) {} 413 414 void runOnOperation() override { 415 auto funcOp = getOperation(); 416 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 417 return; 418 419 MLIRContext *context = funcOp.getContext(); 420 RewritePatternSet patterns(context); 421 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 422 // In a progressive lowering of vectors, this would be the 1st step. 423 if (options.contractionLowering) { 424 patterns.add<ContractionOpToOuterProductOpLowering, 425 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 426 options.vectorTransformOptions, context); 427 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 428 } 429 // In a progressive lowering of vectors, this would be the 2nd step. 430 if (options.multiReductionLowering) { 431 vector::populateVectorMultiReductionLoweringPatterns( 432 patterns, 433 options.vectorTransformOptions.vectorMultiReductionLowering); 434 } 435 // In a progressive lowering of vectors, this would be the 3rd step. 436 if (options.transferPartialRewrite) { 437 patterns.add<vector::VectorTransferFullPartialRewriter>( 438 context, options.vectorTransformOptions); 439 } 440 // In a progressive lowering of vectors, this would be the 4th step. 441 if (options.transferLowering) { 442 vector::populateVectorTransferLoweringPatterns(patterns, 443 options.maxTransferRank); 444 } 445 // In a progressive lowering of vectors, this would be the 5th step. 446 if (options.transferToSCFConversion) { 447 populateVectorToSCFConversionPatterns( 448 patterns, options.vectorTransferToSCFOptions.setTargetRank( 449 options.maxTransferRank)); 450 } 451 // In a progressive lowering of vectors, this would be the 6th step. 452 if (options.shapeCastLowering) { 453 vector::populateVectorShapeCastLoweringPatterns(patterns); 454 } 455 // In a progressive lowering of vectors, this would be the 7th step. 456 if (options.transposeLowering) { 457 vector::populateVectorTransposeLoweringPatterns( 458 patterns, options.vectorTransformOptions); 459 if (options.avx2Lowering) 460 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 461 patterns, options.avx2LoweringOptions, /*benefit=*/10); 462 } 463 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 464 } 465 466 LinalgVectorLoweringOptions options; 467 LinalgTransformationFilter filter; 468 }; 469 470 /// Configurable pass to lower vector operations. 471 struct LinalgStrategyRemoveMarkersPass 472 : public LinalgStrategyRemoveMarkersPassBase< 473 LinalgStrategyRemoveMarkersPass> { 474 475 void runOnOperation() override { 476 auto funcOp = getOperation(); 477 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 478 return; 479 funcOp.walk([](LinalgOp op) { 480 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 481 }); 482 } 483 }; 484 } // namespace 485 486 /// Create a LinalgStrategyTileAndFusePass. 487 std::unique_ptr<OperationPass<func::FuncOp>> 488 mlir::createLinalgStrategyTileAndFusePass( 489 StringRef opName, const LinalgTilingAndFusionOptions &options, 490 const LinalgTransformationFilter &filter) { 491 return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options, 492 filter); 493 } 494 495 /// Create a LinalgStrategyTilePass. 496 std::unique_ptr<OperationPass<func::FuncOp>> 497 mlir::createLinalgStrategyTilePass(StringRef opName, 498 const LinalgTilingOptions &opt, 499 const LinalgTransformationFilter &filter) { 500 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 501 } 502 503 /// Create a LinalgStrategyPadPass. 504 std::unique_ptr<OperationPass<func::FuncOp>> 505 mlir::createLinalgStrategyPadPass(StringRef opName, 506 const LinalgPaddingOptions &opt, 507 const LinalgTransformationFilter &filter) { 508 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 509 } 510 511 /// Create a LinalgStrategyPromotePass. 512 std::unique_ptr<OperationPass<func::FuncOp>> 513 mlir::createLinalgStrategyPromotePass( 514 StringRef opName, const LinalgPromotionOptions &opt, 515 const LinalgTransformationFilter &filter) { 516 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 517 } 518 519 /// Create a LinalgStrategyGeneralizePass. 520 std::unique_ptr<OperationPass<func::FuncOp>> 521 mlir::createLinalgStrategyGeneralizePass( 522 StringRef opName, const LinalgTransformationFilter &filter) { 523 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 524 } 525 526 /// Create a LinalgStrategyDecomposePass. 527 // TODO: if/when we need finer control add an `opName` parameter. 528 std::unique_ptr<OperationPass<func::FuncOp>> 529 mlir::createLinalgStrategyDecomposePass( 530 const LinalgTransformationFilter &filter) { 531 return std::make_unique<LinalgStrategyDecomposePass>(filter); 532 } 533 534 /// Create a LinalgStrategyInterchangePass. 535 std::unique_ptr<OperationPass<func::FuncOp>> 536 mlir::createLinalgStrategyInterchangePass( 537 ArrayRef<int64_t> iteratorInterchange, 538 const LinalgTransformationFilter &filter) { 539 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 540 filter); 541 } 542 543 /// Create a LinalgStrategyPeelPass. 544 std::unique_ptr<OperationPass<func::FuncOp>> 545 mlir::createLinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt, 546 const LinalgTransformationFilter &filter) { 547 return std::make_unique<LinalgStrategyPeelPass>(opName, opt, filter); 548 } 549 550 /// Create a LinalgStrategyVectorizePass. 551 std::unique_ptr<OperationPass<func::FuncOp>> 552 mlir::createLinalgStrategyVectorizePass( 553 StringRef opName, LinalgVectorizationOptions opt, 554 const LinalgTransformationFilter &filter, bool padVectorize) { 555 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter, 556 padVectorize); 557 } 558 559 /// Create a LinalgStrategyEnablePass. 560 std::unique_ptr<OperationPass<func::FuncOp>> 561 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 562 const LinalgTransformationFilter &filter) { 563 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 564 } 565 566 /// Create a LinalgStrategyLowerVectorsPass. 567 std::unique_ptr<OperationPass<func::FuncOp>> 568 mlir::createLinalgStrategyLowerVectorsPass( 569 LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) { 570 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 571 } 572 573 /// Create a LinalgStrategyRemoveMarkersPass. 574 std::unique_ptr<OperationPass<func::FuncOp>> 575 mlir::createLinalgStrategyRemoveMarkersPass() { 576 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 577 } 578