1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "PassDetail.h" 17 #include "mlir/Analysis/SliceAnalysis.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Linalg/IR/Linalg.h" 20 #include "mlir/Dialect/Linalg/Passes.h" 21 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 22 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 23 #include "mlir/Dialect/Linalg/Utils/Utils.h" 24 #include "mlir/Dialect/SCF/Transforms.h" 25 #include "mlir/Dialect/Tensor/IR/Tensor.h" 26 #include "mlir/Dialect/Vector/VectorTransforms.h" 27 #include "mlir/IR/AffineExpr.h" 28 #include "mlir/IR/AffineMap.h" 29 #include "mlir/Pass/PassManager.h" 30 #include "mlir/Support/LLVM.h" 31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 32 #include "mlir/Transforms/LoopUtils.h" 33 #include "mlir/Transforms/Passes.h" 34 #include "mlir/Transforms/Utils.h" 35 36 using namespace mlir; 37 using namespace mlir::vector; 38 using namespace linalg; 39 40 namespace { 41 42 /// Configurable pass to apply pattern-based tiling and fusion. 43 struct LinalgStrategyTileAndFusePass 44 : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> { 45 46 LinalgStrategyTileAndFusePass() = default; 47 48 LinalgStrategyTileAndFusePass(StringRef opName, 49 LinalgTilingAndFusionOptions opt, 50 LinalgTransformationFilter filt) 51 : options(std::move(opt)), filter(std::move(filt)) { 52 this->anchorOpName.setValue(opName.str()); 53 } 54 55 void runOnFunction() override { 56 auto funcOp = getFunction(); 57 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 58 return; 59 60 RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); 61 if (!anchorOpName.empty()) { 62 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 63 anchorOpName, funcOp.getContext(), options, filter); 64 } else { 65 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 66 funcOp.getContext(), options, filter); 67 } 68 // Search the root operation using bottom up traversal. 69 GreedyRewriteConfig config; 70 config.useTopDownTraversal = false; 71 (void)applyPatternsAndFoldGreedily( 72 funcOp, std::move(tilingAndFusionPattern), config); 73 } 74 75 LinalgTilingAndFusionOptions options; 76 LinalgTransformationFilter filter; 77 }; 78 79 /// Configurable pass to apply pattern-based linalg tiling. 80 struct LinalgStrategyTilePass 81 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 82 83 LinalgStrategyTilePass() = default; 84 85 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 86 LinalgTransformationFilter filt) 87 : options(std::move(opt)), filter(std::move(filt)) { 88 this->anchorOpName.setValue(opName.str()); 89 } 90 91 void runOnFunction() override { 92 auto funcOp = getFunction(); 93 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 94 return; 95 96 MLIRContext *ctx = funcOp.getContext(); 97 RewritePatternSet tilingPattern(ctx); 98 if (!anchorOpName.empty()) 99 tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options, 100 filter); 101 else 102 tilingPattern.add<LinalgTilingPattern>(ctx, options, filter); 103 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 104 } 105 106 LinalgTilingOptions options; 107 LinalgTransformationFilter filter; 108 }; 109 110 /// Configurable pass to apply hoisting and padding. 111 struct LinalgStrategyPadPass 112 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 113 114 LinalgStrategyPadPass() = default; 115 116 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 117 LinalgTransformationFilter filt) 118 : options(std::move(opt)), filter(std::move(filt)) { 119 this->anchorOpName.setValue(opName.str()); 120 } 121 122 void runOnFunction() override { 123 auto funcOp = getFunction(); 124 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 125 return; 126 127 RewritePatternSet paddingPattern(funcOp.getContext()); 128 if (!anchorOpName.empty()) { 129 paddingPattern.add<LinalgPaddingPattern>( 130 anchorOpName, funcOp.getContext(), options, filter); 131 } else { 132 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 133 filter); 134 } 135 (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)); 136 } 137 138 LinalgPaddingOptions options; 139 LinalgTransformationFilter filter; 140 }; 141 142 /// Configurable pass to apply pattern-based linalg generalization. 143 struct LinalgStrategyGeneralizePass 144 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 145 146 LinalgStrategyGeneralizePass() = default; 147 148 LinalgStrategyGeneralizePass(StringRef opName, 149 LinalgTransformationFilter filter) 150 : filter(std::move(filter)) { 151 this->anchorOpName.setValue(opName.str()); 152 } 153 154 void runOnFunction() override { 155 auto funcOp = getFunction(); 156 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 157 return; 158 159 RewritePatternSet generalizationPattern(funcOp.getContext()); 160 if (!anchorOpName.empty()) { 161 generalizationPattern.add<LinalgGeneralizationPattern>( 162 anchorOpName, funcOp.getContext(), filter); 163 } else { 164 generalizationPattern.add<LinalgGeneralizationPattern>( 165 funcOp.getContext(), filter); 166 } 167 if (failed(applyPatternsAndFoldGreedily(funcOp, 168 std::move(generalizationPattern)))) 169 signalPassFailure(); 170 } 171 172 LinalgTransformationFilter filter; 173 }; 174 175 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 176 /// finer-grained named versions. 177 struct LinalgStrategyDecomposePass 178 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 179 180 LinalgStrategyDecomposePass() = default; 181 182 LinalgStrategyDecomposePass(LinalgTransformationFilter filter) 183 : filter(std::move(filter)) {} 184 185 void runOnFunction() override { 186 auto funcOp = getFunction(); 187 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 188 return; 189 RewritePatternSet decompositionPattern(funcOp.getContext()); 190 populateDecomposeConvolutionPatterns(decompositionPattern, filter); 191 if (failed(applyPatternsAndFoldGreedily(funcOp, 192 std::move(decompositionPattern)))) 193 signalPassFailure(); 194 } 195 196 LinalgTransformationFilter filter; 197 }; 198 199 /// Configurable pass to apply pattern-based linalg generalization. 200 struct LinalgStrategyInterchangePass 201 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 202 203 LinalgStrategyInterchangePass() = default; 204 205 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 206 LinalgTransformationFilter filter) 207 : iteratorInterchange(iteratorInterchange.begin(), 208 iteratorInterchange.end()), 209 filter(std::move(filter)) {} 210 211 void runOnFunction() override { 212 auto funcOp = getFunction(); 213 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 214 return; 215 216 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 217 iteratorInterchange.end()); 218 RewritePatternSet interchangePattern(funcOp.getContext()); 219 interchangePattern.add<GenericOpInterchangePattern>( 220 funcOp.getContext(), interchangeVector, filter); 221 if (failed(applyPatternsAndFoldGreedily(funcOp, 222 std::move(interchangePattern)))) 223 signalPassFailure(); 224 } 225 226 SmallVector<int64_t> iteratorInterchange; 227 LinalgTransformationFilter filter; 228 }; 229 230 /// Configurable pass to apply pattern-based linalg promotion. 231 struct LinalgStrategyPromotePass 232 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 233 234 LinalgStrategyPromotePass() = default; 235 236 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 237 LinalgTransformationFilter filt) 238 : options(std::move(opt)), filter(std::move(filt)) { 239 this->anchorOpName.setValue(opName.str()); 240 } 241 242 void runOnFunction() override { 243 auto funcOp = getFunction(); 244 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 245 return; 246 247 RewritePatternSet promotionPattern(funcOp.getContext()); 248 if (!anchorOpName.empty()) { 249 promotionPattern.add<LinalgBasePromotionPattern>( 250 anchorOpName, funcOp.getContext(), options, filter); 251 } else { 252 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 253 filter, options); 254 } 255 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 256 } 257 258 LinalgPromotionOptions options; 259 LinalgTransformationFilter filter; 260 }; 261 262 /// Configurable pass to apply pattern-based linalg vectorization. 263 struct LinalgStrategyVectorizePass 264 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 265 266 LinalgStrategyVectorizePass() = default; 267 268 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 269 LinalgTransformationFilter filt, 270 bool padVectorize = false) 271 : options(opt), filter(std::move(filt)) { 272 this->anchorOpName.setValue(opName.str()); 273 this->vectorizePadding.setValue(padVectorize); 274 } 275 276 void runOnFunction() override { 277 auto funcOp = getFunction(); 278 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 279 return; 280 281 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 282 if (!anchorOpName.empty()) { 283 vectorizationPatterns.add<LinalgVectorizationPattern>( 284 anchorOpName, funcOp.getContext(), options, filter); 285 } else { 286 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 287 filter, options); 288 } 289 vector::populateVectorTransferPermutationMapLoweringPatterns( 290 vectorizationPatterns); 291 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 292 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 293 linalg::LinalgCopyVTWForwardingPattern>( 294 funcOp.getContext(), /*benefit=*/2); 295 TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns, 296 funcOp.getContext()); 297 TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns, 298 funcOp.getContext()); 299 (void)applyPatternsAndFoldGreedily(funcOp, 300 std::move(vectorizationPatterns)); 301 302 // Apply the pad tensor op vectorization separately to avoid running the 303 // GenericPadTensorOpVectorizationPattern too early. 304 // TODO: Improve once we have better infrastructure to control pattern 305 // application. 306 if (vectorizePadding) { 307 RewritePatternSet patterns(funcOp.getContext()); 308 linalg::populatePadTensorOpVectorizationPatterns(patterns); 309 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 310 } 311 } 312 313 LinalgVectorizationOptions options; 314 LinalgTransformationFilter filter; 315 }; 316 317 /// Configurable pass to enable the application of other pattern-based linalg 318 /// passes. 319 struct LinalgStrategyEnablePass 320 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 321 322 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 323 LinalgTransformationFilter filt) 324 : options(opt), filter(std::move(filt)) {} 325 326 void runOnFunction() override { 327 auto funcOp = getFunction(); 328 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 329 return; 330 331 MLIRContext *context = funcOp.getContext(); 332 RewritePatternSet patterns = 333 linalg::getLinalgTilingCanonicalizationPatterns(context); 334 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 335 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 336 return signalPassFailure(); 337 338 if (options.licm) { 339 if (funcOp 340 ->walk([&](LoopLikeOpInterface loopLike) { 341 if (failed(moveLoopInvariantCode(loopLike))) 342 return WalkResult::interrupt(); 343 return WalkResult::advance(); 344 }) 345 .wasInterrupted()) 346 return signalPassFailure(); 347 } 348 349 promoteSingleIterationLoops(funcOp); 350 if (options.hoistRedundantVectorTransfers) 351 hoistRedundantVectorTransfers(funcOp); 352 353 if (options.hoistRedundantVectorTransfersOnTensor) 354 hoistRedundantVectorTransfersOnTensor(funcOp); 355 356 // Run CSE to cleanup after canonicalization. 357 OpPassManager dynamicPM("builtin.func"); 358 dynamicPM.addPass(createCSEPass()); 359 if (failed(runPipeline(dynamicPM, funcOp))) 360 return signalPassFailure(); 361 } 362 363 LinalgEnablingOptions options; 364 LinalgTransformationFilter filter; 365 }; 366 367 /// Configurable pass to lower vector operations. 368 struct LinalgStrategyLowerVectorsPass 369 : public LinalgStrategyLowerVectorsPassBase< 370 LinalgStrategyLowerVectorsPass> { 371 372 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 373 LinalgTransformationFilter filt) 374 : options(opt), filter(std::move(filt)) {} 375 376 void runOnFunction() override { 377 auto funcOp = getFunction(); 378 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 379 return; 380 381 MLIRContext *context = funcOp.getContext(); 382 RewritePatternSet patterns(context); 383 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 384 // In a progressive lowering of vectors, this would be the 1st step. 385 if (options.contractionLowering) { 386 patterns.add<ContractionOpToOuterProductOpLowering, 387 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 388 options.vectorTransformOptions, context); 389 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 390 } 391 // In a progressive lowering of vectors, this would be the 2nd step. 392 if (options.multiReductionLowering) { 393 vector::populateVectorMultiReductionLoweringPatterns( 394 patterns, 395 options.vectorTransformOptions.vectorMultiReductionLowering); 396 } 397 // In a progressive lowering of vectors, this would be the 3rd step. 398 if (options.transferPartialRewrite) { 399 patterns.add<vector::VectorTransferFullPartialRewriter>( 400 context, options.vectorTransformOptions); 401 } 402 // In a progressive lowering of vectors, this would be the 4th step. 403 if (options.transferLowering) { 404 vector::populateVectorTransferLoweringPatterns(patterns, 405 options.maxTransferRank); 406 } 407 // In a progressive lowering of vectors, this would be the 5th step. 408 if (options.transferToSCFConversion) { 409 populateVectorToSCFConversionPatterns( 410 patterns, options.vectorTransferToSCFOptions.setTargetRank( 411 options.maxTransferRank)); 412 } 413 // In a progressive lowering of vectors, this would be the 6th step. 414 if (options.shapeCastLowering) { 415 vector::populateVectorShapeCastLoweringPatterns(patterns); 416 } 417 // In a progressive lowering of vectors, this would be the 7th step. 418 if (options.transposeLowering) { 419 vector::populateVectorTransposeLoweringPatterns( 420 patterns, options.vectorTransformOptions); 421 if (options.avx2Lowering) 422 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 423 patterns, options.avx2LoweringOptions, /*benefit=*/10); 424 } 425 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 426 } 427 428 LinalgVectorLoweringOptions options; 429 LinalgTransformationFilter filter; 430 }; 431 432 /// Configurable pass to lower vector operations. 433 struct LinalgStrategyRemoveMarkersPass 434 : public LinalgStrategyRemoveMarkersPassBase< 435 LinalgStrategyRemoveMarkersPass> { 436 437 void runOnFunction() override { 438 auto funcOp = getFunction(); 439 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 440 return; 441 funcOp.walk([](LinalgOp op) { 442 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 443 }); 444 } 445 }; 446 } // namespace 447 448 /// Create a LinalgStrategyTileAndFusePass. 449 std::unique_ptr<OperationPass<FuncOp>> 450 mlir::createLinalgStrategyTileAndFusePass( 451 StringRef opName, const LinalgTilingAndFusionOptions &options, 452 const LinalgTransformationFilter &filter) { 453 return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options, 454 filter); 455 } 456 457 /// Create a LinalgStrategyTilePass. 458 std::unique_ptr<OperationPass<FuncOp>> 459 mlir::createLinalgStrategyTilePass(StringRef opName, 460 const LinalgTilingOptions &opt, 461 const LinalgTransformationFilter &filter) { 462 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 463 } 464 465 /// Create a LinalgStrategyPadPass. 466 std::unique_ptr<OperationPass<FuncOp>> 467 mlir::createLinalgStrategyPadPass(StringRef opName, 468 const LinalgPaddingOptions &opt, 469 const LinalgTransformationFilter &filter) { 470 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 471 } 472 473 /// Create a LinalgStrategyPromotePass. 474 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyPromotePass( 475 StringRef opName, const LinalgPromotionOptions &opt, 476 const LinalgTransformationFilter &filter) { 477 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 478 } 479 480 /// Create a LinalgStrategyGeneralizePass. 481 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyGeneralizePass( 482 StringRef opName, const LinalgTransformationFilter &filter) { 483 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 484 } 485 486 /// Create a LinalgStrategyDecomposePass. 487 // TODO: if/when we need finer control add an `opName` parameter. 488 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyDecomposePass( 489 const LinalgTransformationFilter &filter) { 490 return std::make_unique<LinalgStrategyDecomposePass>(filter); 491 } 492 493 /// Create a LinalgStrategyInterchangePass. 494 std::unique_ptr<OperationPass<FuncOp>> 495 mlir::createLinalgStrategyInterchangePass( 496 ArrayRef<int64_t> iteratorInterchange, 497 const LinalgTransformationFilter &filter) { 498 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 499 filter); 500 } 501 502 /// Create a LinalgStrategyVectorizePass. 503 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass( 504 StringRef opName, LinalgVectorizationOptions opt, 505 const LinalgTransformationFilter &filter, bool padVectorize) { 506 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter, 507 padVectorize); 508 } 509 510 /// Create a LinalgStrategyEnablePass. 511 std::unique_ptr<OperationPass<FuncOp>> 512 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 513 const LinalgTransformationFilter &filter) { 514 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 515 } 516 517 /// Create a LinalgStrategyLowerVectorsPass. 518 std::unique_ptr<OperationPass<FuncOp>> 519 mlir::createLinalgStrategyLowerVectorsPass( 520 LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) { 521 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 522 } 523 524 /// Create a LinalgStrategyRemoveMarkersPass. 525 std::unique_ptr<OperationPass<FuncOp>> 526 mlir::createLinalgStrategyRemoveMarkersPass() { 527 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 528 } 529