1 //===- CodegenStrategy.h - Linalg programmable codegen strategy -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
10 #define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
11 
12 #include <utility>
13 
14 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
15 #include "mlir/Dialect/Linalg/Passes.h"
16 #include "mlir/Pass/PassManager.h"
17 
18 namespace mlir {
19 
20 namespace linalg {
21 
22 /// Abstract Transformation class applied in a sequence that also handles state
23 /// through markers.
24 struct Transformation {
TransformationTransformation25   explicit Transformation(LinalgTransformationFilter::FilterFunction f)
26       : filter(std::move(f)) {}
27   virtual ~Transformation() = default;
28   virtual void addToPassPipeline(OpPassManager &pm,
29                                  LinalgTransformationFilter m) const = 0;
30   LinalgTransformationFilter::FilterFunction filter = nullptr;
31 };
32 
33 /// Represent one application of LinalgStrategyTileAndFusePass.
34 struct TileAndFuse : public Transformation {
35   TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
36               LinalgTransformationFilter::FilterFunction f = nullptr)
TransformationTileAndFuse37       : Transformation(std::move(f)), opName(name),
38         options(std::move(options)) {}
39 
addToPassPipelineTileAndFuse40   void addToPassPipeline(OpPassManager &pm,
41                          LinalgTransformationFilter m) const override {
42     pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m));
43   }
44 
45 private:
46   std::string opName;
47   linalg::LinalgTilingAndFusionOptions options;
48 };
49 
50 /// Represent one application of LinalgStrategyTilePass.
51 struct Tile : public Transformation {
52   Tile(StringRef name, linalg::LinalgTilingOptions options,
53        LinalgTransformationFilter::FilterFunction f = nullptr)
TransformationTile54       : Transformation(std::move(f)), opName(name),
55         options(std::move(options)) {}
56 
addToPassPipelineTile57   void addToPassPipeline(OpPassManager &pm,
58                          LinalgTransformationFilter m) const override {
59     pm.addPass(createLinalgStrategyTilePass(opName, options, m));
60   }
61 
62 private:
63   std::string opName;
64   linalg::LinalgTilingOptions options;
65 };
66 
67 /// Represent one application of LinalgStrategyPadPass.
68 struct Pad : public Transformation {
69   Pad(StringRef name, linalg::LinalgPaddingOptions options,
70       LinalgTransformationFilter::FilterFunction f = nullptr)
TransformationPad71       : Transformation(std::move(f)), opName(name),
72         options(std::move(options)) {}
73 
addToPassPipelinePad74   void addToPassPipeline(OpPassManager &pm,
75                          LinalgTransformationFilter m) const override {
76     pm.addPass(createLinalgStrategyPadPass(opName, options, m));
77   }
78 
79 private:
80   std::string opName;
81   linalg::LinalgPaddingOptions options;
82 };
83 
84 /// Represent one application of createLinalgStrategyGeneralizePass.
85 struct Generalize : public Transformation {
86   explicit Generalize(StringRef name,
87                       LinalgTransformationFilter::FilterFunction f = nullptr)
TransformationGeneralize88       : Transformation(std::move(f)), opName(name) {}
89 
addToPassPipelineGeneralize90   void addToPassPipeline(OpPassManager &pm,
91                          LinalgTransformationFilter m) const override {
92     pm.addPass(createLinalgStrategyGeneralizePass(opName, m));
93   }
94 
95 private:
96   std::string opName;
97 };
98 
99 /// Represent one application of createLinalgStrategyInterchangePass.
100 struct Interchange : public Transformation {
101   explicit Interchange(ArrayRef<int64_t> iteratorInterchange,
102                        LinalgTransformationFilter::FilterFunction f = nullptr)
TransformationInterchange103       : Transformation(std::move(f)),
104         iteratorInterchange(iteratorInterchange.begin(),
105                             iteratorInterchange.end()) {}
106 
addToPassPipelineInterchange107   void addToPassPipeline(OpPassManager &pm,
108                          LinalgTransformationFilter m) const override {
109     pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m));
110   }
111 
112 private:
113   SmallVector<int64_t> iteratorInterchange;
114 };
115 
116 /// Represent one application of createLinalgStrategyDecomposePass.
117 struct Decompose : public Transformation {
118   explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr)
TransformationDecompose119       : Transformation(std::move(f)) {}
120 
addToPassPipelineDecompose121   void addToPassPipeline(OpPassManager &pm,
122                          LinalgTransformationFilter m) const override {
123     pm.addPass(createLinalgStrategyDecomposePass(m));
124   }
125 };
126 
127 /// Represent one application of createLinalgStrategyPeelPass.
128 struct Peel : public Transformation {
129   explicit Peel(linalg::LinalgPeelOptions options,
130                 LinalgTransformationFilter::FilterFunction f = nullptr)
TransformationPeel131       : Transformation(std::move(f)), opName(), options(options) {}
132 
133   Peel(StringRef name, linalg::LinalgPeelOptions options,
134        LinalgTransformationFilter::FilterFunction f = nullptr)
TransformationPeel135       : Transformation(std::move(f)), opName(name), options(options) {}
136 
addToPassPipelinePeel137   void addToPassPipeline(OpPassManager &pm,
138                          LinalgTransformationFilter m) const override {
139     pm.addPass(createLinalgStrategyPeelPass(opName, options, m));
140   }
141 
142 private:
143   std::string opName;
144   linalg::LinalgPeelOptions options;
145 };
146 
147 /// Represent one application of createLinalgStrategyVectorizePass.
148 struct Vectorize : public Transformation {
149   explicit Vectorize(linalg::LinalgVectorizationOptions options,
150                      LinalgTransformationFilter::FilterFunction f = nullptr,
151                      bool padVectorize = false)
TransformationVectorize152       : Transformation(std::move(f)), opName(), options(options),
153         vectorizePadding(padVectorize) {}
154 
155   Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
156             LinalgTransformationFilter::FilterFunction f = nullptr,
157             bool padVectorize = false)
TransformationVectorize158       : Transformation(std::move(f)), opName(name), options(options),
159         vectorizePadding(padVectorize) {}
160 
addToPassPipelineVectorize161   void addToPassPipeline(OpPassManager &pm,
162                          LinalgTransformationFilter m) const override {
163     pm.addPass(createLinalgStrategyVectorizePass(opName, options, m,
164                                                  vectorizePadding));
165   }
166 
167 private:
168   std::string opName;
169   linalg::LinalgVectorizationOptions options;
170   bool vectorizePadding;
171 };
172 
173 /// Represent one application of createLinalgStrategyLowerVectorsPass.
174 struct VectorLowering : public Transformation {
175   explicit VectorLowering(
176       linalg::LinalgVectorLoweringOptions options,
177       LinalgTransformationFilter::FilterFunction f = nullptr)
TransformationVectorLowering178       : Transformation(std::move(f)), options(options) {}
179 
addToPassPipelineVectorLowering180   void addToPassPipeline(OpPassManager &pm,
181                          LinalgTransformationFilter m) const override {
182     pm.addPass(createLinalgStrategyLowerVectorsPass(options, m));
183   }
184 
185 private:
186   linalg::LinalgVectorLoweringOptions options;
187 };
188 
189 /// Codegen strategy controls how a Linalg op is progressively lowered.
190 struct CodegenStrategy {
191   /// Append a pattern to tile the Op `opName` and fuse its producers with
192   /// tiling and fusion `options`.
193   CodegenStrategy &
194   tileAndFuse(StringRef opName, const LinalgTilingAndFusionOptions &options,
195               const LinalgTransformationFilter::FilterFunction &f = nullptr) {
196     transformationSequence.emplace_back(
197         std::make_unique<TileAndFuse>(opName, options, f));
198     return *this;
199   }
200   /// Conditionally append a pattern to tile the Op `opName` and fuse its
201   /// producers with tiling and fusion `options`.
202   CodegenStrategy &
203   tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options,
204                 LinalgTransformationFilter::FilterFunction f = nullptr) {
205     return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this;
206   }
207   /// Append a pattern to add a level of tiling for Op `opName` with tiling
208   /// `options`.
209   CodegenStrategy &
210   tile(StringRef opName, const linalg::LinalgTilingOptions &options,
211        const LinalgTransformationFilter::FilterFunction &f = nullptr) {
212     transformationSequence.emplace_back(
213         std::make_unique<Tile>(opName, options, f));
214     return *this;
215   }
216   /// Conditionally append a pattern to add a level of tiling for
217   /// `LinalgOpType` with tiling `options`.
218   CodegenStrategy &
219   tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
220          LinalgTransformationFilter::FilterFunction f = nullptr) {
221     return b ? tile(opName, std::move(options), std::move(f)) : *this;
222   }
223   /// Append a pattern to pad and hoist the operands of Op `opName` with padding
224   /// `options`.
225   CodegenStrategy &
226   pad(StringRef opName, const linalg::LinalgPaddingOptions &options,
227       const LinalgTransformationFilter::FilterFunction &f = nullptr) {
228     transformationSequence.emplace_back(
229         std::make_unique<Pad>(opName, options, f));
230     return *this;
231   }
232   /// Conditionally append a pattern to pad and hoist the operands of Op
233   /// `opName` with padding `options`.
234   CodegenStrategy &
235   padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options,
236         LinalgTransformationFilter::FilterFunction f = nullptr) {
237     return b ? pad(opName, std::move(options), std::move(f)) : *this;
238   }
239   /// Append a pattern to generalize named operations.
240   CodegenStrategy &
241   generalize(StringRef opName,
242              const LinalgTransformationFilter::FilterFunction &f = nullptr) {
243     transformationSequence.emplace_back(
244         std::make_unique<Generalize>(opName, f));
245     return *this;
246   }
247   /// Conditionally append a pattern to generalize named operations.
248   CodegenStrategy &
249   generalizeIf(bool b, StringRef opName,
250                LinalgTransformationFilter::FilterFunction f = nullptr) {
251     return b ? generalize(opName, std::move(f)) : *this;
252   }
253   /// Append a pattern to interchange iterators.
254   CodegenStrategy &
255   interchange(ArrayRef<int64_t> iteratorInterchange,
256               const LinalgTransformationFilter::FilterFunction &f = nullptr) {
257     transformationSequence.emplace_back(
258         std::make_unique<Interchange>(iteratorInterchange, f));
259     return *this;
260   }
261   /// Conditionally append a pattern to interchange iterators.
262   CodegenStrategy &
263   interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange,
264                 LinalgTransformationFilter::FilterFunction f = nullptr) {
265     return b ? interchange(iteratorInterchange, std::move(f)) : *this;
266   }
267   /// Append patterns to decompose convolutions.
268   CodegenStrategy &
269   decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) {
270     transformationSequence.emplace_back(std::make_unique<Decompose>(f));
271     return *this;
272   }
273   /// Conditionally append patterns to decompose convolutions.
274   CodegenStrategy &
275   decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) {
276     return b ? decompose(std::move(f)) : *this;
277   }
278   /// Append a pattern to peel 'LinalgOpType'.
279   CodegenStrategy &
280   peel(StringRef opName, const LinalgPeelOptions &options,
281        const LinalgTransformationFilter::FilterFunction &f = nullptr) {
282     transformationSequence.emplace_back(
283         std::make_unique<Peel>(opName, options, f));
284     return *this;
285   }
286   /// Conditionally append a pattern to peel 'LinalgOpType'.
287   CodegenStrategy &
288   peelIf(bool b, StringRef opName, const LinalgPeelOptions &options,
289          LinalgTransformationFilter::FilterFunction f = nullptr) {
290     return b ? peel(opName, options, std::move(f)) : *this;
291   }
292   /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
293   CodegenStrategy &
294   vectorize(StringRef opName,
295             const LinalgTransformationFilter::FilterFunction &f = nullptr,
296             bool vectorizePadding = false) {
297     transformationSequence.emplace_back(std::make_unique<Vectorize>(
298         opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding));
299     return *this;
300   }
301   /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
302   /// operation.
303   CodegenStrategy &
304   vectorizeIf(bool b, StringRef opName,
305               LinalgTransformationFilter::FilterFunction f = nullptr,
306               bool vectorizePadding = false) {
307     return b ? vectorize(opName, std::move(f), vectorizePadding) : *this;
308   }
309   /// Append a pattern to lower all vector operations.
vectorLoweringCodegenStrategy310   CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) {
311     transformationSequence.emplace_back(
312         std::make_unique<VectorLowering>(options));
313     return *this;
314   }
315   /// Configure the post staged-patterns global enabling passes options.
316   CodegenStrategy &
setVectorTransferToSCFOptionsCodegenStrategy317   setVectorTransferToSCFOptions(LinalgEnablingOptions options) {
318     linalgEnablingOptions = options;
319     return *this;
320   }
321 
322   /// Apply the transformation patterns in sequence with cleanup
323   /// transformations interleaved.
324   void configurePassPipeline(OpPassManager &pm, MLIRContext *context,
325                              bool addEnablePass = true) const;
326 
327 private:
328   LogicalResult postPatternTransforms(Operation *func) const;
329 
330   LinalgEnablingOptions linalgEnablingOptions;
331   SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
332 };
333 
334 } // namespace linalg
335 } // namespace mlir
336 
337 #endif // MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
338