1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a configurable pass that can apply patterns liberally
10 // and be plugged in a pass pipeline.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "PassDetail.h"
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/Dialect/Affine/Utils.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Passes.h"
23 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
24 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
25 #include "mlir/Dialect/Linalg/Utils/Utils.h"
26 #include "mlir/Dialect/SCF/Transforms.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/AffineMap.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Support/LLVM.h"
33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
35 #include "mlir/Transforms/Passes.h"
36 
37 using namespace mlir;
38 using namespace mlir::vector;
39 using namespace linalg;
40 
41 namespace {
42 
43 /// Configurable pass to apply pattern-based tiling and fusion.
44 struct LinalgStrategyTileAndFusePass
45     : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
46 
47   LinalgStrategyTileAndFusePass() = default;
48 
49   LinalgStrategyTileAndFusePass(StringRef opName,
50                                 LinalgTilingAndFusionOptions opt,
51                                 LinalgTransformationFilter filt)
52       : options(std::move(opt)), filter(std::move(filt)) {
53     this->anchorOpName.setValue(opName.str());
54   }
55 
56   void runOnOperation() override {
57     auto funcOp = getOperation();
58     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
59       return;
60 
61     RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
62     if (!anchorOpName.empty()) {
63       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
64           anchorOpName, funcOp.getContext(), options, filter);
65     } else {
66       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
67           funcOp.getContext(), options, filter);
68     }
69     // Search the root operation using bottom up traversal.
70     GreedyRewriteConfig config;
71     config.useTopDownTraversal = false;
72     (void)applyPatternsAndFoldGreedily(
73         funcOp, std::move(tilingAndFusionPattern), config);
74   }
75 
76   LinalgTilingAndFusionOptions options;
77   LinalgTransformationFilter filter;
78 };
79 
80 /// Configurable pass to apply pattern-based linalg tiling.
81 struct LinalgStrategyTilePass
82     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
83 
84   LinalgStrategyTilePass() = default;
85 
86   LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
87                          LinalgTransformationFilter filt)
88       : options(std::move(opt)), filter(std::move(filt)) {
89     this->anchorOpName.setValue(opName.str());
90   }
91 
92   void runOnOperation() override {
93     auto funcOp = getOperation();
94     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
95       return;
96 
97     MLIRContext *ctx = funcOp.getContext();
98     RewritePatternSet tilingPattern(ctx);
99     if (!anchorOpName.empty())
100       tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
101                                              filter);
102     else
103       tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
104     if (anchorOpName == tensor::PadOp::getOperationName())
105       populatePadTensorTilingPatterns(tilingPattern, options);
106     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
107   }
108 
109   LinalgTilingOptions options;
110   LinalgTransformationFilter filter;
111 };
112 
113 /// Configurable pass to apply hoisting and padding.
114 struct LinalgStrategyPadPass
115     : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
116 
117   LinalgStrategyPadPass() = default;
118 
119   LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
120                         LinalgTransformationFilter filt)
121       : options(std::move(opt)), filter(std::move(filt)) {
122     this->anchorOpName.setValue(opName.str());
123   }
124 
125   void runOnOperation() override {
126     auto funcOp = getOperation();
127     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
128       return;
129 
130     RewritePatternSet paddingPattern(funcOp.getContext());
131     if (!anchorOpName.empty()) {
132       paddingPattern.add<LinalgPaddingPattern>(
133           anchorOpName, funcOp.getContext(), options, filter);
134     } else {
135       paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
136                                                filter);
137     }
138     (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern));
139   }
140 
141   LinalgPaddingOptions options;
142   LinalgTransformationFilter filter;
143 };
144 
145 /// Configurable pass to apply pattern-based linalg generalization.
146 struct LinalgStrategyGeneralizePass
147     : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
148 
149   LinalgStrategyGeneralizePass() = default;
150 
151   LinalgStrategyGeneralizePass(StringRef opName,
152                                LinalgTransformationFilter filter)
153       : filter(std::move(filter)) {
154     this->anchorOpName.setValue(opName.str());
155   }
156 
157   void runOnOperation() override {
158     auto funcOp = getOperation();
159     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
160       return;
161 
162     RewritePatternSet generalizationPattern(funcOp.getContext());
163     if (!anchorOpName.empty()) {
164       generalizationPattern.add<LinalgGeneralizationPattern>(
165           anchorOpName, funcOp.getContext(), filter);
166     } else {
167       generalizationPattern.add<LinalgGeneralizationPattern>(
168           funcOp.getContext(), filter);
169     }
170     if (failed(applyPatternsAndFoldGreedily(funcOp,
171                                             std::move(generalizationPattern))))
172       signalPassFailure();
173   }
174 
175   LinalgTransformationFilter filter;
176 };
177 
178 /// Configurable pass to apply lowering of coarser-grained named linalg ops into
179 /// finer-grained named versions.
180 struct LinalgStrategyDecomposePass
181     : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
182 
183   LinalgStrategyDecomposePass() = default;
184 
185   LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
186       : filter(std::move(filter)) {}
187 
188   void runOnOperation() override {
189     auto funcOp = getOperation();
190     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
191       return;
192     RewritePatternSet decompositionPattern(funcOp.getContext());
193     populateDecomposeConvolutionPatterns(decompositionPattern, filter);
194     if (failed(applyPatternsAndFoldGreedily(funcOp,
195                                             std::move(decompositionPattern))))
196       signalPassFailure();
197   }
198 
199   LinalgTransformationFilter filter;
200 };
201 
202 /// Configurable pass to apply pattern-based linalg generalization.
203 struct LinalgStrategyInterchangePass
204     : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
205 
206   LinalgStrategyInterchangePass() = default;
207 
208   LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
209                                 LinalgTransformationFilter filter)
210       : iteratorInterchange(iteratorInterchange.begin(),
211                             iteratorInterchange.end()),
212         filter(std::move(filter)) {}
213 
214   void runOnOperation() override {
215     auto funcOp = getOperation();
216     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
217       return;
218 
219     SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
220                                             iteratorInterchange.end());
221     RewritePatternSet interchangePattern(funcOp.getContext());
222     interchangePattern.add<GenericOpInterchangePattern>(
223         funcOp.getContext(), interchangeVector, filter);
224     if (failed(applyPatternsAndFoldGreedily(funcOp,
225                                             std::move(interchangePattern))))
226       signalPassFailure();
227   }
228 
229   SmallVector<int64_t> iteratorInterchange;
230   LinalgTransformationFilter filter;
231 };
232 
233 /// Configurable pass to apply pattern-based linalg promotion.
234 struct LinalgStrategyPromotePass
235     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
236 
237   LinalgStrategyPromotePass() = default;
238 
239   LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
240                             LinalgTransformationFilter filt)
241       : options(std::move(opt)), filter(std::move(filt)) {
242     this->anchorOpName.setValue(opName.str());
243   }
244 
245   void runOnOperation() override {
246     auto funcOp = getOperation();
247     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
248       return;
249 
250     RewritePatternSet promotionPattern(funcOp.getContext());
251     if (!anchorOpName.empty()) {
252       promotionPattern.add<LinalgBasePromotionPattern>(
253           anchorOpName, funcOp.getContext(), options, filter);
254     } else {
255       promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
256                                                        filter, options);
257     }
258     (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
259   }
260 
261   LinalgPromotionOptions options;
262   LinalgTransformationFilter filter;
263 };
264 
265 /// Configurable pass to apply pattern-based linalg vectorization.
266 struct LinalgStrategyVectorizePass
267     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
268 
269   LinalgStrategyVectorizePass() = default;
270 
271   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
272                               LinalgTransformationFilter filt,
273                               bool padVectorize = false)
274       : options(opt), filter(std::move(filt)) {
275     this->anchorOpName.setValue(opName.str());
276     this->vectorizePadding.setValue(padVectorize);
277   }
278 
279   void runOnOperation() override {
280     auto funcOp = getOperation();
281     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
282       return;
283 
284     RewritePatternSet vectorizationPatterns(funcOp.getContext());
285     if (!anchorOpName.empty()) {
286       vectorizationPatterns.add<LinalgVectorizationPattern>(
287           anchorOpName, funcOp.getContext(), options, filter);
288     } else {
289       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
290                                                             filter, options);
291     }
292     vector::populateVectorTransferPermutationMapLoweringPatterns(
293         vectorizationPatterns);
294     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
295     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
296                               linalg::LinalgCopyVTWForwardingPattern>(
297         funcOp.getContext(), /*benefit=*/2);
298     TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
299                                                 funcOp.getContext());
300     TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
301                                                  funcOp.getContext());
302     (void)applyPatternsAndFoldGreedily(funcOp,
303                                        std::move(vectorizationPatterns));
304 
305     // Apply the pad tensor op vectorization separately to avoid running the
306     // GenericPadOpVectorizationPattern too early.
307     // TODO: Improve once we have better infrastructure to control pattern
308     // application.
309     if (vectorizePadding) {
310       RewritePatternSet patterns(funcOp.getContext());
311       linalg::populatePadOpVectorizationPatterns(patterns);
312       (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
313     }
314   }
315 
316   LinalgVectorizationOptions options;
317   LinalgTransformationFilter filter;
318 };
319 
320 /// Configurable pass to enable the application of other pattern-based linalg
321 /// passes.
322 struct LinalgStrategyEnablePass
323     : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
324 
325   LinalgStrategyEnablePass(LinalgEnablingOptions opt,
326                            LinalgTransformationFilter filt)
327       : options(opt), filter(std::move(filt)) {}
328 
329   void runOnOperation() override {
330     auto funcOp = getOperation();
331     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
332       return;
333 
334     MLIRContext *context = funcOp.getContext();
335     RewritePatternSet patterns =
336         linalg::getLinalgTilingCanonicalizationPatterns(context);
337     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
338     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
339       return signalPassFailure();
340 
341     if (options.licm) {
342       funcOp->walk([&](LoopLikeOpInterface loopLike) {
343         moveLoopInvariantCode(loopLike);
344       });
345     }
346 
347     // Gathers all innermost loops through a post order pruned walk.
348     funcOp.walk([](Operation *op) {
349       if (auto forOp = dyn_cast<AffineForOp>(op))
350         (void)promoteIfSingleIteration(forOp);
351       else if (auto forOp = dyn_cast<scf::ForOp>(op))
352         (void)promoteIfSingleIteration(forOp);
353     });
354     if (options.hoistRedundantVectorTransfers)
355       hoistRedundantVectorTransfers(funcOp);
356 
357     if (options.hoistRedundantVectorTransfersOnTensor)
358       hoistRedundantVectorTransfersOnTensor(funcOp);
359 
360     // Run CSE to cleanup after canonicalization.
361     OpPassManager dynamicPM("func.func");
362     dynamicPM.addPass(createCSEPass());
363     if (failed(runPipeline(dynamicPM, funcOp)))
364       return signalPassFailure();
365   }
366 
367   LinalgEnablingOptions options;
368   LinalgTransformationFilter filter;
369 };
370 
371 /// Configurable pass to lower vector operations.
372 struct LinalgStrategyLowerVectorsPass
373     : public LinalgStrategyLowerVectorsPassBase<
374           LinalgStrategyLowerVectorsPass> {
375 
376   LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
377                                  LinalgTransformationFilter filt)
378       : options(opt), filter(std::move(filt)) {}
379 
380   void runOnOperation() override {
381     auto funcOp = getOperation();
382     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
383       return;
384 
385     MLIRContext *context = funcOp.getContext();
386     RewritePatternSet patterns(context);
387     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
388     // In a progressive lowering of vectors, this would be the 1st step.
389     if (options.contractionLowering) {
390       patterns.add<ContractionOpToOuterProductOpLowering,
391                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
392           options.vectorTransformOptions, context);
393       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
394     }
395     // In a progressive lowering of vectors, this would be the 2nd step.
396     if (options.multiReductionLowering) {
397       vector::populateVectorMultiReductionLoweringPatterns(
398           patterns,
399           options.vectorTransformOptions.vectorMultiReductionLowering);
400     }
401     // In a progressive lowering of vectors, this would be the 3rd step.
402     if (options.transferPartialRewrite) {
403       patterns.add<vector::VectorTransferFullPartialRewriter>(
404           context, options.vectorTransformOptions);
405     }
406     // In a progressive lowering of vectors, this would be the 4th step.
407     if (options.transferLowering) {
408       vector::populateVectorTransferLoweringPatterns(patterns,
409                                                      options.maxTransferRank);
410     }
411     // In a progressive lowering of vectors, this would be the 5th step.
412     if (options.transferToSCFConversion) {
413       populateVectorToSCFConversionPatterns(
414           patterns, options.vectorTransferToSCFOptions.setTargetRank(
415                         options.maxTransferRank));
416     }
417     // In a progressive lowering of vectors, this would be the 6th step.
418     if (options.shapeCastLowering) {
419       vector::populateVectorShapeCastLoweringPatterns(patterns);
420     }
421     // In a progressive lowering of vectors, this would be the 7th step.
422     if (options.transposeLowering) {
423       vector::populateVectorTransposeLoweringPatterns(
424           patterns, options.vectorTransformOptions);
425       if (options.avx2Lowering)
426         x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
427             patterns, options.avx2LoweringOptions, /*benefit=*/10);
428     }
429     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
430   }
431 
432   LinalgVectorLoweringOptions options;
433   LinalgTransformationFilter filter;
434 };
435 
436 /// Configurable pass to lower vector operations.
437 struct LinalgStrategyRemoveMarkersPass
438     : public LinalgStrategyRemoveMarkersPassBase<
439           LinalgStrategyRemoveMarkersPass> {
440 
441   void runOnOperation() override {
442     auto funcOp = getOperation();
443     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
444       return;
445     funcOp.walk([](LinalgOp op) {
446       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
447     });
448   }
449 };
450 } // namespace
451 
452 /// Create a LinalgStrategyTileAndFusePass.
453 std::unique_ptr<OperationPass<func::FuncOp>>
454 mlir::createLinalgStrategyTileAndFusePass(
455     StringRef opName, const LinalgTilingAndFusionOptions &options,
456     const LinalgTransformationFilter &filter) {
457   return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
458                                                          filter);
459 }
460 
461 /// Create a LinalgStrategyTilePass.
462 std::unique_ptr<OperationPass<func::FuncOp>>
463 mlir::createLinalgStrategyTilePass(StringRef opName,
464                                    const LinalgTilingOptions &opt,
465                                    const LinalgTransformationFilter &filter) {
466   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
467 }
468 
469 /// Create a LinalgStrategyPadPass.
470 std::unique_ptr<OperationPass<func::FuncOp>>
471 mlir::createLinalgStrategyPadPass(StringRef opName,
472                                   const LinalgPaddingOptions &opt,
473                                   const LinalgTransformationFilter &filter) {
474   return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
475 }
476 
477 /// Create a LinalgStrategyPromotePass.
478 std::unique_ptr<OperationPass<func::FuncOp>>
479 mlir::createLinalgStrategyPromotePass(
480     StringRef opName, const LinalgPromotionOptions &opt,
481     const LinalgTransformationFilter &filter) {
482   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
483 }
484 
485 /// Create a LinalgStrategyGeneralizePass.
486 std::unique_ptr<OperationPass<func::FuncOp>>
487 mlir::createLinalgStrategyGeneralizePass(
488     StringRef opName, const LinalgTransformationFilter &filter) {
489   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
490 }
491 
492 /// Create a LinalgStrategyDecomposePass.
493 // TODO: if/when we need finer control add an `opName` parameter.
494 std::unique_ptr<OperationPass<func::FuncOp>>
495 mlir::createLinalgStrategyDecomposePass(
496     const LinalgTransformationFilter &filter) {
497   return std::make_unique<LinalgStrategyDecomposePass>(filter);
498 }
499 
500 /// Create a LinalgStrategyInterchangePass.
501 std::unique_ptr<OperationPass<func::FuncOp>>
502 mlir::createLinalgStrategyInterchangePass(
503     ArrayRef<int64_t> iteratorInterchange,
504     const LinalgTransformationFilter &filter) {
505   return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
506                                                          filter);
507 }
508 
509 /// Create a LinalgStrategyVectorizePass.
510 std::unique_ptr<OperationPass<func::FuncOp>>
511 mlir::createLinalgStrategyVectorizePass(
512     StringRef opName, LinalgVectorizationOptions opt,
513     const LinalgTransformationFilter &filter, bool padVectorize) {
514   return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
515                                                        padVectorize);
516 }
517 
518 /// Create a LinalgStrategyEnablePass.
519 std::unique_ptr<OperationPass<func::FuncOp>>
520 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
521                                      const LinalgTransformationFilter &filter) {
522   return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
523 }
524 
525 /// Create a LinalgStrategyLowerVectorsPass.
526 std::unique_ptr<OperationPass<func::FuncOp>>
527 mlir::createLinalgStrategyLowerVectorsPass(
528     LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) {
529   return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
530 }
531 
532 /// Create a LinalgStrategyRemoveMarkersPass.
533 std::unique_ptr<OperationPass<func::FuncOp>>
534 mlir::createLinalgStrategyRemoveMarkersPass() {
535   return std::make_unique<LinalgStrategyRemoveMarkersPass>();
536 }
537