1 //===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a set of configurable passes, each of which applies a
// group of Linalg or vector transformation patterns and can be plugged into a
// pass pipeline.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/SCF/Transforms.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Vector/VectorTransforms.h"
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Support/LLVM.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "mlir/Transforms/LoopUtils.h"
32 #include "mlir/Transforms/Passes.h"
33 #include "mlir/Transforms/Utils.h"
34 
35 using namespace mlir;
36 using namespace mlir::vector;
37 using namespace linalg;
38 
39 namespace {
40 
41 /// Configurable pass to apply pattern-based tiling and fusion.
42 struct LinalgStrategyTileAndFusePass
43     : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
44 
45   LinalgStrategyTileAndFusePass() = default;
46 
47   LinalgStrategyTileAndFusePass(StringRef opName,
48                                 LinalgTilingAndFusionOptions opt,
49                                 LinalgTransformationFilter filt)
50       : options(opt), filter(filt) {
51     this->anchorOpName.setValue(opName.str());
52   }
53 
54   void runOnFunction() override {
55     auto funcOp = getFunction();
56     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
57       return;
58 
59     RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
60     if (!anchorOpName.empty()) {
61       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
62           anchorOpName, funcOp.getContext(), options, filter);
63     } else {
64       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
65           funcOp.getContext(), options, filter);
66     }
67     // Search the root operation using bottom up traversal.
68     GreedyRewriteConfig config;
69     config.useTopDownTraversal = false;
70     (void)applyPatternsAndFoldGreedily(
71         funcOp, std::move(tilingAndFusionPattern), config);
72   }
73 
74   LinalgTilingAndFusionOptions options;
75   LinalgTransformationFilter filter;
76 };
77 
78 /// Configurable pass to apply pattern-based linalg tiling.
79 struct LinalgStrategyTilePass
80     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
81 
82   LinalgStrategyTilePass() = default;
83 
84   LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
85                          LinalgTransformationFilter filt)
86       : options(opt), filter(filt) {
87     this->anchorOpName.setValue(opName.str());
88   }
89 
90   void runOnFunction() override {
91     auto funcOp = getFunction();
92     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
93       return;
94 
95     RewritePatternSet tilingPattern(funcOp.getContext());
96     if (!anchorOpName.empty()) {
97       tilingPattern.add<LinalgGenericTilingPattern>(
98           anchorOpName, funcOp.getContext(), options, filter);
99     } else {
100       tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
101                                                     options);
102     }
103     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
104   }
105 
106   LinalgTilingOptions options;
107   LinalgTransformationFilter filter;
108 };
109 
110 /// Configurable pass to apply hoisting and padding.
111 struct LinalgStrategyPadPass
112     : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
113 
114   LinalgStrategyPadPass() = default;
115 
116   LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
117                         LinalgTransformationFilter filt)
118       : options(opt), filter(filt) {
119     this->anchorOpName.setValue(opName.str());
120   }
121 
122   void runOnFunction() override {
123     auto funcOp = getFunction();
124     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
125       return;
126 
127     RewritePatternSet paddingPattern(funcOp.getContext());
128     if (!anchorOpName.empty()) {
129       paddingPattern.add<LinalgPaddingPattern>(
130           anchorOpName, funcOp.getContext(), options, filter);
131     } else {
132       paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
133                                                filter);
134     }
135     // Traverse the operations top down to pad producers before consumers. The
136     // extract slice operation introduced after every padded operation enables
137     // padding its consumers. Padding an operation whose producers have not been
138     // padded before fails due to the missing extract slice operations. In this
139     // case, the padding pattern increments the transformation marker without
140     // padding the operation. The top down traversal is thus not only a
141     // performance optimization but also needed to pad all operations along the
142     // use-def chains.
143     GreedyRewriteConfig config;
144     config.useTopDownTraversal = true;
145     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern),
146                                             config)))
147       signalPassFailure();
148   }
149 
150   LinalgPaddingOptions options;
151   LinalgTransformationFilter filter;
152 };
153 
154 /// Configurable pass to apply pattern-based linalg generalization.
155 struct LinalgStrategyGeneralizePass
156     : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
157 
158   LinalgStrategyGeneralizePass() = default;
159 
160   LinalgStrategyGeneralizePass(StringRef opName,
161                                LinalgTransformationFilter filter)
162       : filter(filter) {
163     this->anchorOpName.setValue(opName.str());
164   }
165 
166   void runOnFunction() override {
167     auto funcOp = getFunction();
168     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
169       return;
170 
171     RewritePatternSet generalizationPattern(funcOp.getContext());
172     if (!anchorOpName.empty()) {
173       generalizationPattern.add<LinalgGeneralizationPattern>(
174           anchorOpName, funcOp.getContext(), filter);
175     } else {
176       generalizationPattern.add<LinalgGeneralizationPattern>(
177           funcOp.getContext(), filter);
178     }
179     if (failed(applyPatternsAndFoldGreedily(funcOp,
180                                             std::move(generalizationPattern))))
181       signalPassFailure();
182   }
183 
184   LinalgTransformationFilter filter;
185 };
186 
187 /// Configurable pass to apply lowering of coarser-grained named linalg ops into
188 /// finer-grained named versions.
189 struct LinalgStrategyDecomposePass
190     : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
191 
192   LinalgStrategyDecomposePass() = default;
193 
194   LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
195       : filter(filter) {}
196 
197   void runOnFunction() override {
198     auto funcOp = getFunction();
199     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
200       return;
201     RewritePatternSet decompositionPattern(funcOp.getContext());
202     populateDecomposeConvolutionPatterns(decompositionPattern, filter);
203     if (failed(applyPatternsAndFoldGreedily(funcOp,
204                                             std::move(decompositionPattern))))
205       signalPassFailure();
206   }
207 
208   LinalgTransformationFilter filter;
209 };
210 
211 /// Configurable pass to apply pattern-based linalg generalization.
212 struct LinalgStrategyInterchangePass
213     : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
214 
215   LinalgStrategyInterchangePass() = default;
216 
217   LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
218                                 LinalgTransformationFilter filter)
219       : iteratorInterchange(iteratorInterchange.begin(),
220                             iteratorInterchange.end()),
221         filter(filter) {}
222 
223   void runOnFunction() override {
224     auto funcOp = getFunction();
225     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
226       return;
227 
228     SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
229                                             iteratorInterchange.end());
230     RewritePatternSet interchangePattern(funcOp.getContext());
231     interchangePattern.add<GenericOpInterchangePattern>(
232         funcOp.getContext(), interchangeVector, filter);
233     if (failed(applyPatternsAndFoldGreedily(funcOp,
234                                             std::move(interchangePattern))))
235       signalPassFailure();
236   }
237 
238   SmallVector<int64_t> iteratorInterchange;
239   LinalgTransformationFilter filter;
240 };
241 
242 /// Configurable pass to apply pattern-based linalg promotion.
243 struct LinalgStrategyPromotePass
244     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
245 
246   LinalgStrategyPromotePass() = default;
247 
248   LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
249                             LinalgTransformationFilter filt)
250       : options(opt), filter(filt) {
251     this->anchorOpName.setValue(opName.str());
252   }
253 
254   void runOnFunction() override {
255     auto funcOp = getFunction();
256     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
257       return;
258 
259     RewritePatternSet promotionPattern(funcOp.getContext());
260     if (!anchorOpName.empty()) {
261       promotionPattern.add<LinalgBasePromotionPattern>(
262           anchorOpName, funcOp.getContext(), options, filter);
263     } else {
264       promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
265                                                        filter, options);
266     }
267     (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
268   }
269 
270   LinalgPromotionOptions options;
271   LinalgTransformationFilter filter;
272 };
273 
274 /// Configurable pass to apply pattern-based linalg vectorization.
275 struct LinalgStrategyVectorizePass
276     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
277 
278   LinalgStrategyVectorizePass() = default;
279 
280   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
281                               LinalgTransformationFilter filt,
282                               bool padVectorize = false)
283       : options(opt), filter(filt) {
284     this->anchorOpName.setValue(opName.str());
285     this->vectorizePadding.setValue(padVectorize);
286   }
287 
288   void runOnFunction() override {
289     auto funcOp = getFunction();
290     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
291       return;
292 
293     RewritePatternSet vectorizationPatterns(funcOp.getContext());
294     if (!anchorOpName.empty()) {
295       vectorizationPatterns.add<LinalgVectorizationPattern>(
296           anchorOpName, funcOp.getContext(), options, filter);
297     } else {
298       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
299                                                             filter, options);
300     }
301     vector::populateVectorTransferPermutationMapLoweringPatterns(
302         vectorizationPatterns);
303     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
304     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
305                               linalg::LinalgCopyVTWForwardingPattern>(
306         funcOp.getContext(), /*benefit=*/2);
307     if (vectorizePadding) {
308       linalg::populatePadTensorOpVectorizationPatterns(vectorizationPatterns);
309     }
310     (void)applyPatternsAndFoldGreedily(funcOp,
311                                        std::move(vectorizationPatterns));
312   }
313 
314   LinalgVectorizationOptions options;
315   LinalgTransformationFilter filter;
316 };
317 
318 /// Configurable pass to enable the application of other pattern-based linalg
319 /// passes.
320 struct LinalgStrategyEnablePass
321     : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
322 
323   LinalgStrategyEnablePass(LinalgEnablingOptions opt,
324                            LinalgTransformationFilter filt)
325       : options(opt), filter(filt) {}
326 
327   void runOnFunction() override {
328     auto funcOp = getFunction();
329     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
330       return;
331 
332     MLIRContext *context = funcOp.getContext();
333     RewritePatternSet patterns =
334         linalg::getLinalgTilingCanonicalizationPatterns(context);
335     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
336     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
337       return signalPassFailure();
338 
339     if (options.licm) {
340       if (funcOp
341               ->walk([&](LoopLikeOpInterface loopLike) {
342                 if (failed(moveLoopInvariantCode(loopLike)))
343                   return WalkResult::interrupt();
344                 return WalkResult::advance();
345               })
346               .wasInterrupted())
347         return signalPassFailure();
348     }
349 
350     promoteSingleIterationLoops(funcOp);
351     if (options.hoistRedundantVectorTransfers)
352       hoistRedundantVectorTransfers(funcOp);
353 
354     if (options.hoistRedundantVectorTransfersOnTensor)
355       hoistRedundantVectorTransfersOnTensor(funcOp);
356 
357     // Run CSE to cleanup after canonicalization.
358     OpPassManager dynamicPM("builtin.func");
359     dynamicPM.addPass(createCSEPass());
360     if (failed(runPipeline(dynamicPM, funcOp)))
361       return signalPassFailure();
362   }
363 
364   LinalgEnablingOptions options;
365   LinalgTransformationFilter filter;
366 };
367 
368 /// Configurable pass to lower vector operations.
369 struct LinalgStrategyLowerVectorsPass
370     : public LinalgStrategyLowerVectorsPassBase<
371           LinalgStrategyLowerVectorsPass> {
372 
373   LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
374                                  LinalgTransformationFilter filt)
375       : options(opt), filter(filt) {}
376 
377   void runOnFunction() override {
378     auto funcOp = getFunction();
379     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
380       return;
381 
382     MLIRContext *context = funcOp.getContext();
383     RewritePatternSet patterns(context);
384     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
385     // In a progressive lowering of vectors, this would be the 1st step.
386     if (options.contractionLowering) {
387       patterns.add<ContractionOpToOuterProductOpLowering,
388                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
389           options.vectorTransformOptions, context);
390       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
391     }
392     // In a progressive lowering of vectors, this would be the 2nd step.
393     if (options.multiReductionLowering) {
394       vector::populateVectorMultiReductionLoweringPatterns(
395           patterns,
396           options.vectorTransformOptions.vectorMultiReductionLowering);
397     }
398     // In a progressive lowering of vectors, this would be the 3rd step.
399     if (options.transferPartialRewrite) {
400       patterns.add<vector::VectorTransferFullPartialRewriter>(
401           context, options.vectorTransformOptions);
402     }
403     // In a progressive lowering of vectors, this would be the 4th step.
404     if (options.transferLowering) {
405       vector::populateVectorTransferLoweringPatterns(patterns,
406                                                      options.maxTransferRank);
407     }
408     // In a progressive lowering of vectors, this would be the 5th step.
409     if (options.transferToSCFConversion) {
410       populateVectorToSCFConversionPatterns(
411           patterns, options.vectorTransferToSCFOptions.setTargetRank(
412                         options.maxTransferRank));
413     }
414     // In a progressive lowering of vectors, this would be the 6th step.
415     if (options.shapeCastLowering) {
416       vector::populateVectorShapeCastLoweringPatterns(patterns);
417     }
418     // In a progressive lowering of vectors, this would be the 7th step.
419     if (options.transposeLowering) {
420       vector::populateVectorTransposeLoweringPatterns(
421           patterns, options.vectorTransformOptions);
422       if (options.avx2Lowering)
423         x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
424             patterns, options.avx2LoweringOptions, /*benefit=*/10);
425     }
426     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
427   }
428 
429   LinalgVectorLoweringOptions options;
430   LinalgTransformationFilter filter;
431 };
432 
433 /// Configurable pass to lower vector operations.
434 struct LinalgStrategyRemoveMarkersPass
435     : public LinalgStrategyRemoveMarkersPassBase<
436           LinalgStrategyRemoveMarkersPass> {
437 
438   void runOnFunction() override {
439     auto funcOp = getFunction();
440     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
441       return;
442     funcOp.walk([](LinalgOp op) {
443       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
444     });
445   }
446 };
447 } // namespace
448 
449 /// Create a LinalgStrategyTileAndFusePass.
450 std::unique_ptr<OperationPass<FuncOp>>
451 mlir::createLinalgStrategyTileAndFusePass(StringRef opName,
452                                           LinalgTilingAndFusionOptions options,
453                                           LinalgTransformationFilter filter) {
454   return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
455                                                          filter);
456 }
457 
458 /// Create a LinalgStrategyTilePass.
459 std::unique_ptr<OperationPass<FuncOp>>
460 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
461                                    LinalgTransformationFilter filter) {
462   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
463 }
464 
465 /// Create a LinalgStrategyPadPass.
466 std::unique_ptr<OperationPass<FuncOp>>
467 mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
468                                   LinalgTransformationFilter filter) {
469   return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
470 }
471 
472 /// Create a LinalgStrategyPromotePass.
473 std::unique_ptr<OperationPass<FuncOp>>
474 mlir::createLinalgStrategyPromotePass(StringRef opName,
475                                       LinalgPromotionOptions opt,
476                                       LinalgTransformationFilter filter) {
477   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
478 }
479 
480 /// Create a LinalgStrategyGeneralizePass.
481 std::unique_ptr<OperationPass<FuncOp>>
482 mlir::createLinalgStrategyGeneralizePass(StringRef opName,
483                                          LinalgTransformationFilter filter) {
484   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
485 }
486 
487 /// Create a LinalgStrategyDecomposePass.
488 // TODO: if/when we need finer control add an `opName` parameter.
489 std::unique_ptr<OperationPass<FuncOp>>
490 mlir::createLinalgStrategyDecomposePass(LinalgTransformationFilter filter) {
491   return std::make_unique<LinalgStrategyDecomposePass>(filter);
492 }
493 
494 /// Create a LinalgStrategyInterchangePass.
495 std::unique_ptr<OperationPass<FuncOp>>
496 mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
497                                           LinalgTransformationFilter filter) {
498   return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
499                                                          filter);
500 }
501 
502 /// Create a LinalgStrategyVectorizePass.
503 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
504     StringRef opName, LinalgVectorizationOptions opt,
505     LinalgTransformationFilter filter, bool padVectorize) {
506   return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
507                                                        padVectorize);
508 }
509 
510 /// Create a LinalgStrategyEnablePass.
511 std::unique_ptr<OperationPass<FuncOp>>
512 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
513                                      LinalgTransformationFilter filter) {
514   return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
515 }
516 
517 /// Create a LinalgStrategyLowerVectorsPass.
518 std::unique_ptr<OperationPass<FuncOp>>
519 mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
520                                            LinalgTransformationFilter filter) {
521   return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
522 }
523 
524 /// Create a LinalgStrategyRemoveMarkersPass.
525 std::unique_ptr<OperationPass<FuncOp>>
526 mlir::createLinalgStrategyRemoveMarkersPass() {
527   return std::make_unique<LinalgStrategyRemoveMarkersPass>();
528 }
529