1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "PassDetail.h" 17 #include "mlir/Analysis/SliceAnalysis.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Affine/LoopUtils.h" 20 #include "mlir/Dialect/Affine/Utils.h" 21 #include "mlir/Dialect/Linalg/IR/Linalg.h" 22 #include "mlir/Dialect/Linalg/Passes.h" 23 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 24 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 25 #include "mlir/Dialect/Linalg/Utils/Utils.h" 26 #include "mlir/Dialect/SCF/Transforms.h" 27 #include "mlir/Dialect/Tensor/IR/Tensor.h" 28 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 29 #include "mlir/IR/AffineExpr.h" 30 #include "mlir/IR/AffineMap.h" 31 #include "mlir/Pass/PassManager.h" 32 #include "mlir/Support/LLVM.h" 33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 34 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" 35 #include "mlir/Transforms/Passes.h" 36 37 using namespace mlir; 38 using namespace mlir::vector; 39 using namespace linalg; 40 41 namespace { 42 43 /// Configurable pass to apply pattern-based tiling and fusion. 44 struct LinalgStrategyTileAndFusePass 45 : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> { 46 47 LinalgStrategyTileAndFusePass() = default; 48 49 LinalgStrategyTileAndFusePass(StringRef opName, 50 LinalgTilingAndFusionOptions opt, 51 LinalgTransformationFilter filt) 52 : options(std::move(opt)), filter(std::move(filt)) { 53 this->anchorOpName.setValue(opName.str()); 54 } 55 56 void runOnOperation() override { 57 auto funcOp = getOperation(); 58 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 59 return; 60 61 RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); 62 if (!anchorOpName.empty()) { 63 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 64 anchorOpName, funcOp.getContext(), options, filter); 65 } else { 66 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 67 funcOp.getContext(), options, filter); 68 } 69 // Search the root operation using bottom up traversal. 70 GreedyRewriteConfig config; 71 config.useTopDownTraversal = false; 72 (void)applyPatternsAndFoldGreedily( 73 funcOp, std::move(tilingAndFusionPattern), config); 74 } 75 76 LinalgTilingAndFusionOptions options; 77 LinalgTransformationFilter filter; 78 }; 79 80 /// Configurable pass to apply pattern-based linalg tiling. 81 struct LinalgStrategyTilePass 82 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 83 84 LinalgStrategyTilePass() = default; 85 86 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 87 LinalgTransformationFilter filt) 88 : options(std::move(opt)), filter(std::move(filt)) { 89 this->anchorOpName.setValue(opName.str()); 90 } 91 92 void runOnOperation() override { 93 auto funcOp = getOperation(); 94 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 95 return; 96 97 MLIRContext *ctx = funcOp.getContext(); 98 RewritePatternSet tilingPattern(ctx); 99 if (!anchorOpName.empty()) 100 tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options, 101 filter); 102 else 103 tilingPattern.add<LinalgTilingPattern>(ctx, options, filter); 104 if (anchorOpName == tensor::PadOp::getOperationName()) 105 populatePadTensorTilingPatterns(tilingPattern, options); 106 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 107 } 108 109 LinalgTilingOptions options; 110 LinalgTransformationFilter filter; 111 }; 112 113 /// Configurable pass to apply hoisting and padding. 114 struct LinalgStrategyPadPass 115 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 116 117 LinalgStrategyPadPass() = default; 118 119 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 120 LinalgTransformationFilter filt) 121 : options(std::move(opt)), filter(std::move(filt)) { 122 this->anchorOpName.setValue(opName.str()); 123 } 124 125 void runOnOperation() override { 126 auto funcOp = getOperation(); 127 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 128 return; 129 130 RewritePatternSet paddingPattern(funcOp.getContext()); 131 if (!anchorOpName.empty()) { 132 paddingPattern.add<LinalgPaddingPattern>( 133 anchorOpName, funcOp.getContext(), options, filter); 134 } else { 135 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 136 filter); 137 } 138 (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)); 139 } 140 141 LinalgPaddingOptions options; 142 LinalgTransformationFilter filter; 143 }; 144 145 /// Configurable pass to apply pattern-based linalg generalization. 146 struct LinalgStrategyGeneralizePass 147 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 148 149 LinalgStrategyGeneralizePass() = default; 150 151 LinalgStrategyGeneralizePass(StringRef opName, 152 LinalgTransformationFilter filter) 153 : filter(std::move(filter)) { 154 this->anchorOpName.setValue(opName.str()); 155 } 156 157 void runOnOperation() override { 158 auto funcOp = getOperation(); 159 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 160 return; 161 162 RewritePatternSet generalizationPattern(funcOp.getContext()); 163 if (!anchorOpName.empty()) { 164 generalizationPattern.add<LinalgGeneralizationPattern>( 165 anchorOpName, funcOp.getContext(), filter); 166 } else { 167 generalizationPattern.add<LinalgGeneralizationPattern>( 168 funcOp.getContext(), filter); 169 } 170 if (failed(applyPatternsAndFoldGreedily(funcOp, 171 std::move(generalizationPattern)))) 172 signalPassFailure(); 173 } 174 175 LinalgTransformationFilter filter; 176 }; 177 178 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 179 /// finer-grained named versions. 180 struct LinalgStrategyDecomposePass 181 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 182 183 LinalgStrategyDecomposePass() = default; 184 185 LinalgStrategyDecomposePass(LinalgTransformationFilter filter) 186 : filter(std::move(filter)) {} 187 188 void runOnOperation() override { 189 auto funcOp = getOperation(); 190 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 191 return; 192 RewritePatternSet decompositionPattern(funcOp.getContext()); 193 populateDecomposeConvolutionPatterns(decompositionPattern, filter); 194 if (failed(applyPatternsAndFoldGreedily(funcOp, 195 std::move(decompositionPattern)))) 196 signalPassFailure(); 197 } 198 199 LinalgTransformationFilter filter; 200 }; 201 202 /// Configurable pass to apply pattern-based linalg generalization. 203 struct LinalgStrategyInterchangePass 204 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 205 206 LinalgStrategyInterchangePass() = default; 207 208 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 209 LinalgTransformationFilter filter) 210 : iteratorInterchange(iteratorInterchange.begin(), 211 iteratorInterchange.end()), 212 filter(std::move(filter)) {} 213 214 void runOnOperation() override { 215 auto funcOp = getOperation(); 216 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 217 return; 218 219 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 220 iteratorInterchange.end()); 221 RewritePatternSet interchangePattern(funcOp.getContext()); 222 interchangePattern.add<GenericOpInterchangePattern>( 223 funcOp.getContext(), interchangeVector, filter); 224 if (failed(applyPatternsAndFoldGreedily(funcOp, 225 std::move(interchangePattern)))) 226 signalPassFailure(); 227 } 228 229 SmallVector<int64_t> iteratorInterchange; 230 LinalgTransformationFilter filter; 231 }; 232 233 /// Configurable pass to apply pattern-based linalg promotion. 234 struct LinalgStrategyPromotePass 235 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 236 237 LinalgStrategyPromotePass() = default; 238 239 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 240 LinalgTransformationFilter filt) 241 : options(std::move(opt)), filter(std::move(filt)) { 242 this->anchorOpName.setValue(opName.str()); 243 } 244 245 void runOnOperation() override { 246 auto funcOp = getOperation(); 247 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 248 return; 249 250 RewritePatternSet promotionPattern(funcOp.getContext()); 251 if (!anchorOpName.empty()) { 252 promotionPattern.add<LinalgBasePromotionPattern>( 253 anchorOpName, funcOp.getContext(), options, filter); 254 } else { 255 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 256 filter, options); 257 } 258 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 259 } 260 261 LinalgPromotionOptions options; 262 LinalgTransformationFilter filter; 263 }; 264 265 /// Configurable pass to apply pattern-based linalg vectorization. 266 struct LinalgStrategyVectorizePass 267 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 268 269 LinalgStrategyVectorizePass() = default; 270 271 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 272 LinalgTransformationFilter filt, 273 bool padVectorize = false) 274 : options(opt), filter(std::move(filt)) { 275 this->anchorOpName.setValue(opName.str()); 276 this->vectorizePadding.setValue(padVectorize); 277 } 278 279 void runOnOperation() override { 280 auto funcOp = getOperation(); 281 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 282 return; 283 284 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 285 if (!anchorOpName.empty()) { 286 vectorizationPatterns.add<LinalgVectorizationPattern>( 287 anchorOpName, funcOp.getContext(), options, filter); 288 } else { 289 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 290 filter, options); 291 } 292 vector::populateVectorTransferPermutationMapLoweringPatterns( 293 vectorizationPatterns); 294 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 295 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 296 linalg::LinalgCopyVTWForwardingPattern>( 297 funcOp.getContext(), /*benefit=*/2); 298 TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns, 299 funcOp.getContext()); 300 TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns, 301 funcOp.getContext()); 302 (void)applyPatternsAndFoldGreedily(funcOp, 303 std::move(vectorizationPatterns)); 304 305 // Apply the pad tensor op vectorization separately to avoid running the 306 // GenericPadOpVectorizationPattern too early. 307 // TODO: Improve once we have better infrastructure to control pattern 308 // application. 309 if (vectorizePadding) { 310 RewritePatternSet patterns(funcOp.getContext()); 311 linalg::populatePadOpVectorizationPatterns(patterns); 312 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 313 } 314 } 315 316 LinalgVectorizationOptions options; 317 LinalgTransformationFilter filter; 318 }; 319 320 /// Configurable pass to enable the application of other pattern-based linalg 321 /// passes. 322 struct LinalgStrategyEnablePass 323 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 324 325 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 326 LinalgTransformationFilter filt) 327 : options(opt), filter(std::move(filt)) {} 328 329 void runOnOperation() override { 330 auto funcOp = getOperation(); 331 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 332 return; 333 334 MLIRContext *context = funcOp.getContext(); 335 RewritePatternSet patterns = 336 linalg::getLinalgTilingCanonicalizationPatterns(context); 337 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 338 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 339 return signalPassFailure(); 340 341 if (options.licm) { 342 funcOp->walk([&](LoopLikeOpInterface loopLike) { 343 moveLoopInvariantCode(loopLike); 344 }); 345 } 346 347 // Gathers all innermost loops through a post order pruned walk. 348 funcOp.walk([](Operation *op) { 349 if (auto forOp = dyn_cast<AffineForOp>(op)) 350 (void)promoteIfSingleIteration(forOp); 351 else if (auto forOp = dyn_cast<scf::ForOp>(op)) 352 (void)promoteIfSingleIteration(forOp); 353 }); 354 if (options.hoistRedundantVectorTransfers) 355 hoistRedundantVectorTransfers(funcOp); 356 357 if (options.hoistRedundantVectorTransfersOnTensor) 358 hoistRedundantVectorTransfersOnTensor(funcOp); 359 360 // Run CSE to cleanup after canonicalization. 361 OpPassManager dynamicPM("func.func"); 362 dynamicPM.addPass(createCSEPass()); 363 if (failed(runPipeline(dynamicPM, funcOp))) 364 return signalPassFailure(); 365 } 366 367 LinalgEnablingOptions options; 368 LinalgTransformationFilter filter; 369 }; 370 371 /// Configurable pass to lower vector operations. 372 struct LinalgStrategyLowerVectorsPass 373 : public LinalgStrategyLowerVectorsPassBase< 374 LinalgStrategyLowerVectorsPass> { 375 376 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 377 LinalgTransformationFilter filt) 378 : options(opt), filter(std::move(filt)) {} 379 380 void runOnOperation() override { 381 auto funcOp = getOperation(); 382 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 383 return; 384 385 MLIRContext *context = funcOp.getContext(); 386 RewritePatternSet patterns(context); 387 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 388 // In a progressive lowering of vectors, this would be the 1st step. 389 if (options.contractionLowering) { 390 patterns.add<ContractionOpToOuterProductOpLowering, 391 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 392 options.vectorTransformOptions, context); 393 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 394 } 395 // In a progressive lowering of vectors, this would be the 2nd step. 396 if (options.multiReductionLowering) { 397 vector::populateVectorMultiReductionLoweringPatterns( 398 patterns, 399 options.vectorTransformOptions.vectorMultiReductionLowering); 400 } 401 // In a progressive lowering of vectors, this would be the 3rd step. 402 if (options.transferPartialRewrite) { 403 patterns.add<vector::VectorTransferFullPartialRewriter>( 404 context, options.vectorTransformOptions); 405 } 406 // In a progressive lowering of vectors, this would be the 4th step. 407 if (options.transferLowering) { 408 vector::populateVectorTransferLoweringPatterns(patterns, 409 options.maxTransferRank); 410 } 411 // In a progressive lowering of vectors, this would be the 5th step. 412 if (options.transferToSCFConversion) { 413 populateVectorToSCFConversionPatterns( 414 patterns, options.vectorTransferToSCFOptions.setTargetRank( 415 options.maxTransferRank)); 416 } 417 // In a progressive lowering of vectors, this would be the 6th step. 418 if (options.shapeCastLowering) { 419 vector::populateVectorShapeCastLoweringPatterns(patterns); 420 } 421 // In a progressive lowering of vectors, this would be the 7th step. 422 if (options.transposeLowering) { 423 vector::populateVectorTransposeLoweringPatterns( 424 patterns, options.vectorTransformOptions); 425 if (options.avx2Lowering) 426 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 427 patterns, options.avx2LoweringOptions, /*benefit=*/10); 428 } 429 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 430 } 431 432 LinalgVectorLoweringOptions options; 433 LinalgTransformationFilter filter; 434 }; 435 436 /// Configurable pass to lower vector operations. 437 struct LinalgStrategyRemoveMarkersPass 438 : public LinalgStrategyRemoveMarkersPassBase< 439 LinalgStrategyRemoveMarkersPass> { 440 441 void runOnOperation() override { 442 auto funcOp = getOperation(); 443 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 444 return; 445 funcOp.walk([](LinalgOp op) { 446 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 447 }); 448 } 449 }; 450 } // namespace 451 452 /// Create a LinalgStrategyTileAndFusePass. 453 std::unique_ptr<OperationPass<func::FuncOp>> 454 mlir::createLinalgStrategyTileAndFusePass( 455 StringRef opName, const LinalgTilingAndFusionOptions &options, 456 const LinalgTransformationFilter &filter) { 457 return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options, 458 filter); 459 } 460 461 /// Create a LinalgStrategyTilePass. 462 std::unique_ptr<OperationPass<func::FuncOp>> 463 mlir::createLinalgStrategyTilePass(StringRef opName, 464 const LinalgTilingOptions &opt, 465 const LinalgTransformationFilter &filter) { 466 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 467 } 468 469 /// Create a LinalgStrategyPadPass. 470 std::unique_ptr<OperationPass<func::FuncOp>> 471 mlir::createLinalgStrategyPadPass(StringRef opName, 472 const LinalgPaddingOptions &opt, 473 const LinalgTransformationFilter &filter) { 474 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 475 } 476 477 /// Create a LinalgStrategyPromotePass. 478 std::unique_ptr<OperationPass<func::FuncOp>> 479 mlir::createLinalgStrategyPromotePass( 480 StringRef opName, const LinalgPromotionOptions &opt, 481 const LinalgTransformationFilter &filter) { 482 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 483 } 484 485 /// Create a LinalgStrategyGeneralizePass. 486 std::unique_ptr<OperationPass<func::FuncOp>> 487 mlir::createLinalgStrategyGeneralizePass( 488 StringRef opName, const LinalgTransformationFilter &filter) { 489 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 490 } 491 492 /// Create a LinalgStrategyDecomposePass. 493 // TODO: if/when we need finer control add an `opName` parameter. 494 std::unique_ptr<OperationPass<func::FuncOp>> 495 mlir::createLinalgStrategyDecomposePass( 496 const LinalgTransformationFilter &filter) { 497 return std::make_unique<LinalgStrategyDecomposePass>(filter); 498 } 499 500 /// Create a LinalgStrategyInterchangePass. 501 std::unique_ptr<OperationPass<func::FuncOp>> 502 mlir::createLinalgStrategyInterchangePass( 503 ArrayRef<int64_t> iteratorInterchange, 504 const LinalgTransformationFilter &filter) { 505 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 506 filter); 507 } 508 509 /// Create a LinalgStrategyVectorizePass. 510 std::unique_ptr<OperationPass<func::FuncOp>> 511 mlir::createLinalgStrategyVectorizePass( 512 StringRef opName, LinalgVectorizationOptions opt, 513 const LinalgTransformationFilter &filter, bool padVectorize) { 514 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter, 515 padVectorize); 516 } 517 518 /// Create a LinalgStrategyEnablePass. 519 std::unique_ptr<OperationPass<func::FuncOp>> 520 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 521 const LinalgTransformationFilter &filter) { 522 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 523 } 524 525 /// Create a LinalgStrategyLowerVectorsPass. 526 std::unique_ptr<OperationPass<func::FuncOp>> 527 mlir::createLinalgStrategyLowerVectorsPass( 528 LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) { 529 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 530 } 531 532 /// Create a LinalgStrategyRemoveMarkersPass. 533 std::unique_ptr<OperationPass<func::FuncOp>> 534 mlir::createLinalgStrategyRemoveMarkersPass() { 535 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 536 } 537