1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a set of configurable passes that can apply patterns
// liberally and be plugged into a pass pipeline.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/Dialect/Linalg/Utils/Utils.h"
22 #include "mlir/Dialect/SCF/Transforms.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/Dialect/Vector/VectorTransforms.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/Pass/PassManager.h"
28 #include "mlir/Support/LLVM.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "mlir/Transforms/LoopUtils.h"
31 #include "mlir/Transforms/Passes.h"
32 #include "mlir/Transforms/Utils.h"
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 using namespace linalg;
37 
38 namespace {
39 
40 /// Configurable pass to apply pattern-based tiling and fusion.
41 struct LinalgStrategyTileAndFusePass
42     : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
43 
44   LinalgStrategyTileAndFusePass() = default;
45 
46   LinalgStrategyTileAndFusePass(StringRef opName,
47                                 LinalgTilingAndFusionOptions opt,
48                                 LinalgTransformationFilter filt)
49       : options(opt), filter(filt) {
50     this->anchorOpName.setValue(opName.str());
51   }
52 
53   void runOnFunction() override {
54     auto funcOp = getFunction();
55     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
56       return;
57 
58     RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
59     if (!anchorOpName.empty()) {
60       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
61           anchorOpName, funcOp.getContext(), options, filter);
62     } else {
63       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
64           funcOp.getContext(), options, filter);
65     }
66     // Search the root operation using bottom up traversal.
67     GreedyRewriteConfig config;
68     config.useTopDownTraversal = false;
69     (void)applyPatternsAndFoldGreedily(
70         funcOp, std::move(tilingAndFusionPattern), config);
71   }
72 
73   LinalgTilingAndFusionOptions options;
74   LinalgTransformationFilter filter;
75 };
76 
77 /// Configurable pass to apply pattern-based linalg tiling.
78 struct LinalgStrategyTilePass
79     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
80 
81   LinalgStrategyTilePass() = default;
82 
83   LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
84                          LinalgTransformationFilter filt)
85       : options(opt), filter(filt) {
86     this->anchorOpName.setValue(opName.str());
87   }
88 
89   void runOnFunction() override {
90     auto funcOp = getFunction();
91     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
92       return;
93 
94     RewritePatternSet tilingPattern(funcOp.getContext());
95     if (!anchorOpName.empty()) {
96       tilingPattern.add<LinalgGenericTilingPattern>(
97           anchorOpName, funcOp.getContext(), options, filter);
98     } else {
99       tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
100                                                     options);
101     }
102     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
103   }
104 
105   LinalgTilingOptions options;
106   LinalgTransformationFilter filter;
107 };
108 
109 /// Configurable pass to apply hoisting and padding.
110 struct LinalgStrategyPadPass
111     : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
112 
113   LinalgStrategyPadPass() = default;
114 
115   LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
116                         LinalgTransformationFilter filt)
117       : options(opt), filter(filt) {
118     this->anchorOpName.setValue(opName.str());
119   }
120 
121   void runOnFunction() override {
122     auto funcOp = getFunction();
123     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
124       return;
125 
126     RewritePatternSet paddingPattern(funcOp.getContext());
127     if (!anchorOpName.empty()) {
128       paddingPattern.add<LinalgPaddingPattern>(
129           anchorOpName, funcOp.getContext(), options, filter);
130     } else {
131       paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
132                                                filter);
133     }
134     (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern));
135   }
136 
137   LinalgPaddingOptions options;
138   LinalgTransformationFilter filter;
139 };
140 
141 /// Configurable pass to apply pattern-based linalg generalization.
142 struct LinalgStrategyGeneralizePass
143     : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
144 
145   LinalgStrategyGeneralizePass() = default;
146 
147   LinalgStrategyGeneralizePass(StringRef opName,
148                                LinalgTransformationFilter filter)
149       : filter(filter) {
150     this->anchorOpName.setValue(opName.str());
151   }
152 
153   void runOnFunction() override {
154     auto funcOp = getFunction();
155     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
156       return;
157 
158     RewritePatternSet generalizationPattern(funcOp.getContext());
159     if (!anchorOpName.empty()) {
160       generalizationPattern.add<LinalgGeneralizationPattern>(
161           anchorOpName, funcOp.getContext(), filter);
162     } else {
163       generalizationPattern.add<LinalgGeneralizationPattern>(
164           funcOp.getContext(), filter);
165     }
166     if (failed(applyPatternsAndFoldGreedily(funcOp,
167                                             std::move(generalizationPattern))))
168       signalPassFailure();
169   }
170 
171   LinalgTransformationFilter filter;
172 };
173 
174 /// Configurable pass to apply lowering of coarser-grained named linalg ops into
175 /// finer-grained named versions.
176 struct LinalgStrategyDecomposePass
177     : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
178 
179   LinalgStrategyDecomposePass() = default;
180 
181   LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
182       : filter(filter) {}
183 
184   void runOnFunction() override {
185     auto funcOp = getFunction();
186     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
187       return;
188     RewritePatternSet decompositionPattern(funcOp.getContext());
189     populateDecomposeConvolutionPatterns(decompositionPattern, filter);
190     if (failed(applyPatternsAndFoldGreedily(funcOp,
191                                             std::move(decompositionPattern))))
192       signalPassFailure();
193   }
194 
195   LinalgTransformationFilter filter;
196 };
197 
198 /// Configurable pass to apply pattern-based linalg generalization.
199 struct LinalgStrategyInterchangePass
200     : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
201 
202   LinalgStrategyInterchangePass() = default;
203 
204   LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
205                                 LinalgTransformationFilter filter)
206       : iteratorInterchange(iteratorInterchange.begin(),
207                             iteratorInterchange.end()),
208         filter(filter) {}
209 
210   void runOnFunction() override {
211     auto funcOp = getFunction();
212     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
213       return;
214 
215     SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
216                                             iteratorInterchange.end());
217     RewritePatternSet interchangePattern(funcOp.getContext());
218     interchangePattern.add<GenericOpInterchangePattern>(
219         funcOp.getContext(), interchangeVector, filter);
220     if (failed(applyPatternsAndFoldGreedily(funcOp,
221                                             std::move(interchangePattern))))
222       signalPassFailure();
223   }
224 
225   SmallVector<int64_t> iteratorInterchange;
226   LinalgTransformationFilter filter;
227 };
228 
229 /// Configurable pass to apply pattern-based linalg promotion.
230 struct LinalgStrategyPromotePass
231     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
232 
233   LinalgStrategyPromotePass() = default;
234 
235   LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
236                             LinalgTransformationFilter filt)
237       : options(opt), filter(filt) {
238     this->anchorOpName.setValue(opName.str());
239   }
240 
241   void runOnFunction() override {
242     auto funcOp = getFunction();
243     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
244       return;
245 
246     RewritePatternSet promotionPattern(funcOp.getContext());
247     if (!anchorOpName.empty()) {
248       promotionPattern.add<LinalgBasePromotionPattern>(
249           anchorOpName, funcOp.getContext(), options, filter);
250     } else {
251       promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
252                                                        filter, options);
253     }
254     (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
255   }
256 
257   LinalgPromotionOptions options;
258   LinalgTransformationFilter filter;
259 };
260 
261 /// Configurable pass to apply pattern-based linalg vectorization.
262 struct LinalgStrategyVectorizePass
263     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
264 
265   LinalgStrategyVectorizePass() = default;
266 
267   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
268                               LinalgTransformationFilter filt,
269                               bool padVectorize = false)
270       : options(opt), filter(filt) {
271     this->anchorOpName.setValue(opName.str());
272     this->vectorizePadding.setValue(padVectorize);
273   }
274 
275   void runOnFunction() override {
276     auto funcOp = getFunction();
277     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
278       return;
279 
280     RewritePatternSet vectorizationPatterns(funcOp.getContext());
281     if (!anchorOpName.empty()) {
282       vectorizationPatterns.add<LinalgVectorizationPattern>(
283           anchorOpName, funcOp.getContext(), options, filter);
284     } else {
285       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
286                                                             filter, options);
287     }
288     vector::populateVectorTransferPermutationMapLoweringPatterns(
289         vectorizationPatterns);
290     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
291     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
292                               linalg::LinalgCopyVTWForwardingPattern>(
293         funcOp.getContext(), /*benefit=*/2);
294     TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
295                                                 funcOp.getContext());
296     TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
297                                                  funcOp.getContext());
298     (void)applyPatternsAndFoldGreedily(funcOp,
299                                        std::move(vectorizationPatterns));
300 
301     // Apply the pad tensor op vectorization separately to avoid running the
302     // GenericPadTensorOpVectorizationPattern too early.
303     // TODO: Improve once we have better infrastructure to control pattern
304     // application.
305     if (vectorizePadding) {
306       RewritePatternSet patterns(funcOp.getContext());
307       linalg::populatePadTensorOpVectorizationPatterns(patterns);
308       (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
309     }
310   }
311 
312   LinalgVectorizationOptions options;
313   LinalgTransformationFilter filter;
314 };
315 
316 /// Configurable pass to enable the application of other pattern-based linalg
317 /// passes.
318 struct LinalgStrategyEnablePass
319     : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
320 
321   LinalgStrategyEnablePass(LinalgEnablingOptions opt,
322                            LinalgTransformationFilter filt)
323       : options(opt), filter(filt) {}
324 
325   void runOnFunction() override {
326     auto funcOp = getFunction();
327     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
328       return;
329 
330     MLIRContext *context = funcOp.getContext();
331     RewritePatternSet patterns =
332         linalg::getLinalgTilingCanonicalizationPatterns(context);
333     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
334     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
335       return signalPassFailure();
336 
337     if (options.licm) {
338       if (funcOp
339               ->walk([&](LoopLikeOpInterface loopLike) {
340                 if (failed(moveLoopInvariantCode(loopLike)))
341                   return WalkResult::interrupt();
342                 return WalkResult::advance();
343               })
344               .wasInterrupted())
345         return signalPassFailure();
346     }
347 
348     promoteSingleIterationLoops(funcOp);
349     if (options.hoistRedundantVectorTransfers)
350       hoistRedundantVectorTransfers(funcOp);
351 
352     if (options.hoistRedundantVectorTransfersOnTensor)
353       hoistRedundantVectorTransfersOnTensor(funcOp);
354 
355     // Run CSE to cleanup after canonicalization.
356     OpPassManager dynamicPM("builtin.func");
357     dynamicPM.addPass(createCSEPass());
358     if (failed(runPipeline(dynamicPM, funcOp)))
359       return signalPassFailure();
360   }
361 
362   LinalgEnablingOptions options;
363   LinalgTransformationFilter filter;
364 };
365 
366 /// Configurable pass to lower vector operations.
367 struct LinalgStrategyLowerVectorsPass
368     : public LinalgStrategyLowerVectorsPassBase<
369           LinalgStrategyLowerVectorsPass> {
370 
371   LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
372                                  LinalgTransformationFilter filt)
373       : options(opt), filter(filt) {}
374 
375   void runOnFunction() override {
376     auto funcOp = getFunction();
377     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
378       return;
379 
380     MLIRContext *context = funcOp.getContext();
381     RewritePatternSet patterns(context);
382     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
383     // In a progressive lowering of vectors, this would be the 1st step.
384     if (options.contractionLowering) {
385       patterns.add<ContractionOpToOuterProductOpLowering,
386                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
387           options.vectorTransformOptions, context);
388       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
389     }
390     // In a progressive lowering of vectors, this would be the 2nd step.
391     if (options.multiReductionLowering) {
392       vector::populateVectorMultiReductionLoweringPatterns(
393           patterns,
394           options.vectorTransformOptions.vectorMultiReductionLowering);
395     }
396     // In a progressive lowering of vectors, this would be the 3rd step.
397     if (options.transferPartialRewrite) {
398       patterns.add<vector::VectorTransferFullPartialRewriter>(
399           context, options.vectorTransformOptions);
400     }
401     // In a progressive lowering of vectors, this would be the 4th step.
402     if (options.transferLowering) {
403       vector::populateVectorTransferLoweringPatterns(patterns,
404                                                      options.maxTransferRank);
405     }
406     // In a progressive lowering of vectors, this would be the 5th step.
407     if (options.transferToSCFConversion) {
408       populateVectorToSCFConversionPatterns(
409           patterns, options.vectorTransferToSCFOptions.setTargetRank(
410                         options.maxTransferRank));
411     }
412     // In a progressive lowering of vectors, this would be the 6th step.
413     if (options.shapeCastLowering) {
414       vector::populateVectorShapeCastLoweringPatterns(patterns);
415     }
416     // In a progressive lowering of vectors, this would be the 7th step.
417     if (options.transposeLowering) {
418       vector::populateVectorTransposeLoweringPatterns(
419           patterns, options.vectorTransformOptions);
420       if (options.avx2Lowering)
421         x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
422             patterns, options.avx2LoweringOptions, /*benefit=*/10);
423     }
424     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
425   }
426 
427   LinalgVectorLoweringOptions options;
428   LinalgTransformationFilter filter;
429 };
430 
431 /// Configurable pass to lower vector operations.
432 struct LinalgStrategyRemoveMarkersPass
433     : public LinalgStrategyRemoveMarkersPassBase<
434           LinalgStrategyRemoveMarkersPass> {
435 
436   void runOnFunction() override {
437     auto funcOp = getFunction();
438     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
439       return;
440     funcOp.walk([](LinalgOp op) {
441       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
442     });
443   }
444 };
445 } // namespace
446 
447 /// Create a LinalgStrategyTileAndFusePass.
448 std::unique_ptr<OperationPass<FuncOp>>
449 mlir::createLinalgStrategyTileAndFusePass(StringRef opName,
450                                           LinalgTilingAndFusionOptions options,
451                                           LinalgTransformationFilter filter) {
452   return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
453                                                          filter);
454 }
455 
456 /// Create a LinalgStrategyTilePass.
457 std::unique_ptr<OperationPass<FuncOp>>
458 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
459                                    LinalgTransformationFilter filter) {
460   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
461 }
462 
463 /// Create a LinalgStrategyPadPass.
464 std::unique_ptr<OperationPass<FuncOp>>
465 mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
466                                   LinalgTransformationFilter filter) {
467   return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
468 }
469 
470 /// Create a LinalgStrategyPromotePass.
471 std::unique_ptr<OperationPass<FuncOp>>
472 mlir::createLinalgStrategyPromotePass(StringRef opName,
473                                       LinalgPromotionOptions opt,
474                                       LinalgTransformationFilter filter) {
475   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
476 }
477 
478 /// Create a LinalgStrategyGeneralizePass.
479 std::unique_ptr<OperationPass<FuncOp>>
480 mlir::createLinalgStrategyGeneralizePass(StringRef opName,
481                                          LinalgTransformationFilter filter) {
482   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
483 }
484 
485 /// Create a LinalgStrategyDecomposePass.
486 // TODO: if/when we need finer control add an `opName` parameter.
487 std::unique_ptr<OperationPass<FuncOp>>
488 mlir::createLinalgStrategyDecomposePass(LinalgTransformationFilter filter) {
489   return std::make_unique<LinalgStrategyDecomposePass>(filter);
490 }
491 
492 /// Create a LinalgStrategyInterchangePass.
493 std::unique_ptr<OperationPass<FuncOp>>
494 mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
495                                           LinalgTransformationFilter filter) {
496   return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
497                                                          filter);
498 }
499 
500 /// Create a LinalgStrategyVectorizePass.
501 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
502     StringRef opName, LinalgVectorizationOptions opt,
503     LinalgTransformationFilter filter, bool padVectorize) {
504   return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
505                                                        padVectorize);
506 }
507 
508 /// Create a LinalgStrategyEnablePass.
509 std::unique_ptr<OperationPass<FuncOp>>
510 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
511                                      LinalgTransformationFilter filter) {
512   return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
513 }
514 
515 /// Create a LinalgStrategyLowerVectorsPass.
516 std::unique_ptr<OperationPass<FuncOp>>
517 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
518                                            LinalgTransformationFilter filter) {
519   return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
520 }
521 
522 /// Create a LinalgStrategyRemoveMarkersPass.
523 std::unique_ptr<OperationPass<FuncOp>>
524 mlir::createLinalgStrategyRemoveMarkersPass() {
525   return std::make_unique<LinalgStrategyRemoveMarkersPass>();
526 }
527