1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Affine/IR/AffineOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/SCF/Transforms.h" 24 #include "mlir/Dialect/Tensor/IR/Tensor.h" 25 #include "mlir/Dialect/Vector/VectorTransforms.h" 26 #include "mlir/IR/AffineExpr.h" 27 #include "mlir/IR/AffineMap.h" 28 #include "mlir/Support/LLVM.h" 29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30 #include "mlir/Transforms/LoopUtils.h" 31 #include "mlir/Transforms/Utils.h" 32 33 using namespace mlir; 34 using namespace mlir::vector; 35 using namespace linalg; 36 37 namespace { 38 39 /// Configurable pass to apply pattern-based linalg tiling. 40 struct LinalgStrategyTilePass 41 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 42 43 LinalgStrategyTilePass() = default; 44 45 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 46 LinalgTransformationFilter filt) 47 : options(opt), filter(filt) { 48 this->anchorOpName.setValue(opName.str()); 49 } 50 51 void runOnFunction() override { 52 auto funcOp = getFunction(); 53 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 54 return; 55 56 RewritePatternSet tilingPattern(funcOp.getContext()); 57 if (!anchorOpName.empty()) { 58 tilingPattern.add<LinalgGenericTilingPattern>( 59 anchorOpName, funcOp.getContext(), options, filter); 60 } else { 61 tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter, 62 options); 63 } 64 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 65 } 66 67 LinalgTilingOptions options; 68 LinalgTransformationFilter filter; 69 }; 70 71 /// Configurable pass to apply hoisting and padding. 72 struct LinalgStrategyPadPass 73 : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> { 74 75 LinalgStrategyPadPass() = default; 76 77 LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 78 LinalgTransformationFilter filt) 79 : options(opt), filter(filt) { 80 this->anchorOpName.setValue(opName.str()); 81 } 82 83 void runOnFunction() override { 84 auto funcOp = getFunction(); 85 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 86 return; 87 88 RewritePatternSet paddingPattern(funcOp.getContext()); 89 if (!anchorOpName.empty()) { 90 paddingPattern.add<LinalgPaddingPattern>( 91 anchorOpName, funcOp.getContext(), options, filter); 92 } else { 93 paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options, 94 filter); 95 } 96 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)))) 97 signalPassFailure(); 98 } 99 100 LinalgPaddingOptions options; 101 LinalgTransformationFilter filter; 102 }; 103 104 /// Configurable pass to apply pattern-based linalg generalization. 105 struct LinalgStrategyGeneralizePass 106 : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { 107 108 LinalgStrategyGeneralizePass() = default; 109 110 LinalgStrategyGeneralizePass(StringRef opName, 111 LinalgTransformationFilter filter) 112 : filter(filter) { 113 this->anchorOpName.setValue(opName.str()); 114 } 115 116 void runOnFunction() override { 117 auto funcOp = getFunction(); 118 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 119 return; 120 121 RewritePatternSet generalizationPattern(funcOp.getContext()); 122 if (!anchorOpName.empty()) { 123 generalizationPattern.add<LinalgGeneralizationPattern>( 124 anchorOpName, funcOp.getContext(), filter); 125 } else { 126 generalizationPattern.add<LinalgGeneralizationPattern>( 127 funcOp.getContext(), filter); 128 } 129 if (failed(applyPatternsAndFoldGreedily(funcOp, 130 std::move(generalizationPattern)))) 131 signalPassFailure(); 132 } 133 134 LinalgTransformationFilter filter; 135 }; 136 137 /// Configurable pass to apply lowering of coarser-grained named linalg ops into 138 /// finer-grained named versions. 139 struct LinalgStrategyDecomposePass 140 : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { 141 142 LinalgStrategyDecomposePass() = default; 143 144 void runOnFunction() override { 145 auto funcOp = getFunction(); 146 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 147 return; 148 RewritePatternSet decompositionPattern(funcOp.getContext()); 149 populateDecomposeConvolutionPatterns(decompositionPattern); 150 if (failed(applyPatternsAndFoldGreedily(funcOp, 151 std::move(decompositionPattern)))) 152 signalPassFailure(); 153 } 154 }; 155 156 /// Configurable pass to apply pattern-based linalg generalization. 157 struct LinalgStrategyInterchangePass 158 : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { 159 160 LinalgStrategyInterchangePass() = default; 161 162 LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 163 LinalgTransformationFilter filter) 164 : iteratorInterchange(iteratorInterchange.begin(), 165 iteratorInterchange.end()), 166 filter(filter) {} 167 168 void runOnFunction() override { 169 auto funcOp = getFunction(); 170 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 171 return; 172 173 SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), 174 iteratorInterchange.end()); 175 RewritePatternSet interchangePattern(funcOp.getContext()); 176 interchangePattern.add<GenericOpInterchangePattern>( 177 funcOp.getContext(), interchangeVector, filter); 178 if (failed(applyPatternsAndFoldGreedily(funcOp, 179 std::move(interchangePattern)))) 180 signalPassFailure(); 181 } 182 183 SmallVector<int64_t> iteratorInterchange; 184 LinalgTransformationFilter filter; 185 }; 186 187 /// Configurable pass to apply pattern-based linalg promotion. 188 struct LinalgStrategyPromotePass 189 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 190 191 LinalgStrategyPromotePass() = default; 192 193 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 194 LinalgTransformationFilter filt) 195 : options(opt), filter(filt) { 196 this->anchorOpName.setValue(opName.str()); 197 } 198 199 void runOnFunction() override { 200 auto funcOp = getFunction(); 201 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 202 return; 203 204 RewritePatternSet promotionPattern(funcOp.getContext()); 205 if (!anchorOpName.empty()) { 206 promotionPattern.add<LinalgBasePromotionPattern>( 207 anchorOpName, funcOp.getContext(), options, filter); 208 } else { 209 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 210 filter, options); 211 } 212 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 213 } 214 215 LinalgPromotionOptions options; 216 LinalgTransformationFilter filter; 217 }; 218 219 /// Configurable pass to apply pattern-based linalg vectorization. 220 struct LinalgStrategyVectorizePass 221 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 222 223 LinalgStrategyVectorizePass() = default; 224 225 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 226 LinalgTransformationFilter filt) 227 : options(opt), filter(filt) { 228 this->anchorOpName.setValue(opName.str()); 229 } 230 231 void runOnFunction() override { 232 auto funcOp = getFunction(); 233 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 234 return; 235 236 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 237 if (!anchorOpName.empty()) { 238 vectorizationPatterns.add<LinalgVectorizationPattern>( 239 anchorOpName, funcOp.getContext(), options, filter); 240 } else { 241 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 242 filter, options); 243 } 244 vector::populateVectorTransferPermutationMapLoweringPatterns( 245 vectorizationPatterns); 246 vector::populateVectorReductionToContractPatterns(vectorizationPatterns); 247 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 248 linalg::LinalgCopyVTWForwardingPattern>( 249 funcOp.getContext(), /*benefit=*/2); 250 (void)applyPatternsAndFoldGreedily(funcOp, 251 std::move(vectorizationPatterns)); 252 } 253 254 LinalgVectorizationOptions options; 255 LinalgTransformationFilter filter; 256 }; 257 258 /// Configurable pass to enable the application of other pattern-based linalg 259 /// passes. 260 struct LinalgStrategyEnablePass 261 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 262 263 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 264 LinalgTransformationFilter filt) 265 : options(opt), filter(filt) {} 266 267 void runOnFunction() override { 268 auto funcOp = getFunction(); 269 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 270 return; 271 272 MLIRContext *context = funcOp.getContext(); 273 RewritePatternSet patterns = 274 linalg::getLinalgTilingCanonicalizationPatterns(context); 275 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 276 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 277 return signalPassFailure(); 278 279 if (options.licm) { 280 if (funcOp 281 ->walk([&](LoopLikeOpInterface loopLike) { 282 if (failed(moveLoopInvariantCode(loopLike))) 283 return WalkResult::interrupt(); 284 return WalkResult::advance(); 285 }) 286 .wasInterrupted()) 287 return signalPassFailure(); 288 } 289 290 promoteSingleIterationLoops(funcOp); 291 if (options.hoistRedundantVectorTransfers) 292 hoistRedundantVectorTransfers(funcOp); 293 294 if (options.hoistRedundantVectorTransfersOnTensor) 295 hoistRedundantVectorTransfersOnTensor(funcOp); 296 } 297 298 LinalgEnablingOptions options; 299 LinalgTransformationFilter filter; 300 }; 301 302 /// Configurable pass to lower vector operations. 303 struct LinalgStrategyLowerVectorsPass 304 : public LinalgStrategyLowerVectorsPassBase< 305 LinalgStrategyLowerVectorsPass> { 306 307 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 308 LinalgTransformationFilter filt) 309 : options(opt), filter(filt) {} 310 311 void runOnFunction() override { 312 auto funcOp = getFunction(); 313 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 314 return; 315 316 MLIRContext *context = funcOp.getContext(); 317 RewritePatternSet patterns(context); 318 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 319 // In a progressive lowering of vectors, this would be the 1st step. 320 if (options.contractionLowering) { 321 patterns.add<ContractionOpToOuterProductOpLowering, 322 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 323 options.vectorTransformOptions, context); 324 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 325 } 326 // In a progressive lowering of vectors, this would be the 2nd step. 327 if (options.multiReductionLowering) { 328 vector::populateVectorMultiReductionLoweringPatterns( 329 patterns, 330 options.vectorTransformOptions.vectorMultiReductionLowering); 331 } 332 // In a progressive lowering of vectors, this would be the 3rd step. 333 if (options.transferPartialRewrite) { 334 patterns.add<vector::VectorTransferFullPartialRewriter>( 335 context, options.vectorTransformOptions); 336 } 337 // In a progressive lowering of vectors, this would be the 4th step. 338 if (options.transferLowering) { 339 vector::populateVectorTransferLoweringPatterns(patterns, 340 options.maxTransferRank); 341 } 342 // In a progressive lowering of vectors, this would be the 5th step. 343 if (options.transferToSCFConversion) { 344 populateVectorToSCFConversionPatterns( 345 patterns, options.vectorTransferToSCFOptions.setTargetRank( 346 options.maxTransferRank)); 347 } 348 // In a progressive lowering of vectors, this would be the 6th step. 349 if (options.shapeCastLowering) { 350 vector::populateVectorShapeCastLoweringPatterns(patterns); 351 } 352 // In a progressive lowering of vectors, this would be the 7th step. 353 if (options.transposeLowering) { 354 vector::populateVectorTransposeLoweringPatterns( 355 patterns, options.vectorTransformOptions); 356 if (options.avx2Lowering) 357 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 358 patterns, options.avx2LoweringOptions, /*benefit=*/10); 359 } 360 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 361 } 362 363 LinalgVectorLoweringOptions options; 364 LinalgTransformationFilter filter; 365 }; 366 367 /// Configurable pass to lower vector operations. 368 struct LinalgStrategyRemoveMarkersPass 369 : public LinalgStrategyRemoveMarkersPassBase< 370 LinalgStrategyRemoveMarkersPass> { 371 372 void runOnFunction() override { 373 auto funcOp = getFunction(); 374 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 375 return; 376 funcOp.walk([](LinalgOp op) { 377 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 378 }); 379 } 380 }; 381 } // namespace 382 383 /// Create a LinalgStrategyTilePass. 384 std::unique_ptr<OperationPass<FuncOp>> 385 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 386 LinalgTransformationFilter filter) { 387 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 388 } 389 390 /// Create a LinalgStrategyPadPass. 391 std::unique_ptr<OperationPass<FuncOp>> 392 mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt, 393 LinalgTransformationFilter filter) { 394 return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter); 395 } 396 397 /// Create a LinalgStrategyPromotePass. 398 std::unique_ptr<OperationPass<FuncOp>> 399 mlir::createLinalgStrategyPromotePass(StringRef opName, 400 LinalgPromotionOptions opt, 401 LinalgTransformationFilter filter) { 402 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 403 } 404 405 /// Create a LinalgStrategyGeneralizePass. 406 std::unique_ptr<OperationPass<FuncOp>> 407 mlir::createLinalgStrategyGeneralizePass(StringRef opName, 408 LinalgTransformationFilter filter) { 409 return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); 410 } 411 /// Create a LinalgStrategyDecomposePass. 412 // TODO: atm this is applied to all supported ops. If/when we need finer control 413 // this should be exposed with an opName + filter and a proper pattern. 414 std::unique_ptr<OperationPass<FuncOp>> 415 mlir::createLinalgStrategyDecomposePass() { 416 return std::make_unique<LinalgStrategyDecomposePass>(); 417 } 418 419 /// Create a LinalgStrategyInterchangePass. 420 std::unique_ptr<OperationPass<FuncOp>> 421 mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, 422 LinalgTransformationFilter filter) { 423 return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, 424 filter); 425 } 426 427 /// Create a LinalgStrategyVectorizePass. 428 std::unique_ptr<OperationPass<FuncOp>> 429 mlir::createLinalgStrategyVectorizePass(StringRef opName, 430 LinalgVectorizationOptions opt, 431 LinalgTransformationFilter filter) { 432 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter); 433 } 434 435 /// Create a LinalgStrategyEnablePass. 436 std::unique_ptr<OperationPass<FuncOp>> 437 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 438 LinalgTransformationFilter filter) { 439 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 440 } 441 442 /// Create a LinalgStrategyLowerVectorsPass. 443 std::unique_ptr<OperationPass<FuncOp>> 444 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 445 LinalgTransformationFilter filter) { 446 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 447 } 448 449 /// Create a LinalgStrategyRemoveMarkersPass. 450 std::unique_ptr<OperationPass<FuncOp>> 451 mlir::createLinalgStrategyRemoveMarkersPass() { 452 return std::make_unique<LinalgStrategyRemoveMarkersPass>(); 453 } 454