1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "PassDetail.h" 17 #include "mlir/Analysis/SliceAnalysis.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Affine/LoopUtils.h" 20 #include "mlir/Dialect/Affine/Utils.h" 21 #include "mlir/Dialect/Linalg/IR/Linalg.h" 22 #include "mlir/Dialect/Linalg/Passes.h" 23 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 24 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 25 #include "mlir/Dialect/Linalg/Utils/Utils.h" 26 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 27 #include "mlir/Dialect/Tensor/IR/Tensor.h" 28 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 29 #include "mlir/IR/AffineExpr.h" 30 #include "mlir/IR/AffineMap.h" 31 #include "mlir/Pass/PassManager.h" 32 #include "mlir/Support/LLVM.h" 33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 34 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" 35 #include "mlir/Transforms/Passes.h" 36 37 using namespace mlir; 38 using namespace mlir::vector; 39 using namespace linalg; 40 41 namespace { 42 43 /// Configurable pass to apply pattern-based tiling and fusion. 44 struct LinalgStrategyTileAndFusePass 45 : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> { 46 47 LinalgStrategyTileAndFusePass() = default; 48 49 LinalgStrategyTileAndFusePass(StringRef opName, 50 LinalgTilingAndFusionOptions opt, 51 LinalgTransformationFilter filt) 52 : options(std::move(opt)), filter(std::move(filt)) { 53 this->anchorOpName.setValue(opName.str()); 54 } 55 56 void runOnOperation() override { 57 auto funcOp = getOperation(); 58 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 59 return; 60 61 RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); 62 if (!anchorOpName.empty()) { 63 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 64 anchorOpName, funcOp.getContext(), options, filter); 65 } else { 66 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 67 funcOp.getContext(), options, filter); 68 } 69 // Search the root operation using bottom up traversal. 70 GreedyRewriteConfig config; 71 config.useTopDownTraversal = false; 72 (void)applyPatternsAndFoldGreedily( 73 funcOp, std::move(tilingAndFusionPattern), config); 74 } 75 76 LinalgTilingAndFusionOptions options; 77 LinalgTransformationFilter filter; 78 }; 79 80 /// Configurable pass to apply pattern-based linalg tiling. 81 struct LinalgStrategyTilePass 82 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 83 84 LinalgStrategyTilePass() = default; 85 86 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 87 LinalgTransformationFilter filt) 88 : options(std::move(opt)), filter(std::move(filt)) { 89 this->anchorOpName.setValue(opName.str()); 90 } 91 92 void runOnOperation() override { 93 auto funcOp = getOperation(); 94 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 95 return; 96 97 MLIRContext *ctx = funcOp.getContext(); 98 RewritePatternSet tilingPattern(ctx); 99 if (!anchorOpName.empty()) 100 tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options, 101 filter); 102 else 103 tilingPattern.add<LinalgTilingPattern>(ctx, options, filter); 104 if (anchorOpName == tensor::PadOp::getOperationName()) 105 populatePadTensorTilingPatterns(tilingPattern, options); 106 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 107 } 108 109 LinalgTilingOptions options; 110 LinalgTransformationFilter filter; 111 }; 112 113 /// Configurable pass to apply hoisting and padding. 114 struct LinalgStrategyPadPass 115 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 116 117 LinalgStrategyPadPass() = default; 118 119 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 120 LinalgTransformationFilter filt) 121 : options(std::move(opt)), filter(std::move(filt)) { 122 this->anchorOpName.setValue(opName.str()); 123 } 124 125 void runOnOperation() override { 126 auto funcOp = getOperation(); 127 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 128 return; 129 130 RewritePatternSet paddingPattern(funcOp.getContext()); 131 if (!anchorOpName.empty()) { 132 paddingPattern.add<LinalgPaddingPattern>( 133 anchorOpName, funcOp.getContext(), options, filter); 134 } else { 135 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 136 filter); 137 } 138 (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)); 139 } 140 141 LinalgPaddingOptions options; 142 LinalgTransformationFilter filter; 143 }; 144 145 /// Configurable pass to apply pattern-based linalg generalization. 146 struct LinalgStrategyGeneralizePass 147 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 148 149 LinalgStrategyGeneralizePass() = default; 150 151 LinalgStrategyGeneralizePass(StringRef opName, 152 LinalgTransformationFilter filter) 153 : filter(std::move(filter)) { 154 this->anchorOpName.setValue(opName.str()); 155 } 156 157 void runOnOperation() override { 158 auto funcOp = getOperation(); 159 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 160 return; 161 162 RewritePatternSet generalizationPattern(funcOp.getContext()); 163 if (!anchorOpName.empty()) { 164 generalizationPattern.add<LinalgGeneralizationPattern>( 165 anchorOpName, funcOp.getContext(), filter); 166 } else { 167 generalizationPattern.add<LinalgGeneralizationPattern>( 168 funcOp.getContext(), filter); 169 } 170 if (failed(applyPatternsAndFoldGreedily(funcOp, 171 std::move(generalizationPattern)))) 172 signalPassFailure(); 173 } 174 175 LinalgTransformationFilter filter; 176 }; 177 178 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 179 /// finer-grained named versions. 180 struct LinalgStrategyDecomposePass 181 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 182 183 LinalgStrategyDecomposePass() = default; 184 185 LinalgStrategyDecomposePass(LinalgTransformationFilter filter) 186 : filter(std::move(filter)) {} 187 188 void runOnOperation() override { 189 auto funcOp = getOperation(); 190 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 191 return; 192 RewritePatternSet decompositionPattern(funcOp.getContext()); 193 populateDecomposeConvolutionPatterns(decompositionPattern, filter); 194 if (failed(applyPatternsAndFoldGreedily(funcOp, 195 std::move(decompositionPattern)))) 196 signalPassFailure(); 197 } 198 199 LinalgTransformationFilter filter; 200 }; 201 202 /// Configurable pass to apply pattern-based linalg generalization. 203 struct LinalgStrategyInterchangePass 204 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 205 206 LinalgStrategyInterchangePass() = default; 207 208 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 209 LinalgTransformationFilter filter) 210 : iteratorInterchange(iteratorInterchange.begin(), 211 iteratorInterchange.end()), 212 filter(std::move(filter)) {} 213 214 void runOnOperation() override { 215 auto funcOp = getOperation(); 216 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 217 return; 218 219 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 220 iteratorInterchange.end()); 221 RewritePatternSet interchangePattern(funcOp.getContext()); 222 interchangePattern.add<GenericOpInterchangePattern>( 223 funcOp.getContext(), interchangeVector, filter); 224 if (failed(applyPatternsAndFoldGreedily(funcOp, 225 std::move(interchangePattern)))) 226 signalPassFailure(); 227 } 228 229 SmallVector<int64_t> iteratorInterchange; 230 LinalgTransformationFilter filter; 231 }; 232 233 /// Configurable pass to apply pattern-based linalg peeling. 234 struct LinalgStrategyPeelPass 235 : public LinalgStrategyPeelPassBase<LinalgStrategyPeelPass> { 236 237 LinalgStrategyPeelPass() = default; 238 239 LinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt, 240 LinalgTransformationFilter filt) 241 : options(std::move(opt)), filter(std::move(filt)) { 242 this->anchorOpName.setValue(opName.str()); 243 } 244 245 void runOnOperation() override { 246 auto funcOp = getOperation(); 247 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 248 return; 249 250 RewritePatternSet peelingPatterns(funcOp.getContext()); 251 if (!anchorOpName.empty()) { 252 peelingPatterns.add<LinalgPeelingPattern>( 253 anchorOpName, funcOp.getContext(), options, filter); 254 } else { 255 peelingPatterns.add<LinalgPeelingPattern>(funcOp.getContext(), filter, 256 options); 257 } 258 if (failed( 259 applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns)))) 260 return signalPassFailure(); 261 } 262 263 LinalgPeelOptions options; 264 LinalgTransformationFilter filter; 265 }; 266 267 /// Configurable pass to apply pattern-based linalg vectorization. 268 struct LinalgStrategyVectorizePass 269 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 270 271 LinalgStrategyVectorizePass() = default; 272 273 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 274 LinalgTransformationFilter filt, 275 bool padVectorize = false) 276 : options(opt), filter(std::move(filt)) { 277 this->anchorOpName.setValue(opName.str()); 278 this->vectorizePadding.setValue(padVectorize); 279 } 280 281 void runOnOperation() override { 282 auto funcOp = getOperation(); 283 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 284 return; 285 286 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 287 if (!anchorOpName.empty()) { 288 vectorizationPatterns.add<LinalgVectorizationPattern>( 289 anchorOpName, funcOp.getContext(), options, filter); 290 } else { 291 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 292 filter, options); 293 } 294 vector::populateVectorTransferPermutationMapLoweringPatterns( 295 vectorizationPatterns); 296 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 297 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 298 linalg::LinalgCopyVTWForwardingPattern>( 299 funcOp.getContext(), /*benefit=*/2); 300 TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns, 301 funcOp.getContext()); 302 TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns, 303 funcOp.getContext()); 304 (void)applyPatternsAndFoldGreedily(funcOp, 305 std::move(vectorizationPatterns)); 306 307 // Apply the pad tensor op vectorization separately to avoid running the 308 // GenericPadOpVectorizationPattern too early. 309 // TODO: Improve once we have better infrastructure to control pattern 310 // application. 311 if (vectorizePadding) { 312 RewritePatternSet patterns(funcOp.getContext()); 313 linalg::populatePadOpVectorizationPatterns(patterns); 314 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 315 } 316 } 317 318 LinalgVectorizationOptions options; 319 LinalgTransformationFilter filter; 320 }; 321 322 /// Configurable pass to enable the application of other pattern-based linalg 323 /// passes. 324 struct LinalgStrategyEnablePass 325 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 326 327 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 328 LinalgTransformationFilter filt) 329 : options(opt), filter(std::move(filt)) {} 330 331 void runOnOperation() override { 332 auto funcOp = getOperation(); 333 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 334 return; 335 336 MLIRContext *context = funcOp.getContext(); 337 RewritePatternSet patterns = 338 linalg::getLinalgTilingCanonicalizationPatterns(context); 339 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 340 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 341 return signalPassFailure(); 342 343 if (options.licm) { 344 funcOp->walk([&](LoopLikeOpInterface loopLike) { 345 moveLoopInvariantCode(loopLike); 346 }); 347 } 348 349 // Gathers all innermost loops through a post order pruned walk. 350 funcOp.walk([](Operation *op) { 351 if (auto forOp = dyn_cast<AffineForOp>(op)) 352 (void)promoteIfSingleIteration(forOp); 353 else if (auto forOp = dyn_cast<scf::ForOp>(op)) 354 (void)promoteIfSingleIteration(forOp); 355 }); 356 if (options.hoistRedundantVectorTransfers) 357 hoistRedundantVectorTransfers(funcOp); 358 359 if (options.hoistRedundantVectorTransfersOnTensor) 360 hoistRedundantVectorTransfersOnTensor(funcOp); 361 362 // Run CSE to cleanup after canonicalization. 363 OpPassManager dynamicPM("func.func"); 364 dynamicPM.addPass(createCSEPass()); 365 if (failed(runPipeline(dynamicPM, funcOp))) 366 return signalPassFailure(); 367 } 368 369 LinalgEnablingOptions options; 370 LinalgTransformationFilter filter; 371 }; 372 373 /// Configurable pass to lower vector operations. 374 struct LinalgStrategyLowerVectorsPass 375 : public LinalgStrategyLowerVectorsPassBase< 376 LinalgStrategyLowerVectorsPass> { 377 378 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 379 LinalgTransformationFilter filt) 380 : options(opt), filter(std::move(filt)) {} 381 382 void runOnOperation() override { 383 auto funcOp = getOperation(); 384 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 385 return; 386 387 MLIRContext *context = funcOp.getContext(); 388 RewritePatternSet patterns(context); 389 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 390 // In a progressive lowering of vectors, this would be the 1st step. 391 if (options.contractionLowering) { 392 patterns.add<ContractionOpToOuterProductOpLowering, 393 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 394 options.vectorTransformOptions, context); 395 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 396 } 397 // In a progressive lowering of vectors, this would be the 2nd step. 398 if (options.multiReductionLowering) { 399 vector::populateVectorMultiReductionLoweringPatterns( 400 patterns, 401 options.vectorTransformOptions.vectorMultiReductionLowering); 402 } 403 // In a progressive lowering of vectors, this would be the 3rd step. 404 if (options.transferPartialRewrite) { 405 patterns.add<vector::VectorTransferFullPartialRewriter>( 406 context, options.vectorTransformOptions); 407 } 408 // In a progressive lowering of vectors, this would be the 4th step. 409 if (options.transferLowering) { 410 vector::populateVectorTransferLoweringPatterns(patterns, 411 options.maxTransferRank); 412 } 413 // In a progressive lowering of vectors, this would be the 5th step. 414 if (options.transferToSCFConversion) { 415 populateVectorToSCFConversionPatterns( 416 patterns, options.vectorTransferToSCFOptions.setTargetRank( 417 options.maxTransferRank)); 418 } 419 // In a progressive lowering of vectors, this would be the 6th step. 420 if (options.shapeCastLowering) { 421 vector::populateVectorShapeCastLoweringPatterns(patterns); 422 } 423 // In a progressive lowering of vectors, this would be the 7th step. 424 if (options.transposeLowering) { 425 vector::populateVectorTransposeLoweringPatterns( 426 patterns, options.vectorTransformOptions); 427 if (options.avx2Lowering) 428 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 429 patterns, options.avx2LoweringOptions, /*benefit=*/10); 430 } 431 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 432 } 433 434 LinalgVectorLoweringOptions options; 435 LinalgTransformationFilter filter; 436 }; 437 438 /// Configurable pass to lower vector operations. 439 struct LinalgStrategyRemoveMarkersPass 440 : public LinalgStrategyRemoveMarkersPassBase< 441 LinalgStrategyRemoveMarkersPass> { 442 443 void runOnOperation() override { 444 auto funcOp = getOperation(); 445 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 446 return; 447 funcOp.walk([](LinalgOp op) { 448 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 449 }); 450 } 451 }; 452 } // namespace 453 454 /// Create a LinalgStrategyTileAndFusePass. 455 std::unique_ptr<OperationPass<func::FuncOp>> 456 mlir::createLinalgStrategyTileAndFusePass( 457 StringRef opName, const LinalgTilingAndFusionOptions &options, 458 const LinalgTransformationFilter &filter) { 459 return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options, 460 filter); 461 } 462 463 /// Create a LinalgStrategyTilePass. 464 std::unique_ptr<OperationPass<func::FuncOp>> 465 mlir::createLinalgStrategyTilePass(StringRef opName, 466 const LinalgTilingOptions &opt, 467 const LinalgTransformationFilter &filter) { 468 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 469 } 470 471 /// Create a LinalgStrategyPadPass. 472 std::unique_ptr<OperationPass<func::FuncOp>> 473 mlir::createLinalgStrategyPadPass(StringRef opName, 474 const LinalgPaddingOptions &opt, 475 const LinalgTransformationFilter &filter) { 476 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 477 } 478 479 /// Create a LinalgStrategyGeneralizePass. 480 std::unique_ptr<OperationPass<func::FuncOp>> 481 mlir::createLinalgStrategyGeneralizePass( 482 StringRef opName, const LinalgTransformationFilter &filter) { 483 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 484 } 485 486 /// Create a LinalgStrategyDecomposePass. 487 // TODO: if/when we need finer control add an `opName` parameter. 488 std::unique_ptr<OperationPass<func::FuncOp>> 489 mlir::createLinalgStrategyDecomposePass( 490 const LinalgTransformationFilter &filter) { 491 return std::make_unique<LinalgStrategyDecomposePass>(filter); 492 } 493 494 /// Create a LinalgStrategyInterchangePass. 495 std::unique_ptr<OperationPass<func::FuncOp>> 496 mlir::createLinalgStrategyInterchangePass( 497 ArrayRef<int64_t> iteratorInterchange, 498 const LinalgTransformationFilter &filter) { 499 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 500 filter); 501 } 502 503 /// Create a LinalgStrategyPeelPass. 504 std::unique_ptr<OperationPass<func::FuncOp>> 505 mlir::createLinalgStrategyPeelPass(StringRef opName, 506 const LinalgPeelOptions &opt, 507 const LinalgTransformationFilter &filter) { 508 return std::make_unique<LinalgStrategyPeelPass>(opName, opt, filter); 509 } 510 511 /// Create a LinalgStrategyVectorizePass. 512 std::unique_ptr<OperationPass<func::FuncOp>> 513 mlir::createLinalgStrategyVectorizePass( 514 StringRef opName, LinalgVectorizationOptions opt, 515 const LinalgTransformationFilter &filter, bool padVectorize) { 516 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter, 517 padVectorize); 518 } 519 520 /// Create a LinalgStrategyEnablePass. 521 std::unique_ptr<OperationPass<func::FuncOp>> 522 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 523 const LinalgTransformationFilter &filter) { 524 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 525 } 526 527 /// Create a LinalgStrategyLowerVectorsPass. 528 std::unique_ptr<OperationPass<func::FuncOp>> 529 mlir::createLinalgStrategyLowerVectorsPass( 530 LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) { 531 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 532 } 533 534 /// Create a LinalgStrategyRemoveMarkersPass. 535 std::unique_ptr<OperationPass<func::FuncOp>> 536 mlir::createLinalgStrategyRemoveMarkersPass() { 537 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 538 } 539