1 //===- CodegenStrategy.h - Linalg programmable codegen strategy -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ 10 #define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ 11 12 #include <utility> 13 14 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 15 #include "mlir/Dialect/Linalg/Passes.h" 16 #include "mlir/Pass/PassManager.h" 17 18 namespace mlir { 19 20 namespace linalg { 21 22 /// Abstract Transformation class applied in a sequence that also handles state 23 /// through markers. 24 struct Transformation { TransformationTransformation25 explicit Transformation(LinalgTransformationFilter::FilterFunction f) 26 : filter(std::move(f)) {} 27 virtual ~Transformation() = default; 28 virtual void addToPassPipeline(OpPassManager &pm, 29 LinalgTransformationFilter m) const = 0; 30 LinalgTransformationFilter::FilterFunction filter = nullptr; 31 }; 32 33 /// Represent one application of LinalgStrategyTileAndFusePass. 34 struct TileAndFuse : public Transformation { 35 TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options, 36 LinalgTransformationFilter::FilterFunction f = nullptr) TransformationTileAndFuse37 : Transformation(std::move(f)), opName(name), 38 options(std::move(options)) {} 39 addToPassPipelineTileAndFuse40 void addToPassPipeline(OpPassManager &pm, 41 LinalgTransformationFilter m) const override { 42 pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m)); 43 } 44 45 private: 46 std::string opName; 47 linalg::LinalgTilingAndFusionOptions options; 48 }; 49 50 /// Represent one application of LinalgStrategyTilePass. 51 struct Tile : public Transformation { 52 Tile(StringRef name, linalg::LinalgTilingOptions options, 53 LinalgTransformationFilter::FilterFunction f = nullptr) TransformationTile54 : Transformation(std::move(f)), opName(name), 55 options(std::move(options)) {} 56 addToPassPipelineTile57 void addToPassPipeline(OpPassManager &pm, 58 LinalgTransformationFilter m) const override { 59 pm.addPass(createLinalgStrategyTilePass(opName, options, m)); 60 } 61 62 private: 63 std::string opName; 64 linalg::LinalgTilingOptions options; 65 }; 66 67 /// Represent one application of LinalgStrategyPadPass. 68 struct Pad : public Transformation { 69 Pad(StringRef name, linalg::LinalgPaddingOptions options, 70 LinalgTransformationFilter::FilterFunction f = nullptr) TransformationPad71 : Transformation(std::move(f)), opName(name), 72 options(std::move(options)) {} 73 addToPassPipelinePad74 void addToPassPipeline(OpPassManager &pm, 75 LinalgTransformationFilter m) const override { 76 pm.addPass(createLinalgStrategyPadPass(opName, options, m)); 77 } 78 79 private: 80 std::string opName; 81 linalg::LinalgPaddingOptions options; 82 }; 83 84 /// Represent one application of createLinalgStrategyGeneralizePass. 85 struct Generalize : public Transformation { 86 explicit Generalize(StringRef name, 87 LinalgTransformationFilter::FilterFunction f = nullptr) TransformationGeneralize88 : Transformation(std::move(f)), opName(name) {} 89 addToPassPipelineGeneralize90 void addToPassPipeline(OpPassManager &pm, 91 LinalgTransformationFilter m) const override { 92 pm.addPass(createLinalgStrategyGeneralizePass(opName, m)); 93 } 94 95 private: 96 std::string opName; 97 }; 98 99 /// Represent one application of createLinalgStrategyInterchangePass. 100 struct Interchange : public Transformation { 101 explicit Interchange(ArrayRef<int64_t> iteratorInterchange, 102 LinalgTransformationFilter::FilterFunction f = nullptr) TransformationInterchange103 : Transformation(std::move(f)), 104 iteratorInterchange(iteratorInterchange.begin(), 105 iteratorInterchange.end()) {} 106 addToPassPipelineInterchange107 void addToPassPipeline(OpPassManager &pm, 108 LinalgTransformationFilter m) const override { 109 pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m)); 110 } 111 112 private: 113 SmallVector<int64_t> iteratorInterchange; 114 }; 115 116 /// Represent one application of createLinalgStrategyDecomposePass. 117 struct Decompose : public Transformation { 118 explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr) TransformationDecompose119 : Transformation(std::move(f)) {} 120 addToPassPipelineDecompose121 void addToPassPipeline(OpPassManager &pm, 122 LinalgTransformationFilter m) const override { 123 pm.addPass(createLinalgStrategyDecomposePass(m)); 124 } 125 }; 126 127 /// Represent one application of createLinalgStrategyPeelPass. 128 struct Peel : public Transformation { 129 explicit Peel(linalg::LinalgPeelOptions options, 130 LinalgTransformationFilter::FilterFunction f = nullptr) TransformationPeel131 : Transformation(std::move(f)), opName(), options(options) {} 132 133 Peel(StringRef name, linalg::LinalgPeelOptions options, 134 LinalgTransformationFilter::FilterFunction f = nullptr) TransformationPeel135 : Transformation(std::move(f)), opName(name), options(options) {} 136 addToPassPipelinePeel137 void addToPassPipeline(OpPassManager &pm, 138 LinalgTransformationFilter m) const override { 139 pm.addPass(createLinalgStrategyPeelPass(opName, options, m)); 140 } 141 142 private: 143 std::string opName; 144 linalg::LinalgPeelOptions options; 145 }; 146 147 /// Represent one application of createLinalgStrategyVectorizePass. 148 struct Vectorize : public Transformation { 149 explicit Vectorize(linalg::LinalgVectorizationOptions options, 150 LinalgTransformationFilter::FilterFunction f = nullptr, 151 bool padVectorize = false) TransformationVectorize152 : Transformation(std::move(f)), opName(), options(options), 153 vectorizePadding(padVectorize) {} 154 155 Vectorize(StringRef name, linalg::LinalgVectorizationOptions options, 156 LinalgTransformationFilter::FilterFunction f = nullptr, 157 bool padVectorize = false) TransformationVectorize158 : Transformation(std::move(f)), opName(name), options(options), 159 vectorizePadding(padVectorize) {} 160 addToPassPipelineVectorize161 void addToPassPipeline(OpPassManager &pm, 162 LinalgTransformationFilter m) const override { 163 pm.addPass(createLinalgStrategyVectorizePass(opName, options, m, 164 vectorizePadding)); 165 } 166 167 private: 168 std::string opName; 169 linalg::LinalgVectorizationOptions options; 170 bool vectorizePadding; 171 }; 172 173 /// Represent one application of createLinalgStrategyLowerVectorsPass. 174 struct VectorLowering : public Transformation { 175 explicit VectorLowering( 176 linalg::LinalgVectorLoweringOptions options, 177 LinalgTransformationFilter::FilterFunction f = nullptr) TransformationVectorLowering178 : Transformation(std::move(f)), options(options) {} 179 addToPassPipelineVectorLowering180 void addToPassPipeline(OpPassManager &pm, 181 LinalgTransformationFilter m) const override { 182 pm.addPass(createLinalgStrategyLowerVectorsPass(options, m)); 183 } 184 185 private: 186 linalg::LinalgVectorLoweringOptions options; 187 }; 188 189 /// Codegen strategy controls how a Linalg op is progressively lowered. 190 struct CodegenStrategy { 191 /// Append a pattern to tile the Op `opName` and fuse its producers with 192 /// tiling and fusion `options`. 193 CodegenStrategy & 194 tileAndFuse(StringRef opName, const LinalgTilingAndFusionOptions &options, 195 const LinalgTransformationFilter::FilterFunction &f = nullptr) { 196 transformationSequence.emplace_back( 197 std::make_unique<TileAndFuse>(opName, options, f)); 198 return *this; 199 } 200 /// Conditionally append a pattern to tile the Op `opName` and fuse its 201 /// producers with tiling and fusion `options`. 202 CodegenStrategy & 203 tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options, 204 LinalgTransformationFilter::FilterFunction f = nullptr) { 205 return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this; 206 } 207 /// Append a pattern to add a level of tiling for Op `opName` with tiling 208 /// `options`. 209 CodegenStrategy & 210 tile(StringRef opName, const linalg::LinalgTilingOptions &options, 211 const LinalgTransformationFilter::FilterFunction &f = nullptr) { 212 transformationSequence.emplace_back( 213 std::make_unique<Tile>(opName, options, f)); 214 return *this; 215 } 216 /// Conditionally append a pattern to add a level of tiling for 217 /// `LinalgOpType` with tiling `options`. 218 CodegenStrategy & 219 tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options, 220 LinalgTransformationFilter::FilterFunction f = nullptr) { 221 return b ? tile(opName, std::move(options), std::move(f)) : *this; 222 } 223 /// Append a pattern to pad and hoist the operands of Op `opName` with padding 224 /// `options`. 225 CodegenStrategy & 226 pad(StringRef opName, const linalg::LinalgPaddingOptions &options, 227 const LinalgTransformationFilter::FilterFunction &f = nullptr) { 228 transformationSequence.emplace_back( 229 std::make_unique<Pad>(opName, options, f)); 230 return *this; 231 } 232 /// Conditionally append a pattern to pad and hoist the operands of Op 233 /// `opName` with padding `options`. 234 CodegenStrategy & 235 padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options, 236 LinalgTransformationFilter::FilterFunction f = nullptr) { 237 return b ? pad(opName, std::move(options), std::move(f)) : *this; 238 } 239 /// Append a pattern to generalize named operations. 240 CodegenStrategy & 241 generalize(StringRef opName, 242 const LinalgTransformationFilter::FilterFunction &f = nullptr) { 243 transformationSequence.emplace_back( 244 std::make_unique<Generalize>(opName, f)); 245 return *this; 246 } 247 /// Conditionally append a pattern to generalize named operations. 248 CodegenStrategy & 249 generalizeIf(bool b, StringRef opName, 250 LinalgTransformationFilter::FilterFunction f = nullptr) { 251 return b ? generalize(opName, std::move(f)) : *this; 252 } 253 /// Append a pattern to interchange iterators. 254 CodegenStrategy & 255 interchange(ArrayRef<int64_t> iteratorInterchange, 256 const LinalgTransformationFilter::FilterFunction &f = nullptr) { 257 transformationSequence.emplace_back( 258 std::make_unique<Interchange>(iteratorInterchange, f)); 259 return *this; 260 } 261 /// Conditionally append a pattern to interchange iterators. 262 CodegenStrategy & 263 interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange, 264 LinalgTransformationFilter::FilterFunction f = nullptr) { 265 return b ? interchange(iteratorInterchange, std::move(f)) : *this; 266 } 267 /// Append patterns to decompose convolutions. 268 CodegenStrategy & 269 decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) { 270 transformationSequence.emplace_back(std::make_unique<Decompose>(f)); 271 return *this; 272 } 273 /// Conditionally append patterns to decompose convolutions. 274 CodegenStrategy & 275 decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) { 276 return b ? decompose(std::move(f)) : *this; 277 } 278 /// Append a pattern to peel 'LinalgOpType'. 279 CodegenStrategy & 280 peel(StringRef opName, const LinalgPeelOptions &options, 281 const LinalgTransformationFilter::FilterFunction &f = nullptr) { 282 transformationSequence.emplace_back( 283 std::make_unique<Peel>(opName, options, f)); 284 return *this; 285 } 286 /// Conditionally append a pattern to peel 'LinalgOpType'. 287 CodegenStrategy & 288 peelIf(bool b, StringRef opName, const LinalgPeelOptions &options, 289 LinalgTransformationFilter::FilterFunction f = nullptr) { 290 return b ? peel(opName, options, std::move(f)) : *this; 291 } 292 /// Append a pattern to rewrite `LinalgOpType` as a vector operation. 293 CodegenStrategy & 294 vectorize(StringRef opName, 295 const LinalgTransformationFilter::FilterFunction &f = nullptr, 296 bool vectorizePadding = false) { 297 transformationSequence.emplace_back(std::make_unique<Vectorize>( 298 opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding)); 299 return *this; 300 } 301 /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector 302 /// operation. 303 CodegenStrategy & 304 vectorizeIf(bool b, StringRef opName, 305 LinalgTransformationFilter::FilterFunction f = nullptr, 306 bool vectorizePadding = false) { 307 return b ? vectorize(opName, std::move(f), vectorizePadding) : *this; 308 } 309 /// Append a pattern to lower all vector operations. vectorLoweringCodegenStrategy310 CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) { 311 transformationSequence.emplace_back( 312 std::make_unique<VectorLowering>(options)); 313 return *this; 314 } 315 /// Configure the post staged-patterns global enabling passes options. 316 CodegenStrategy & setVectorTransferToSCFOptionsCodegenStrategy317 setVectorTransferToSCFOptions(LinalgEnablingOptions options) { 318 linalgEnablingOptions = options; 319 return *this; 320 } 321 322 /// Apply the transformation patterns in sequence with cleanup 323 /// transformations interleaved. 324 void configurePassPipeline(OpPassManager &pm, MLIRContext *context, 325 bool addEnablePass = true) const; 326 327 private: 328 LogicalResult postPatternTransforms(Operation *func) const; 329 330 LinalgEnablingOptions linalgEnablingOptions; 331 SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence; 332 }; 333 334 } // namespace linalg 335 } // namespace mlir 336 337 #endif // MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ 338