//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the tiling using TilingInterface.
//
//===----------------------------------------------------------------------===//
12*cf6a7c19SMahesh Ravishankar 
13*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/TileUsingInterface.h"
14*cf6a7c19SMahesh Ravishankar 
15*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
16*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h"
18*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/Utils/Utils.h"
19*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
20*cf6a7c19SMahesh Ravishankar #include "mlir/IR/Matchers.h"
21*cf6a7c19SMahesh Ravishankar #include "mlir/IR/PatternMatch.h"
22*cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
23*cf6a7c19SMahesh Ravishankar #include "llvm/Support/Debug.h"
24*cf6a7c19SMahesh Ravishankar 
25*cf6a7c19SMahesh Ravishankar #define DEBUG_TYPE "tile-using-interface"
26*cf6a7c19SMahesh Ravishankar 
27*cf6a7c19SMahesh Ravishankar using namespace mlir;
28*cf6a7c19SMahesh Ravishankar 
29*cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions &
30*cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
31*cf6a7c19SMahesh Ravishankar   assert(!tileSizeComputationFunction && "tile sizes already set");
32*cf6a7c19SMahesh Ravishankar   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
33*cf6a7c19SMahesh Ravishankar   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
34*cf6a7c19SMahesh Ravishankar     OpBuilder::InsertionGuard guard(b);
35*cf6a7c19SMahesh Ravishankar     b.setInsertionPointToStart(
36*cf6a7c19SMahesh Ravishankar         &op->getParentOfType<func::FuncOp>().getBody().front());
37*cf6a7c19SMahesh Ravishankar     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
38*cf6a7c19SMahesh Ravishankar       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
39*cf6a7c19SMahesh Ravishankar       return v;
40*cf6a7c19SMahesh Ravishankar     }));
41*cf6a7c19SMahesh Ravishankar   };
42*cf6a7c19SMahesh Ravishankar   return *this;
43*cf6a7c19SMahesh Ravishankar }
44*cf6a7c19SMahesh Ravishankar 
45*cf6a7c19SMahesh Ravishankar /// Generate an empty loop nest that represents the tiled loop nest shell.
46*cf6a7c19SMahesh Ravishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
47*cf6a7c19SMahesh Ravishankar /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
48*cf6a7c19SMahesh Ravishankar /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
49*cf6a7c19SMahesh Ravishankar /// the
50*cf6a7c19SMahesh Ravishankar ///   tile processed within the inner most loop.
51*cf6a7c19SMahesh Ravishankar static SmallVector<scf::ForOp>
52*cf6a7c19SMahesh Ravishankar generateTileLoopNest(OpBuilder &builder, Location loc,
53*cf6a7c19SMahesh Ravishankar                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
54*cf6a7c19SMahesh Ravishankar                      SmallVector<OpFoldResult> &offsets,
55*cf6a7c19SMahesh Ravishankar                      SmallVector<OpFoldResult> &sizes) {
56*cf6a7c19SMahesh Ravishankar   assert(!loopRanges.empty() && "expected at least one loop range");
57*cf6a7c19SMahesh Ravishankar   assert(loopRanges.size() == tileSizeVals.size() &&
58*cf6a7c19SMahesh Ravishankar          "expected as many tile sizes as loop ranges");
59*cf6a7c19SMahesh Ravishankar   OpBuilder::InsertionGuard guard(builder);
60*cf6a7c19SMahesh Ravishankar   SmallVector<scf::ForOp> loops;
61*cf6a7c19SMahesh Ravishankar   offsets.resize(loopRanges.size());
62*cf6a7c19SMahesh Ravishankar   sizes.resize(loopRanges.size());
63*cf6a7c19SMahesh Ravishankar 
64*cf6a7c19SMahesh Ravishankar   // The tile size to use (to avoid out of bounds access) is  minimum of
65*cf6a7c19SMahesh Ravishankar   // `tileSize` and `ub - iv`, where `iv` is the induction variable
66*cf6a7c19SMahesh Ravishankar   // of the tiled loop.
67*cf6a7c19SMahesh Ravishankar   AffineExpr s0, s1, d0;
68*cf6a7c19SMahesh Ravishankar   bindDims(builder.getContext(), d0);
69*cf6a7c19SMahesh Ravishankar   bindSymbols(builder.getContext(), s0, s1);
70*cf6a7c19SMahesh Ravishankar   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
71*cf6a7c19SMahesh Ravishankar 
72*cf6a7c19SMahesh Ravishankar   for (auto loopRange : llvm::enumerate(loopRanges)) {
73*cf6a7c19SMahesh Ravishankar     // No loops if tile size is zero. Set offset and size to the loop
74*cf6a7c19SMahesh Ravishankar     // offset and size.
75*cf6a7c19SMahesh Ravishankar     if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
76*cf6a7c19SMahesh Ravishankar       offsets[loopRange.index()] = loopRange.value().offset;
77*cf6a7c19SMahesh Ravishankar       sizes[loopRange.index()] = loopRange.value().size;
78*cf6a7c19SMahesh Ravishankar       continue;
79*cf6a7c19SMahesh Ravishankar     }
80*cf6a7c19SMahesh Ravishankar 
81*cf6a7c19SMahesh Ravishankar     auto loop = builder.create<scf::ForOp>(
82*cf6a7c19SMahesh Ravishankar         loc, loopRange.value().offset, loopRange.value().size,
83*cf6a7c19SMahesh Ravishankar         tileSizeVals[loopRange.index()], ValueRange{},
84*cf6a7c19SMahesh Ravishankar         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
85*cf6a7c19SMahesh Ravishankar             ValueRange /*iterArgs*/) {
86*cf6a7c19SMahesh Ravishankar           Value boundedTileSize = builder.create<AffineMinOp>(
87*cf6a7c19SMahesh Ravishankar               bodyLoc, minMap,
88*cf6a7c19SMahesh Ravishankar               ValueRange{iv, tileSizeVals[loopRange.index()],
89*cf6a7c19SMahesh Ravishankar                          loopRange.value().size});
90*cf6a7c19SMahesh Ravishankar           sizes[loopRange.index()] = boundedTileSize;
91*cf6a7c19SMahesh Ravishankar           builder.create<scf::YieldOp>(loc);
92*cf6a7c19SMahesh Ravishankar         });
93*cf6a7c19SMahesh Ravishankar     offsets[loopRange.index()] = loop.getInductionVar();
94*cf6a7c19SMahesh Ravishankar     loops.push_back(loop);
95*cf6a7c19SMahesh Ravishankar     builder.setInsertionPoint(loop.getBody()->getTerminator());
96*cf6a7c19SMahesh Ravishankar   }
97*cf6a7c19SMahesh Ravishankar   return loops;
98*cf6a7c19SMahesh Ravishankar }
99*cf6a7c19SMahesh Ravishankar 
100*cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
101*cf6a7c19SMahesh Ravishankar                                           scf::SCFTilingOptions options,
102*cf6a7c19SMahesh Ravishankar                                           PatternBenefit benefit)
103*cf6a7c19SMahesh Ravishankar     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
104*cf6a7c19SMahesh Ravishankar       options(std::move(options)) {}
105*cf6a7c19SMahesh Ravishankar 
106*cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
107*cf6a7c19SMahesh Ravishankar                                           MLIRContext *context,
108*cf6a7c19SMahesh Ravishankar                                           scf::SCFTilingOptions options,
109*cf6a7c19SMahesh Ravishankar                                           PatternBenefit benefit)
110*cf6a7c19SMahesh Ravishankar     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
111*cf6a7c19SMahesh Ravishankar       options(std::move(options)) {}
112*cf6a7c19SMahesh Ravishankar 
113*cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult>
114*cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::returningMatchAndRewrite(
115*cf6a7c19SMahesh Ravishankar     TilingInterface op, PatternRewriter &rewriter) const {
116*cf6a7c19SMahesh Ravishankar   OpBuilder::InsertionGuard guard(rewriter);
117*cf6a7c19SMahesh Ravishankar   rewriter.setInsertionPointAfter(op);
118*cf6a7c19SMahesh Ravishankar 
119*cf6a7c19SMahesh Ravishankar   if (!options.tileSizeComputationFunction) {
120*cf6a7c19SMahesh Ravishankar     return rewriter.notifyMatchFailure(
121*cf6a7c19SMahesh Ravishankar         op, "missing tile size computation function");
122*cf6a7c19SMahesh Ravishankar   }
123*cf6a7c19SMahesh Ravishankar 
124*cf6a7c19SMahesh Ravishankar   // 1. Get the range of the loops that are represented by the operation.
125*cf6a7c19SMahesh Ravishankar   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
126*cf6a7c19SMahesh Ravishankar   size_t numLoops = iterationDomain.size();
127*cf6a7c19SMahesh Ravishankar   if (numLoops == 0) {
128*cf6a7c19SMahesh Ravishankar     return rewriter.notifyMatchFailure(
129*cf6a7c19SMahesh Ravishankar         op, "unable to tile op with no iteration domain");
130*cf6a7c19SMahesh Ravishankar   }
131*cf6a7c19SMahesh Ravishankar 
132*cf6a7c19SMahesh Ravishankar   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
133*cf6a7c19SMahesh Ravishankar   // skips tiling a particular dimension. This convention is significantly
134*cf6a7c19SMahesh Ravishankar   // simpler to handle instead of adjusting affine maps to account for missing
135*cf6a7c19SMahesh Ravishankar   // dimensions.
136*cf6a7c19SMahesh Ravishankar   SmallVector<Value, 4> tileSizeVector =
137*cf6a7c19SMahesh Ravishankar       options.tileSizeComputationFunction(rewriter, op);
138*cf6a7c19SMahesh Ravishankar   if (tileSizeVector.size() < iterationDomain.size()) {
139*cf6a7c19SMahesh Ravishankar     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
140*cf6a7c19SMahesh Ravishankar     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
141*cf6a7c19SMahesh Ravishankar   }
142*cf6a7c19SMahesh Ravishankar 
143*cf6a7c19SMahesh Ravishankar   scf::SCFTilingResult tilingResult;
144*cf6a7c19SMahesh Ravishankar   SmallVector<OpFoldResult> offsets, sizes;
145*cf6a7c19SMahesh Ravishankar   {
146*cf6a7c19SMahesh Ravishankar     // 3. Materialize an empty loop nest that iterates over the tiles. These
147*cf6a7c19SMahesh Ravishankar     // loops for now do not return any values even if the original operation has
148*cf6a7c19SMahesh Ravishankar     // results.
149*cf6a7c19SMahesh Ravishankar     tilingResult.loops = generateTileLoopNest(
150*cf6a7c19SMahesh Ravishankar         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
151*cf6a7c19SMahesh Ravishankar 
152*cf6a7c19SMahesh Ravishankar     LLVM_DEBUG({
153*cf6a7c19SMahesh Ravishankar       if (!tilingResult.loops.empty()) {
154*cf6a7c19SMahesh Ravishankar         llvm::errs() << "LoopNest shell :\n";
155*cf6a7c19SMahesh Ravishankar         tilingResult.loops.front().dump();
156*cf6a7c19SMahesh Ravishankar         llvm::errs() << "\n";
157*cf6a7c19SMahesh Ravishankar       }
158*cf6a7c19SMahesh Ravishankar     });
159*cf6a7c19SMahesh Ravishankar 
160*cf6a7c19SMahesh Ravishankar     // 4. Generate the tiled implementation within the inner most loop.
161*cf6a7c19SMahesh Ravishankar     if (!tilingResult.loops.empty())
162*cf6a7c19SMahesh Ravishankar       rewriter.setInsertionPoint(
163*cf6a7c19SMahesh Ravishankar           tilingResult.loops.back().getBody()->getTerminator());
164*cf6a7c19SMahesh Ravishankar     SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
165*cf6a7c19SMahesh Ravishankar         rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
166*cf6a7c19SMahesh Ravishankar     if (tiledImplementation.size() != 1) {
167*cf6a7c19SMahesh Ravishankar       return rewriter.notifyMatchFailure(
168*cf6a7c19SMahesh Ravishankar           op, "expected tiled implementation to return a single op");
169*cf6a7c19SMahesh Ravishankar     }
170*cf6a7c19SMahesh Ravishankar     tilingResult.tiledOp = tiledImplementation[0];
171*cf6a7c19SMahesh Ravishankar 
172*cf6a7c19SMahesh Ravishankar     LLVM_DEBUG({
173*cf6a7c19SMahesh Ravishankar       if (!tilingResult.loops.empty()) {
174*cf6a7c19SMahesh Ravishankar         llvm::errs() << "After tiled implementation :\n";
175*cf6a7c19SMahesh Ravishankar         tilingResult.loops.front().dump();
176*cf6a7c19SMahesh Ravishankar         llvm::errs() << "\n";
177*cf6a7c19SMahesh Ravishankar       }
178*cf6a7c19SMahesh Ravishankar     });
179*cf6a7c19SMahesh Ravishankar   }
180*cf6a7c19SMahesh Ravishankar 
181*cf6a7c19SMahesh Ravishankar   if (op->getNumResults() == 0) {
182*cf6a7c19SMahesh Ravishankar     rewriter.eraseOp(op);
183*cf6a7c19SMahesh Ravishankar     return tilingResult;
184*cf6a7c19SMahesh Ravishankar   }
185*cf6a7c19SMahesh Ravishankar 
186*cf6a7c19SMahesh Ravishankar   // 5. If the original operations has results, modify the loop nest to yield
187*cf6a7c19SMahesh Ravishankar   // the replacement values.
188*cf6a7c19SMahesh Ravishankar   SmallVector<Value> replacements;
189*cf6a7c19SMahesh Ravishankar   if (tilingResult.loops.empty()) {
190*cf6a7c19SMahesh Ravishankar     // 5a. If there were no loops, the tiled implementation results are the
191*cf6a7c19SMahesh Ravishankar     // replacements.
192*cf6a7c19SMahesh Ravishankar     rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
193*cf6a7c19SMahesh Ravishankar     return tilingResult;
194*cf6a7c19SMahesh Ravishankar   }
195*cf6a7c19SMahesh Ravishankar 
196*cf6a7c19SMahesh Ravishankar   // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
197*cf6a7c19SMahesh Ravishankar   // replacement values using destructive updates. Use the `TilingInterface`
198*cf6a7c19SMahesh Ravishankar   // to get the position of the result tiles and use that to generate the
199*cf6a7c19SMahesh Ravishankar   // destructive update pattern, i.e.,
200*cf6a7c19SMahesh Ravishankar   //
201*cf6a7c19SMahesh Ravishankar   // ```mlir
202*cf6a7c19SMahesh Ravishankar   // scf.for %iv0 = ... {
203*cf6a7c19SMahesh Ravishankar   //   %0 = tiled_op
204*cf6a7c19SMahesh Ravishankar   // }
205*cf6a7c19SMahesh Ravishankar   // ```
206*cf6a7c19SMahesh Ravishankar   //
207*cf6a7c19SMahesh Ravishankar   // is transformed to
208*cf6a7c19SMahesh Ravishankar   //
209*cf6a7c19SMahesh Ravishankar   // ```mlir
210*cf6a7c19SMahesh Ravishankar   // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
211*cf6a7c19SMahesh Ravishankar   //   %0 = tiled_op
212*cf6a7c19SMahesh Ravishankar   //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
213*cf6a7c19SMahesh Ravishankar   //   scf.yield %1
214*cf6a7c19SMahesh Ravishankar   // }
215*cf6a7c19SMahesh Ravishankar   // ```
216*cf6a7c19SMahesh Ravishankar   NewYieldValueFn yieldValueFn =
217*cf6a7c19SMahesh Ravishankar       [&](OpBuilder &b, Location loc,
218*cf6a7c19SMahesh Ravishankar           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
219*cf6a7c19SMahesh Ravishankar     SmallVector<Value> yieldedValues;
220*cf6a7c19SMahesh Ravishankar     Attribute one = b.getIndexAttr(1);
221*cf6a7c19SMahesh Ravishankar     for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
222*cf6a7c19SMahesh Ravishankar       SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
223*cf6a7c19SMahesh Ravishankar       if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
224*cf6a7c19SMahesh Ravishankar                                           resultTileOffsets,
225*cf6a7c19SMahesh Ravishankar                                           resultTileSizes))) {
226*cf6a7c19SMahesh Ravishankar         op.emitOpError("unable to get position of result ")
227*cf6a7c19SMahesh Ravishankar             << resultNum << " of the tiled implementation";
228*cf6a7c19SMahesh Ravishankar         return {};
229*cf6a7c19SMahesh Ravishankar       }
230*cf6a7c19SMahesh Ravishankar       SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
231*cf6a7c19SMahesh Ravishankar                                                   one);
232*cf6a7c19SMahesh Ravishankar       Value yieldedValue = b.create<tensor::InsertSliceOp>(
233*cf6a7c19SMahesh Ravishankar           op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
234*cf6a7c19SMahesh Ravishankar           newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
235*cf6a7c19SMahesh Ravishankar           resultTileStrides);
236*cf6a7c19SMahesh Ravishankar       yieldedValues.push_back(yieldedValue);
237*cf6a7c19SMahesh Ravishankar     }
238*cf6a7c19SMahesh Ravishankar     return yieldedValues;
239*cf6a7c19SMahesh Ravishankar   };
240*cf6a7c19SMahesh Ravishankar   SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
241*cf6a7c19SMahesh Ravishankar       rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
242*cf6a7c19SMahesh Ravishankar       yieldValueFn);
243*cf6a7c19SMahesh Ravishankar   for (auto loop : llvm::enumerate(tilingResult.loops)) {
244*cf6a7c19SMahesh Ravishankar     rewriter.eraseOp(loop.value());
245*cf6a7c19SMahesh Ravishankar     tilingResult.loops[loop.index()] = newLoops[loop.index()];
246*cf6a7c19SMahesh Ravishankar   }
247*cf6a7c19SMahesh Ravishankar   rewriter.replaceOp(op, tilingResult.loops.front().getResults());
248*cf6a7c19SMahesh Ravishankar   return tilingResult;
249*cf6a7c19SMahesh Ravishankar }
250