1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Affine/IR/AffineOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/SCF/Transforms.h" 24 #include "mlir/Dialect/Tensor/IR/Tensor.h" 25 #include "mlir/Dialect/Vector/VectorTransforms.h" 26 #include "mlir/IR/AffineExpr.h" 27 #include "mlir/IR/AffineMap.h" 28 #include "mlir/Pass/PassManager.h" 29 #include "mlir/Support/LLVM.h" 30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 31 #include "mlir/Transforms/LoopUtils.h" 32 #include "mlir/Transforms/Passes.h" 33 #include "mlir/Transforms/Utils.h" 34 35 using namespace mlir; 36 using namespace mlir::vector; 37 using namespace linalg; 38 39 namespace { 40 41 /// Configurable pass to apply pattern-based tiling and fusion. 42 struct LinalgStrategyTileAndFusePass 43 : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> { 44 45 LinalgStrategyTileAndFusePass() = default; 46 47 LinalgStrategyTileAndFusePass(StringRef opName, 48 LinalgTilingAndFusionOptions opt, 49 LinalgTransformationFilter filt) 50 : options(opt), filter(filt) { 51 this->anchorOpName.setValue(opName.str()); 52 } 53 54 void runOnFunction() override { 55 auto funcOp = getFunction(); 56 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 57 return; 58 59 RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); 60 if (!anchorOpName.empty()) { 61 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 62 anchorOpName, funcOp.getContext(), options, filter); 63 } else { 64 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 65 funcOp.getContext(), options, filter); 66 } 67 // Search the root operation using bottom up traversal. 68 GreedyRewriteConfig config; 69 config.useTopDownTraversal = false; 70 (void)applyPatternsAndFoldGreedily( 71 funcOp, std::move(tilingAndFusionPattern), config); 72 } 73 74 LinalgTilingAndFusionOptions options; 75 LinalgTransformationFilter filter; 76 }; 77 78 /// Configurable pass to apply pattern-based linalg tiling. 79 struct LinalgStrategyTilePass 80 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 81 82 LinalgStrategyTilePass() = default; 83 84 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 85 LinalgTransformationFilter filt) 86 : options(opt), filter(filt) { 87 this->anchorOpName.setValue(opName.str()); 88 } 89 90 void runOnFunction() override { 91 auto funcOp = getFunction(); 92 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 93 return; 94 95 RewritePatternSet tilingPattern(funcOp.getContext()); 96 if (!anchorOpName.empty()) { 97 tilingPattern.add<LinalgGenericTilingPattern>( 98 anchorOpName, funcOp.getContext(), options, filter); 99 } else { 100 tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter, 101 options); 102 } 103 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 104 } 105 106 LinalgTilingOptions options; 107 LinalgTransformationFilter filter; 108 }; 109 110 /// Configurable pass to apply hoisting and padding. 111 struct LinalgStrategyPadPass 112 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 113 114 LinalgStrategyPadPass() = default; 115 116 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 117 LinalgTransformationFilter filt) 118 : options(opt), filter(filt) { 119 this->anchorOpName.setValue(opName.str()); 120 } 121 122 void runOnFunction() override { 123 auto funcOp = getFunction(); 124 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 125 return; 126 127 RewritePatternSet paddingPattern(funcOp.getContext()); 128 if (!anchorOpName.empty()) { 129 paddingPattern.add<LinalgPaddingPattern>( 130 anchorOpName, funcOp.getContext(), options, filter); 131 } else { 132 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 133 filter); 134 } 135 // Traverse the operations top down to pad producers before consumers. The 136 // extract slice operation introduced after every padded operation enables 137 // padding its consumers. Padding an operation whose producers have not been 138 // padded before fails due to the missing extract slice operations. In this 139 // case, the padding pattern increments the transformation marker without 140 // padding the operation. The top down traversal is thus not only a 141 // performance optimization but also needed to pad all operations along the 142 // use-def chains. 143 GreedyRewriteConfig config; 144 config.useTopDownTraversal = true; 145 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern), 146 config))) 147 signalPassFailure(); 148 } 149 150 LinalgPaddingOptions options; 151 LinalgTransformationFilter filter; 152 }; 153 154 /// Configurable pass to apply pattern-based linalg generalization. 155 struct LinalgStrategyGeneralizePass 156 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 157 158 LinalgStrategyGeneralizePass() = default; 159 160 LinalgStrategyGeneralizePass(StringRef opName, 161 LinalgTransformationFilter filter) 162 : filter(filter) { 163 this->anchorOpName.setValue(opName.str()); 164 } 165 166 void runOnFunction() override { 167 auto funcOp = getFunction(); 168 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 169 return; 170 171 RewritePatternSet generalizationPattern(funcOp.getContext()); 172 if (!anchorOpName.empty()) { 173 generalizationPattern.add<LinalgGeneralizationPattern>( 174 anchorOpName, funcOp.getContext(), filter); 175 } else { 176 generalizationPattern.add<LinalgGeneralizationPattern>( 177 funcOp.getContext(), filter); 178 } 179 if (failed(applyPatternsAndFoldGreedily(funcOp, 180 std::move(generalizationPattern)))) 181 signalPassFailure(); 182 } 183 184 LinalgTransformationFilter filter; 185 }; 186 187 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 188 /// finer-grained named versions. 189 struct LinalgStrategyDecomposePass 190 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 191 192 LinalgStrategyDecomposePass() = default; 193 194 LinalgStrategyDecomposePass(LinalgTransformationFilter filter) 195 : filter(filter) {} 196 197 void runOnFunction() override { 198 auto funcOp = getFunction(); 199 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 200 return; 201 RewritePatternSet decompositionPattern(funcOp.getContext()); 202 populateDecomposeConvolutionPatterns(decompositionPattern, filter); 203 if (failed(applyPatternsAndFoldGreedily(funcOp, 204 std::move(decompositionPattern)))) 205 signalPassFailure(); 206 } 207 208 LinalgTransformationFilter filter; 209 }; 210 211 /// Configurable pass to apply pattern-based linalg generalization. 212 struct LinalgStrategyInterchangePass 213 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 214 215 LinalgStrategyInterchangePass() = default; 216 217 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 218 LinalgTransformationFilter filter) 219 : iteratorInterchange(iteratorInterchange.begin(), 220 iteratorInterchange.end()), 221 filter(filter) {} 222 223 void runOnFunction() override { 224 auto funcOp = getFunction(); 225 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 226 return; 227 228 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 229 iteratorInterchange.end()); 230 RewritePatternSet interchangePattern(funcOp.getContext()); 231 interchangePattern.add<GenericOpInterchangePattern>( 232 funcOp.getContext(), interchangeVector, filter); 233 if (failed(applyPatternsAndFoldGreedily(funcOp, 234 std::move(interchangePattern)))) 235 signalPassFailure(); 236 } 237 238 SmallVector<int64_t> iteratorInterchange; 239 LinalgTransformationFilter filter; 240 }; 241 242 /// Configurable pass to apply pattern-based linalg promotion. 243 struct LinalgStrategyPromotePass 244 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 245 246 LinalgStrategyPromotePass() = default; 247 248 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 249 LinalgTransformationFilter filt) 250 : options(opt), filter(filt) { 251 this->anchorOpName.setValue(opName.str()); 252 } 253 254 void runOnFunction() override { 255 auto funcOp = getFunction(); 256 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 257 return; 258 259 RewritePatternSet promotionPattern(funcOp.getContext()); 260 if (!anchorOpName.empty()) { 261 promotionPattern.add<LinalgBasePromotionPattern>( 262 anchorOpName, funcOp.getContext(), options, filter); 263 } else { 264 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 265 filter, options); 266 } 267 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 268 } 269 270 LinalgPromotionOptions options; 271 LinalgTransformationFilter filter; 272 }; 273 274 /// Configurable pass to apply pattern-based linalg vectorization. 275 struct LinalgStrategyVectorizePass 276 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 277 278 LinalgStrategyVectorizePass() = default; 279 280 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 281 LinalgTransformationFilter filt, 282 bool padVectorize = false) 283 : options(opt), filter(filt) { 284 this->anchorOpName.setValue(opName.str()); 285 this->vectorizePadding.setValue(padVectorize); 286 } 287 288 void runOnFunction() override { 289 auto funcOp = getFunction(); 290 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 291 return; 292 293 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 294 if (!anchorOpName.empty()) { 295 vectorizationPatterns.add<LinalgVectorizationPattern>( 296 anchorOpName, funcOp.getContext(), options, filter); 297 } else { 298 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 299 filter, options); 300 } 301 vector::populateVectorTransferPermutationMapLoweringPatterns( 302 vectorizationPatterns); 303 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 304 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 305 linalg::LinalgCopyVTWForwardingPattern>( 306 funcOp.getContext(), /*benefit=*/2); 307 if (vectorizePadding) { 308 linalg::populatePadTensorOpVectorizationPatterns(vectorizationPatterns); 309 } 310 (void)applyPatternsAndFoldGreedily(funcOp, 311 std::move(vectorizationPatterns)); 312 } 313 314 LinalgVectorizationOptions options; 315 LinalgTransformationFilter filter; 316 }; 317 318 /// Configurable pass to enable the application of other pattern-based linalg 319 /// passes. 320 struct LinalgStrategyEnablePass 321 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 322 323 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 324 LinalgTransformationFilter filt) 325 : options(opt), filter(filt) {} 326 327 void runOnFunction() override { 328 auto funcOp = getFunction(); 329 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 330 return; 331 332 MLIRContext *context = funcOp.getContext(); 333 RewritePatternSet patterns = 334 linalg::getLinalgTilingCanonicalizationPatterns(context); 335 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 336 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 337 return signalPassFailure(); 338 339 if (options.licm) { 340 if (funcOp 341 ->walk([&](LoopLikeOpInterface loopLike) { 342 if (failed(moveLoopInvariantCode(loopLike))) 343 return WalkResult::interrupt(); 344 return WalkResult::advance(); 345 }) 346 .wasInterrupted()) 347 return signalPassFailure(); 348 } 349 350 promoteSingleIterationLoops(funcOp); 351 if (options.hoistRedundantVectorTransfers) 352 hoistRedundantVectorTransfers(funcOp); 353 354 if (options.hoistRedundantVectorTransfersOnTensor) 355 hoistRedundantVectorTransfersOnTensor(funcOp); 356 357 // Run CSE to cleanup after canonicalization. 358 OpPassManager dynamicPM("builtin.func"); 359 dynamicPM.addPass(createCSEPass()); 360 if (failed(runPipeline(dynamicPM, funcOp))) 361 return signalPassFailure(); 362 } 363 364 LinalgEnablingOptions options; 365 LinalgTransformationFilter filter; 366 }; 367 368 /// Configurable pass to lower vector operations. 369 struct LinalgStrategyLowerVectorsPass 370 : public LinalgStrategyLowerVectorsPassBase< 371 LinalgStrategyLowerVectorsPass> { 372 373 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 374 LinalgTransformationFilter filt) 375 : options(opt), filter(filt) {} 376 377 void runOnFunction() override { 378 auto funcOp = getFunction(); 379 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 380 return; 381 382 MLIRContext *context = funcOp.getContext(); 383 RewritePatternSet patterns(context); 384 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 385 // In a progressive lowering of vectors, this would be the 1st step. 386 if (options.contractionLowering) { 387 patterns.add<ContractionOpToOuterProductOpLowering, 388 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 389 options.vectorTransformOptions, context); 390 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 391 } 392 // In a progressive lowering of vectors, this would be the 2nd step. 393 if (options.multiReductionLowering) { 394 vector::populateVectorMultiReductionLoweringPatterns( 395 patterns, 396 options.vectorTransformOptions.vectorMultiReductionLowering); 397 } 398 // In a progressive lowering of vectors, this would be the 3rd step. 399 if (options.transferPartialRewrite) { 400 patterns.add<vector::VectorTransferFullPartialRewriter>( 401 context, options.vectorTransformOptions); 402 } 403 // In a progressive lowering of vectors, this would be the 4th step. 404 if (options.transferLowering) { 405 vector::populateVectorTransferLoweringPatterns(patterns, 406 options.maxTransferRank); 407 } 408 // In a progressive lowering of vectors, this would be the 5th step. 409 if (options.transferToSCFConversion) { 410 populateVectorToSCFConversionPatterns( 411 patterns, options.vectorTransferToSCFOptions.setTargetRank( 412 options.maxTransferRank)); 413 } 414 // In a progressive lowering of vectors, this would be the 6th step. 415 if (options.shapeCastLowering) { 416 vector::populateVectorShapeCastLoweringPatterns(patterns); 417 } 418 // In a progressive lowering of vectors, this would be the 7th step. 419 if (options.transposeLowering) { 420 vector::populateVectorTransposeLoweringPatterns( 421 patterns, options.vectorTransformOptions); 422 if (options.avx2Lowering) 423 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 424 patterns, options.avx2LoweringOptions, /*benefit=*/10); 425 } 426 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 427 } 428 429 LinalgVectorLoweringOptions options; 430 LinalgTransformationFilter filter; 431 }; 432 433 /// Configurable pass to lower vector operations. 434 struct LinalgStrategyRemoveMarkersPass 435 : public LinalgStrategyRemoveMarkersPassBase< 436 LinalgStrategyRemoveMarkersPass> { 437 438 void runOnFunction() override { 439 auto funcOp = getFunction(); 440 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 441 return; 442 funcOp.walk([](LinalgOp op) { 443 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 444 }); 445 } 446 }; 447 } // namespace 448 449 /// Create a LinalgStrategyTileAndFusePass. 450 std::unique_ptr<OperationPass<FuncOp>> 451 mlir::createLinalgStrategyTileAndFusePass(StringRef opName, 452 LinalgTilingAndFusionOptions options, 453 LinalgTransformationFilter filter) { 454 return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options, 455 filter); 456 } 457 458 /// Create a LinalgStrategyTilePass. 459 std::unique_ptr<OperationPass<FuncOp>> 460 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 461 LinalgTransformationFilter filter) { 462 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 463 } 464 465 /// Create a LinalgStrategyPadPass. 466 std::unique_ptr<OperationPass<FuncOp>> 467 mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 468 LinalgTransformationFilter filter) { 469 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 470 } 471 472 /// Create a LinalgStrategyPromotePass. 473 std::unique_ptr<OperationPass<FuncOp>> 474 mlir::createLinalgStrategyPromotePass(StringRef opName, 475 LinalgPromotionOptions opt, 476 LinalgTransformationFilter filter) { 477 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 478 } 479 480 /// Create a LinalgStrategyGeneralizePass. 481 std::unique_ptr<OperationPass<FuncOp>> 482 mlir::createLinalgStrategyGeneralizePass(StringRef opName, 483 LinalgTransformationFilter filter) { 484 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 485 } 486 487 /// Create a LinalgStrategyDecomposePass. 488 // TODO: if/when we need finer control add an `opName` parameter. 489 std::unique_ptr<OperationPass<FuncOp>> 490 mlir::createLinalgStrategyDecomposePass(LinalgTransformationFilter filter) { 491 return std::make_unique<LinalgStrategyDecomposePass>(filter); 492 } 493 494 /// Create a LinalgStrategyInterchangePass. 495 std::unique_ptr<OperationPass<FuncOp>> 496 mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 497 LinalgTransformationFilter filter) { 498 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 499 filter); 500 } 501 502 /// Create a LinalgStrategyVectorizePass. 503 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass( 504 StringRef opName, LinalgVectorizationOptions opt, 505 LinalgTransformationFilter filter, bool padVectorize) { 506 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter, 507 padVectorize); 508 } 509 510 /// Create a LinalgStrategyEnablePass. 511 std::unique_ptr<OperationPass<FuncOp>> 512 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 513 LinalgTransformationFilter filter) { 514 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 515 } 516 517 /// Create a LinalgStrategyLowerVectorsPass. 518 std::unique_ptr<OperationPass<FuncOp>> 519 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 520 LinalgTransformationFilter filter) { 521 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 522 } 523 524 /// Create a LinalgStrategyRemoveMarkersPass. 525 std::unique_ptr<OperationPass<FuncOp>> 526 mlir::createLinalgStrategyRemoveMarkersPass() { 527 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 528 } 529