1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a configurable pass that can apply patterns liberally
10 // and be plugged in a pass pipeline.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "PassDetail.h"
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Passes.h"
21 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
22 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
23 #include "mlir/Dialect/Linalg/Utils/Utils.h"
24 #include "mlir/Dialect/SCF/Transforms.h"
25 #include "mlir/Dialect/Tensor/IR/Tensor.h"
26 #include "mlir/Dialect/Vector/VectorTransforms.h"
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/Pass/PassManager.h"
30 #include "mlir/Support/LLVM.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "mlir/Transforms/LoopUtils.h"
33 #include "mlir/Transforms/Passes.h"
34 #include "mlir/Transforms/Utils.h"
35 
36 using namespace mlir;
37 using namespace mlir::vector;
38 using namespace linalg;
39 
40 namespace {
41 
42 /// Configurable pass to apply pattern-based tiling and fusion.
43 struct LinalgStrategyTileAndFusePass
44     : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
45 
46   LinalgStrategyTileAndFusePass() = default;
47 
48   LinalgStrategyTileAndFusePass(StringRef opName,
49                                 LinalgTilingAndFusionOptions opt,
50                                 LinalgTransformationFilter filt)
51       : options(std::move(opt)), filter(std::move(filt)) {
52     this->anchorOpName.setValue(opName.str());
53   }
54 
55   void runOnOperation() override {
56     auto funcOp = getOperation();
57     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
58       return;
59 
60     RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
61     if (!anchorOpName.empty()) {
62       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
63           anchorOpName, funcOp.getContext(), options, filter);
64     } else {
65       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
66           funcOp.getContext(), options, filter);
67     }
68     // Search the root operation using bottom up traversal.
69     GreedyRewriteConfig config;
70     config.useTopDownTraversal = false;
71     (void)applyPatternsAndFoldGreedily(
72         funcOp, std::move(tilingAndFusionPattern), config);
73   }
74 
75   LinalgTilingAndFusionOptions options;
76   LinalgTransformationFilter filter;
77 };
78 
79 /// Configurable pass to apply pattern-based linalg tiling.
80 struct LinalgStrategyTilePass
81     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
82 
83   LinalgStrategyTilePass() = default;
84 
85   LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
86                          LinalgTransformationFilter filt)
87       : options(std::move(opt)), filter(std::move(filt)) {
88     this->anchorOpName.setValue(opName.str());
89   }
90 
91   void runOnOperation() override {
92     auto funcOp = getOperation();
93     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
94       return;
95 
96     MLIRContext *ctx = funcOp.getContext();
97     RewritePatternSet tilingPattern(ctx);
98     if (!anchorOpName.empty())
99       tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
100                                              filter);
101     else
102       tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
103     if (anchorOpName == linalg::PadTensorOp::getOperationName())
104       populatePadTensorTilingPatterns(tilingPattern, options);
105     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
106   }
107 
108   LinalgTilingOptions options;
109   LinalgTransformationFilter filter;
110 };
111 
112 /// Configurable pass to apply hoisting and padding.
113 struct LinalgStrategyPadPass
114     : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
115 
116   LinalgStrategyPadPass() = default;
117 
118   LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
119                         LinalgTransformationFilter filt)
120       : options(std::move(opt)), filter(std::move(filt)) {
121     this->anchorOpName.setValue(opName.str());
122   }
123 
124   void runOnOperation() override {
125     auto funcOp = getOperation();
126     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
127       return;
128 
129     RewritePatternSet paddingPattern(funcOp.getContext());
130     if (!anchorOpName.empty()) {
131       paddingPattern.add<LinalgPaddingPattern>(
132           anchorOpName, funcOp.getContext(), options, filter);
133     } else {
134       paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
135                                                filter);
136     }
137     (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern));
138   }
139 
140   LinalgPaddingOptions options;
141   LinalgTransformationFilter filter;
142 };
143 
144 /// Configurable pass to apply pattern-based linalg generalization.
145 struct LinalgStrategyGeneralizePass
146     : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
147 
148   LinalgStrategyGeneralizePass() = default;
149 
150   LinalgStrategyGeneralizePass(StringRef opName,
151                                LinalgTransformationFilter filter)
152       : filter(std::move(filter)) {
153     this->anchorOpName.setValue(opName.str());
154   }
155 
156   void runOnOperation() override {
157     auto funcOp = getOperation();
158     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
159       return;
160 
161     RewritePatternSet generalizationPattern(funcOp.getContext());
162     if (!anchorOpName.empty()) {
163       generalizationPattern.add<LinalgGeneralizationPattern>(
164           anchorOpName, funcOp.getContext(), filter);
165     } else {
166       generalizationPattern.add<LinalgGeneralizationPattern>(
167           funcOp.getContext(), filter);
168     }
169     if (failed(applyPatternsAndFoldGreedily(funcOp,
170                                             std::move(generalizationPattern))))
171       signalPassFailure();
172   }
173 
174   LinalgTransformationFilter filter;
175 };
176 
177 /// Configurable pass to apply lowering of coarser-grained named linalg ops into
178 /// finer-grained named versions.
179 struct LinalgStrategyDecomposePass
180     : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
181 
182   LinalgStrategyDecomposePass() = default;
183 
184   LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
185       : filter(std::move(filter)) {}
186 
187   void runOnOperation() override {
188     auto funcOp = getOperation();
189     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
190       return;
191     RewritePatternSet decompositionPattern(funcOp.getContext());
192     populateDecomposeConvolutionPatterns(decompositionPattern, filter);
193     if (failed(applyPatternsAndFoldGreedily(funcOp,
194                                             std::move(decompositionPattern))))
195       signalPassFailure();
196   }
197 
198   LinalgTransformationFilter filter;
199 };
200 
201 /// Configurable pass to apply pattern-based linalg generalization.
202 struct LinalgStrategyInterchangePass
203     : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
204 
205   LinalgStrategyInterchangePass() = default;
206 
207   LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
208                                 LinalgTransformationFilter filter)
209       : iteratorInterchange(iteratorInterchange.begin(),
210                             iteratorInterchange.end()),
211         filter(std::move(filter)) {}
212 
213   void runOnOperation() override {
214     auto funcOp = getOperation();
215     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
216       return;
217 
218     SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
219                                             iteratorInterchange.end());
220     RewritePatternSet interchangePattern(funcOp.getContext());
221     interchangePattern.add<GenericOpInterchangePattern>(
222         funcOp.getContext(), interchangeVector, filter);
223     if (failed(applyPatternsAndFoldGreedily(funcOp,
224                                             std::move(interchangePattern))))
225       signalPassFailure();
226   }
227 
228   SmallVector<int64_t> iteratorInterchange;
229   LinalgTransformationFilter filter;
230 };
231 
232 /// Configurable pass to apply pattern-based linalg promotion.
233 struct LinalgStrategyPromotePass
234     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
235 
236   LinalgStrategyPromotePass() = default;
237 
238   LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
239                             LinalgTransformationFilter filt)
240       : options(std::move(opt)), filter(std::move(filt)) {
241     this->anchorOpName.setValue(opName.str());
242   }
243 
244   void runOnOperation() override {
245     auto funcOp = getOperation();
246     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
247       return;
248 
249     RewritePatternSet promotionPattern(funcOp.getContext());
250     if (!anchorOpName.empty()) {
251       promotionPattern.add<LinalgBasePromotionPattern>(
252           anchorOpName, funcOp.getContext(), options, filter);
253     } else {
254       promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
255                                                        filter, options);
256     }
257     (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
258   }
259 
260   LinalgPromotionOptions options;
261   LinalgTransformationFilter filter;
262 };
263 
264 /// Configurable pass to apply pattern-based linalg vectorization.
265 struct LinalgStrategyVectorizePass
266     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
267 
268   LinalgStrategyVectorizePass() = default;
269 
270   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
271                               LinalgTransformationFilter filt,
272                               bool padVectorize = false)
273       : options(opt), filter(std::move(filt)) {
274     this->anchorOpName.setValue(opName.str());
275     this->vectorizePadding.setValue(padVectorize);
276   }
277 
278   void runOnOperation() override {
279     auto funcOp = getOperation();
280     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
281       return;
282 
283     RewritePatternSet vectorizationPatterns(funcOp.getContext());
284     if (!anchorOpName.empty()) {
285       vectorizationPatterns.add<LinalgVectorizationPattern>(
286           anchorOpName, funcOp.getContext(), options, filter);
287     } else {
288       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
289                                                             filter, options);
290     }
291     vector::populateVectorTransferPermutationMapLoweringPatterns(
292         vectorizationPatterns);
293     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
294     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
295                               linalg::LinalgCopyVTWForwardingPattern>(
296         funcOp.getContext(), /*benefit=*/2);
297     TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
298                                                 funcOp.getContext());
299     TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
300                                                  funcOp.getContext());
301     (void)applyPatternsAndFoldGreedily(funcOp,
302                                        std::move(vectorizationPatterns));
303 
304     // Apply the pad tensor op vectorization separately to avoid running the
305     // GenericPadTensorOpVectorizationPattern too early.
306     // TODO: Improve once we have better infrastructure to control pattern
307     // application.
308     if (vectorizePadding) {
309       RewritePatternSet patterns(funcOp.getContext());
310       linalg::populatePadTensorOpVectorizationPatterns(patterns);
311       (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
312     }
313   }
314 
315   LinalgVectorizationOptions options;
316   LinalgTransformationFilter filter;
317 };
318 
319 /// Configurable pass to enable the application of other pattern-based linalg
320 /// passes.
321 struct LinalgStrategyEnablePass
322     : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
323 
324   LinalgStrategyEnablePass(LinalgEnablingOptions opt,
325                            LinalgTransformationFilter filt)
326       : options(opt), filter(std::move(filt)) {}
327 
328   void runOnOperation() override {
329     auto funcOp = getOperation();
330     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
331       return;
332 
333     MLIRContext *context = funcOp.getContext();
334     RewritePatternSet patterns =
335         linalg::getLinalgTilingCanonicalizationPatterns(context);
336     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
337     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
338       return signalPassFailure();
339 
340     if (options.licm) {
341       if (funcOp
342               ->walk([&](LoopLikeOpInterface loopLike) {
343                 if (failed(moveLoopInvariantCode(loopLike)))
344                   return WalkResult::interrupt();
345                 return WalkResult::advance();
346               })
347               .wasInterrupted())
348         return signalPassFailure();
349     }
350 
351     promoteSingleIterationLoops(funcOp);
352     if (options.hoistRedundantVectorTransfers)
353       hoistRedundantVectorTransfers(funcOp);
354 
355     if (options.hoistRedundantVectorTransfersOnTensor)
356       hoistRedundantVectorTransfersOnTensor(funcOp);
357 
358     // Run CSE to cleanup after canonicalization.
359     OpPassManager dynamicPM("builtin.func");
360     dynamicPM.addPass(createCSEPass());
361     if (failed(runPipeline(dynamicPM, funcOp)))
362       return signalPassFailure();
363   }
364 
365   LinalgEnablingOptions options;
366   LinalgTransformationFilter filter;
367 };
368 
369 /// Configurable pass to lower vector operations.
370 struct LinalgStrategyLowerVectorsPass
371     : public LinalgStrategyLowerVectorsPassBase<
372           LinalgStrategyLowerVectorsPass> {
373 
374   LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
375                                  LinalgTransformationFilter filt)
376       : options(opt), filter(std::move(filt)) {}
377 
378   void runOnOperation() override {
379     auto funcOp = getOperation();
380     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
381       return;
382 
383     MLIRContext *context = funcOp.getContext();
384     RewritePatternSet patterns(context);
385     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
386     // In a progressive lowering of vectors, this would be the 1st step.
387     if (options.contractionLowering) {
388       patterns.add<ContractionOpToOuterProductOpLowering,
389                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
390           options.vectorTransformOptions, context);
391       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
392     }
393     // In a progressive lowering of vectors, this would be the 2nd step.
394     if (options.multiReductionLowering) {
395       vector::populateVectorMultiReductionLoweringPatterns(
396           patterns,
397           options.vectorTransformOptions.vectorMultiReductionLowering);
398     }
399     // In a progressive lowering of vectors, this would be the 3rd step.
400     if (options.transferPartialRewrite) {
401       patterns.add<vector::VectorTransferFullPartialRewriter>(
402           context, options.vectorTransformOptions);
403     }
404     // In a progressive lowering of vectors, this would be the 4th step.
405     if (options.transferLowering) {
406       vector::populateVectorTransferLoweringPatterns(patterns,
407                                                      options.maxTransferRank);
408     }
409     // In a progressive lowering of vectors, this would be the 5th step.
410     if (options.transferToSCFConversion) {
411       populateVectorToSCFConversionPatterns(
412           patterns, options.vectorTransferToSCFOptions.setTargetRank(
413                         options.maxTransferRank));
414     }
415     // In a progressive lowering of vectors, this would be the 6th step.
416     if (options.shapeCastLowering) {
417       vector::populateVectorShapeCastLoweringPatterns(patterns);
418     }
419     // In a progressive lowering of vectors, this would be the 7th step.
420     if (options.transposeLowering) {
421       vector::populateVectorTransposeLoweringPatterns(
422           patterns, options.vectorTransformOptions);
423       if (options.avx2Lowering)
424         x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
425             patterns, options.avx2LoweringOptions, /*benefit=*/10);
426     }
427     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
428   }
429 
430   LinalgVectorLoweringOptions options;
431   LinalgTransformationFilter filter;
432 };
433 
434 /// Configurable pass to lower vector operations.
435 struct LinalgStrategyRemoveMarkersPass
436     : public LinalgStrategyRemoveMarkersPassBase<
437           LinalgStrategyRemoveMarkersPass> {
438 
439   void runOnOperation() override {
440     auto funcOp = getOperation();
441     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
442       return;
443     funcOp.walk([](LinalgOp op) {
444       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
445     });
446   }
447 };
448 } // namespace
449 
450 /// Create a LinalgStrategyTileAndFusePass.
451 std::unique_ptr<OperationPass<FuncOp>>
452 mlir::createLinalgStrategyTileAndFusePass(
453     StringRef opName, const LinalgTilingAndFusionOptions &options,
454     const LinalgTransformationFilter &filter) {
455   return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
456                                                          filter);
457 }
458 
459 /// Create a LinalgStrategyTilePass.
460 std::unique_ptr<OperationPass<FuncOp>>
461 mlir::createLinalgStrategyTilePass(StringRef opName,
462                                    const LinalgTilingOptions &opt,
463                                    const LinalgTransformationFilter &filter) {
464   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
465 }
466 
467 /// Create a LinalgStrategyPadPass.
468 std::unique_ptr<OperationPass<FuncOp>>
469 mlir::createLinalgStrategyPadPass(StringRef opName,
470                                   const LinalgPaddingOptions &opt,
471                                   const LinalgTransformationFilter &filter) {
472   return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
473 }
474 
475 /// Create a LinalgStrategyPromotePass.
476 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyPromotePass(
477     StringRef opName, const LinalgPromotionOptions &opt,
478     const LinalgTransformationFilter &filter) {
479   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
480 }
481 
482 /// Create a LinalgStrategyGeneralizePass.
483 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyGeneralizePass(
484     StringRef opName, const LinalgTransformationFilter &filter) {
485   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
486 }
487 
488 /// Create a LinalgStrategyDecomposePass.
489 // TODO: if/when we need finer control add an `opName` parameter.
490 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyDecomposePass(
491     const LinalgTransformationFilter &filter) {
492   return std::make_unique<LinalgStrategyDecomposePass>(filter);
493 }
494 
495 /// Create a LinalgStrategyInterchangePass.
496 std::unique_ptr<OperationPass<FuncOp>>
497 mlir::createLinalgStrategyInterchangePass(
498     ArrayRef<int64_t> iteratorInterchange,
499     const LinalgTransformationFilter &filter) {
500   return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
501                                                          filter);
502 }
503 
504 /// Create a LinalgStrategyVectorizePass.
505 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
506     StringRef opName, LinalgVectorizationOptions opt,
507     const LinalgTransformationFilter &filter, bool padVectorize) {
508   return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
509                                                        padVectorize);
510 }
511 
512 /// Create a LinalgStrategyEnablePass.
513 std::unique_ptr<OperationPass<FuncOp>>
514 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
515                                      const LinalgTransformationFilter &filter) {
516   return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
517 }
518 
519 /// Create a LinalgStrategyLowerVectorsPass.
520 std::unique_ptr<OperationPass<FuncOp>>
521 mlir::createLinalgStrategyLowerVectorsPass(
522     LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) {
523   return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
524 }
525 
526 /// Create a LinalgStrategyRemoveMarkersPass.
527 std::unique_ptr<OperationPass<FuncOp>>
528 mlir::createLinalgStrategyRemoveMarkersPass() {
529   return std::make_unique<LinalgStrategyRemoveMarkersPass>();
530 }
531