1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "PassDetail.h" 17 #include "mlir/Analysis/SliceAnalysis.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Linalg/IR/Linalg.h" 20 #include "mlir/Dialect/Linalg/Passes.h" 21 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 22 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 23 #include "mlir/Dialect/Linalg/Utils/Utils.h" 24 #include "mlir/Dialect/SCF/Transforms.h" 25 #include "mlir/Dialect/Tensor/IR/Tensor.h" 26 #include "mlir/Dialect/Vector/VectorTransforms.h" 27 #include "mlir/IR/AffineExpr.h" 28 #include "mlir/IR/AffineMap.h" 29 #include "mlir/Pass/PassManager.h" 30 #include "mlir/Support/LLVM.h" 31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 32 #include "mlir/Transforms/LoopUtils.h" 33 #include "mlir/Transforms/Passes.h" 34 #include "mlir/Transforms/Utils.h" 35 36 using namespace mlir; 37 using namespace mlir::vector; 38 using namespace linalg; 39 40 namespace { 41 42 /// Configurable pass to apply pattern-based tiling and fusion. 43 struct LinalgStrategyTileAndFusePass 44 : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> { 45 46 LinalgStrategyTileAndFusePass() = default; 47 48 LinalgStrategyTileAndFusePass(StringRef opName, 49 LinalgTilingAndFusionOptions opt, 50 LinalgTransformationFilter filt) 51 : options(std::move(opt)), filter(std::move(filt)) { 52 this->anchorOpName.setValue(opName.str()); 53 } 54 55 void runOnOperation() override { 56 auto funcOp = getOperation(); 57 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 58 return; 59 60 RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); 61 if (!anchorOpName.empty()) { 62 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 63 anchorOpName, funcOp.getContext(), options, filter); 64 } else { 65 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 66 funcOp.getContext(), options, filter); 67 } 68 // Search the root operation using bottom up traversal. 69 GreedyRewriteConfig config; 70 config.useTopDownTraversal = false; 71 (void)applyPatternsAndFoldGreedily( 72 funcOp, std::move(tilingAndFusionPattern), config); 73 } 74 75 LinalgTilingAndFusionOptions options; 76 LinalgTransformationFilter filter; 77 }; 78 79 /// Configurable pass to apply pattern-based linalg tiling. 80 struct LinalgStrategyTilePass 81 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 82 83 LinalgStrategyTilePass() = default; 84 85 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 86 LinalgTransformationFilter filt) 87 : options(std::move(opt)), filter(std::move(filt)) { 88 this->anchorOpName.setValue(opName.str()); 89 } 90 91 void runOnOperation() override { 92 auto funcOp = getOperation(); 93 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 94 return; 95 96 MLIRContext *ctx = funcOp.getContext(); 97 RewritePatternSet tilingPattern(ctx); 98 if (!anchorOpName.empty()) 99 tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options, 100 filter); 101 else 102 tilingPattern.add<LinalgTilingPattern>(ctx, options, filter); 103 if (anchorOpName == linalg::PadTensorOp::getOperationName()) 104 populatePadTensorTilingPatterns(tilingPattern, options); 105 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 106 } 107 108 LinalgTilingOptions options; 109 LinalgTransformationFilter filter; 110 }; 111 112 /// Configurable pass to apply hoisting and padding. 113 struct LinalgStrategyPadPass 114 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 115 116 LinalgStrategyPadPass() = default; 117 118 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 119 LinalgTransformationFilter filt) 120 : options(std::move(opt)), filter(std::move(filt)) { 121 this->anchorOpName.setValue(opName.str()); 122 } 123 124 void runOnOperation() override { 125 auto funcOp = getOperation(); 126 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 127 return; 128 129 RewritePatternSet paddingPattern(funcOp.getContext()); 130 if (!anchorOpName.empty()) { 131 paddingPattern.add<LinalgPaddingPattern>( 132 anchorOpName, funcOp.getContext(), options, filter); 133 } else { 134 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 135 filter); 136 } 137 (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)); 138 } 139 140 LinalgPaddingOptions options; 141 LinalgTransformationFilter filter; 142 }; 143 144 /// Configurable pass to apply pattern-based linalg generalization. 145 struct LinalgStrategyGeneralizePass 146 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 147 148 LinalgStrategyGeneralizePass() = default; 149 150 LinalgStrategyGeneralizePass(StringRef opName, 151 LinalgTransformationFilter filter) 152 : filter(std::move(filter)) { 153 this->anchorOpName.setValue(opName.str()); 154 } 155 156 void runOnOperation() override { 157 auto funcOp = getOperation(); 158 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 159 return; 160 161 RewritePatternSet generalizationPattern(funcOp.getContext()); 162 if (!anchorOpName.empty()) { 163 generalizationPattern.add<LinalgGeneralizationPattern>( 164 anchorOpName, funcOp.getContext(), filter); 165 } else { 166 generalizationPattern.add<LinalgGeneralizationPattern>( 167 funcOp.getContext(), filter); 168 } 169 if (failed(applyPatternsAndFoldGreedily(funcOp, 170 std::move(generalizationPattern)))) 171 signalPassFailure(); 172 } 173 174 LinalgTransformationFilter filter; 175 }; 176 177 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 178 /// finer-grained named versions. 179 struct LinalgStrategyDecomposePass 180 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 181 182 LinalgStrategyDecomposePass() = default; 183 184 LinalgStrategyDecomposePass(LinalgTransformationFilter filter) 185 : filter(std::move(filter)) {} 186 187 void runOnOperation() override { 188 auto funcOp = getOperation(); 189 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 190 return; 191 RewritePatternSet decompositionPattern(funcOp.getContext()); 192 populateDecomposeConvolutionPatterns(decompositionPattern, filter); 193 if (failed(applyPatternsAndFoldGreedily(funcOp, 194 std::move(decompositionPattern)))) 195 signalPassFailure(); 196 } 197 198 LinalgTransformationFilter filter; 199 }; 200 201 /// Configurable pass to apply pattern-based linalg generalization. 202 struct LinalgStrategyInterchangePass 203 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 204 205 LinalgStrategyInterchangePass() = default; 206 207 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 208 LinalgTransformationFilter filter) 209 : iteratorInterchange(iteratorInterchange.begin(), 210 iteratorInterchange.end()), 211 filter(std::move(filter)) {} 212 213 void runOnOperation() override { 214 auto funcOp = getOperation(); 215 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 216 return; 217 218 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 219 iteratorInterchange.end()); 220 RewritePatternSet interchangePattern(funcOp.getContext()); 221 interchangePattern.add<GenericOpInterchangePattern>( 222 funcOp.getContext(), interchangeVector, filter); 223 if (failed(applyPatternsAndFoldGreedily(funcOp, 224 std::move(interchangePattern)))) 225 signalPassFailure(); 226 } 227 228 SmallVector<int64_t> iteratorInterchange; 229 LinalgTransformationFilter filter; 230 }; 231 232 /// Configurable pass to apply pattern-based linalg promotion. 233 struct LinalgStrategyPromotePass 234 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 235 236 LinalgStrategyPromotePass() = default; 237 238 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 239 LinalgTransformationFilter filt) 240 : options(std::move(opt)), filter(std::move(filt)) { 241 this->anchorOpName.setValue(opName.str()); 242 } 243 244 void runOnOperation() override { 245 auto funcOp = getOperation(); 246 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 247 return; 248 249 RewritePatternSet promotionPattern(funcOp.getContext()); 250 if (!anchorOpName.empty()) { 251 promotionPattern.add<LinalgBasePromotionPattern>( 252 anchorOpName, funcOp.getContext(), options, filter); 253 } else { 254 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 255 filter, options); 256 } 257 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 258 } 259 260 LinalgPromotionOptions options; 261 LinalgTransformationFilter filter; 262 }; 263 264 /// Configurable pass to apply pattern-based linalg vectorization. 265 struct LinalgStrategyVectorizePass 266 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 267 268 LinalgStrategyVectorizePass() = default; 269 270 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 271 LinalgTransformationFilter filt, 272 bool padVectorize = false) 273 : options(opt), filter(std::move(filt)) { 274 this->anchorOpName.setValue(opName.str()); 275 this->vectorizePadding.setValue(padVectorize); 276 } 277 278 void runOnOperation() override { 279 auto funcOp = getOperation(); 280 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 281 return; 282 283 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 284 if (!anchorOpName.empty()) { 285 vectorizationPatterns.add<LinalgVectorizationPattern>( 286 anchorOpName, funcOp.getContext(), options, filter); 287 } else { 288 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 289 filter, options); 290 } 291 vector::populateVectorTransferPermutationMapLoweringPatterns( 292 vectorizationPatterns); 293 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 294 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 295 linalg::LinalgCopyVTWForwardingPattern>( 296 funcOp.getContext(), /*benefit=*/2); 297 TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns, 298 funcOp.getContext()); 299 TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns, 300 funcOp.getContext()); 301 (void)applyPatternsAndFoldGreedily(funcOp, 302 std::move(vectorizationPatterns)); 303 304 // Apply the pad tensor op vectorization separately to avoid running the 305 // GenericPadTensorOpVectorizationPattern too early. 306 // TODO: Improve once we have better infrastructure to control pattern 307 // application. 308 if (vectorizePadding) { 309 RewritePatternSet patterns(funcOp.getContext()); 310 linalg::populatePadTensorOpVectorizationPatterns(patterns); 311 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 312 } 313 } 314 315 LinalgVectorizationOptions options; 316 LinalgTransformationFilter filter; 317 }; 318 319 /// Configurable pass to enable the application of other pattern-based linalg 320 /// passes. 321 struct LinalgStrategyEnablePass 322 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 323 324 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 325 LinalgTransformationFilter filt) 326 : options(opt), filter(std::move(filt)) {} 327 328 void runOnOperation() override { 329 auto funcOp = getOperation(); 330 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 331 return; 332 333 MLIRContext *context = funcOp.getContext(); 334 RewritePatternSet patterns = 335 linalg::getLinalgTilingCanonicalizationPatterns(context); 336 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 337 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 338 return signalPassFailure(); 339 340 if (options.licm) { 341 if (funcOp 342 ->walk([&](LoopLikeOpInterface loopLike) { 343 if (failed(moveLoopInvariantCode(loopLike))) 344 return WalkResult::interrupt(); 345 return WalkResult::advance(); 346 }) 347 .wasInterrupted()) 348 return signalPassFailure(); 349 } 350 351 promoteSingleIterationLoops(funcOp); 352 if (options.hoistRedundantVectorTransfers) 353 hoistRedundantVectorTransfers(funcOp); 354 355 if (options.hoistRedundantVectorTransfersOnTensor) 356 hoistRedundantVectorTransfersOnTensor(funcOp); 357 358 // Run CSE to cleanup after canonicalization. 359 OpPassManager dynamicPM("builtin.func"); 360 dynamicPM.addPass(createCSEPass()); 361 if (failed(runPipeline(dynamicPM, funcOp))) 362 return signalPassFailure(); 363 } 364 365 LinalgEnablingOptions options; 366 LinalgTransformationFilter filter; 367 }; 368 369 /// Configurable pass to lower vector operations. 370 struct LinalgStrategyLowerVectorsPass 371 : public LinalgStrategyLowerVectorsPassBase< 372 LinalgStrategyLowerVectorsPass> { 373 374 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 375 LinalgTransformationFilter filt) 376 : options(opt), filter(std::move(filt)) {} 377 378 void runOnOperation() override { 379 auto funcOp = getOperation(); 380 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 381 return; 382 383 MLIRContext *context = funcOp.getContext(); 384 RewritePatternSet patterns(context); 385 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 386 // In a progressive lowering of vectors, this would be the 1st step. 387 if (options.contractionLowering) { 388 patterns.add<ContractionOpToOuterProductOpLowering, 389 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 390 options.vectorTransformOptions, context); 391 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 392 } 393 // In a progressive lowering of vectors, this would be the 2nd step. 394 if (options.multiReductionLowering) { 395 vector::populateVectorMultiReductionLoweringPatterns( 396 patterns, 397 options.vectorTransformOptions.vectorMultiReductionLowering); 398 } 399 // In a progressive lowering of vectors, this would be the 3rd step. 400 if (options.transferPartialRewrite) { 401 patterns.add<vector::VectorTransferFullPartialRewriter>( 402 context, options.vectorTransformOptions); 403 } 404 // In a progressive lowering of vectors, this would be the 4th step. 405 if (options.transferLowering) { 406 vector::populateVectorTransferLoweringPatterns(patterns, 407 options.maxTransferRank); 408 } 409 // In a progressive lowering of vectors, this would be the 5th step. 410 if (options.transferToSCFConversion) { 411 populateVectorToSCFConversionPatterns( 412 patterns, options.vectorTransferToSCFOptions.setTargetRank( 413 options.maxTransferRank)); 414 } 415 // In a progressive lowering of vectors, this would be the 6th step. 416 if (options.shapeCastLowering) { 417 vector::populateVectorShapeCastLoweringPatterns(patterns); 418 } 419 // In a progressive lowering of vectors, this would be the 7th step. 420 if (options.transposeLowering) { 421 vector::populateVectorTransposeLoweringPatterns( 422 patterns, options.vectorTransformOptions); 423 if (options.avx2Lowering) 424 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 425 patterns, options.avx2LoweringOptions, /*benefit=*/10); 426 } 427 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 428 } 429 430 LinalgVectorLoweringOptions options; 431 LinalgTransformationFilter filter; 432 }; 433 434 /// Configurable pass to lower vector operations. 435 struct LinalgStrategyRemoveMarkersPass 436 : public LinalgStrategyRemoveMarkersPassBase< 437 LinalgStrategyRemoveMarkersPass> { 438 439 void runOnOperation() override { 440 auto funcOp = getOperation(); 441 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 442 return; 443 funcOp.walk([](LinalgOp op) { 444 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 445 }); 446 } 447 }; 448 } // namespace 449 450 /// Create a LinalgStrategyTileAndFusePass. 451 std::unique_ptr<OperationPass<FuncOp>> 452 mlir::createLinalgStrategyTileAndFusePass( 453 StringRef opName, const LinalgTilingAndFusionOptions &options, 454 const LinalgTransformationFilter &filter) { 455 return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options, 456 filter); 457 } 458 459 /// Create a LinalgStrategyTilePass. 460 std::unique_ptr<OperationPass<FuncOp>> 461 mlir::createLinalgStrategyTilePass(StringRef opName, 462 const LinalgTilingOptions &opt, 463 const LinalgTransformationFilter &filter) { 464 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 465 } 466 467 /// Create a LinalgStrategyPadPass. 468 std::unique_ptr<OperationPass<FuncOp>> 469 mlir::createLinalgStrategyPadPass(StringRef opName, 470 const LinalgPaddingOptions &opt, 471 const LinalgTransformationFilter &filter) { 472 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 473 } 474 475 /// Create a LinalgStrategyPromotePass. 476 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyPromotePass( 477 StringRef opName, const LinalgPromotionOptions &opt, 478 const LinalgTransformationFilter &filter) { 479 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 480 } 481 482 /// Create a LinalgStrategyGeneralizePass. 483 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyGeneralizePass( 484 StringRef opName, const LinalgTransformationFilter &filter) { 485 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 486 } 487 488 /// Create a LinalgStrategyDecomposePass. 489 // TODO: if/when we need finer control add an `opName` parameter. 490 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyDecomposePass( 491 const LinalgTransformationFilter &filter) { 492 return std::make_unique<LinalgStrategyDecomposePass>(filter); 493 } 494 495 /// Create a LinalgStrategyInterchangePass. 496 std::unique_ptr<OperationPass<FuncOp>> 497 mlir::createLinalgStrategyInterchangePass( 498 ArrayRef<int64_t> iteratorInterchange, 499 const LinalgTransformationFilter &filter) { 500 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 501 filter); 502 } 503 504 /// Create a LinalgStrategyVectorizePass. 505 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass( 506 StringRef opName, LinalgVectorizationOptions opt, 507 const LinalgTransformationFilter &filter, bool padVectorize) { 508 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter, 509 padVectorize); 510 } 511 512 /// Create a LinalgStrategyEnablePass. 513 std::unique_ptr<OperationPass<FuncOp>> 514 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 515 const LinalgTransformationFilter &filter) { 516 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 517 } 518 519 /// Create a LinalgStrategyLowerVectorsPass. 520 std::unique_ptr<OperationPass<FuncOp>> 521 mlir::createLinalgStrategyLowerVectorsPass( 522 LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) { 523 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 524 } 525 526 /// Create a LinalgStrategyRemoveMarkersPass. 527 std::unique_ptr<OperationPass<FuncOp>> 528 mlir::createLinalgStrategyRemoveMarkersPass() { 529 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 530 } 531