1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a configurable pass that can apply patterns liberally 10 // and be plugged in a pass pipeline. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Affine/IR/AffineOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/SCF/Transforms.h" 24 #include "mlir/Dialect/Tensor/IR/Tensor.h" 25 #include "mlir/Dialect/Vector/VectorOps.h" 26 #include "mlir/Dialect/Vector/VectorTransforms.h" 27 #include "mlir/IR/AffineExpr.h" 28 #include "mlir/IR/AffineMap.h" 29 #include "mlir/Support/LLVM.h" 30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 31 #include "mlir/Transforms/LoopUtils.h" 32 #include "mlir/Transforms/Utils.h" 33 34 using namespace mlir; 35 using namespace linalg; 36 37 namespace { 38 39 /// Configurable pass to apply pattern-based linalg tiling. 40 struct LinalgStrategyTilePass 41 : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> { 42 43 LinalgStrategyTilePass() = default; 44 45 LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 46 LinalgTransformationFilter filt) 47 : options(opt), filter(filt) { 48 this->anchorOpName.setValue(opName.str()); 49 } 50 51 void runOnFunction() override { 52 auto funcOp = getFunction(); 53 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 54 return; 55 56 RewritePatternSet tilingPattern(funcOp.getContext()); 57 if (!anchorOpName.empty()) { 58 tilingPattern.add<LinalgGenericTilingPattern>( 59 anchorOpName, funcOp.getContext(), options, filter); 60 } else { 61 tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter, 62 options); 63 } 64 (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); 65 } 66 67 LinalgTilingOptions options; 68 LinalgTransformationFilter filter; 69 }; 70 71 /// Configurable pass to apply pattern-based linalg promotion. 72 struct LinalgStrategyPromotePass 73 : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { 74 75 LinalgStrategyPromotePass() = default; 76 77 LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, 78 LinalgTransformationFilter filt) 79 : options(opt), filter(filt) { 80 this->anchorOpName.setValue(opName.str()); 81 } 82 83 void runOnFunction() override { 84 auto funcOp = getFunction(); 85 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 86 return; 87 88 RewritePatternSet promotionPattern(funcOp.getContext()); 89 if (!anchorOpName.empty()) { 90 promotionPattern.add<LinalgBasePromotionPattern>( 91 anchorOpName, funcOp.getContext(), options, filter); 92 } else { 93 promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(), 94 filter, options); 95 } 96 (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); 97 } 98 99 LinalgPromotionOptions options; 100 LinalgTransformationFilter filter; 101 }; 102 103 /// Configurable pass to apply pattern-based linalg vectorization. 104 struct LinalgStrategyVectorizePass 105 : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> { 106 107 LinalgStrategyVectorizePass() = default; 108 109 LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt, 110 LinalgTransformationFilter filt) 111 : options(opt), filter(filt) { 112 this->anchorOpName.setValue(opName.str()); 113 } 114 115 void runOnFunction() override { 116 auto funcOp = getFunction(); 117 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 118 return; 119 120 RewritePatternSet vectorizationPatterns(funcOp.getContext()); 121 if (!anchorOpName.empty()) { 122 vectorizationPatterns.add<LinalgVectorizationPattern>( 123 anchorOpName, funcOp.getContext(), options, filter); 124 } else { 125 vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(), 126 filter, options); 127 } 128 vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern, 129 linalg::LinalgCopyVTWForwardingPattern>( 130 funcOp.getContext(), /*benefit=*/2); 131 (void)applyPatternsAndFoldGreedily(funcOp, 132 std::move(vectorizationPatterns)); 133 } 134 135 LinalgVectorizationOptions options; 136 LinalgTransformationFilter filter; 137 }; 138 139 /// Configurable pass to enable the application of other pattern-based linalg 140 /// passes. 141 struct LinalgStrategyEnablePass 142 : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> { 143 144 LinalgStrategyEnablePass(LinalgEnablingOptions opt, 145 LinalgTransformationFilter filt) 146 : options(opt), filter(filt) {} 147 148 void runOnFunction() override { 149 auto funcOp = getFunction(); 150 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 151 return; 152 153 MLIRContext *context = funcOp.getContext(); 154 RewritePatternSet patterns = 155 linalg::getLinalgTilingCanonicalizationPatterns(context); 156 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 157 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 158 return signalPassFailure(); 159 160 if (options.enableLICM) { 161 if (funcOp 162 ->walk([&](LoopLikeOpInterface loopLike) { 163 if (failed(moveLoopInvariantCode(loopLike))) 164 return WalkResult::interrupt(); 165 return WalkResult::advance(); 166 }) 167 .wasInterrupted()) 168 return signalPassFailure(); 169 } 170 171 promoteSingleIterationLoops(funcOp); 172 if (options.enableHoistRedundantVectorTransfers) 173 hoistRedundantVectorTransfers(funcOp); 174 175 if (options.enableHoistRedundantVectorTransfersOnTensor) 176 hoistRedundantVectorTransfersOnTensor(funcOp); 177 } 178 179 LinalgEnablingOptions options; 180 LinalgTransformationFilter filter; 181 }; 182 183 /// Configurable pass to lower vector operations. 184 struct LinalgStrategyLowerVectorsPass 185 : public LinalgStrategyLowerVectorsPassBase< 186 LinalgStrategyLowerVectorsPass> { 187 188 LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 189 LinalgTransformationFilter filt) 190 : options(opt), filter(filt) {} 191 192 void runOnFunction() override { 193 auto funcOp = getFunction(); 194 if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) 195 return; 196 197 MLIRContext *context = funcOp.getContext(); 198 RewritePatternSet patterns(context); 199 if (options.enableVectorTransferPartialRewrite) { 200 patterns.add<vector::VectorTransferFullPartialRewriter>( 201 context, options.vectorTransformOptions); 202 } 203 if (options.enableVectorContractLowering) { 204 patterns.add<ContractionOpToOuterProductOpLowering, 205 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 206 options.vectorTransformOptions, context); 207 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 208 } 209 if (options.enableVectorToSCFConversion) { 210 populateVectorToSCFConversionPatterns(patterns, 211 options.vectorTransferToSCFOptions); 212 } 213 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 214 } 215 216 LinalgVectorLoweringOptions options; 217 LinalgTransformationFilter filter; 218 }; 219 } // namespace 220 221 /// Create a LinalgStrategyTilePass. 222 std::unique_ptr<OperationPass<FuncOp>> 223 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt, 224 LinalgTransformationFilter filter) { 225 return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter); 226 } 227 228 /// Create a LinalgStrategyPromotePass. 229 std::unique_ptr<OperationPass<FuncOp>> 230 mlir::createLinalgStrategyPromotePass(StringRef opName, 231 LinalgPromotionOptions opt, 232 LinalgTransformationFilter filter) { 233 return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); 234 } 235 236 /// Create a LinalgStrategyVectorizePass. 237 std::unique_ptr<OperationPass<FuncOp>> 238 mlir::createLinalgStrategyVectorizePass(StringRef opName, 239 LinalgVectorizationOptions opt, 240 LinalgTransformationFilter filter) { 241 return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter); 242 } 243 244 /// Create a LinalgStrategyEnablePass. 245 std::unique_ptr<OperationPass<FuncOp>> 246 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt, 247 LinalgTransformationFilter filter) { 248 return std::make_unique<LinalgStrategyEnablePass>(opt, filter); 249 } 250 251 /// Create a LinalgStrategyLowerVectorsPass. 252 std::unique_ptr<OperationPass<FuncOp>> 253 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, 254 LinalgTransformationFilter filter) { 255 return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); 256 } 257