1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "PassDetail.h" 17 #include "mlir/Analysis/SliceAnalysis.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Affine/LoopUtils.h" 20 #include "mlir/Dialect/Affine/Utils.h" 21 #include "mlir/Dialect/Linalg/IR/Linalg.h" 22 #include "mlir/Dialect/Linalg/Passes.h" 23 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 24 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 25 #include "mlir/Dialect/Linalg/Utils/Utils.h" 26 #include "mlir/Dialect/SCF/Transforms.h" 27 #include "mlir/Dialect/Tensor/IR/Tensor.h" 28 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 29 #include "mlir/IR/AffineExpr.h" 30 #include "mlir/IR/AffineMap.h" 31 #include "mlir/Pass/PassManager.h" 32 #include "mlir/Support/LLVM.h" 33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 34 #include "mlir/Transforms/Passes.h" 35 36 using namespace mlir; 37 using namespace mlir::vector; 38 using namespace linalg; 39 40 namespace { 41 42 /// Configurable pass to apply pattern-based tiling and fusion. 43 struct LinalgStrategyTileAndFusePass 44 : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> { 45 46 LinalgStrategyTileAndFusePass() = default; 47 48 LinalgStrategyTileAndFusePass(StringRef opName, 49 LinalgTilingAndFusionOptions opt, 50 LinalgTransformationFilter filt) 51 : options(std::move(opt)), filter(std::move(filt)) { 52 this->anchorOpName.setValue(opName.str()); 53 } 54 55 void runOnOperation() override { 56 auto funcOp = getOperation(); 57 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 58 return; 59 60 RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); 61 if (!anchorOpName.empty()) { 62 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 63 anchorOpName, funcOp.getContext(), options, filter); 64 } else { 65 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 66 funcOp.getContext(), options, filter); 67 } 68 // Search the root operation using bottom up traversal. 69 GreedyRewriteConfig config; 70 config.useTopDownTraversal = false; 71 (void)applyPatternsAndFoldGreedily( 72 funcOp, std::move(tilingAndFusionPattern), config); 73 } 74 75 LinalgTilingAndFusionOptions options; 76 LinalgTransformationFilter filter; 77 }; 78 79 /// Configurable pass to apply pattern-based linalg tiling. 80 struct LinalgStrategyTilePass 81 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 82 83 LinalgStrategyTilePass() = default; 84 85 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 86 LinalgTransformationFilter filt) 87 : options(std::move(opt)), filter(std::move(filt)) { 88 this->anchorOpName.setValue(opName.str()); 89 } 90 91 void runOnOperation() override { 92 auto funcOp = getOperation(); 93 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 94 return; 95 96 MLIRContext *ctx = funcOp.getContext(); 97 RewritePatternSet tilingPattern(ctx); 98 if (!anchorOpName.empty()) 99 tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options, 100 filter); 101 else 102 tilingPattern.add<LinalgTilingPattern>(ctx, options, filter); 103 if (anchorOpName == tensor::PadOp::getOperationName()) 104 populatePadTensorTilingPatterns(tilingPattern, options); 105 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 106 } 107 108 LinalgTilingOptions options; 109 LinalgTransformationFilter filter; 110 }; 111 112 /// Configurable pass to apply hoisting and padding. 113 struct LinalgStrategyPadPass 114 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 115 116 LinalgStrategyPadPass() = default; 117 118 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 119 LinalgTransformationFilter filt) 120 : options(std::move(opt)), filter(std::move(filt)) { 121 this->anchorOpName.setValue(opName.str()); 122 } 123 124 void runOnOperation() override { 125 auto funcOp = getOperation(); 126 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 127 return; 128 129 RewritePatternSet paddingPattern(funcOp.getContext()); 130 if (!anchorOpName.empty()) { 131 paddingPattern.add<LinalgPaddingPattern>( 132 anchorOpName, funcOp.getContext(), options, filter); 133 } else { 134 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 135 filter); 136 } 137 (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)); 138 } 139 140 LinalgPaddingOptions options; 141 LinalgTransformationFilter filter; 142 }; 143 144 /// Configurable pass to apply pattern-based linalg generalization. 145 struct LinalgStrategyGeneralizePass 146 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 147 148 LinalgStrategyGeneralizePass() = default; 149 150 LinalgStrategyGeneralizePass(StringRef opName, 151 LinalgTransformationFilter filter) 152 : filter(std::move(filter)) { 153 this->anchorOpName.setValue(opName.str()); 154 } 155 156 void runOnOperation() override { 157 auto funcOp = getOperation(); 158 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 159 return; 160 161 RewritePatternSet generalizationPattern(funcOp.getContext()); 162 if (!anchorOpName.empty()) { 163 generalizationPattern.add<LinalgGeneralizationPattern>( 164 anchorOpName, funcOp.getContext(), filter); 165 } else { 166 generalizationPattern.add<LinalgGeneralizationPattern>( 167 funcOp.getContext(), filter); 168 } 169 if (failed(applyPatternsAndFoldGreedily(funcOp, 170 std::move(generalizationPattern)))) 171 signalPassFailure(); 172 } 173 174 LinalgTransformationFilter filter; 175 }; 176 177 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 178 /// finer-grained named versions. 179 struct LinalgStrategyDecomposePass 180 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 181 182 LinalgStrategyDecomposePass() = default; 183 184 LinalgStrategyDecomposePass(LinalgTransformationFilter filter) 185 : filter(std::move(filter)) {} 186 187 void runOnOperation() override { 188 auto funcOp = getOperation(); 189 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 190 return; 191 RewritePatternSet decompositionPattern(funcOp.getContext()); 192 populateDecomposeConvolutionPatterns(decompositionPattern, filter); 193 if (failed(applyPatternsAndFoldGreedily(funcOp, 194 std::move(decompositionPattern)))) 195 signalPassFailure(); 196 } 197 198 LinalgTransformationFilter filter; 199 }; 200 201 /// Configurable pass to apply pattern-based linalg generalization. 202 struct LinalgStrategyInterchangePass 203 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 204 205 LinalgStrategyInterchangePass() = default; 206 207 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 208 LinalgTransformationFilter filter) 209 : iteratorInterchange(iteratorInterchange.begin(), 210 iteratorInterchange.end()), 211 filter(std::move(filter)) {} 212 213 void runOnOperation() override { 214 auto funcOp = getOperation(); 215 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 216 return; 217 218 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 219 iteratorInterchange.end()); 220 RewritePatternSet interchangePattern(funcOp.getContext()); 221 interchangePattern.add<GenericOpInterchangePattern>( 222 funcOp.getContext(), interchangeVector, filter); 223 if (failed(applyPatternsAndFoldGreedily(funcOp, 224 std::move(interchangePattern)))) 225 signalPassFailure(); 226 } 227 228 SmallVector<int64_t> iteratorInterchange; 229 LinalgTransformationFilter filter; 230 }; 231 232 /// Configurable pass to apply pattern-based linalg promotion. 233 struct LinalgStrategyPromotePass 234 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 235 236 LinalgStrategyPromotePass() = default; 237 238 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 239 LinalgTransformationFilter filt) 240 : options(std::move(opt)), filter(std::move(filt)) { 241 this->anchorOpName.setValue(opName.str()); 242 } 243 244 void runOnOperation() override { 245 auto funcOp = getOperation(); 246 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 247 return; 248 249 RewritePatternSet promotionPattern(funcOp.getContext()); 250 if (!anchorOpName.empty()) { 251 promotionPattern.add<LinalgBasePromotionPattern>( 252 anchorOpName, funcOp.getContext(), options, filter); 253 } else { 254 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 255 filter, options); 256 } 257 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 258 } 259 260 LinalgPromotionOptions options; 261 LinalgTransformationFilter filter; 262 }; 263 264 /// Configurable pass to apply pattern-based linalg vectorization. 265 struct LinalgStrategyVectorizePass 266 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 267 268 LinalgStrategyVectorizePass() = default; 269 270 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 271 LinalgTransformationFilter filt, 272 bool padVectorize = false) 273 : options(opt), filter(std::move(filt)) { 274 this->anchorOpName.setValue(opName.str()); 275 this->vectorizePadding.setValue(padVectorize); 276 } 277 278 void runOnOperation() override { 279 auto funcOp = getOperation(); 280 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 281 return; 282 283 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 284 if (!anchorOpName.empty()) { 285 vectorizationPatterns.add<LinalgVectorizationPattern>( 286 anchorOpName, funcOp.getContext(), options, filter); 287 } else { 288 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 289 filter, options); 290 } 291 vector::populateVectorTransferPermutationMapLoweringPatterns( 292 vectorizationPatterns); 293 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 294 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 295 linalg::LinalgCopyVTWForwardingPattern>( 296 funcOp.getContext(), /*benefit=*/2); 297 TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns, 298 funcOp.getContext()); 299 TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns, 300 funcOp.getContext()); 301 (void)applyPatternsAndFoldGreedily(funcOp, 302 std::move(vectorizationPatterns)); 303 304 // Apply the pad tensor op vectorization separately to avoid running the 305 // GenericPadOpVectorizationPattern too early. 306 // TODO: Improve once we have better infrastructure to control pattern 307 // application. 308 if (vectorizePadding) { 309 RewritePatternSet patterns(funcOp.getContext()); 310 linalg::populatePadOpVectorizationPatterns(patterns); 311 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 312 } 313 } 314 315 LinalgVectorizationOptions options; 316 LinalgTransformationFilter filter; 317 }; 318 319 /// Configurable pass to enable the application of other pattern-based linalg 320 /// passes. 321 struct LinalgStrategyEnablePass 322 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 323 324 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 325 LinalgTransformationFilter filt) 326 : options(opt), filter(std::move(filt)) {} 327 328 void runOnOperation() override { 329 auto funcOp = getOperation(); 330 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 331 return; 332 333 MLIRContext *context = funcOp.getContext(); 334 RewritePatternSet patterns = 335 linalg::getLinalgTilingCanonicalizationPatterns(context); 336 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 337 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 338 return signalPassFailure(); 339 340 if (options.licm) { 341 if (funcOp 342 ->walk([&](LoopLikeOpInterface loopLike) { 343 if (failed(moveLoopInvariantCode(loopLike))) 344 return WalkResult::interrupt(); 345 return WalkResult::advance(); 346 }) 347 .wasInterrupted()) 348 return signalPassFailure(); 349 } 350 351 // Gathers all innermost loops through a post order pruned walk. 352 funcOp.walk([](Operation *op) { 353 if (auto forOp = dyn_cast<AffineForOp>(op)) 354 (void)promoteIfSingleIteration(forOp); 355 else if (auto forOp = dyn_cast<scf::ForOp>(op)) 356 (void)promoteIfSingleIteration(forOp); 357 }); 358 if (options.hoistRedundantVectorTransfers) 359 hoistRedundantVectorTransfers(funcOp); 360 361 if (options.hoistRedundantVectorTransfersOnTensor) 362 hoistRedundantVectorTransfersOnTensor(funcOp); 363 364 // Run CSE to cleanup after canonicalization. 365 OpPassManager dynamicPM("builtin.func"); 366 dynamicPM.addPass(createCSEPass()); 367 if (failed(runPipeline(dynamicPM, funcOp))) 368 return signalPassFailure(); 369 } 370 371 LinalgEnablingOptions options; 372 LinalgTransformationFilter filter; 373 }; 374 375 /// Configurable pass to lower vector operations. 376 struct LinalgStrategyLowerVectorsPass 377 : public LinalgStrategyLowerVectorsPassBase< 378 LinalgStrategyLowerVectorsPass> { 379 380 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 381 LinalgTransformationFilter filt) 382 : options(opt), filter(std::move(filt)) {} 383 384 void runOnOperation() override { 385 auto funcOp = getOperation(); 386 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 387 return; 388 389 MLIRContext *context = funcOp.getContext(); 390 RewritePatternSet patterns(context); 391 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 392 // In a progressive lowering of vectors, this would be the 1st step. 393 if (options.contractionLowering) { 394 patterns.add<ContractionOpToOuterProductOpLowering, 395 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 396 options.vectorTransformOptions, context); 397 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 398 } 399 // In a progressive lowering of vectors, this would be the 2nd step. 400 if (options.multiReductionLowering) { 401 vector::populateVectorMultiReductionLoweringPatterns( 402 patterns, 403 options.vectorTransformOptions.vectorMultiReductionLowering); 404 } 405 // In a progressive lowering of vectors, this would be the 3rd step. 406 if (options.transferPartialRewrite) { 407 patterns.add<vector::VectorTransferFullPartialRewriter>( 408 context, options.vectorTransformOptions); 409 } 410 // In a progressive lowering of vectors, this would be the 4th step. 411 if (options.transferLowering) { 412 vector::populateVectorTransferLoweringPatterns(patterns, 413 options.maxTransferRank); 414 } 415 // In a progressive lowering of vectors, this would be the 5th step. 416 if (options.transferToSCFConversion) { 417 populateVectorToSCFConversionPatterns( 418 patterns, options.vectorTransferToSCFOptions.setTargetRank( 419 options.maxTransferRank)); 420 } 421 // In a progressive lowering of vectors, this would be the 6th step. 422 if (options.shapeCastLowering) { 423 vector::populateVectorShapeCastLoweringPatterns(patterns); 424 } 425 // In a progressive lowering of vectors, this would be the 7th step. 426 if (options.transposeLowering) { 427 vector::populateVectorTransposeLoweringPatterns( 428 patterns, options.vectorTransformOptions); 429 if (options.avx2Lowering) 430 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 431 patterns, options.avx2LoweringOptions, /*benefit=*/10); 432 } 433 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 434 } 435 436 LinalgVectorLoweringOptions options; 437 LinalgTransformationFilter filter; 438 }; 439 440 /// Configurable pass to lower vector operations. 441 struct LinalgStrategyRemoveMarkersPass 442 : public LinalgStrategyRemoveMarkersPassBase< 443 LinalgStrategyRemoveMarkersPass> { 444 445 void runOnOperation() override { 446 auto funcOp = getOperation(); 447 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 448 return; 449 funcOp.walk([](LinalgOp op) { 450 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 451 }); 452 } 453 }; 454 } // namespace 455 456 /// Create a LinalgStrategyTileAndFusePass. 457 std::unique_ptr<OperationPass<FuncOp>> 458 mlir::createLinalgStrategyTileAndFusePass( 459 StringRef opName, const LinalgTilingAndFusionOptions &options, 460 const LinalgTransformationFilter &filter) { 461 return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options, 462 filter); 463 } 464 465 /// Create a LinalgStrategyTilePass. 466 std::unique_ptr<OperationPass<FuncOp>> 467 mlir::createLinalgStrategyTilePass(StringRef opName, 468 const LinalgTilingOptions &opt, 469 const LinalgTransformationFilter &filter) { 470 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 471 } 472 473 /// Create a LinalgStrategyPadPass. 474 std::unique_ptr<OperationPass<FuncOp>> 475 mlir::createLinalgStrategyPadPass(StringRef opName, 476 const LinalgPaddingOptions &opt, 477 const LinalgTransformationFilter &filter) { 478 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 479 } 480 481 /// Create a LinalgStrategyPromotePass. 482 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyPromotePass( 483 StringRef opName, const LinalgPromotionOptions &opt, 484 const LinalgTransformationFilter &filter) { 485 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 486 } 487 488 /// Create a LinalgStrategyGeneralizePass. 489 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyGeneralizePass( 490 StringRef opName, const LinalgTransformationFilter &filter) { 491 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 492 } 493 494 /// Create a LinalgStrategyDecomposePass. 495 // TODO: if/when we need finer control add an `opName` parameter. 496 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyDecomposePass( 497 const LinalgTransformationFilter &filter) { 498 return std::make_unique<LinalgStrategyDecomposePass>(filter); 499 } 500 501 /// Create a LinalgStrategyInterchangePass. 502 std::unique_ptr<OperationPass<FuncOp>> 503 mlir::createLinalgStrategyInterchangePass( 504 ArrayRef<int64_t> iteratorInterchange, 505 const LinalgTransformationFilter &filter) { 506 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 507 filter); 508 } 509 510 /// Create a LinalgStrategyVectorizePass. 511 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass( 512 StringRef opName, LinalgVectorizationOptions opt, 513 const LinalgTransformationFilter &filter, bool padVectorize) { 514 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter, 515 padVectorize); 516 } 517 518 /// Create a LinalgStrategyEnablePass. 519 std::unique_ptr<OperationPass<FuncOp>> 520 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 521 const LinalgTransformationFilter &filter) { 522 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 523 } 524 525 /// Create a LinalgStrategyLowerVectorsPass. 526 std::unique_ptr<OperationPass<FuncOp>> 527 mlir::createLinalgStrategyLowerVectorsPass( 528 LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) { 529 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 530 } 531 532 /// Create a LinalgStrategyRemoveMarkersPass. 533 std::unique_ptr<OperationPass<FuncOp>> 534 mlir::createLinalgStrategyRemoveMarkersPass() { 535 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 536 } 537