1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements configurable passes that apply Linalg strategy patterns
// (tiling, promotion, vectorization, enabling transformations and vector
// lowering) and can be plugged into a pass pipeline.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/SCF/Transforms.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Vector/VectorOps.h"
26 #include "mlir/Dialect/Vector/VectorTransforms.h"
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/Support/LLVM.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "mlir/Transforms/LoopUtils.h"
32 #include "mlir/Transforms/Utils.h"
33 
34 using namespace mlir;
35 using namespace linalg;
36 
37 namespace {
38 
39 /// Configurable pass to apply pattern-based linalg tiling.
40 struct LinalgStrategyTilePass
41     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
42 
43   LinalgStrategyTilePass() = default;
44 
45   LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
46                          LinalgTransformationFilter filt)
47       : options(opt), filter(filt) {
48     this->anchorOpName.setValue(opName.str());
49   }
50 
51   void runOnFunction() override {
52     auto funcOp = getFunction();
53     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
54       return;
55 
56     RewritePatternSet tilingPattern(funcOp.getContext());
57     if (!anchorOpName.empty()) {
58       tilingPattern.add<LinalgGenericTilingPattern>(
59           anchorOpName, funcOp.getContext(), options, filter);
60     } else {
61       tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
62                                                     options);
63     }
64     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
65   }
66 
67   LinalgTilingOptions options;
68   LinalgTransformationFilter filter;
69 };
70 
71 /// Configurable pass to apply pattern-based linalg promotion.
72 struct LinalgStrategyPromotePass
73     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
74 
75   LinalgStrategyPromotePass() = default;
76 
77   LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
78                             LinalgTransformationFilter filt)
79       : options(opt), filter(filt) {
80     this->anchorOpName.setValue(opName.str());
81   }
82 
83   void runOnFunction() override {
84     auto funcOp = getFunction();
85     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
86       return;
87 
88     RewritePatternSet promotionPattern(funcOp.getContext());
89     if (!anchorOpName.empty()) {
90       promotionPattern.add<LinalgBasePromotionPattern>(
91           anchorOpName, funcOp.getContext(), options, filter);
92     } else {
93       promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
94                                                        filter, options);
95     }
96     (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
97   }
98 
99   LinalgPromotionOptions options;
100   LinalgTransformationFilter filter;
101 };
102 
103 /// Configurable pass to apply pattern-based linalg vectorization.
104 struct LinalgStrategyVectorizePass
105     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
106 
107   LinalgStrategyVectorizePass() = default;
108 
109   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
110                               LinalgTransformationFilter filt)
111       : options(opt), filter(filt) {
112     this->anchorOpName.setValue(opName.str());
113   }
114 
115   void runOnFunction() override {
116     auto funcOp = getFunction();
117     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
118       return;
119 
120     RewritePatternSet vectorizationPatterns(funcOp.getContext());
121     if (!anchorOpName.empty()) {
122       vectorizationPatterns.add<LinalgVectorizationPattern>(
123           anchorOpName, funcOp.getContext(), options, filter);
124     } else {
125       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
126                                                             filter, options);
127     }
128     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
129                               linalg::LinalgCopyVTWForwardingPattern>(
130         funcOp.getContext(), /*benefit=*/2);
131     (void)applyPatternsAndFoldGreedily(funcOp,
132                                        std::move(vectorizationPatterns));
133   }
134 
135   LinalgVectorizationOptions options;
136   LinalgTransformationFilter filter;
137 };
138 
139 /// Configurable pass to enable the application of other pattern-based linalg
140 /// passes.
141 struct LinalgStrategyEnablePass
142     : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
143 
144   LinalgStrategyEnablePass(LinalgEnablingOptions opt,
145                            LinalgTransformationFilter filt)
146       : options(opt), filter(filt) {}
147 
148   void runOnFunction() override {
149     auto funcOp = getFunction();
150     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
151       return;
152 
153     MLIRContext *context = funcOp.getContext();
154     RewritePatternSet patterns =
155         linalg::getLinalgTilingCanonicalizationPatterns(context);
156     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
157     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
158       return signalPassFailure();
159 
160     if (options.enableLICM) {
161       if (funcOp
162               ->walk([&](LoopLikeOpInterface loopLike) {
163                 if (failed(moveLoopInvariantCode(loopLike)))
164                   return WalkResult::interrupt();
165                 return WalkResult::advance();
166               })
167               .wasInterrupted())
168         return signalPassFailure();
169     }
170 
171     promoteSingleIterationLoops(funcOp);
172     if (options.enableHoistRedundantVectorTransfers)
173       hoistRedundantVectorTransfers(funcOp);
174 
175     if (options.enableHoistRedundantVectorTransfersOnTensor)
176       hoistRedundantVectorTransfersOnTensor(funcOp);
177   }
178 
179   LinalgEnablingOptions options;
180   LinalgTransformationFilter filter;
181 };
182 
183 /// Configurable pass to lower vector operations.
184 struct LinalgStrategyLowerVectorsPass
185     : public LinalgStrategyLowerVectorsPassBase<
186           LinalgStrategyLowerVectorsPass> {
187 
188   LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
189                                  LinalgTransformationFilter filt)
190       : options(opt), filter(filt) {}
191 
192   void runOnFunction() override {
193     auto funcOp = getFunction();
194     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
195       return;
196 
197     MLIRContext *context = funcOp.getContext();
198     RewritePatternSet patterns(context);
199     if (options.enableVectorTransferPartialRewrite) {
200       patterns.add<vector::VectorTransferFullPartialRewriter>(
201           context, options.vectorTransformOptions);
202     }
203     if (options.enableVectorContractLowering) {
204       patterns.add<ContractionOpToOuterProductOpLowering,
205                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
206           options.vectorTransformOptions, context);
207       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
208     }
209     if (options.enableVectorToSCFConversion) {
210       populateVectorToSCFConversionPatterns(patterns,
211                                             options.vectorTransferToSCFOptions);
212     }
213     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
214   }
215 
216   LinalgVectorLoweringOptions options;
217   LinalgTransformationFilter filter;
218 };
219 } // namespace
220 
221 /// Create a LinalgStrategyTilePass.
222 std::unique_ptr<OperationPass<FuncOp>>
223 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
224                                    LinalgTransformationFilter filter) {
225   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
226 }
227 
228 /// Create a LinalgStrategyPromotePass.
229 std::unique_ptr<OperationPass<FuncOp>>
230 mlir::createLinalgStrategyPromotePass(StringRef opName,
231                                       LinalgPromotionOptions opt,
232                                       LinalgTransformationFilter filter) {
233   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
234 }
235 
236 /// Create a LinalgStrategyVectorizePass.
237 std::unique_ptr<OperationPass<FuncOp>>
238 mlir::createLinalgStrategyVectorizePass(StringRef opName,
239                                         LinalgVectorizationOptions opt,
240                                         LinalgTransformationFilter filter) {
241   return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter);
242 }
243 
244 /// Create a LinalgStrategyEnablePass.
245 std::unique_ptr<OperationPass<FuncOp>>
246 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
247                                      LinalgTransformationFilter filter) {
248   return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
249 }
250 
251 /// Create a LinalgStrategyLowerVectorsPass.
252 std::unique_ptr<OperationPass<FuncOp>>
253 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
254                                            LinalgTransformationFilter filter) {
255   return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
256 }
257