1 //===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
10 #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
11 
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/TilingInterface.h"
16 
17 #include <deque>
18 
// Forward declarations of the MLIR entities referenced by the tiling API
// declared in this header.
namespace mlir {
class TilingInterface;
class PatternRewriter;
class Operation;
} // namespace mlir
24 
25 namespace mlir {
26 namespace scf {
27 
28 using SCFTileSizeComputationFunction =
29     std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
30 
31 /// Options to use to control tiling.
32 struct SCFTilingOptions {
33   /// Computation function that returns the tile sizes for each operation.
34   /// Delayed construction of constant tile sizes should occur to interoperate
35   /// with folding.
36   SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
37 
38   SCFTilingOptions &
setTileSizeComputationFunctionSCFTilingOptions39   setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
40     tileSizeComputationFunction = std::move(fun);
41     return *this;
42   }
43   /// Set the `tileSizeComputationFunction` to return the values `ts`. The
44   /// values must not fold away when tiling. Otherwise, use a more robust
45   /// `tileSizeComputationFunction`.
setTileSizesSCFTilingOptions46   SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
47     tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
48     return *this;
49   }
50   /// Convenience function to set the `tileSizeComputationFunction` to a
51   /// function that computes tile sizes at the point they are needed. Allows
52   /// proper interaction with folding.
53   SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
54 
55   /// The interchange vector to reorder the tiled loops.
56   SmallVector<unsigned> interchangeVector = {};
setInterchangeSCFTilingOptions57   SCFTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
58     interchangeVector = llvm::to_vector(interchange);
59     return *this;
60   }
61 };
62 
63 struct SCFTilingResult {
64   Operation *tiledOp;
65   SmallVector<scf::ForOp> loops;
66 };
67 
68 /// Pattern to tile an op that implements the `TilingInterface` using
69 /// `scf.for` for iterating over the tiles.
70 struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
71   /// Construct a generic pattern applied to all TilingInterface ops.
72   TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options,
73                     PatternBenefit benefit = 1);
74 
75   /// Construct a generic pattern applied to `opName`.
76   TileUsingSCFForOp(StringRef opName, MLIRContext *context,
77                     SCFTilingOptions options, PatternBenefit benefit = 1);
78 
79   /// `matchAndRewrite` implementation that returns the significant transformed
80   /// pieces of IR.
81   FailureOr<SCFTilingResult>
82   returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
83 
matchAndRewriteTileUsingSCFForOp84   LogicalResult matchAndRewrite(TilingInterface op,
85                                 PatternRewriter &rewriter) const override {
86     return returningMatchAndRewrite(op, rewriter);
87   }
88 
89 private:
90   /// Options to control tiling;
91   SCFTilingOptions options;
92 };
93 
94 /// Pattern to tile and fuse a sequence of operations, by tiling the consumer
95 /// and fusing its producers. Note that this assumes that it is valid to
96 /// tile+fuse the producer into the innermost tiled loop. Its up to the caller
97 /// to ensure that the tile sizes provided make this fusion valid.
98 ///
99 /// For example, for the following sequence
100 ///
101 /// ```mlir
102 /// %0 = linalg.fill ...
103 /// %1 = linalg.matmul ... outs(%0 : ...) ...
104 /// ```
105 ///
106 /// it is legal to fuse the fill with the matmul only if the matmul is tiled
107 /// along the parallel dimensions and not the reduction dimension, i.e. the tile
108 /// size for the reduction dimension should be 0.
109 struct SCFTileAndFuseResult {
110   SmallVector<Operation *> tiledAndFusedOps;
111   SmallVector<scf::ForOp> loops;
112 };
113 struct TileConsumerAndFuseProducersUsingSCFForOp
114     : public OpInterfaceRewritePattern<TilingInterface> {
115 
116   /// Construct a generic pattern applied to all TilingInterface ops.
117   TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
118                                             SCFTilingOptions options,
119                                             PatternBenefit benefit = 1);
120 
121   /// Construct a generic pattern applied to `opName`.
122   TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
123                                             MLIRContext *context,
124                                             SCFTilingOptions options,
125                                             PatternBenefit benefit = 1);
126 
127   /// `matchAndRewrite` implementation that returns the significant transformed
128   /// pieces of IR.
129   FailureOr<SCFTileAndFuseResult>
130   returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
131 
matchAndRewriteTileConsumerAndFuseProducersUsingSCFForOp132   LogicalResult matchAndRewrite(TilingInterface op,
133                                 PatternRewriter &rewriter) const override {
134     return returningMatchAndRewrite(op, rewriter);
135   }
136 
137 private:
138   /// This pattern uses the tiling pattern. Instead of using inheritance, use
139   /// the patterns as private object that is instantiated at the same time as
140   /// this pattern.
141   TileUsingSCFForOp tilingPattern;
142 };
143 
144 } // namespace scf
145 } // namespace mlir
146 
147 #endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
148