1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements configurable passes that can apply patterns liberally
// and be plugged into a pass pipeline.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/SCF/Transforms.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Vector/VectorTransforms.h"
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/Support/LLVM.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "mlir/Transforms/LoopUtils.h"
31 #include "mlir/Transforms/Utils.h"
32 
33 using namespace mlir;
34 using namespace mlir::vector;
35 using namespace linalg;
36 
37 namespace {
38 
39 /// Configurable pass to apply pattern-based tiling and fusion.
40 struct LinalgStrategyTileAndFusePass
41     : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
42 
43   LinalgStrategyTileAndFusePass() = default;
44 
45   LinalgStrategyTileAndFusePass(StringRef opName,
46                                 LinalgTilingAndFusionOptions opt,
47                                 LinalgTransformationFilter filt)
48       : options(opt), filter(filt) {
49     this->anchorOpName.setValue(opName.str());
50   }
51 
52   void runOnFunction() override {
53     auto funcOp = getFunction();
54     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
55       return;
56 
57     RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
58     if (!anchorOpName.empty()) {
59       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
60           anchorOpName, funcOp.getContext(), options, filter);
61     } else {
62       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
63           funcOp.getContext(), options, filter);
64     }
65     // Search the root operation using bottom up traversal.
66     GreedyRewriteConfig grc;
67     grc.useTopDownTraversal = false;
68     (void)applyPatternsAndFoldGreedily(funcOp,
69                                        std::move(tilingAndFusionPattern), grc);
70   }
71 
72   LinalgTilingAndFusionOptions options;
73   LinalgTransformationFilter filter;
74 };
75 
76 /// Configurable pass to apply pattern-based linalg tiling.
77 struct LinalgStrategyTilePass
78     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
79 
80   LinalgStrategyTilePass() = default;
81 
82   LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
83                          LinalgTransformationFilter filt)
84       : options(opt), filter(filt) {
85     this->anchorOpName.setValue(opName.str());
86   }
87 
88   void runOnFunction() override {
89     auto funcOp = getFunction();
90     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
91       return;
92 
93     RewritePatternSet tilingPattern(funcOp.getContext());
94     if (!anchorOpName.empty()) {
95       tilingPattern.add<LinalgGenericTilingPattern>(
96           anchorOpName, funcOp.getContext(), options, filter);
97     } else {
98       tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
99                                                     options);
100     }
101     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
102   }
103 
104   LinalgTilingOptions options;
105   LinalgTransformationFilter filter;
106 };
107 
108 /// Configurable pass to apply hoisting and padding.
109 struct LinalgStrategyPadPass
110     : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
111 
112   LinalgStrategyPadPass() = default;
113 
114   LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
115                         LinalgTransformationFilter filt)
116       : options(opt), filter(filt) {
117     this->anchorOpName.setValue(opName.str());
118   }
119 
120   void runOnFunction() override {
121     auto funcOp = getFunction();
122     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
123       return;
124 
125     RewritePatternSet paddingPattern(funcOp.getContext());
126     if (!anchorOpName.empty()) {
127       paddingPattern.add<LinalgPaddingPattern>(
128           anchorOpName, funcOp.getContext(), options, filter);
129     } else {
130       paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
131                                                filter);
132     }
133     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern))))
134       signalPassFailure();
135   }
136 
137   LinalgPaddingOptions options;
138   LinalgTransformationFilter filter;
139 };
140 
141 /// Configurable pass to apply pattern-based linalg generalization.
142 struct LinalgStrategyGeneralizePass
143     : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
144 
145   LinalgStrategyGeneralizePass() = default;
146 
147   LinalgStrategyGeneralizePass(StringRef opName,
148                                LinalgTransformationFilter filter)
149       : filter(filter) {
150     this->anchorOpName.setValue(opName.str());
151   }
152 
153   void runOnFunction() override {
154     auto funcOp = getFunction();
155     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
156       return;
157 
158     RewritePatternSet generalizationPattern(funcOp.getContext());
159     if (!anchorOpName.empty()) {
160       generalizationPattern.add<LinalgGeneralizationPattern>(
161           anchorOpName, funcOp.getContext(), filter);
162     } else {
163       generalizationPattern.add<LinalgGeneralizationPattern>(
164           funcOp.getContext(), filter);
165     }
166     if (failed(applyPatternsAndFoldGreedily(funcOp,
167                                             std::move(generalizationPattern))))
168       signalPassFailure();
169   }
170 
171   LinalgTransformationFilter filter;
172 };
173 
174 /// Configurable pass to apply lowering of coarser-grained named linalg ops into
175 /// finer-grained named versions.
176 struct LinalgStrategyDecomposePass
177     : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
178 
179   LinalgStrategyDecomposePass() = default;
180 
181   void runOnFunction() override {
182     auto funcOp = getFunction();
183     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
184       return;
185     RewritePatternSet decompositionPattern(funcOp.getContext());
186     populateDecomposeConvolutionPatterns(decompositionPattern);
187     if (failed(applyPatternsAndFoldGreedily(funcOp,
188                                             std::move(decompositionPattern))))
189       signalPassFailure();
190   }
191 };
192 
193 /// Configurable pass to apply pattern-based linalg generalization.
194 struct LinalgStrategyInterchangePass
195     : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
196 
197   LinalgStrategyInterchangePass() = default;
198 
199   LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
200                                 LinalgTransformationFilter filter)
201       : iteratorInterchange(iteratorInterchange.begin(),
202                             iteratorInterchange.end()),
203         filter(filter) {}
204 
205   void runOnFunction() override {
206     auto funcOp = getFunction();
207     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
208       return;
209 
210     SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
211                                             iteratorInterchange.end());
212     RewritePatternSet interchangePattern(funcOp.getContext());
213     interchangePattern.add<GenericOpInterchangePattern>(
214         funcOp.getContext(), interchangeVector, filter);
215     if (failed(applyPatternsAndFoldGreedily(funcOp,
216                                             std::move(interchangePattern))))
217       signalPassFailure();
218   }
219 
220   SmallVector<int64_t> iteratorInterchange;
221   LinalgTransformationFilter filter;
222 };
223 
224 /// Configurable pass to apply pattern-based linalg promotion.
225 struct LinalgStrategyPromotePass
226     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
227 
228   LinalgStrategyPromotePass() = default;
229 
230   LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
231                             LinalgTransformationFilter filt)
232       : options(opt), filter(filt) {
233     this->anchorOpName.setValue(opName.str());
234   }
235 
236   void runOnFunction() override {
237     auto funcOp = getFunction();
238     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
239       return;
240 
241     RewritePatternSet promotionPattern(funcOp.getContext());
242     if (!anchorOpName.empty()) {
243       promotionPattern.add<LinalgBasePromotionPattern>(
244           anchorOpName, funcOp.getContext(), options, filter);
245     } else {
246       promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
247                                                        filter, options);
248     }
249     (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
250   }
251 
252   LinalgPromotionOptions options;
253   LinalgTransformationFilter filter;
254 };
255 
256 /// Configurable pass to apply pattern-based linalg vectorization.
257 struct LinalgStrategyVectorizePass
258     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
259 
260   LinalgStrategyVectorizePass() = default;
261 
262   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
263                               LinalgTransformationFilter filt,
264                               bool padVectorize = false)
265       : options(opt), filter(filt) {
266     this->anchorOpName.setValue(opName.str());
267     this->vectorizePadding.setValue(padVectorize);
268   }
269 
270   void runOnFunction() override {
271     auto funcOp = getFunction();
272     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
273       return;
274 
275     RewritePatternSet vectorizationPatterns(funcOp.getContext());
276     if (!anchorOpName.empty()) {
277       vectorizationPatterns.add<LinalgVectorizationPattern>(
278           anchorOpName, funcOp.getContext(), options, filter);
279     } else {
280       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
281                                                             filter, options);
282     }
283     vector::populateVectorTransferPermutationMapLoweringPatterns(
284         vectorizationPatterns);
285     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
286     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
287                               linalg::LinalgCopyVTWForwardingPattern>(
288         funcOp.getContext(), /*benefit=*/2);
289     if (vectorizePadding) {
290       linalg::populatePadTensorOpVectorizationPatterns(vectorizationPatterns);
291     }
292     (void)applyPatternsAndFoldGreedily(funcOp,
293                                        std::move(vectorizationPatterns));
294   }
295 
296   LinalgVectorizationOptions options;
297   LinalgTransformationFilter filter;
298 };
299 
300 /// Configurable pass to enable the application of other pattern-based linalg
301 /// passes.
302 struct LinalgStrategyEnablePass
303     : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
304 
305   LinalgStrategyEnablePass(LinalgEnablingOptions opt,
306                            LinalgTransformationFilter filt)
307       : options(opt), filter(filt) {}
308 
309   void runOnFunction() override {
310     auto funcOp = getFunction();
311     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
312       return;
313 
314     MLIRContext *context = funcOp.getContext();
315     RewritePatternSet patterns =
316         linalg::getLinalgTilingCanonicalizationPatterns(context);
317     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
318     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
319       return signalPassFailure();
320 
321     if (options.licm) {
322       if (funcOp
323               ->walk([&](LoopLikeOpInterface loopLike) {
324                 if (failed(moveLoopInvariantCode(loopLike)))
325                   return WalkResult::interrupt();
326                 return WalkResult::advance();
327               })
328               .wasInterrupted())
329         return signalPassFailure();
330     }
331 
332     promoteSingleIterationLoops(funcOp);
333     if (options.hoistRedundantVectorTransfers)
334       hoistRedundantVectorTransfers(funcOp);
335 
336     if (options.hoistRedundantVectorTransfersOnTensor)
337       hoistRedundantVectorTransfersOnTensor(funcOp);
338   }
339 
340   LinalgEnablingOptions options;
341   LinalgTransformationFilter filter;
342 };
343 
344 /// Configurable pass to lower vector operations.
345 struct LinalgStrategyLowerVectorsPass
346     : public LinalgStrategyLowerVectorsPassBase<
347           LinalgStrategyLowerVectorsPass> {
348 
349   LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
350                                  LinalgTransformationFilter filt)
351       : options(opt), filter(filt) {}
352 
353   void runOnFunction() override {
354     auto funcOp = getFunction();
355     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
356       return;
357 
358     MLIRContext *context = funcOp.getContext();
359     RewritePatternSet patterns(context);
360     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
361     // In a progressive lowering of vectors, this would be the 1st step.
362     if (options.contractionLowering) {
363       patterns.add<ContractionOpToOuterProductOpLowering,
364                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
365           options.vectorTransformOptions, context);
366       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
367     }
368     // In a progressive lowering of vectors, this would be the 2nd step.
369     if (options.multiReductionLowering) {
370       vector::populateVectorMultiReductionLoweringPatterns(
371           patterns,
372           options.vectorTransformOptions.vectorMultiReductionLowering);
373     }
374     // In a progressive lowering of vectors, this would be the 3rd step.
375     if (options.transferPartialRewrite) {
376       patterns.add<vector::VectorTransferFullPartialRewriter>(
377           context, options.vectorTransformOptions);
378     }
379     // In a progressive lowering of vectors, this would be the 4th step.
380     if (options.transferLowering) {
381       vector::populateVectorTransferLoweringPatterns(patterns,
382                                                      options.maxTransferRank);
383     }
384     // In a progressive lowering of vectors, this would be the 5th step.
385     if (options.transferToSCFConversion) {
386       populateVectorToSCFConversionPatterns(
387           patterns, options.vectorTransferToSCFOptions.setTargetRank(
388                         options.maxTransferRank));
389     }
390     // In a progressive lowering of vectors, this would be the 6th step.
391     if (options.shapeCastLowering) {
392       vector::populateVectorShapeCastLoweringPatterns(patterns);
393     }
394     // In a progressive lowering of vectors, this would be the 7th step.
395     if (options.transposeLowering) {
396       vector::populateVectorTransposeLoweringPatterns(
397           patterns, options.vectorTransformOptions);
398       if (options.avx2Lowering)
399         x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
400             patterns, options.avx2LoweringOptions, /*benefit=*/10);
401     }
402     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
403   }
404 
405   LinalgVectorLoweringOptions options;
406   LinalgTransformationFilter filter;
407 };
408 
409 /// Configurable pass to lower vector operations.
410 struct LinalgStrategyRemoveMarkersPass
411     : public LinalgStrategyRemoveMarkersPassBase<
412           LinalgStrategyRemoveMarkersPass> {
413 
414   void runOnFunction() override {
415     auto funcOp = getFunction();
416     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
417       return;
418     funcOp.walk([](LinalgOp op) {
419       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
420     });
421   }
422 };
423 } // namespace
424 
425 /// Create a LinalgStrategyTileAndFusePass.
426 std::unique_ptr<OperationPass<FuncOp>>
427 mlir::createLinalgStrategyTileAndFusePass(StringRef opName,
428                                           LinalgTilingAndFusionOptions options,
429                                           LinalgTransformationFilter filter) {
430   return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
431                                                          filter);
432 }
433 
434 /// Create a LinalgStrategyTilePass.
435 std::unique_ptr<OperationPass<FuncOp>>
436 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
437                                    LinalgTransformationFilter filter) {
438   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
439 }
440 
441 /// Create a LinalgStrategyPadPass.
442 std::unique_ptr<OperationPass<FuncOp>>
443 mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
444                                   LinalgTransformationFilter filter) {
445   return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
446 }
447 
448 /// Create a LinalgStrategyPromotePass.
449 std::unique_ptr<OperationPass<FuncOp>>
450 mlir::createLinalgStrategyPromotePass(StringRef opName,
451                                       LinalgPromotionOptions opt,
452                                       LinalgTransformationFilter filter) {
453   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
454 }
455 
456 /// Create a LinalgStrategyGeneralizePass.
457 std::unique_ptr<OperationPass<FuncOp>>
458 mlir::createLinalgStrategyGeneralizePass(StringRef opName,
459                                          LinalgTransformationFilter filter) {
460   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
461 }
462 /// Create a LinalgStrategyDecomposePass.
463 // TODO: atm this is applied to all supported ops. If/when we need finer control
464 // this should be exposed with an opName + filter and a proper pattern.
465 std::unique_ptr<OperationPass<FuncOp>>
466 mlir::createLinalgStrategyDecomposePass() {
467   return std::make_unique<LinalgStrategyDecomposePass>();
468 }
469 
470 /// Create a LinalgStrategyInterchangePass.
471 std::unique_ptr<OperationPass<FuncOp>>
472 mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
473                                           LinalgTransformationFilter filter) {
474   return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
475                                                          filter);
476 }
477 
478 /// Create a LinalgStrategyVectorizePass.
479 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
480     StringRef opName, LinalgVectorizationOptions opt,
481     LinalgTransformationFilter filter, bool padVectorize) {
482   return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
483                                                        padVectorize);
484 }
485 
486 /// Create a LinalgStrategyEnablePass.
487 std::unique_ptr<OperationPass<FuncOp>>
488 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
489                                      LinalgTransformationFilter filter) {
490   return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
491 }
492 
493 /// Create a LinalgStrategyLowerVectorsPass.
494 std::unique_ptr<OperationPass<FuncOp>>
495 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
496                                            LinalgTransformationFilter filter) {
497   return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
498 }
499 
500 /// Create a LinalgStrategyRemoveMarkersPass.
501 std::unique_ptr<OperationPass<FuncOp>>
502 mlir::createLinalgStrategyRemoveMarkersPass() {
503   return std::make_unique<LinalgStrategyRemoveMarkersPass>();
504 }
505