1cf6a7c19SMahesh Ravishankar //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2cf6a7c19SMahesh Ravishankar //
3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information.
5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cf6a7c19SMahesh Ravishankar //
7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
8cf6a7c19SMahesh Ravishankar //
9cf6a7c19SMahesh Ravishankar // This file implements the tiling using TilingInterface.
10cf6a7c19SMahesh Ravishankar //
11cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
12cf6a7c19SMahesh Ravishankar 
13*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14cf6a7c19SMahesh Ravishankar 
15cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
16cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h"
18cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/Utils/Utils.h"
19cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
20cf6a7c19SMahesh Ravishankar #include "mlir/IR/Matchers.h"
21cf6a7c19SMahesh Ravishankar #include "mlir/IR/PatternMatch.h"
22cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
23cf6a7c19SMahesh Ravishankar #include "llvm/Support/Debug.h"
24cf6a7c19SMahesh Ravishankar 
25cf6a7c19SMahesh Ravishankar #define DEBUG_TYPE "tile-using-interface"
26cf6a7c19SMahesh Ravishankar 
27cf6a7c19SMahesh Ravishankar using namespace mlir;
28cf6a7c19SMahesh Ravishankar 
29cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions &
30cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
31cf6a7c19SMahesh Ravishankar   assert(!tileSizeComputationFunction && "tile sizes already set");
32cf6a7c19SMahesh Ravishankar   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
33cf6a7c19SMahesh Ravishankar   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
34cf6a7c19SMahesh Ravishankar     OpBuilder::InsertionGuard guard(b);
35cf6a7c19SMahesh Ravishankar     b.setInsertionPointToStart(
36cf6a7c19SMahesh Ravishankar         &op->getParentOfType<func::FuncOp>().getBody().front());
37cf6a7c19SMahesh Ravishankar     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
38cf6a7c19SMahesh Ravishankar       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
39cf6a7c19SMahesh Ravishankar       return v;
40cf6a7c19SMahesh Ravishankar     }));
41cf6a7c19SMahesh Ravishankar   };
42cf6a7c19SMahesh Ravishankar   return *this;
43cf6a7c19SMahesh Ravishankar }
44cf6a7c19SMahesh Ravishankar 
45cf6a7c19SMahesh Ravishankar /// Generate an empty loop nest that represents the tiled loop nest shell.
46cf6a7c19SMahesh Ravishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
47cf6a7c19SMahesh Ravishankar /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
48cf6a7c19SMahesh Ravishankar /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
49cf6a7c19SMahesh Ravishankar /// the
50cf6a7c19SMahesh Ravishankar ///   tile processed within the inner most loop.
51cf6a7c19SMahesh Ravishankar static SmallVector<scf::ForOp>
52cf6a7c19SMahesh Ravishankar generateTileLoopNest(OpBuilder &builder, Location loc,
53cf6a7c19SMahesh Ravishankar                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
54cf6a7c19SMahesh Ravishankar                      SmallVector<OpFoldResult> &offsets,
55cf6a7c19SMahesh Ravishankar                      SmallVector<OpFoldResult> &sizes) {
56cf6a7c19SMahesh Ravishankar   assert(!loopRanges.empty() && "expected at least one loop range");
57cf6a7c19SMahesh Ravishankar   assert(loopRanges.size() == tileSizeVals.size() &&
58cf6a7c19SMahesh Ravishankar          "expected as many tile sizes as loop ranges");
59cf6a7c19SMahesh Ravishankar   OpBuilder::InsertionGuard guard(builder);
60cf6a7c19SMahesh Ravishankar   SmallVector<scf::ForOp> loops;
61cf6a7c19SMahesh Ravishankar   offsets.resize(loopRanges.size());
62cf6a7c19SMahesh Ravishankar   sizes.resize(loopRanges.size());
63cf6a7c19SMahesh Ravishankar 
64cf6a7c19SMahesh Ravishankar   // The tile size to use (to avoid out of bounds access) is  minimum of
65cf6a7c19SMahesh Ravishankar   // `tileSize` and `ub - iv`, where `iv` is the induction variable
66cf6a7c19SMahesh Ravishankar   // of the tiled loop.
67cf6a7c19SMahesh Ravishankar   AffineExpr s0, s1, d0;
68cf6a7c19SMahesh Ravishankar   bindDims(builder.getContext(), d0);
69cf6a7c19SMahesh Ravishankar   bindSymbols(builder.getContext(), s0, s1);
70cf6a7c19SMahesh Ravishankar   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
71cf6a7c19SMahesh Ravishankar 
72cf6a7c19SMahesh Ravishankar   for (auto loopRange : llvm::enumerate(loopRanges)) {
73cf6a7c19SMahesh Ravishankar     // No loops if tile size is zero. Set offset and size to the loop
74cf6a7c19SMahesh Ravishankar     // offset and size.
75cf6a7c19SMahesh Ravishankar     if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
76cf6a7c19SMahesh Ravishankar       offsets[loopRange.index()] = loopRange.value().offset;
77cf6a7c19SMahesh Ravishankar       sizes[loopRange.index()] = loopRange.value().size;
78cf6a7c19SMahesh Ravishankar       continue;
79cf6a7c19SMahesh Ravishankar     }
80cf6a7c19SMahesh Ravishankar 
81cf6a7c19SMahesh Ravishankar     auto loop = builder.create<scf::ForOp>(
82cf6a7c19SMahesh Ravishankar         loc, loopRange.value().offset, loopRange.value().size,
83cf6a7c19SMahesh Ravishankar         tileSizeVals[loopRange.index()], ValueRange{},
84cf6a7c19SMahesh Ravishankar         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
85cf6a7c19SMahesh Ravishankar             ValueRange /*iterArgs*/) {
86cf6a7c19SMahesh Ravishankar           Value boundedTileSize = builder.create<AffineMinOp>(
87cf6a7c19SMahesh Ravishankar               bodyLoc, minMap,
88cf6a7c19SMahesh Ravishankar               ValueRange{iv, tileSizeVals[loopRange.index()],
89cf6a7c19SMahesh Ravishankar                          loopRange.value().size});
90cf6a7c19SMahesh Ravishankar           sizes[loopRange.index()] = boundedTileSize;
91cf6a7c19SMahesh Ravishankar           builder.create<scf::YieldOp>(loc);
92cf6a7c19SMahesh Ravishankar         });
93cf6a7c19SMahesh Ravishankar     offsets[loopRange.index()] = loop.getInductionVar();
94cf6a7c19SMahesh Ravishankar     loops.push_back(loop);
95cf6a7c19SMahesh Ravishankar     builder.setInsertionPoint(loop.getBody()->getTerminator());
96cf6a7c19SMahesh Ravishankar   }
97cf6a7c19SMahesh Ravishankar   return loops;
98cf6a7c19SMahesh Ravishankar }
99cf6a7c19SMahesh Ravishankar 
100cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
101cf6a7c19SMahesh Ravishankar                                           scf::SCFTilingOptions options,
102cf6a7c19SMahesh Ravishankar                                           PatternBenefit benefit)
103cf6a7c19SMahesh Ravishankar     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
104cf6a7c19SMahesh Ravishankar       options(std::move(options)) {}
105cf6a7c19SMahesh Ravishankar 
106cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
107cf6a7c19SMahesh Ravishankar                                           MLIRContext *context,
108cf6a7c19SMahesh Ravishankar                                           scf::SCFTilingOptions options,
109cf6a7c19SMahesh Ravishankar                                           PatternBenefit benefit)
110cf6a7c19SMahesh Ravishankar     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
111cf6a7c19SMahesh Ravishankar       options(std::move(options)) {}
112cf6a7c19SMahesh Ravishankar 
113cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult>
114cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::returningMatchAndRewrite(
115cf6a7c19SMahesh Ravishankar     TilingInterface op, PatternRewriter &rewriter) const {
116cf6a7c19SMahesh Ravishankar   OpBuilder::InsertionGuard guard(rewriter);
117cf6a7c19SMahesh Ravishankar   rewriter.setInsertionPointAfter(op);
118cf6a7c19SMahesh Ravishankar 
119cf6a7c19SMahesh Ravishankar   if (!options.tileSizeComputationFunction) {
120cf6a7c19SMahesh Ravishankar     return rewriter.notifyMatchFailure(
121cf6a7c19SMahesh Ravishankar         op, "missing tile size computation function");
122cf6a7c19SMahesh Ravishankar   }
123cf6a7c19SMahesh Ravishankar 
124cf6a7c19SMahesh Ravishankar   // 1. Get the range of the loops that are represented by the operation.
125cf6a7c19SMahesh Ravishankar   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
126cf6a7c19SMahesh Ravishankar   size_t numLoops = iterationDomain.size();
127cf6a7c19SMahesh Ravishankar   if (numLoops == 0) {
128cf6a7c19SMahesh Ravishankar     return rewriter.notifyMatchFailure(
129cf6a7c19SMahesh Ravishankar         op, "unable to tile op with no iteration domain");
130cf6a7c19SMahesh Ravishankar   }
131cf6a7c19SMahesh Ravishankar 
132cf6a7c19SMahesh Ravishankar   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
133cf6a7c19SMahesh Ravishankar   // skips tiling a particular dimension. This convention is significantly
134cf6a7c19SMahesh Ravishankar   // simpler to handle instead of adjusting affine maps to account for missing
135cf6a7c19SMahesh Ravishankar   // dimensions.
136cf6a7c19SMahesh Ravishankar   SmallVector<Value, 4> tileSizeVector =
137cf6a7c19SMahesh Ravishankar       options.tileSizeComputationFunction(rewriter, op);
138cf6a7c19SMahesh Ravishankar   if (tileSizeVector.size() < iterationDomain.size()) {
139cf6a7c19SMahesh Ravishankar     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
140cf6a7c19SMahesh Ravishankar     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
141cf6a7c19SMahesh Ravishankar   }
142cf6a7c19SMahesh Ravishankar 
143cf6a7c19SMahesh Ravishankar   scf::SCFTilingResult tilingResult;
144cf6a7c19SMahesh Ravishankar   SmallVector<OpFoldResult> offsets, sizes;
145cf6a7c19SMahesh Ravishankar   {
146cf6a7c19SMahesh Ravishankar     // 3. Materialize an empty loop nest that iterates over the tiles. These
147cf6a7c19SMahesh Ravishankar     // loops for now do not return any values even if the original operation has
148cf6a7c19SMahesh Ravishankar     // results.
149cf6a7c19SMahesh Ravishankar     tilingResult.loops = generateTileLoopNest(
150cf6a7c19SMahesh Ravishankar         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
151cf6a7c19SMahesh Ravishankar 
152cf6a7c19SMahesh Ravishankar     LLVM_DEBUG({
153cf6a7c19SMahesh Ravishankar       if (!tilingResult.loops.empty()) {
154cf6a7c19SMahesh Ravishankar         llvm::errs() << "LoopNest shell :\n";
155cf6a7c19SMahesh Ravishankar         tilingResult.loops.front().dump();
156cf6a7c19SMahesh Ravishankar         llvm::errs() << "\n";
157cf6a7c19SMahesh Ravishankar       }
158cf6a7c19SMahesh Ravishankar     });
159cf6a7c19SMahesh Ravishankar 
160cf6a7c19SMahesh Ravishankar     // 4. Generate the tiled implementation within the inner most loop.
161cf6a7c19SMahesh Ravishankar     if (!tilingResult.loops.empty())
162cf6a7c19SMahesh Ravishankar       rewriter.setInsertionPoint(
163cf6a7c19SMahesh Ravishankar           tilingResult.loops.back().getBody()->getTerminator());
164cf6a7c19SMahesh Ravishankar     SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
165cf6a7c19SMahesh Ravishankar         rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
166cf6a7c19SMahesh Ravishankar     if (tiledImplementation.size() != 1) {
167cf6a7c19SMahesh Ravishankar       return rewriter.notifyMatchFailure(
168cf6a7c19SMahesh Ravishankar           op, "expected tiled implementation to return a single op");
169cf6a7c19SMahesh Ravishankar     }
170cf6a7c19SMahesh Ravishankar     tilingResult.tiledOp = tiledImplementation[0];
171cf6a7c19SMahesh Ravishankar 
172cf6a7c19SMahesh Ravishankar     LLVM_DEBUG({
173cf6a7c19SMahesh Ravishankar       if (!tilingResult.loops.empty()) {
174cf6a7c19SMahesh Ravishankar         llvm::errs() << "After tiled implementation :\n";
175cf6a7c19SMahesh Ravishankar         tilingResult.loops.front().dump();
176cf6a7c19SMahesh Ravishankar         llvm::errs() << "\n";
177cf6a7c19SMahesh Ravishankar       }
178cf6a7c19SMahesh Ravishankar     });
179cf6a7c19SMahesh Ravishankar   }
180cf6a7c19SMahesh Ravishankar 
181cf6a7c19SMahesh Ravishankar   if (op->getNumResults() == 0) {
182cf6a7c19SMahesh Ravishankar     rewriter.eraseOp(op);
183cf6a7c19SMahesh Ravishankar     return tilingResult;
184cf6a7c19SMahesh Ravishankar   }
185cf6a7c19SMahesh Ravishankar 
186cf6a7c19SMahesh Ravishankar   // 5. If the original operations has results, modify the loop nest to yield
187cf6a7c19SMahesh Ravishankar   // the replacement values.
188cf6a7c19SMahesh Ravishankar   SmallVector<Value> replacements;
189cf6a7c19SMahesh Ravishankar   if (tilingResult.loops.empty()) {
190cf6a7c19SMahesh Ravishankar     // 5a. If there were no loops, the tiled implementation results are the
191cf6a7c19SMahesh Ravishankar     // replacements.
192cf6a7c19SMahesh Ravishankar     rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
193cf6a7c19SMahesh Ravishankar     return tilingResult;
194cf6a7c19SMahesh Ravishankar   }
195cf6a7c19SMahesh Ravishankar 
196cf6a7c19SMahesh Ravishankar   // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
197cf6a7c19SMahesh Ravishankar   // replacement values using destructive updates. Use the `TilingInterface`
198cf6a7c19SMahesh Ravishankar   // to get the position of the result tiles and use that to generate the
199cf6a7c19SMahesh Ravishankar   // destructive update pattern, i.e.,
200cf6a7c19SMahesh Ravishankar   //
201cf6a7c19SMahesh Ravishankar   // ```mlir
202cf6a7c19SMahesh Ravishankar   // scf.for %iv0 = ... {
203cf6a7c19SMahesh Ravishankar   //   %0 = tiled_op
204cf6a7c19SMahesh Ravishankar   // }
205cf6a7c19SMahesh Ravishankar   // ```
206cf6a7c19SMahesh Ravishankar   //
207cf6a7c19SMahesh Ravishankar   // is transformed to
208cf6a7c19SMahesh Ravishankar   //
209cf6a7c19SMahesh Ravishankar   // ```mlir
210cf6a7c19SMahesh Ravishankar   // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
211cf6a7c19SMahesh Ravishankar   //   %0 = tiled_op
212cf6a7c19SMahesh Ravishankar   //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
213cf6a7c19SMahesh Ravishankar   //   scf.yield %1
214cf6a7c19SMahesh Ravishankar   // }
215cf6a7c19SMahesh Ravishankar   // ```
216cf6a7c19SMahesh Ravishankar   NewYieldValueFn yieldValueFn =
217cf6a7c19SMahesh Ravishankar       [&](OpBuilder &b, Location loc,
218cf6a7c19SMahesh Ravishankar           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
219cf6a7c19SMahesh Ravishankar     SmallVector<Value> yieldedValues;
220cf6a7c19SMahesh Ravishankar     Attribute one = b.getIndexAttr(1);
221cf6a7c19SMahesh Ravishankar     for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
222cf6a7c19SMahesh Ravishankar       SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
223cf6a7c19SMahesh Ravishankar       if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
224cf6a7c19SMahesh Ravishankar                                           resultTileOffsets,
225cf6a7c19SMahesh Ravishankar                                           resultTileSizes))) {
226cf6a7c19SMahesh Ravishankar         op.emitOpError("unable to get position of result ")
227cf6a7c19SMahesh Ravishankar             << resultNum << " of the tiled implementation";
228cf6a7c19SMahesh Ravishankar         return {};
229cf6a7c19SMahesh Ravishankar       }
230cf6a7c19SMahesh Ravishankar       SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
231cf6a7c19SMahesh Ravishankar                                                   one);
232cf6a7c19SMahesh Ravishankar       Value yieldedValue = b.create<tensor::InsertSliceOp>(
233cf6a7c19SMahesh Ravishankar           op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
234cf6a7c19SMahesh Ravishankar           newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
235cf6a7c19SMahesh Ravishankar           resultTileStrides);
236cf6a7c19SMahesh Ravishankar       yieldedValues.push_back(yieldedValue);
237cf6a7c19SMahesh Ravishankar     }
238cf6a7c19SMahesh Ravishankar     return yieldedValues;
239cf6a7c19SMahesh Ravishankar   };
240cf6a7c19SMahesh Ravishankar   SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
241cf6a7c19SMahesh Ravishankar       rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
242cf6a7c19SMahesh Ravishankar       yieldValueFn);
243cf6a7c19SMahesh Ravishankar   for (auto loop : llvm::enumerate(tilingResult.loops)) {
244cf6a7c19SMahesh Ravishankar     rewriter.eraseOp(loop.value());
245cf6a7c19SMahesh Ravishankar     tilingResult.loops[loop.index()] = newLoops[loop.index()];
246cf6a7c19SMahesh Ravishankar   }
247cf6a7c19SMahesh Ravishankar   rewriter.replaceOp(op, tilingResult.loops.front().getResults());
248cf6a7c19SMahesh Ravishankar   return tilingResult;
249cf6a7c19SMahesh Ravishankar }
250