1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Affine/IR/AffineOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/SCF/Transforms.h" 24 #include "mlir/Dialect/Tensor/IR/Tensor.h" 25 #include "mlir/Dialect/Vector/VectorTransforms.h" 26 #include "mlir/IR/AffineExpr.h" 27 #include "mlir/IR/AffineMap.h" 28 #include "mlir/Support/LLVM.h" 29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30 #include "mlir/Transforms/LoopUtils.h" 31 #include "mlir/Transforms/Utils.h" 32 33 using namespace mlir; 34 using namespace mlir::vector; 35 using namespace linalg; 36 37 namespace { 38 39 /// Configurable pass to apply pattern-based tiling and fusion. 40 struct LinalgStrategyTileAndFusePass 41 : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> { 42 43 LinalgStrategyTileAndFusePass() = default; 44 45 LinalgStrategyTileAndFusePass(StringRef opName, 46 LinalgTilingAndFusionOptions opt, 47 LinalgTransformationFilter filt) 48 : options(opt), filter(filt) { 49 this->anchorOpName.setValue(opName.str()); 50 } 51 52 void runOnFunction() override { 53 auto funcOp = getFunction(); 54 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 55 return; 56 57 RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); 58 if (!anchorOpName.empty()) { 59 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 60 anchorOpName, funcOp.getContext(), options, filter); 61 } else { 62 tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>( 63 funcOp.getContext(), options, filter); 64 } 65 // Search the root operation using bottom up traversal. 66 GreedyRewriteConfig grc; 67 grc.useTopDownTraversal = false; 68 (void)applyPatternsAndFoldGreedily(funcOp, 69 std::move(tilingAndFusionPattern), grc); 70 } 71 72 LinalgTilingAndFusionOptions options; 73 LinalgTransformationFilter filter; 74 }; 75 76 /// Configurable pass to apply pattern-based linalg tiling. 77 struct LinalgStrategyTilePass 78 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 79 80 LinalgStrategyTilePass() = default; 81 82 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 83 LinalgTransformationFilter filt) 84 : options(opt), filter(filt) { 85 this->anchorOpName.setValue(opName.str()); 86 } 87 88 void runOnFunction() override { 89 auto funcOp = getFunction(); 90 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 91 return; 92 93 RewritePatternSet tilingPattern(funcOp.getContext()); 94 if (!anchorOpName.empty()) { 95 tilingPattern.add<LinalgGenericTilingPattern>( 96 anchorOpName, funcOp.getContext(), options, filter); 97 } else { 98 tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter, 99 options); 100 } 101 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 102 } 103 104 LinalgTilingOptions options; 105 LinalgTransformationFilter filter; 106 }; 107 108 /// Configurable pass to apply hoisting and padding. 109 struct LinalgStrategyPadPass 110 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 111 112 LinalgStrategyPadPass() = default; 113 114 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 115 LinalgTransformationFilter filt) 116 : options(opt), filter(filt) { 117 this->anchorOpName.setValue(opName.str()); 118 } 119 120 void runOnFunction() override { 121 auto funcOp = getFunction(); 122 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 123 return; 124 125 RewritePatternSet paddingPattern(funcOp.getContext()); 126 if (!anchorOpName.empty()) { 127 paddingPattern.add<LinalgPaddingPattern>( 128 anchorOpName, funcOp.getContext(), options, filter); 129 } else { 130 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 131 filter); 132 } 133 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)))) 134 signalPassFailure(); 135 } 136 137 LinalgPaddingOptions options; 138 LinalgTransformationFilter filter; 139 }; 140 141 /// Configurable pass to apply pattern-based linalg generalization. 142 struct LinalgStrategyGeneralizePass 143 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 144 145 LinalgStrategyGeneralizePass() = default; 146 147 LinalgStrategyGeneralizePass(StringRef opName, 148 LinalgTransformationFilter filter) 149 : filter(filter) { 150 this->anchorOpName.setValue(opName.str()); 151 } 152 153 void runOnFunction() override { 154 auto funcOp = getFunction(); 155 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 156 return; 157 158 RewritePatternSet generalizationPattern(funcOp.getContext()); 159 if (!anchorOpName.empty()) { 160 generalizationPattern.add<LinalgGeneralizationPattern>( 161 anchorOpName, funcOp.getContext(), filter); 162 } else { 163 generalizationPattern.add<LinalgGeneralizationPattern>( 164 funcOp.getContext(), filter); 165 } 166 if (failed(applyPatternsAndFoldGreedily(funcOp, 167 std::move(generalizationPattern)))) 168 signalPassFailure(); 169 } 170 171 LinalgTransformationFilter filter; 172 }; 173 174 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 175 /// finer-grained named versions. 176 struct LinalgStrategyDecomposePass 177 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 178 179 LinalgStrategyDecomposePass() = default; 180 181 void runOnFunction() override { 182 auto funcOp = getFunction(); 183 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 184 return; 185 RewritePatternSet decompositionPattern(funcOp.getContext()); 186 populateDecomposeConvolutionPatterns(decompositionPattern); 187 if (failed(applyPatternsAndFoldGreedily(funcOp, 188 std::move(decompositionPattern)))) 189 signalPassFailure(); 190 } 191 }; 192 193 /// Configurable pass to apply pattern-based linalg generalization. 194 struct LinalgStrategyInterchangePass 195 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 196 197 LinalgStrategyInterchangePass() = default; 198 199 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 200 LinalgTransformationFilter filter) 201 : iteratorInterchange(iteratorInterchange.begin(), 202 iteratorInterchange.end()), 203 filter(filter) {} 204 205 void runOnFunction() override { 206 auto funcOp = getFunction(); 207 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 208 return; 209 210 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 211 iteratorInterchange.end()); 212 RewritePatternSet interchangePattern(funcOp.getContext()); 213 interchangePattern.add<GenericOpInterchangePattern>( 214 funcOp.getContext(), interchangeVector, filter); 215 if (failed(applyPatternsAndFoldGreedily(funcOp, 216 std::move(interchangePattern)))) 217 signalPassFailure(); 218 } 219 220 SmallVector<int64_t> iteratorInterchange; 221 LinalgTransformationFilter filter; 222 }; 223 224 /// Configurable pass to apply pattern-based linalg promotion. 225 struct LinalgStrategyPromotePass 226 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 227 228 LinalgStrategyPromotePass() = default; 229 230 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 231 LinalgTransformationFilter filt) 232 : options(opt), filter(filt) { 233 this->anchorOpName.setValue(opName.str()); 234 } 235 236 void runOnFunction() override { 237 auto funcOp = getFunction(); 238 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 239 return; 240 241 RewritePatternSet promotionPattern(funcOp.getContext()); 242 if (!anchorOpName.empty()) { 243 promotionPattern.add<LinalgBasePromotionPattern>( 244 anchorOpName, funcOp.getContext(), options, filter); 245 } else { 246 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 247 filter, options); 248 } 249 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 250 } 251 252 LinalgPromotionOptions options; 253 LinalgTransformationFilter filter; 254 }; 255 256 /// Configurable pass to apply pattern-based linalg vectorization. 257 struct LinalgStrategyVectorizePass 258 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 259 260 LinalgStrategyVectorizePass() = default; 261 262 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 263 LinalgTransformationFilter filt, 264 bool padVectorize = false) 265 : options(opt), filter(filt) { 266 this->anchorOpName.setValue(opName.str()); 267 this->vectorizePadding.setValue(padVectorize); 268 } 269 270 void runOnFunction() override { 271 auto funcOp = getFunction(); 272 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 273 return; 274 275 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 276 if (!anchorOpName.empty()) { 277 vectorizationPatterns.add<LinalgVectorizationPattern>( 278 anchorOpName, funcOp.getContext(), options, filter); 279 } else { 280 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 281 filter, options); 282 } 283 vector::populateVectorTransferPermutationMapLoweringPatterns( 284 vectorizationPatterns); 285 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 286 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 287 linalg::LinalgCopyVTWForwardingPattern>( 288 funcOp.getContext(), /*benefit=*/2); 289 if (vectorizePadding) { 290 linalg::populatePadTensorOpVectorizationPatterns(vectorizationPatterns); 291 } 292 (void)applyPatternsAndFoldGreedily(funcOp, 293 std::move(vectorizationPatterns)); 294 } 295 296 LinalgVectorizationOptions options; 297 LinalgTransformationFilter filter; 298 }; 299 300 /// Configurable pass to enable the application of other pattern-based linalg 301 /// passes. 302 struct LinalgStrategyEnablePass 303 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 304 305 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 306 LinalgTransformationFilter filt) 307 : options(opt), filter(filt) {} 308 309 void runOnFunction() override { 310 auto funcOp = getFunction(); 311 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 312 return; 313 314 MLIRContext *context = funcOp.getContext(); 315 RewritePatternSet patterns = 316 linalg::getLinalgTilingCanonicalizationPatterns(context); 317 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 318 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 319 return signalPassFailure(); 320 321 if (options.licm) { 322 if (funcOp 323 ->walk([&](LoopLikeOpInterface loopLike) { 324 if (failed(moveLoopInvariantCode(loopLike))) 325 return WalkResult::interrupt(); 326 return WalkResult::advance(); 327 }) 328 .wasInterrupted()) 329 return signalPassFailure(); 330 } 331 332 promoteSingleIterationLoops(funcOp); 333 if (options.hoistRedundantVectorTransfers) 334 hoistRedundantVectorTransfers(funcOp); 335 336 if (options.hoistRedundantVectorTransfersOnTensor) 337 hoistRedundantVectorTransfersOnTensor(funcOp); 338 } 339 340 LinalgEnablingOptions options; 341 LinalgTransformationFilter filter; 342 }; 343 344 /// Configurable pass to lower vector operations. 345 struct LinalgStrategyLowerVectorsPass 346 : public LinalgStrategyLowerVectorsPassBase< 347 LinalgStrategyLowerVectorsPass> { 348 349 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 350 LinalgTransformationFilter filt) 351 : options(opt), filter(filt) {} 352 353 void runOnFunction() override { 354 auto funcOp = getFunction(); 355 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 356 return; 357 358 MLIRContext *context = funcOp.getContext(); 359 RewritePatternSet patterns(context); 360 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 361 // In a progressive lowering of vectors, this would be the 1st step. 362 if (options.contractionLowering) { 363 patterns.add<ContractionOpToOuterProductOpLowering, 364 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 365 options.vectorTransformOptions, context); 366 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 367 } 368 // In a progressive lowering of vectors, this would be the 2nd step. 369 if (options.multiReductionLowering) { 370 vector::populateVectorMultiReductionLoweringPatterns( 371 patterns, 372 options.vectorTransformOptions.vectorMultiReductionLowering); 373 } 374 // In a progressive lowering of vectors, this would be the 3rd step. 375 if (options.transferPartialRewrite) { 376 patterns.add<vector::VectorTransferFullPartialRewriter>( 377 context, options.vectorTransformOptions); 378 } 379 // In a progressive lowering of vectors, this would be the 4th step. 380 if (options.transferLowering) { 381 vector::populateVectorTransferLoweringPatterns(patterns, 382 options.maxTransferRank); 383 } 384 // In a progressive lowering of vectors, this would be the 5th step. 385 if (options.transferToSCFConversion) { 386 populateVectorToSCFConversionPatterns( 387 patterns, options.vectorTransferToSCFOptions.setTargetRank( 388 options.maxTransferRank)); 389 } 390 // In a progressive lowering of vectors, this would be the 6th step. 391 if (options.shapeCastLowering) { 392 vector::populateVectorShapeCastLoweringPatterns(patterns); 393 } 394 // In a progressive lowering of vectors, this would be the 7th step. 395 if (options.transposeLowering) { 396 vector::populateVectorTransposeLoweringPatterns( 397 patterns, options.vectorTransformOptions); 398 if (options.avx2Lowering) 399 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 400 patterns, options.avx2LoweringOptions, /*benefit=*/10); 401 } 402 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 403 } 404 405 LinalgVectorLoweringOptions options; 406 LinalgTransformationFilter filter; 407 }; 408 409 /// Configurable pass to lower vector operations. 410 struct LinalgStrategyRemoveMarkersPass 411 : public LinalgStrategyRemoveMarkersPassBase< 412 LinalgStrategyRemoveMarkersPass> { 413 414 void runOnFunction() override { 415 auto funcOp = getFunction(); 416 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 417 return; 418 funcOp.walk([](LinalgOp op) { 419 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 420 }); 421 } 422 }; 423 } // namespace 424 425 /// Create a LinalgStrategyTileAndFusePass. 426 std::unique_ptr<OperationPass<FuncOp>> 427 mlir::createLinalgStrategyTileAndFusePass(StringRef opName, 428 LinalgTilingAndFusionOptions options, 429 LinalgTransformationFilter filter) { 430 return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options, 431 filter); 432 } 433 434 /// Create a LinalgStrategyTilePass. 435 std::unique_ptr<OperationPass<FuncOp>> 436 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 437 LinalgTransformationFilter filter) { 438 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 439 } 440 441 /// Create a LinalgStrategyPadPass. 442 std::unique_ptr<OperationPass<FuncOp>> 443 mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 444 LinalgTransformationFilter filter) { 445 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 446 } 447 448 /// Create a LinalgStrategyPromotePass. 449 std::unique_ptr<OperationPass<FuncOp>> 450 mlir::createLinalgStrategyPromotePass(StringRef opName, 451 LinalgPromotionOptions opt, 452 LinalgTransformationFilter filter) { 453 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 454 } 455 456 /// Create a LinalgStrategyGeneralizePass. 457 std::unique_ptr<OperationPass<FuncOp>> 458 mlir::createLinalgStrategyGeneralizePass(StringRef opName, 459 LinalgTransformationFilter filter) { 460 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 461 } 462 /// Create a LinalgStrategyDecomposePass. 463 // TODO: atm this is applied to all supported ops. If/when we need finer control 464 // this should be exposed with an opName + filter and a proper pattern. 465 std::unique_ptr<OperationPass<FuncOp>> 466 mlir::createLinalgStrategyDecomposePass() { 467 return std::make_unique<LinalgStrategyDecomposePass>(); 468 } 469 470 /// Create a LinalgStrategyInterchangePass. 471 std::unique_ptr<OperationPass<FuncOp>> 472 mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 473 LinalgTransformationFilter filter) { 474 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 475 filter); 476 } 477 478 /// Create a LinalgStrategyVectorizePass. 479 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass( 480 StringRef opName, LinalgVectorizationOptions opt, 481 LinalgTransformationFilter filter, bool padVectorize) { 482 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter, 483 padVectorize); 484 } 485 486 /// Create a LinalgStrategyEnablePass. 487 std::unique_ptr<OperationPass<FuncOp>> 488 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 489 LinalgTransformationFilter filter) { 490 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 491 } 492 493 /// Create a LinalgStrategyLowerVectorsPass. 494 std::unique_ptr<OperationPass<FuncOp>> 495 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 496 LinalgTransformationFilter filter) { 497 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 498 } 499 500 /// Create a LinalgStrategyRemoveMarkersPass. 501 std::unique_ptr<OperationPass<FuncOp>> 502 mlir::createLinalgStrategyRemoveMarkersPass() { 503 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 504 } 505