1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Affine/IR/AffineOps.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/Linalg/Utils/Utils.h" 22 #include "mlir/Dialect/SCF/Transforms.h" 23 #include "mlir/Dialect/Tensor/IR/Tensor.h" 24 #include "mlir/Dialect/Vector/VectorTransforms.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/AffineMap.h" 27 #include "mlir/Pass/PassManager.h" 28 #include "mlir/Support/LLVM.h" 29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30 #include "mlir/Transforms/LoopUtils.h" 31 #include "mlir/Transforms/Passes.h" 32 #include "mlir/Transforms/Utils.h" 33 34 using namespace mlir; 35 using namespace mlir::vector; 36 using namespace linalg; 37 38 namespace { 39 40 /// Configurable pass to apply pattern-based tiling and fusion. 41 struct LinalgStrategyTileAndFusePass 42 : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> { 43 44 LinalgStrategyTileAndFusePass() = default; 45 46 LinalgStrategyTileAndFusePass(StringRef opName, 47 LinalgTilingAndFusionOptions opt, 48 LinalgTransformationFilter filt) 49 : options(opt), filter(filt) { 50 this->anchorOpName.setValue(opName.str()); 51 } 52 53 void runOnFunction() override { 54 auto funcOp = getFunction(); 55 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 56 return; 57 58 RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); 59 if (!anchorOpName.empty()) { 60 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 61 anchorOpName, funcOp.getContext(), options, filter); 62 } else { 63 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 64 funcOp.getContext(), options, filter); 65 } 66 // Search the root operation using bottom up traversal. 67 GreedyRewriteConfig config; 68 config.useTopDownTraversal = false; 69 (void)applyPatternsAndFoldGreedily( 70 funcOp, std::move(tilingAndFusionPattern), config); 71 } 72 73 LinalgTilingAndFusionOptions options; 74 LinalgTransformationFilter filter; 75 }; 76 77 /// Configurable pass to apply pattern-based linalg tiling. 78 struct LinalgStrategyTilePass 79 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 80 81 LinalgStrategyTilePass() = default; 82 83 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 84 LinalgTransformationFilter filt) 85 : options(opt), filter(filt) { 86 this->anchorOpName.setValue(opName.str()); 87 } 88 89 void runOnFunction() override { 90 auto funcOp = getFunction(); 91 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 92 return; 93 94 RewritePatternSet tilingPattern(funcOp.getContext()); 95 if (!anchorOpName.empty()) { 96 tilingPattern.add<LinalgGenericTilingPattern>( 97 anchorOpName, funcOp.getContext(), options, filter); 98 } else { 99 tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter, 100 options); 101 } 102 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 103 } 104 105 LinalgTilingOptions options; 106 LinalgTransformationFilter filter; 107 }; 108 109 /// Configurable pass to apply hoisting and padding. 110 struct LinalgStrategyPadPass 111 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 112 113 LinalgStrategyPadPass() = default; 114 115 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 116 LinalgTransformationFilter filt) 117 : options(opt), filter(filt) { 118 this->anchorOpName.setValue(opName.str()); 119 } 120 121 void runOnFunction() override { 122 auto funcOp = getFunction(); 123 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 124 return; 125 126 RewritePatternSet paddingPattern(funcOp.getContext()); 127 if (!anchorOpName.empty()) { 128 paddingPattern.add<LinalgPaddingPattern>( 129 anchorOpName, funcOp.getContext(), options, filter); 130 } else { 131 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 132 filter); 133 } 134 (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)); 135 } 136 137 LinalgPaddingOptions options; 138 LinalgTransformationFilter filter; 139 }; 140 141 /// Configurable pass to apply pattern-based linalg generalization. 142 struct LinalgStrategyGeneralizePass 143 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 144 145 LinalgStrategyGeneralizePass() = default; 146 147 LinalgStrategyGeneralizePass(StringRef opName, 148 LinalgTransformationFilter filter) 149 : filter(filter) { 150 this->anchorOpName.setValue(opName.str()); 151 } 152 153 void runOnFunction() override { 154 auto funcOp = getFunction(); 155 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 156 return; 157 158 RewritePatternSet generalizationPattern(funcOp.getContext()); 159 if (!anchorOpName.empty()) { 160 generalizationPattern.add<LinalgGeneralizationPattern>( 161 anchorOpName, funcOp.getContext(), filter); 162 } else { 163 generalizationPattern.add<LinalgGeneralizationPattern>( 164 funcOp.getContext(), filter); 165 } 166 if (failed(applyPatternsAndFoldGreedily(funcOp, 167 std::move(generalizationPattern)))) 168 signalPassFailure(); 169 } 170 171 LinalgTransformationFilter filter; 172 }; 173 174 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 175 /// finer-grained named versions. 176 struct LinalgStrategyDecomposePass 177 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 178 179 LinalgStrategyDecomposePass() = default; 180 181 LinalgStrategyDecomposePass(LinalgTransformationFilter filter) 182 : filter(filter) {} 183 184 void runOnFunction() override { 185 auto funcOp = getFunction(); 186 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 187 return; 188 RewritePatternSet decompositionPattern(funcOp.getContext()); 189 populateDecomposeConvolutionPatterns(decompositionPattern, filter); 190 if (failed(applyPatternsAndFoldGreedily(funcOp, 191 std::move(decompositionPattern)))) 192 signalPassFailure(); 193 } 194 195 LinalgTransformationFilter filter; 196 }; 197 198 /// Configurable pass to apply pattern-based linalg generalization. 199 struct LinalgStrategyInterchangePass 200 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 201 202 LinalgStrategyInterchangePass() = default; 203 204 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 205 LinalgTransformationFilter filter) 206 : iteratorInterchange(iteratorInterchange.begin(), 207 iteratorInterchange.end()), 208 filter(filter) {} 209 210 void runOnFunction() override { 211 auto funcOp = getFunction(); 212 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 213 return; 214 215 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 216 iteratorInterchange.end()); 217 RewritePatternSet interchangePattern(funcOp.getContext()); 218 interchangePattern.add<GenericOpInterchangePattern>( 219 funcOp.getContext(), interchangeVector, filter); 220 if (failed(applyPatternsAndFoldGreedily(funcOp, 221 std::move(interchangePattern)))) 222 signalPassFailure(); 223 } 224 225 SmallVector<int64_t> iteratorInterchange; 226 LinalgTransformationFilter filter; 227 }; 228 229 /// Configurable pass to apply pattern-based linalg promotion. 230 struct LinalgStrategyPromotePass 231 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 232 233 LinalgStrategyPromotePass() = default; 234 235 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 236 LinalgTransformationFilter filt) 237 : options(opt), filter(filt) { 238 this->anchorOpName.setValue(opName.str()); 239 } 240 241 void runOnFunction() override { 242 auto funcOp = getFunction(); 243 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 244 return; 245 246 RewritePatternSet promotionPattern(funcOp.getContext()); 247 if (!anchorOpName.empty()) { 248 promotionPattern.add<LinalgBasePromotionPattern>( 249 anchorOpName, funcOp.getContext(), options, filter); 250 } else { 251 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 252 filter, options); 253 } 254 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 255 } 256 257 LinalgPromotionOptions options; 258 LinalgTransformationFilter filter; 259 }; 260 261 /// Configurable pass to apply pattern-based linalg vectorization. 262 struct LinalgStrategyVectorizePass 263 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 264 265 LinalgStrategyVectorizePass() = default; 266 267 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 268 LinalgTransformationFilter filt, 269 bool padVectorize = false) 270 : options(opt), filter(filt) { 271 this->anchorOpName.setValue(opName.str()); 272 this->vectorizePadding.setValue(padVectorize); 273 } 274 275 void runOnFunction() override { 276 auto funcOp = getFunction(); 277 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 278 return; 279 280 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 281 if (!anchorOpName.empty()) { 282 vectorizationPatterns.add<LinalgVectorizationPattern>( 283 anchorOpName, funcOp.getContext(), options, filter); 284 } else { 285 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 286 filter, options); 287 } 288 vector::populateVectorTransferPermutationMapLoweringPatterns( 289 vectorizationPatterns); 290 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 291 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 292 linalg::LinalgCopyVTWForwardingPattern>( 293 funcOp.getContext(), /*benefit=*/2); 294 TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns, 295 funcOp.getContext()); 296 TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns, 297 funcOp.getContext()); 298 (void)applyPatternsAndFoldGreedily(funcOp, 299 std::move(vectorizationPatterns)); 300 301 // Apply the pad tensor op vectorization separately to avoid running the 302 // GenericPadTensorOpVectorizationPattern too early. 303 // TODO: Improve once we have better infrastructure to control pattern 304 // application. 305 if (vectorizePadding) { 306 RewritePatternSet patterns(funcOp.getContext()); 307 linalg::populatePadTensorOpVectorizationPatterns(patterns); 308 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 309 } 310 } 311 312 LinalgVectorizationOptions options; 313 LinalgTransformationFilter filter; 314 }; 315 316 /// Configurable pass to enable the application of other pattern-based linalg 317 /// passes. 318 struct LinalgStrategyEnablePass 319 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 320 321 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 322 LinalgTransformationFilter filt) 323 : options(opt), filter(filt) {} 324 325 void runOnFunction() override { 326 auto funcOp = getFunction(); 327 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 328 return; 329 330 MLIRContext *context = funcOp.getContext(); 331 RewritePatternSet patterns = 332 linalg::getLinalgTilingCanonicalizationPatterns(context); 333 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 334 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 335 return signalPassFailure(); 336 337 if (options.licm) { 338 if (funcOp 339 ->walk([&](LoopLikeOpInterface loopLike) { 340 if (failed(moveLoopInvariantCode(loopLike))) 341 return WalkResult::interrupt(); 342 return WalkResult::advance(); 343 }) 344 .wasInterrupted()) 345 return signalPassFailure(); 346 } 347 348 promoteSingleIterationLoops(funcOp); 349 if (options.hoistRedundantVectorTransfers) 350 hoistRedundantVectorTransfers(funcOp); 351 352 if (options.hoistRedundantVectorTransfersOnTensor) 353 hoistRedundantVectorTransfersOnTensor(funcOp); 354 355 // Run CSE to cleanup after canonicalization. 356 OpPassManager dynamicPM("builtin.func"); 357 dynamicPM.addPass(createCSEPass()); 358 if (failed(runPipeline(dynamicPM, funcOp))) 359 return signalPassFailure(); 360 } 361 362 LinalgEnablingOptions options; 363 LinalgTransformationFilter filter; 364 }; 365 366 /// Configurable pass to lower vector operations. 367 struct LinalgStrategyLowerVectorsPass 368 : public LinalgStrategyLowerVectorsPassBase< 369 LinalgStrategyLowerVectorsPass> { 370 371 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 372 LinalgTransformationFilter filt) 373 : options(opt), filter(filt) {} 374 375 void runOnFunction() override { 376 auto funcOp = getFunction(); 377 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 378 return; 379 380 MLIRContext *context = funcOp.getContext(); 381 RewritePatternSet patterns(context); 382 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 383 // In a progressive lowering of vectors, this would be the 1st step. 384 if (options.contractionLowering) { 385 patterns.add<ContractionOpToOuterProductOpLowering, 386 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 387 options.vectorTransformOptions, context); 388 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 389 } 390 // In a progressive lowering of vectors, this would be the 2nd step. 391 if (options.multiReductionLowering) { 392 vector::populateVectorMultiReductionLoweringPatterns( 393 patterns, 394 options.vectorTransformOptions.vectorMultiReductionLowering); 395 } 396 // In a progressive lowering of vectors, this would be the 3rd step. 397 if (options.transferPartialRewrite) { 398 patterns.add<vector::VectorTransferFullPartialRewriter>( 399 context, options.vectorTransformOptions); 400 } 401 // In a progressive lowering of vectors, this would be the 4th step. 402 if (options.transferLowering) { 403 vector::populateVectorTransferLoweringPatterns(patterns, 404 options.maxTransferRank); 405 } 406 // In a progressive lowering of vectors, this would be the 5th step. 407 if (options.transferToSCFConversion) { 408 populateVectorToSCFConversionPatterns( 409 patterns, options.vectorTransferToSCFOptions.setTargetRank( 410 options.maxTransferRank)); 411 } 412 // In a progressive lowering of vectors, this would be the 6th step. 413 if (options.shapeCastLowering) { 414 vector::populateVectorShapeCastLoweringPatterns(patterns); 415 } 416 // In a progressive lowering of vectors, this would be the 7th step. 417 if (options.transposeLowering) { 418 vector::populateVectorTransposeLoweringPatterns( 419 patterns, options.vectorTransformOptions); 420 if (options.avx2Lowering) 421 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 422 patterns, options.avx2LoweringOptions, /*benefit=*/10); 423 } 424 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 425 } 426 427 LinalgVectorLoweringOptions options; 428 LinalgTransformationFilter filter; 429 }; 430 431 /// Configurable pass to lower vector operations. 432 struct LinalgStrategyRemoveMarkersPass 433 : public LinalgStrategyRemoveMarkersPassBase< 434 LinalgStrategyRemoveMarkersPass> { 435 436 void runOnFunction() override { 437 auto funcOp = getFunction(); 438 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 439 return; 440 funcOp.walk([](LinalgOp op) { 441 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 442 }); 443 } 444 }; 445 } // namespace 446 447 /// Create a LinalgStrategyTileAndFusePass. 448 std::unique_ptr<OperationPass<FuncOp>> 449 mlir::createLinalgStrategyTileAndFusePass(StringRef opName, 450 LinalgTilingAndFusionOptions options, 451 LinalgTransformationFilter filter) { 452 return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options, 453 filter); 454 } 455 456 /// Create a LinalgStrategyTilePass. 457 std::unique_ptr<OperationPass<FuncOp>> 458 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 459 LinalgTransformationFilter filter) { 460 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 461 } 462 463 /// Create a LinalgStrategyPadPass. 464 std::unique_ptr<OperationPass<FuncOp>> 465 mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 466 LinalgTransformationFilter filter) { 467 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 468 } 469 470 /// Create a LinalgStrategyPromotePass. 471 std::unique_ptr<OperationPass<FuncOp>> 472 mlir::createLinalgStrategyPromotePass(StringRef opName, 473 LinalgPromotionOptions opt, 474 LinalgTransformationFilter filter) { 475 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 476 } 477 478 /// Create a LinalgStrategyGeneralizePass. 479 std::unique_ptr<OperationPass<FuncOp>> 480 mlir::createLinalgStrategyGeneralizePass(StringRef opName, 481 LinalgTransformationFilter filter) { 482 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 483 } 484 485 /// Create a LinalgStrategyDecomposePass. 486 // TODO: if/when we need finer control add an `opName` parameter. 487 std::unique_ptr<OperationPass<FuncOp>> 488 mlir::createLinalgStrategyDecomposePass(LinalgTransformationFilter filter) { 489 return std::make_unique<LinalgStrategyDecomposePass>(filter); 490 } 491 492 /// Create a LinalgStrategyInterchangePass. 493 std::unique_ptr<OperationPass<FuncOp>> 494 mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 495 LinalgTransformationFilter filter) { 496 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 497 filter); 498 } 499 500 /// Create a LinalgStrategyVectorizePass. 501 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass( 502 StringRef opName, LinalgVectorizationOptions opt, 503 LinalgTransformationFilter filter, bool padVectorize) { 504 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter, 505 padVectorize); 506 } 507 508 /// Create a LinalgStrategyEnablePass. 509 std::unique_ptr<OperationPass<FuncOp>> 510 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 511 LinalgTransformationFilter filter) { 512 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 513 } 514 515 /// Create a LinalgStrategyLowerVectorsPass. 516 std::unique_ptr<OperationPass<FuncOp>> 517 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 518 LinalgTransformationFilter filter) { 519 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 520 } 521 522 /// Create a LinalgStrategyRemoveMarkersPass. 523 std::unique_ptr<OperationPass<FuncOp>> 524 mlir::createLinalgStrategyRemoveMarkersPass() { 525 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 526 } 527