1 //===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a configurable pass that can apply patterns liberally
10 // and be plugged in a pass pipeline.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "PassDetail.h"
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/Dialect/Affine/Utils.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Passes.h"
23 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
24 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
25 #include "mlir/Dialect/Linalg/Utils/Utils.h"
26 #include "mlir/Dialect/SCF/Transforms.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/AffineMap.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Support/LLVM.h"
33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34 #include "mlir/Transforms/Passes.h"
35 
36 using namespace mlir;
37 using namespace mlir::vector;
38 using namespace linalg;
39 
40 namespace {
41 
42 /// Configurable pass to apply pattern-based tiling and fusion.
43 struct LinalgStrategyTileAndFusePass
44     : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
45 
46   LinalgStrategyTileAndFusePass() = default;
47 
48   LinalgStrategyTileAndFusePass(StringRef opName,
49                                 LinalgTilingAndFusionOptions opt,
50                                 LinalgTransformationFilter filt)
51       : options(std::move(opt)), filter(std::move(filt)) {
52     this->anchorOpName.setValue(opName.str());
53   }
54 
55   void runOnOperation() override {
56     auto funcOp = getOperation();
57     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
58       return;
59 
60     RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
61     if (!anchorOpName.empty()) {
62       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
63           anchorOpName, funcOp.getContext(), options, filter);
64     } else {
65       tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
66           funcOp.getContext(), options, filter);
67     }
68     // Search the root operation using bottom up traversal.
69     GreedyRewriteConfig config;
70     config.useTopDownTraversal = false;
71     (void)applyPatternsAndFoldGreedily(
72         funcOp, std::move(tilingAndFusionPattern), config);
73   }
74 
75   LinalgTilingAndFusionOptions options;
76   LinalgTransformationFilter filter;
77 };
78 
79 /// Configurable pass to apply pattern-based linalg tiling.
80 struct LinalgStrategyTilePass
81     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
82 
83   LinalgStrategyTilePass() = default;
84 
85   LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
86                          LinalgTransformationFilter filt)
87       : options(std::move(opt)), filter(std::move(filt)) {
88     this->anchorOpName.setValue(opName.str());
89   }
90 
91   void runOnOperation() override {
92     auto funcOp = getOperation();
93     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
94       return;
95 
96     MLIRContext *ctx = funcOp.getContext();
97     RewritePatternSet tilingPattern(ctx);
98     if (!anchorOpName.empty())
99       tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
100                                              filter);
101     else
102       tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
103     if (anchorOpName == tensor::PadOp::getOperationName())
104       populatePadTensorTilingPatterns(tilingPattern, options);
105     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
106   }
107 
108   LinalgTilingOptions options;
109   LinalgTransformationFilter filter;
110 };
111 
112 /// Configurable pass to apply hoisting and padding.
113 struct LinalgStrategyPadPass
114     : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
115 
116   LinalgStrategyPadPass() = default;
117 
118   LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
119                         LinalgTransformationFilter filt)
120       : options(std::move(opt)), filter(std::move(filt)) {
121     this->anchorOpName.setValue(opName.str());
122   }
123 
124   void runOnOperation() override {
125     auto funcOp = getOperation();
126     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
127       return;
128 
129     RewritePatternSet paddingPattern(funcOp.getContext());
130     if (!anchorOpName.empty()) {
131       paddingPattern.add<LinalgPaddingPattern>(
132           anchorOpName, funcOp.getContext(), options, filter);
133     } else {
134       paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
135                                                filter);
136     }
137     (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern));
138   }
139 
140   LinalgPaddingOptions options;
141   LinalgTransformationFilter filter;
142 };
143 
144 /// Configurable pass to apply pattern-based linalg generalization.
145 struct LinalgStrategyGeneralizePass
146     : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
147 
148   LinalgStrategyGeneralizePass() = default;
149 
150   LinalgStrategyGeneralizePass(StringRef opName,
151                                LinalgTransformationFilter filter)
152       : filter(std::move(filter)) {
153     this->anchorOpName.setValue(opName.str());
154   }
155 
156   void runOnOperation() override {
157     auto funcOp = getOperation();
158     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
159       return;
160 
161     RewritePatternSet generalizationPattern(funcOp.getContext());
162     if (!anchorOpName.empty()) {
163       generalizationPattern.add<LinalgGeneralizationPattern>(
164           anchorOpName, funcOp.getContext(), filter);
165     } else {
166       generalizationPattern.add<LinalgGeneralizationPattern>(
167           funcOp.getContext(), filter);
168     }
169     if (failed(applyPatternsAndFoldGreedily(funcOp,
170                                             std::move(generalizationPattern))))
171       signalPassFailure();
172   }
173 
174   LinalgTransformationFilter filter;
175 };
176 
177 /// Configurable pass to apply lowering of coarser-grained named linalg ops into
178 /// finer-grained named versions.
179 struct LinalgStrategyDecomposePass
180     : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
181 
182   LinalgStrategyDecomposePass() = default;
183 
184   LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
185       : filter(std::move(filter)) {}
186 
187   void runOnOperation() override {
188     auto funcOp = getOperation();
189     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
190       return;
191     RewritePatternSet decompositionPattern(funcOp.getContext());
192     populateDecomposeConvolutionPatterns(decompositionPattern, filter);
193     if (failed(applyPatternsAndFoldGreedily(funcOp,
194                                             std::move(decompositionPattern))))
195       signalPassFailure();
196   }
197 
198   LinalgTransformationFilter filter;
199 };
200 
201 /// Configurable pass to apply pattern-based linalg generalization.
202 struct LinalgStrategyInterchangePass
203     : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
204 
205   LinalgStrategyInterchangePass() = default;
206 
207   LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
208                                 LinalgTransformationFilter filter)
209       : iteratorInterchange(iteratorInterchange.begin(),
210                             iteratorInterchange.end()),
211         filter(std::move(filter)) {}
212 
213   void runOnOperation() override {
214     auto funcOp = getOperation();
215     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
216       return;
217 
218     SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
219                                             iteratorInterchange.end());
220     RewritePatternSet interchangePattern(funcOp.getContext());
221     interchangePattern.add<GenericOpInterchangePattern>(
222         funcOp.getContext(), interchangeVector, filter);
223     if (failed(applyPatternsAndFoldGreedily(funcOp,
224                                             std::move(interchangePattern))))
225       signalPassFailure();
226   }
227 
228   SmallVector<int64_t> iteratorInterchange;
229   LinalgTransformationFilter filter;
230 };
231 
232 /// Configurable pass to apply pattern-based linalg promotion.
233 struct LinalgStrategyPromotePass
234     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
235 
236   LinalgStrategyPromotePass() = default;
237 
238   LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
239                             LinalgTransformationFilter filt)
240       : options(std::move(opt)), filter(std::move(filt)) {
241     this->anchorOpName.setValue(opName.str());
242   }
243 
244   void runOnOperation() override {
245     auto funcOp = getOperation();
246     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
247       return;
248 
249     RewritePatternSet promotionPattern(funcOp.getContext());
250     if (!anchorOpName.empty()) {
251       promotionPattern.add<LinalgBasePromotionPattern>(
252           anchorOpName, funcOp.getContext(), options, filter);
253     } else {
254       promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
255                                                        filter, options);
256     }
257     (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
258   }
259 
260   LinalgPromotionOptions options;
261   LinalgTransformationFilter filter;
262 };
263 
264 /// Configurable pass to apply pattern-based linalg vectorization.
265 struct LinalgStrategyVectorizePass
266     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
267 
268   LinalgStrategyVectorizePass() = default;
269 
270   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
271                               LinalgTransformationFilter filt,
272                               bool padVectorize = false)
273       : options(opt), filter(std::move(filt)) {
274     this->anchorOpName.setValue(opName.str());
275     this->vectorizePadding.setValue(padVectorize);
276   }
277 
278   void runOnOperation() override {
279     auto funcOp = getOperation();
280     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
281       return;
282 
283     RewritePatternSet vectorizationPatterns(funcOp.getContext());
284     if (!anchorOpName.empty()) {
285       vectorizationPatterns.add<LinalgVectorizationPattern>(
286           anchorOpName, funcOp.getContext(), options, filter);
287     } else {
288       vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
289                                                             filter, options);
290     }
291     vector::populateVectorTransferPermutationMapLoweringPatterns(
292         vectorizationPatterns);
293     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
294     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
295                               linalg::LinalgCopyVTWForwardingPattern>(
296         funcOp.getContext(), /*benefit=*/2);
297     TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
298                                                 funcOp.getContext());
299     TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
300                                                  funcOp.getContext());
301     (void)applyPatternsAndFoldGreedily(funcOp,
302                                        std::move(vectorizationPatterns));
303 
304     // Apply the pad tensor op vectorization separately to avoid running the
305     // GenericPadOpVectorizationPattern too early.
306     // TODO: Improve once we have better infrastructure to control pattern
307     // application.
308     if (vectorizePadding) {
309       RewritePatternSet patterns(funcOp.getContext());
310       linalg::populatePadOpVectorizationPatterns(patterns);
311       (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
312     }
313   }
314 
315   LinalgVectorizationOptions options;
316   LinalgTransformationFilter filter;
317 };
318 
319 /// Configurable pass to enable the application of other pattern-based linalg
320 /// passes.
321 struct LinalgStrategyEnablePass
322     : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
323 
324   LinalgStrategyEnablePass(LinalgEnablingOptions opt,
325                            LinalgTransformationFilter filt)
326       : options(opt), filter(std::move(filt)) {}
327 
328   void runOnOperation() override {
329     auto funcOp = getOperation();
330     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
331       return;
332 
333     MLIRContext *context = funcOp.getContext();
334     RewritePatternSet patterns =
335         linalg::getLinalgTilingCanonicalizationPatterns(context);
336     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
337     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
338       return signalPassFailure();
339 
340     if (options.licm) {
341       if (funcOp
342               ->walk([&](LoopLikeOpInterface loopLike) {
343                 if (failed(moveLoopInvariantCode(loopLike)))
344                   return WalkResult::interrupt();
345                 return WalkResult::advance();
346               })
347               .wasInterrupted())
348         return signalPassFailure();
349     }
350 
351     // Gathers all innermost loops through a post order pruned walk.
352     funcOp.walk([](Operation *op) {
353       if (auto forOp = dyn_cast<AffineForOp>(op))
354         (void)promoteIfSingleIteration(forOp);
355       else if (auto forOp = dyn_cast<scf::ForOp>(op))
356         (void)promoteIfSingleIteration(forOp);
357     });
358     if (options.hoistRedundantVectorTransfers)
359       hoistRedundantVectorTransfers(funcOp);
360 
361     if (options.hoistRedundantVectorTransfersOnTensor)
362       hoistRedundantVectorTransfersOnTensor(funcOp);
363 
364     // Run CSE to cleanup after canonicalization.
365     OpPassManager dynamicPM("builtin.func");
366     dynamicPM.addPass(createCSEPass());
367     if (failed(runPipeline(dynamicPM, funcOp)))
368       return signalPassFailure();
369   }
370 
371   LinalgEnablingOptions options;
372   LinalgTransformationFilter filter;
373 };
374 
375 /// Configurable pass to lower vector operations.
376 struct LinalgStrategyLowerVectorsPass
377     : public LinalgStrategyLowerVectorsPassBase<
378           LinalgStrategyLowerVectorsPass> {
379 
380   LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
381                                  LinalgTransformationFilter filt)
382       : options(opt), filter(std::move(filt)) {}
383 
384   void runOnOperation() override {
385     auto funcOp = getOperation();
386     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
387       return;
388 
389     MLIRContext *context = funcOp.getContext();
390     RewritePatternSet patterns(context);
391     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
392     // In a progressive lowering of vectors, this would be the 1st step.
393     if (options.contractionLowering) {
394       patterns.add<ContractionOpToOuterProductOpLowering,
395                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
396           options.vectorTransformOptions, context);
397       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
398     }
399     // In a progressive lowering of vectors, this would be the 2nd step.
400     if (options.multiReductionLowering) {
401       vector::populateVectorMultiReductionLoweringPatterns(
402           patterns,
403           options.vectorTransformOptions.vectorMultiReductionLowering);
404     }
405     // In a progressive lowering of vectors, this would be the 3rd step.
406     if (options.transferPartialRewrite) {
407       patterns.add<vector::VectorTransferFullPartialRewriter>(
408           context, options.vectorTransformOptions);
409     }
410     // In a progressive lowering of vectors, this would be the 4th step.
411     if (options.transferLowering) {
412       vector::populateVectorTransferLoweringPatterns(patterns,
413                                                      options.maxTransferRank);
414     }
415     // In a progressive lowering of vectors, this would be the 5th step.
416     if (options.transferToSCFConversion) {
417       populateVectorToSCFConversionPatterns(
418           patterns, options.vectorTransferToSCFOptions.setTargetRank(
419                         options.maxTransferRank));
420     }
421     // In a progressive lowering of vectors, this would be the 6th step.
422     if (options.shapeCastLowering) {
423       vector::populateVectorShapeCastLoweringPatterns(patterns);
424     }
425     // In a progressive lowering of vectors, this would be the 7th step.
426     if (options.transposeLowering) {
427       vector::populateVectorTransposeLoweringPatterns(
428           patterns, options.vectorTransformOptions);
429       if (options.avx2Lowering)
430         x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
431             patterns, options.avx2LoweringOptions, /*benefit=*/10);
432     }
433     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
434   }
435 
436   LinalgVectorLoweringOptions options;
437   LinalgTransformationFilter filter;
438 };
439 
440 /// Configurable pass to lower vector operations.
441 struct LinalgStrategyRemoveMarkersPass
442     : public LinalgStrategyRemoveMarkersPassBase<
443           LinalgStrategyRemoveMarkersPass> {
444 
445   void runOnOperation() override {
446     auto funcOp = getOperation();
447     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
448       return;
449     funcOp.walk([](LinalgOp op) {
450       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
451     });
452   }
453 };
454 } // namespace
455 
456 /// Create a LinalgStrategyTileAndFusePass.
457 std::unique_ptr<OperationPass<FuncOp>>
458 mlir::createLinalgStrategyTileAndFusePass(
459     StringRef opName, const LinalgTilingAndFusionOptions &options,
460     const LinalgTransformationFilter &filter) {
461   return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
462                                                          filter);
463 }
464 
465 /// Create a LinalgStrategyTilePass.
466 std::unique_ptr<OperationPass<FuncOp>>
467 mlir::createLinalgStrategyTilePass(StringRef opName,
468                                    const LinalgTilingOptions &opt,
469                                    const LinalgTransformationFilter &filter) {
470   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
471 }
472 
473 /// Create a LinalgStrategyPadPass.
474 std::unique_ptr<OperationPass<FuncOp>>
475 mlir::createLinalgStrategyPadPass(StringRef opName,
476                                   const LinalgPaddingOptions &opt,
477                                   const LinalgTransformationFilter &filter) {
478   return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
479 }
480 
481 /// Create a LinalgStrategyPromotePass.
482 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyPromotePass(
483     StringRef opName, const LinalgPromotionOptions &opt,
484     const LinalgTransformationFilter &filter) {
485   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
486 }
487 
488 /// Create a LinalgStrategyGeneralizePass.
489 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyGeneralizePass(
490     StringRef opName, const LinalgTransformationFilter &filter) {
491   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
492 }
493 
494 /// Create a LinalgStrategyDecomposePass.
495 // TODO: if/when we need finer control add an `opName` parameter.
496 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyDecomposePass(
497     const LinalgTransformationFilter &filter) {
498   return std::make_unique<LinalgStrategyDecomposePass>(filter);
499 }
500 
501 /// Create a LinalgStrategyInterchangePass.
502 std::unique_ptr<OperationPass<FuncOp>>
503 mlir::createLinalgStrategyInterchangePass(
504     ArrayRef<int64_t> iteratorInterchange,
505     const LinalgTransformationFilter &filter) {
506   return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
507                                                          filter);
508 }
509 
510 /// Create a LinalgStrategyVectorizePass.
511 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
512     StringRef opName, LinalgVectorizationOptions opt,
513     const LinalgTransformationFilter &filter, bool padVectorize) {
514   return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
515                                                        padVectorize);
516 }
517 
518 /// Create a LinalgStrategyEnablePass.
519 std::unique_ptr<OperationPass<FuncOp>>
520 mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
521                                      const LinalgTransformationFilter &filter) {
522   return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
523 }
524 
525 /// Create a LinalgStrategyLowerVectorsPass.
526 std::unique_ptr<OperationPass<FuncOp>>
527 mlir::createLinalgStrategyLowerVectorsPass(
528     LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) {
529   return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
530 }
531 
532 /// Create a LinalgStrategyRemoveMarkersPass.
533 std::unique_ptr<OperationPass<FuncOp>>
534 mlir::createLinalgStrategyRemoveMarkersPass() {
535   return std::make_unique<LinalgStrategyRemoveMarkersPass>();
536 }
537