1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a configurable pass that can apply patterns liberally
10 // and be plugged in a pass pipeline.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/SCF/Transforms.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Vector/VectorTransforms.h"
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/Support/LLVM.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "mlir/Transforms/LoopUtils.h"
31 #include "mlir/Transforms/Utils.h"
32 
33 using namespace mlir;
34 using namespace mlir::vector;
35 using namespace linalg;
36 
37 namespace {
38 
39 /// Configurable pass to apply pattern-based linalg tiling.
40 struct LinalgStrategyTilePass
41     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
42 
43   LinalgStrategyTilePass() = default;
44 
45   LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
46                          LinalgTransformationFilter filt)
47       : options(opt), filter(filt) {
48     this->anchorOpName.setValue(opName.str());
49   }
50 
51   void runOnFunction() override {
52     auto funcOp = getFunction();
53     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
54       return;
55 
56     RewritePatternSet tilingPattern(funcOp.getContext());
57     if (!anchorOpName.empty()) {
58       tilingPattern.add<LinalgGenericTilingPattern>(
59           anchorOpName, funcOp.getContext(), options, filter);
60     } else {
61       tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
62                                                     options);
63     }
64     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
65   }
66 
67   LinalgTilingOptions options;
68   LinalgTransformationFilter filter;
69 };
70 
71 /// Configurable pass to apply hoisting and padding.
72 struct LinalgStrategyPadPass
73     : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
74 
75   LinalgStrategyPadPass() = default;
76 
77   LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
78                         LinalgTransformationFilter filt)
79       : options(opt), filter(filt) {
80     this->anchorOpName.setValue(opName.str());
81   }
82 
83   void runOnFunction() override {
84     auto funcOp = getFunction();
85     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
86       return;
87 
88     RewritePatternSet paddingPattern(funcOp.getContext());
89     if (!anchorOpName.empty()) {
90       paddingPattern.add<LinalgPaddingPattern>(
91           anchorOpName, funcOp.getContext(), options, filter);
92     } else {
93       paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
94                                                filter);
95     }
96     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern))))
97       signalPassFailure();
98   }
99 
100   LinalgPaddingOptions options;
101   LinalgTransformationFilter filter;
102 };
103 
104 /// Configurable pass to apply pattern-based linalg generalization.
105 struct LinalgStrategyGeneralizePass
106     : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
107 
108   LinalgStrategyGeneralizePass() = default;
109 
110   LinalgStrategyGeneralizePass(StringRef opName,
111                                LinalgTransformationFilter filter)
112       : filter(filter) {
113     this->anchorOpName.setValue(opName.str());
114   }
115 
116   void runOnFunction() override {
117     auto funcOp = getFunction();
118     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
119       return;
120 
121     RewritePatternSet generalizationPattern(funcOp.getContext());
122     if (!anchorOpName.empty()) {
123       generalizationPattern.add<LinalgGeneralizationPattern>(
124           anchorOpName, funcOp.getContext(), filter);
125     } else {
126       generalizationPattern.add<LinalgGeneralizationPattern>(
127           funcOp.getContext(), filter);
128     }
129     if (failed(applyPatternsAndFoldGreedily(funcOp,
130                                             std::move(generalizationPattern))))
131       signalPassFailure();
132   }
133 
134   LinalgTransformationFilter filter;
135 };
136 
137 /// Configurable pass to apply lowering of coarser-grained named linalg ops into
138 /// finer-grained named versions.
139 struct LinalgStrategyDecomposePass
140     : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
141 
142   LinalgStrategyDecomposePass() = default;
143 
144   void runOnFunction() override {
145     auto funcOp = getFunction();
146     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
147       return;
148     RewritePatternSet decompositionPattern(funcOp.getContext());
149     populateDecomposeConvolutionPatterns(decompositionPattern);
150     if (failed(applyPatternsAndFoldGreedily(funcOp,
151                                             std::move(decompositionPattern))))
152       signalPassFailure();
153   }
154 };
155 
156 /// Configurable pass to apply pattern-based linalg generalization.
157 struct LinalgStrategyInterchangePass
158     : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
159 
160   LinalgStrategyInterchangePass() = default;
161 
162   LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
163                                 LinalgTransformationFilter filter)
164       : iteratorInterchange(iteratorInterchange.begin(),
165                             iteratorInterchange.end()),
166         filter(filter) {}
167 
168   void runOnFunction() override {
169     auto funcOp = getFunction();
170     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
171       return;
172 
173     SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
174                                             iteratorInterchange.end());
175     RewritePatternSet interchangePattern(funcOp.getContext());
176     interchangePattern.add<GenericOpInterchangePattern>(
177         funcOp.getContext(), interchangeVector, filter);
178     if (failed(applyPatternsAndFoldGreedily(funcOp,
179                                             std::move(interchangePattern))))
180       signalPassFailure();
181   }
182 
183   SmallVector<int64_t> iteratorInterchange;
184   LinalgTransformationFilter filter;
185 };
186 
187 /// Configurable pass to apply pattern-based linalg promotion.
188 struct LinalgStrategyPromotePass
189     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
190 
191   LinalgStrategyPromotePass() = default;
192 
193   LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
194                             LinalgTransformationFilter filt)
195       : options(opt), filter(filt) {
196     this->anchorOpName.setValue(opName.str());
197   }
198 
199   void runOnFunction() override {
200     auto funcOp = getFunction();
201     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
202       return;
203 
204     RewritePatternSet promotionPattern(funcOp.getContext());
205     if (!anchorOpName.empty()) {
206       promotionPattern.add<LinalgBasePromotionPattern>(
207           anchorOpName, funcOp.getContext(), options, filter);
208     } else {
209       promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
210                                                        filter, options);
211     }
212     (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
213   }
214 
215   LinalgPromotionOptions options;
216   LinalgTransformationFilter filter;
217 };
218 
219 /// Configurable pass to apply pattern-based linalg vectorization.
220 struct LinalgStrategyVectorizePass
221     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
222 
223   LinalgStrategyVectorizePass() = default;
224 
225   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
226                               LinalgTransformationFilter filt)
227       : options(opt), filter(filt) {
228     this->anchorOpName.setValue(opName.str());
229   }
230 
231   void runOnFunction() override {
232     auto funcOp = getFunction();
233     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
234       return;
235 
236     RewritePatternSet vectorizationPatterns(funcOp.getContext());
237     if (!anchorOpName.empty()) {
238       vectorizationPatterns.add<LinalgVectorizationPattern>(
239           anchorOpName, funcOp.getContext(), options, filter);
240     } else {
241       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
242                                                             filter, options);
243     }
244     vector::populateVectorTransferPermutationMapLoweringPatterns(
245         vectorizationPatterns);
246     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
247     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
248                               linalg::LinalgCopyVTWForwardingPattern>(
249         funcOp.getContext(), /*benefit=*/2);
250     (void)applyPatternsAndFoldGreedily(funcOp,
251                                        std::move(vectorizationPatterns));
252   }
253 
254   LinalgVectorizationOptions options;
255   LinalgTransformationFilter filter;
256 };
257 
258 /// Configurable pass to enable the application of other pattern-based linalg
259 /// passes.
260 struct LinalgStrategyEnablePass
261     : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
262 
263   LinalgStrategyEnablePass(LinalgEnablingOptions opt,
264                            LinalgTransformationFilter filt)
265       : options(opt), filter(filt) {}
266 
267   void runOnFunction() override {
268     auto funcOp = getFunction();
269     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
270       return;
271 
272     MLIRContext *context = funcOp.getContext();
273     RewritePatternSet patterns =
274         linalg::getLinalgTilingCanonicalizationPatterns(context);
275     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
276     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
277       return signalPassFailure();
278 
279     if (options.licm) {
280       if (funcOp
281               ->walk([&](LoopLikeOpInterface loopLike) {
282                 if (failed(moveLoopInvariantCode(loopLike)))
283                   return WalkResult::interrupt();
284                 return WalkResult::advance();
285               })
286               .wasInterrupted())
287         return signalPassFailure();
288     }
289 
290     promoteSingleIterationLoops(funcOp);
291     if (options.hoistRedundantVectorTransfers)
292       hoistRedundantVectorTransfers(funcOp);
293 
294     if (options.hoistRedundantVectorTransfersOnTensor)
295       hoistRedundantVectorTransfersOnTensor(funcOp);
296   }
297 
298   LinalgEnablingOptions options;
299   LinalgTransformationFilter filter;
300 };
301 
302 /// Configurable pass to lower vector operations.
303 struct LinalgStrategyLowerVectorsPass
304     : public LinalgStrategyLowerVectorsPassBase<
305           LinalgStrategyLowerVectorsPass> {
306 
307   LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
308                                  LinalgTransformationFilter filt)
309       : options(opt), filter(filt) {}
310 
311   void runOnFunction() override {
312     auto funcOp = getFunction();
313     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
314       return;
315 
316     MLIRContext *context = funcOp.getContext();
317     RewritePatternSet patterns(context);
318     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
319     // In a progressive lowering of vectors, this would be the 1st step.
320     if (options.contractionLowering) {
321       patterns.add<ContractionOpToOuterProductOpLowering,
322                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
323           options.vectorTransformOptions, context);
324       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
325     }
326     // In a progressive lowering of vectors, this would be the 2nd step.
327     if (options.multiReductionLowering) {
328       vector::populateVectorMultiReductionLoweringPatterns(
329           patterns,
330           options.vectorTransformOptions.vectorMultiReductionLowering);
331     }
332     // In a progressive lowering of vectors, this would be the 3rd step.
333     if (options.transferPartialRewrite) {
334       patterns.add<vector::VectorTransferFullPartialRewriter>(
335           context, options.vectorTransformOptions);
336     }
337     // In a progressive lowering of vectors, this would be the 4th step.
338     if (options.transferLowering) {
339       vector::populateVectorTransferLoweringPatterns(patterns,
340                                                      options.maxTransferRank);
341     }
342     // In a progressive lowering of vectors, this would be the 5th step.
343     if (options.transferToSCFConversion) {
344       populateVectorToSCFConversionPatterns(
345           patterns, options.vectorTransferToSCFOptions.setTargetRank(
346                         options.maxTransferRank));
347     }
348     // In a progressive lowering of vectors, this would be the 6th step.
349     if (options.shapeCastLowering) {
350       vector::populateVectorShapeCastLoweringPatterns(patterns);
351     }
352     // In a progressive lowering of vectors, this would be the 7th step.
353     if (options.transposeLowering) {
354       vector::populateVectorTransposeLoweringPatterns(
355           patterns, options.vectorTransformOptions);
356       if (options.avx2Lowering)
357         x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
358             patterns, options.avx2LoweringOptions, /*benefit=*/10);
359     }
360     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
361   }
362 
363   LinalgVectorLoweringOptions options;
364   LinalgTransformationFilter filter;
365 };
366 
367 /// Configurable pass to lower vector operations.
368 struct LinalgStrategyRemoveMarkersPass
369     : public LinalgStrategyRemoveMarkersPassBase<
370           LinalgStrategyRemoveMarkersPass> {
371 
372   void runOnFunction() override {
373     auto funcOp = getFunction();
374     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
375       return;
376     funcOp.walk([](LinalgOp op) {
377       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
378     });
379   }
380 };
381 } // namespace
382 
383 /// Create a LinalgStrategyTilePass.
384 std::unique_ptr<OperationPass<FuncOp>>
385 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
386                                    LinalgTransformationFilter filter) {
387   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
388 }
389 
390 /// Create a LinalgStrategyPadPass.
391 std::unique_ptr<OperationPass<FuncOp>>
392 mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
393                                   LinalgTransformationFilter filter) {
394   return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
395 }
396 
397 /// Create a LinalgStrategyPromotePass.
398 std::unique_ptr<OperationPass<FuncOp>>
399 mlir::createLinalgStrategyPromotePass(StringRef opName,
400                                       LinalgPromotionOptions opt,
401                                       LinalgTransformationFilter filter) {
402   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
403 }
404 
405 /// Create a LinalgStrategyGeneralizePass.
406 std::unique_ptr<OperationPass<FuncOp>>
407 mlir::createLinalgStrategyGeneralizePass(StringRef opName,
408                                          LinalgTransformationFilter filter) {
409   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
410 }
411 /// Create a LinalgStrategyDecomposePass.
412 // TODO: atm this is applied to all supported ops. If/when we need finer control
413 // this should be exposed with an opName + filter and a proper pattern.
414 std::unique_ptr<OperationPass<FuncOp>>
415 mlir::createLinalgStrategyDecomposePass() {
416   return std::make_unique<LinalgStrategyDecomposePass>();
417 }
418 
419 /// Create a LinalgStrategyInterchangePass.
420 std::unique_ptr<OperationPass<FuncOp>>
421 mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
422                                           LinalgTransformationFilter filter) {
423   return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
424                                                          filter);
425 }
426 
427 /// Create a LinalgStrategyVectorizePass.
428 std::unique_ptr<OperationPass<FuncOp>>
429 mlir::createLinalgStrategyVectorizePass(StringRef opName,
430                                         LinalgVectorizationOptions opt,
431                                         LinalgTransformationFilter filter) {
432   return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter);
433 }
434 
435 /// Create a LinalgStrategyEnablePass.
436 std::unique_ptr<OperationPass<FuncOp>>
437 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
438                                      LinalgTransformationFilter filter) {
439   return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
440 }
441 
442 /// Create a LinalgStrategyLowerVectorsPass.
443 std::unique_ptr<OperationPass<FuncOp>>
444 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
445                                            LinalgTransformationFilter filter) {
446   return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
447 }
448 
449 /// Create a LinalgStrategyRemoveMarkersPass.
450 std::unique_ptr<OperationPass<FuncOp>>
451 mlir::createLinalgStrategyRemoveMarkersPass() {
452   return std::make_unique<LinalgStrategyRemoveMarkersPass>();
453 }
454