1 //===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H 10 #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H 11 12 #include "mlir/Dialect/SCF/IR/SCF.h" 13 #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Interfaces/TilingInterface.h" 16 17 #include <deque> 18 19 namespace mlir { 20 class Operation; 21 class PatternRewriter; 22 class TilingInterface; 23 } // namespace mlir 24 25 namespace mlir { 26 namespace scf { 27 28 using SCFTileSizeComputationFunction = 29 std::function<SmallVector<Value>(OpBuilder &, Operation *)>; 30 31 /// Options to use to control tiling. 32 struct SCFTilingOptions { 33 /// Computation function that returns the tile sizes for each operation. 34 /// Delayed construction of constant tile sizes should occur to interoperate 35 /// with folding. 36 SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr; 37 38 SCFTilingOptions & setTileSizeComputationFunctionSCFTilingOptions39 setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) { 40 tileSizeComputationFunction = std::move(fun); 41 return *this; 42 } 43 /// Set the `tileSizeComputationFunction` to return the values `ts`. The 44 /// values must not fold away when tiling. Otherwise, use a more robust 45 /// `tileSizeComputationFunction`. setTileSizesSCFTilingOptions46 SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) { 47 tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; }; 48 return *this; 49 } 50 /// Convenience function to set the `tileSizeComputationFunction` to a 51 /// function that computes tile sizes at the point they are needed. Allows 52 /// proper interaction with folding. 53 SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts); 54 55 /// The interchange vector to reorder the tiled loops. 56 SmallVector<unsigned> interchangeVector = {}; setInterchangeSCFTilingOptions57 SCFTilingOptions &setInterchange(ArrayRef<unsigned> interchange) { 58 interchangeVector = llvm::to_vector(interchange); 59 return *this; 60 } 61 }; 62 63 struct SCFTilingResult { 64 Operation *tiledOp; 65 SmallVector<scf::ForOp> loops; 66 }; 67 68 /// Pattern to tile an op that implements the `TilingInterface` using 69 /// `scf.for` for iterating over the tiles. 70 struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> { 71 /// Construct a generic pattern applied to all TilingInterface ops. 72 TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options, 73 PatternBenefit benefit = 1); 74 75 /// Construct a generic pattern applied to `opName`. 76 TileUsingSCFForOp(StringRef opName, MLIRContext *context, 77 SCFTilingOptions options, PatternBenefit benefit = 1); 78 79 /// `matchAndRewrite` implementation that returns the significant transformed 80 /// pieces of IR. 81 FailureOr<SCFTilingResult> 82 returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; 83 matchAndRewriteTileUsingSCFForOp84 LogicalResult matchAndRewrite(TilingInterface op, 85 PatternRewriter &rewriter) const override { 86 return returningMatchAndRewrite(op, rewriter); 87 } 88 89 private: 90 /// Options to control tiling; 91 SCFTilingOptions options; 92 }; 93 94 /// Pattern to tile and fuse a sequence of operations, by tiling the consumer 95 /// and fusing its producers. Note that this assumes that it is valid to 96 /// tile+fuse the producer into the innermost tiled loop. Its up to the caller 97 /// to ensure that the tile sizes provided make this fusion valid. 98 /// 99 /// For example, for the following sequence 100 /// 101 /// ```mlir 102 /// %0 = linalg.fill ... 103 /// %1 = linalg.matmul ... outs(%0 : ...) ... 104 /// ``` 105 /// 106 /// it is legal to fuse the fill with the matmul only if the matmul is tiled 107 /// along the parallel dimensions and not the reduction dimension, i.e. the tile 108 /// size for the reduction dimension should be 0. 109 struct SCFTileAndFuseResult { 110 SmallVector<Operation *> tiledAndFusedOps; 111 SmallVector<scf::ForOp> loops; 112 }; 113 struct TileConsumerAndFuseProducersUsingSCFForOp 114 : public OpInterfaceRewritePattern<TilingInterface> { 115 116 /// Construct a generic pattern applied to all TilingInterface ops. 117 TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, 118 SCFTilingOptions options, 119 PatternBenefit benefit = 1); 120 121 /// Construct a generic pattern applied to `opName`. 122 TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, 123 MLIRContext *context, 124 SCFTilingOptions options, 125 PatternBenefit benefit = 1); 126 127 /// `matchAndRewrite` implementation that returns the significant transformed 128 /// pieces of IR. 129 FailureOr<SCFTileAndFuseResult> 130 returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; 131 matchAndRewriteTileConsumerAndFuseProducersUsingSCFForOp132 LogicalResult matchAndRewrite(TilingInterface op, 133 PatternRewriter &rewriter) const override { 134 return returningMatchAndRewrite(op, rewriter); 135 } 136 137 private: 138 /// This pattern uses the tiling pattern. Instead of using inheritance, use 139 /// the patterns as private object that is instantiated at the same time as 140 /// this pattern. 141 TileUsingSCFForOp tilingPattern; 142 }; 143 144 } // namespace scf 145 } // namespace mlir 146 147 #endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H 148