1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements configurable passes that can apply patterns liberally
// and be plugged into a pass pipeline.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "PassDetail.h"
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Passes.h"
21 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
22 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
23 #include "mlir/Dialect/Linalg/Utils/Utils.h"
24 #include "mlir/Dialect/SCF/Transforms.h"
25 #include "mlir/Dialect/Tensor/IR/Tensor.h"
26 #include "mlir/Dialect/Vector/VectorTransforms.h"
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/Pass/PassManager.h"
30 #include "mlir/Support/LLVM.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "mlir/Transforms/LoopUtils.h"
33 #include "mlir/Transforms/Passes.h"
34 #include "mlir/Transforms/Utils.h"
35 
36 using namespace mlir;
37 using namespace mlir::vector;
38 using namespace linalg;
39 
40 namespace {
41 
42 /// Configurable pass to apply pattern-based tiling and fusion.
43 struct LinalgStrategyTileAndFusePass
44     : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
45 
46   LinalgStrategyTileAndFusePass() = default;
47 
48   LinalgStrategyTileAndFusePass(StringRef opName,
49                                 LinalgTilingAndFusionOptions opt,
50                                 LinalgTransformationFilter filt)
51       : options(std::move(opt)), filter(std::move(filt)) {
52     this->anchorOpName.setValue(opName.str());
53   }
54 
55   void runOnFunction() override {
56     auto funcOp = getFunction();
57     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
58       return;
59 
60     RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
61     if (!anchorOpName.empty()) {
62       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
63           anchorOpName, funcOp.getContext(), options, filter);
64     } else {
65       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
66           funcOp.getContext(), options, filter);
67     }
68     // Search the root operation using bottom up traversal.
69     GreedyRewriteConfig config;
70     config.useTopDownTraversal = false;
71     (void)applyPatternsAndFoldGreedily(
72         funcOp, std::move(tilingAndFusionPattern), config);
73   }
74 
75   LinalgTilingAndFusionOptions options;
76   LinalgTransformationFilter filter;
77 };
78 
79 /// Configurable pass to apply pattern-based linalg tiling.
80 struct LinalgStrategyTilePass
81     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
82 
83   LinalgStrategyTilePass() = default;
84 
85   LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
86                          LinalgTransformationFilter filt)
87       : options(std::move(opt)), filter(std::move(filt)) {
88     this->anchorOpName.setValue(opName.str());
89   }
90 
91   void runOnFunction() override {
92     auto funcOp = getFunction();
93     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
94       return;
95 
96     MLIRContext *ctx = funcOp.getContext();
97     RewritePatternSet tilingPattern(ctx);
98     if (!anchorOpName.empty())
99       tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
100                                              filter);
101     else
102       tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
103     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
104   }
105 
106   LinalgTilingOptions options;
107   LinalgTransformationFilter filter;
108 };
109 
110 /// Configurable pass to apply hoisting and padding.
111 struct LinalgStrategyPadPass
112     : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
113 
114   LinalgStrategyPadPass() = default;
115 
116   LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
117                         LinalgTransformationFilter filt)
118       : options(std::move(opt)), filter(std::move(filt)) {
119     this->anchorOpName.setValue(opName.str());
120   }
121 
122   void runOnFunction() override {
123     auto funcOp = getFunction();
124     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
125       return;
126 
127     RewritePatternSet paddingPattern(funcOp.getContext());
128     if (!anchorOpName.empty()) {
129       paddingPattern.add<LinalgPaddingPattern>(
130           anchorOpName, funcOp.getContext(), options, filter);
131     } else {
132       paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
133                                                filter);
134     }
135     (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern));
136   }
137 
138   LinalgPaddingOptions options;
139   LinalgTransformationFilter filter;
140 };
141 
142 /// Configurable pass to apply pattern-based linalg generalization.
143 struct LinalgStrategyGeneralizePass
144     : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
145 
146   LinalgStrategyGeneralizePass() = default;
147 
148   LinalgStrategyGeneralizePass(StringRef opName,
149                                LinalgTransformationFilter filter)
150       : filter(std::move(filter)) {
151     this->anchorOpName.setValue(opName.str());
152   }
153 
154   void runOnFunction() override {
155     auto funcOp = getFunction();
156     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
157       return;
158 
159     RewritePatternSet generalizationPattern(funcOp.getContext());
160     if (!anchorOpName.empty()) {
161       generalizationPattern.add<LinalgGeneralizationPattern>(
162           anchorOpName, funcOp.getContext(), filter);
163     } else {
164       generalizationPattern.add<LinalgGeneralizationPattern>(
165           funcOp.getContext(), filter);
166     }
167     if (failed(applyPatternsAndFoldGreedily(funcOp,
168                                             std::move(generalizationPattern))))
169       signalPassFailure();
170   }
171 
172   LinalgTransformationFilter filter;
173 };
174 
175 /// Configurable pass to apply lowering of coarser-grained named linalg ops into
176 /// finer-grained named versions.
177 struct LinalgStrategyDecomposePass
178     : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
179 
180   LinalgStrategyDecomposePass() = default;
181 
182   LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
183       : filter(std::move(filter)) {}
184 
185   void runOnFunction() override {
186     auto funcOp = getFunction();
187     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
188       return;
189     RewritePatternSet decompositionPattern(funcOp.getContext());
190     populateDecomposeConvolutionPatterns(decompositionPattern, filter);
191     if (failed(applyPatternsAndFoldGreedily(funcOp,
192                                             std::move(decompositionPattern))))
193       signalPassFailure();
194   }
195 
196   LinalgTransformationFilter filter;
197 };
198 
199 /// Configurable pass to apply pattern-based linalg generalization.
200 struct LinalgStrategyInterchangePass
201     : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
202 
203   LinalgStrategyInterchangePass() = default;
204 
205   LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
206                                 LinalgTransformationFilter filter)
207       : iteratorInterchange(iteratorInterchange.begin(),
208                             iteratorInterchange.end()),
209         filter(std::move(filter)) {}
210 
211   void runOnFunction() override {
212     auto funcOp = getFunction();
213     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
214       return;
215 
216     SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
217                                             iteratorInterchange.end());
218     RewritePatternSet interchangePattern(funcOp.getContext());
219     interchangePattern.add<GenericOpInterchangePattern>(
220         funcOp.getContext(), interchangeVector, filter);
221     if (failed(applyPatternsAndFoldGreedily(funcOp,
222                                             std::move(interchangePattern))))
223       signalPassFailure();
224   }
225 
226   SmallVector<int64_t> iteratorInterchange;
227   LinalgTransformationFilter filter;
228 };
229 
230 /// Configurable pass to apply pattern-based linalg promotion.
231 struct LinalgStrategyPromotePass
232     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
233 
234   LinalgStrategyPromotePass() = default;
235 
236   LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
237                             LinalgTransformationFilter filt)
238       : options(std::move(opt)), filter(std::move(filt)) {
239     this->anchorOpName.setValue(opName.str());
240   }
241 
242   void runOnFunction() override {
243     auto funcOp = getFunction();
244     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
245       return;
246 
247     RewritePatternSet promotionPattern(funcOp.getContext());
248     if (!anchorOpName.empty()) {
249       promotionPattern.add<LinalgBasePromotionPattern>(
250           anchorOpName, funcOp.getContext(), options, filter);
251     } else {
252       promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
253                                                        filter, options);
254     }
255     (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
256   }
257 
258   LinalgPromotionOptions options;
259   LinalgTransformationFilter filter;
260 };
261 
262 /// Configurable pass to apply pattern-based linalg vectorization.
263 struct LinalgStrategyVectorizePass
264     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
265 
266   LinalgStrategyVectorizePass() = default;
267 
268   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
269                               LinalgTransformationFilter filt,
270                               bool padVectorize = false)
271       : options(opt), filter(std::move(filt)) {
272     this->anchorOpName.setValue(opName.str());
273     this->vectorizePadding.setValue(padVectorize);
274   }
275 
276   void runOnFunction() override {
277     auto funcOp = getFunction();
278     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
279       return;
280 
281     RewritePatternSet vectorizationPatterns(funcOp.getContext());
282     if (!anchorOpName.empty()) {
283       vectorizationPatterns.add<LinalgVectorizationPattern>(
284           anchorOpName, funcOp.getContext(), options, filter);
285     } else {
286       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
287                                                             filter, options);
288     }
289     vector::populateVectorTransferPermutationMapLoweringPatterns(
290         vectorizationPatterns);
291     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
292     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
293                               linalg::LinalgCopyVTWForwardingPattern>(
294         funcOp.getContext(), /*benefit=*/2);
295     TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
296                                                 funcOp.getContext());
297     TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
298                                                  funcOp.getContext());
299     (void)applyPatternsAndFoldGreedily(funcOp,
300                                        std::move(vectorizationPatterns));
301 
302     // Apply the pad tensor op vectorization separately to avoid running the
303     // GenericPadTensorOpVectorizationPattern too early.
304     // TODO: Improve once we have better infrastructure to control pattern
305     // application.
306     if (vectorizePadding) {
307       RewritePatternSet patterns(funcOp.getContext());
308       linalg::populatePadTensorOpVectorizationPatterns(patterns);
309       (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
310     }
311   }
312 
313   LinalgVectorizationOptions options;
314   LinalgTransformationFilter filter;
315 };
316 
317 /// Configurable pass to enable the application of other pattern-based linalg
318 /// passes.
319 struct LinalgStrategyEnablePass
320     : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
321 
322   LinalgStrategyEnablePass(LinalgEnablingOptions opt,
323                            LinalgTransformationFilter filt)
324       : options(opt), filter(std::move(filt)) {}
325 
326   void runOnFunction() override {
327     auto funcOp = getFunction();
328     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
329       return;
330 
331     MLIRContext *context = funcOp.getContext();
332     RewritePatternSet patterns =
333         linalg::getLinalgTilingCanonicalizationPatterns(context);
334     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
335     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
336       return signalPassFailure();
337 
338     if (options.licm) {
339       if (funcOp
340               ->walk([&](LoopLikeOpInterface loopLike) {
341                 if (failed(moveLoopInvariantCode(loopLike)))
342                   return WalkResult::interrupt();
343                 return WalkResult::advance();
344               })
345               .wasInterrupted())
346         return signalPassFailure();
347     }
348 
349     promoteSingleIterationLoops(funcOp);
350     if (options.hoistRedundantVectorTransfers)
351       hoistRedundantVectorTransfers(funcOp);
352 
353     if (options.hoistRedundantVectorTransfersOnTensor)
354       hoistRedundantVectorTransfersOnTensor(funcOp);
355 
356     // Run CSE to cleanup after canonicalization.
357     OpPassManager dynamicPM("builtin.func");
358     dynamicPM.addPass(createCSEPass());
359     if (failed(runPipeline(dynamicPM, funcOp)))
360       return signalPassFailure();
361   }
362 
363   LinalgEnablingOptions options;
364   LinalgTransformationFilter filter;
365 };
366 
367 /// Configurable pass to lower vector operations.
368 struct LinalgStrategyLowerVectorsPass
369     : public LinalgStrategyLowerVectorsPassBase<
370           LinalgStrategyLowerVectorsPass> {
371 
372   LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
373                                  LinalgTransformationFilter filt)
374       : options(opt), filter(std::move(filt)) {}
375 
376   void runOnFunction() override {
377     auto funcOp = getFunction();
378     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
379       return;
380 
381     MLIRContext *context = funcOp.getContext();
382     RewritePatternSet patterns(context);
383     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
384     // In a progressive lowering of vectors, this would be the 1st step.
385     if (options.contractionLowering) {
386       patterns.add<ContractionOpToOuterProductOpLowering,
387                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
388           options.vectorTransformOptions, context);
389       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
390     }
391     // In a progressive lowering of vectors, this would be the 2nd step.
392     if (options.multiReductionLowering) {
393       vector::populateVectorMultiReductionLoweringPatterns(
394           patterns,
395           options.vectorTransformOptions.vectorMultiReductionLowering);
396     }
397     // In a progressive lowering of vectors, this would be the 3rd step.
398     if (options.transferPartialRewrite) {
399       patterns.add<vector::VectorTransferFullPartialRewriter>(
400           context, options.vectorTransformOptions);
401     }
402     // In a progressive lowering of vectors, this would be the 4th step.
403     if (options.transferLowering) {
404       vector::populateVectorTransferLoweringPatterns(patterns,
405                                                      options.maxTransferRank);
406     }
407     // In a progressive lowering of vectors, this would be the 5th step.
408     if (options.transferToSCFConversion) {
409       populateVectorToSCFConversionPatterns(
410           patterns, options.vectorTransferToSCFOptions.setTargetRank(
411                         options.maxTransferRank));
412     }
413     // In a progressive lowering of vectors, this would be the 6th step.
414     if (options.shapeCastLowering) {
415       vector::populateVectorShapeCastLoweringPatterns(patterns);
416     }
417     // In a progressive lowering of vectors, this would be the 7th step.
418     if (options.transposeLowering) {
419       vector::populateVectorTransposeLoweringPatterns(
420           patterns, options.vectorTransformOptions);
421       if (options.avx2Lowering)
422         x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
423             patterns, options.avx2LoweringOptions, /*benefit=*/10);
424     }
425     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
426   }
427 
428   LinalgVectorLoweringOptions options;
429   LinalgTransformationFilter filter;
430 };
431 
432 /// Configurable pass to lower vector operations.
433 struct LinalgStrategyRemoveMarkersPass
434     : public LinalgStrategyRemoveMarkersPassBase<
435           LinalgStrategyRemoveMarkersPass> {
436 
437   void runOnFunction() override {
438     auto funcOp = getFunction();
439     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
440       return;
441     funcOp.walk([](LinalgOp op) {
442       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
443     });
444   }
445 };
446 } // namespace
447 
448 /// Create a LinalgStrategyTileAndFusePass.
449 std::unique_ptr<OperationPass<FuncOp>>
450 mlir::createLinalgStrategyTileAndFusePass(
451     StringRef opName, const LinalgTilingAndFusionOptions &options,
452     const LinalgTransformationFilter &filter) {
453   return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
454                                                          filter);
455 }
456 
457 /// Create a LinalgStrategyTilePass.
458 std::unique_ptr<OperationPass<FuncOp>>
459 mlir::createLinalgStrategyTilePass(StringRef opName,
460                                    const LinalgTilingOptions &opt,
461                                    const LinalgTransformationFilter &filter) {
462   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
463 }
464 
465 /// Create a LinalgStrategyPadPass.
466 std::unique_ptr<OperationPass<FuncOp>>
467 mlir::createLinalgStrategyPadPass(StringRef opName,
468                                   const LinalgPaddingOptions &opt,
469                                   const LinalgTransformationFilter &filter) {
470   return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
471 }
472 
473 /// Create a LinalgStrategyPromotePass.
474 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyPromotePass(
475     StringRef opName, const LinalgPromotionOptions &opt,
476     const LinalgTransformationFilter &filter) {
477   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
478 }
479 
480 /// Create a LinalgStrategyGeneralizePass.
481 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyGeneralizePass(
482     StringRef opName, const LinalgTransformationFilter &filter) {
483   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
484 }
485 
486 /// Create a LinalgStrategyDecomposePass.
487 // TODO: if/when we need finer control add an `opName` parameter.
488 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyDecomposePass(
489     const LinalgTransformationFilter &filter) {
490   return std::make_unique<LinalgStrategyDecomposePass>(filter);
491 }
492 
493 /// Create a LinalgStrategyInterchangePass.
494 std::unique_ptr<OperationPass<FuncOp>>
495 mlir::createLinalgStrategyInterchangePass(
496     ArrayRef<int64_t> iteratorInterchange,
497     const LinalgTransformationFilter &filter) {
498   return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
499                                                          filter);
500 }
501 
502 /// Create a LinalgStrategyVectorizePass.
503 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
504     StringRef opName, LinalgVectorizationOptions opt,
505     const LinalgTransformationFilter &filter, bool padVectorize) {
506   return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
507                                                        padVectorize);
508 }
509 
510 /// Create a LinalgStrategyEnablePass.
511 std::unique_ptr<OperationPass<FuncOp>>
512 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
513                                      const LinalgTransformationFilter &filter) {
514   return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
515 }
516 
517 /// Create a LinalgStrategyLowerVectorsPass.
518 std::unique_ptr<OperationPass<FuncOp>>
519 mlir::createLinalgStrategyLowerVectorsPass(
520     LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) {
521   return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
522 }
523 
524 /// Create a LinalgStrategyRemoveMarkersPass.
525 std::unique_ptr<OperationPass<FuncOp>>
526 mlir::createLinalgStrategyRemoveMarkersPass() {
527   return std::make_unique<LinalgStrategyRemoveMarkersPass>();
528 }
529