1cf6a7c19SMahesh Ravishankar //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2cf6a7c19SMahesh Ravishankar //
3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information.
5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cf6a7c19SMahesh Ravishankar //
7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
8cf6a7c19SMahesh Ravishankar //
9cf6a7c19SMahesh Ravishankar // This file implements the tiling using TilingInterface.
10cf6a7c19SMahesh Ravishankar //
11cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
12cf6a7c19SMahesh Ravishankar 
138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14cf6a7c19SMahesh Ravishankar 
15cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
16cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h"
18cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/Utils/Utils.h"
19cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
20cf6a7c19SMahesh Ravishankar #include "mlir/IR/Matchers.h"
21cf6a7c19SMahesh Ravishankar #include "mlir/IR/PatternMatch.h"
22cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
23cf6a7c19SMahesh Ravishankar #include "llvm/Support/Debug.h"
24cf6a7c19SMahesh Ravishankar 
25cf6a7c19SMahesh Ravishankar #define DEBUG_TYPE "tile-using-interface"
26cf6a7c19SMahesh Ravishankar 
27cf6a7c19SMahesh Ravishankar using namespace mlir;
28cf6a7c19SMahesh Ravishankar 
29cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions &
30cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
31cf6a7c19SMahesh Ravishankar   assert(!tileSizeComputationFunction && "tile sizes already set");
32cf6a7c19SMahesh Ravishankar   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
33cf6a7c19SMahesh Ravishankar   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
34cf6a7c19SMahesh Ravishankar     OpBuilder::InsertionGuard guard(b);
35cf6a7c19SMahesh Ravishankar     b.setInsertionPointToStart(
36cf6a7c19SMahesh Ravishankar         &op->getParentOfType<func::FuncOp>().getBody().front());
37cf6a7c19SMahesh Ravishankar     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
38cf6a7c19SMahesh Ravishankar       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
39cf6a7c19SMahesh Ravishankar       return v;
40cf6a7c19SMahesh Ravishankar     }));
41cf6a7c19SMahesh Ravishankar   };
42cf6a7c19SMahesh Ravishankar   return *this;
43cf6a7c19SMahesh Ravishankar }
44cf6a7c19SMahesh Ravishankar 
452f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
462f637fe7SMahesh Ravishankar // TileUsingSCFForOp pattern implementation.
472f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
482f637fe7SMahesh Ravishankar 
49cf6a7c19SMahesh Ravishankar /// Generate an empty loop nest that represents the tiled loop nest shell.
50cf6a7c19SMahesh Ravishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
51cf6a7c19SMahesh Ravishankar /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
52cf6a7c19SMahesh Ravishankar /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
53cf6a7c19SMahesh Ravishankar /// the
54cf6a7c19SMahesh Ravishankar ///   tile processed within the inner most loop.
55cf6a7c19SMahesh Ravishankar static SmallVector<scf::ForOp>
56cf6a7c19SMahesh Ravishankar generateTileLoopNest(OpBuilder &builder, Location loc,
57cf6a7c19SMahesh Ravishankar                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
58cf6a7c19SMahesh Ravishankar                      SmallVector<OpFoldResult> &offsets,
59cf6a7c19SMahesh Ravishankar                      SmallVector<OpFoldResult> &sizes) {
60cf6a7c19SMahesh Ravishankar   assert(!loopRanges.empty() && "expected at least one loop range");
61cf6a7c19SMahesh Ravishankar   assert(loopRanges.size() == tileSizeVals.size() &&
62cf6a7c19SMahesh Ravishankar          "expected as many tile sizes as loop ranges");
63cf6a7c19SMahesh Ravishankar   OpBuilder::InsertionGuard guard(builder);
64cf6a7c19SMahesh Ravishankar   SmallVector<scf::ForOp> loops;
65cf6a7c19SMahesh Ravishankar   offsets.resize(loopRanges.size());
66cf6a7c19SMahesh Ravishankar   sizes.resize(loopRanges.size());
67cf6a7c19SMahesh Ravishankar 
68cf6a7c19SMahesh Ravishankar   // The tile size to use (to avoid out of bounds access) is  minimum of
69cf6a7c19SMahesh Ravishankar   // `tileSize` and `ub - iv`, where `iv` is the induction variable
70cf6a7c19SMahesh Ravishankar   // of the tiled loop.
71cf6a7c19SMahesh Ravishankar   AffineExpr s0, s1, d0;
72cf6a7c19SMahesh Ravishankar   bindDims(builder.getContext(), d0);
73cf6a7c19SMahesh Ravishankar   bindSymbols(builder.getContext(), s0, s1);
74cf6a7c19SMahesh Ravishankar   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
75cf6a7c19SMahesh Ravishankar 
76cf6a7c19SMahesh Ravishankar   for (auto loopRange : llvm::enumerate(loopRanges)) {
77cf6a7c19SMahesh Ravishankar     // No loops if tile size is zero. Set offset and size to the loop
78cf6a7c19SMahesh Ravishankar     // offset and size.
79cf6a7c19SMahesh Ravishankar     if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
80cf6a7c19SMahesh Ravishankar       offsets[loopRange.index()] = loopRange.value().offset;
81cf6a7c19SMahesh Ravishankar       sizes[loopRange.index()] = loopRange.value().size;
82cf6a7c19SMahesh Ravishankar       continue;
83cf6a7c19SMahesh Ravishankar     }
84cf6a7c19SMahesh Ravishankar 
85cf6a7c19SMahesh Ravishankar     auto loop = builder.create<scf::ForOp>(
86cf6a7c19SMahesh Ravishankar         loc, loopRange.value().offset, loopRange.value().size,
87cf6a7c19SMahesh Ravishankar         tileSizeVals[loopRange.index()], ValueRange{},
88cf6a7c19SMahesh Ravishankar         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
89cf6a7c19SMahesh Ravishankar             ValueRange /*iterArgs*/) {
90cf6a7c19SMahesh Ravishankar           Value boundedTileSize = builder.create<AffineMinOp>(
91cf6a7c19SMahesh Ravishankar               bodyLoc, minMap,
92cf6a7c19SMahesh Ravishankar               ValueRange{iv, tileSizeVals[loopRange.index()],
93cf6a7c19SMahesh Ravishankar                          loopRange.value().size});
94cf6a7c19SMahesh Ravishankar           sizes[loopRange.index()] = boundedTileSize;
95cf6a7c19SMahesh Ravishankar           builder.create<scf::YieldOp>(loc);
96cf6a7c19SMahesh Ravishankar         });
97cf6a7c19SMahesh Ravishankar     offsets[loopRange.index()] = loop.getInductionVar();
98cf6a7c19SMahesh Ravishankar     loops.push_back(loop);
99cf6a7c19SMahesh Ravishankar     builder.setInsertionPoint(loop.getBody()->getTerminator());
100cf6a7c19SMahesh Ravishankar   }
101cf6a7c19SMahesh Ravishankar   return loops;
102cf6a7c19SMahesh Ravishankar }
103cf6a7c19SMahesh Ravishankar 
104cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
105cf6a7c19SMahesh Ravishankar                                           scf::SCFTilingOptions options,
106cf6a7c19SMahesh Ravishankar                                           PatternBenefit benefit)
107cf6a7c19SMahesh Ravishankar     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
108cf6a7c19SMahesh Ravishankar       options(std::move(options)) {}
109cf6a7c19SMahesh Ravishankar 
110cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
111cf6a7c19SMahesh Ravishankar                                           MLIRContext *context,
112cf6a7c19SMahesh Ravishankar                                           scf::SCFTilingOptions options,
113cf6a7c19SMahesh Ravishankar                                           PatternBenefit benefit)
114cf6a7c19SMahesh Ravishankar     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
115cf6a7c19SMahesh Ravishankar       options(std::move(options)) {}
116cf6a7c19SMahesh Ravishankar 
117cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult>
118cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::returningMatchAndRewrite(
119cf6a7c19SMahesh Ravishankar     TilingInterface op, PatternRewriter &rewriter) const {
120cf6a7c19SMahesh Ravishankar   OpBuilder::InsertionGuard guard(rewriter);
121cf6a7c19SMahesh Ravishankar   rewriter.setInsertionPointAfter(op);
122cf6a7c19SMahesh Ravishankar 
123cf6a7c19SMahesh Ravishankar   if (!options.tileSizeComputationFunction) {
124cf6a7c19SMahesh Ravishankar     return rewriter.notifyMatchFailure(
125cf6a7c19SMahesh Ravishankar         op, "missing tile size computation function");
126cf6a7c19SMahesh Ravishankar   }
127cf6a7c19SMahesh Ravishankar 
128cf6a7c19SMahesh Ravishankar   // 1. Get the range of the loops that are represented by the operation.
129cf6a7c19SMahesh Ravishankar   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
130cf6a7c19SMahesh Ravishankar   size_t numLoops = iterationDomain.size();
131cf6a7c19SMahesh Ravishankar   if (numLoops == 0) {
132cf6a7c19SMahesh Ravishankar     return rewriter.notifyMatchFailure(
133cf6a7c19SMahesh Ravishankar         op, "unable to tile op with no iteration domain");
134cf6a7c19SMahesh Ravishankar   }
135cf6a7c19SMahesh Ravishankar 
136cf6a7c19SMahesh Ravishankar   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
137cf6a7c19SMahesh Ravishankar   // skips tiling a particular dimension. This convention is significantly
138cf6a7c19SMahesh Ravishankar   // simpler to handle instead of adjusting affine maps to account for missing
139cf6a7c19SMahesh Ravishankar   // dimensions.
140cf6a7c19SMahesh Ravishankar   SmallVector<Value, 4> tileSizeVector =
141cf6a7c19SMahesh Ravishankar       options.tileSizeComputationFunction(rewriter, op);
142cf6a7c19SMahesh Ravishankar   if (tileSizeVector.size() < iterationDomain.size()) {
143cf6a7c19SMahesh Ravishankar     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
144cf6a7c19SMahesh Ravishankar     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
145cf6a7c19SMahesh Ravishankar   }
146cf6a7c19SMahesh Ravishankar 
147cf6a7c19SMahesh Ravishankar   scf::SCFTilingResult tilingResult;
148cf6a7c19SMahesh Ravishankar   SmallVector<OpFoldResult> offsets, sizes;
149cf6a7c19SMahesh Ravishankar   {
150cf6a7c19SMahesh Ravishankar     // 3. Materialize an empty loop nest that iterates over the tiles. These
151cf6a7c19SMahesh Ravishankar     // loops for now do not return any values even if the original operation has
152cf6a7c19SMahesh Ravishankar     // results.
153cf6a7c19SMahesh Ravishankar     tilingResult.loops = generateTileLoopNest(
154cf6a7c19SMahesh Ravishankar         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
155cf6a7c19SMahesh Ravishankar 
156cf6a7c19SMahesh Ravishankar     LLVM_DEBUG({
157cf6a7c19SMahesh Ravishankar       if (!tilingResult.loops.empty()) {
158cf6a7c19SMahesh Ravishankar         llvm::errs() << "LoopNest shell :\n";
159cf6a7c19SMahesh Ravishankar         tilingResult.loops.front().dump();
160cf6a7c19SMahesh Ravishankar         llvm::errs() << "\n";
161cf6a7c19SMahesh Ravishankar       }
162cf6a7c19SMahesh Ravishankar     });
163cf6a7c19SMahesh Ravishankar 
164cf6a7c19SMahesh Ravishankar     // 4. Generate the tiled implementation within the inner most loop.
165cf6a7c19SMahesh Ravishankar     if (!tilingResult.loops.empty())
166cf6a7c19SMahesh Ravishankar       rewriter.setInsertionPoint(
167cf6a7c19SMahesh Ravishankar           tilingResult.loops.back().getBody()->getTerminator());
168cf6a7c19SMahesh Ravishankar     SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
169cf6a7c19SMahesh Ravishankar         rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
170cf6a7c19SMahesh Ravishankar     if (tiledImplementation.size() != 1) {
171cf6a7c19SMahesh Ravishankar       return rewriter.notifyMatchFailure(
172cf6a7c19SMahesh Ravishankar           op, "expected tiled implementation to return a single op");
173cf6a7c19SMahesh Ravishankar     }
174cf6a7c19SMahesh Ravishankar     tilingResult.tiledOp = tiledImplementation[0];
175cf6a7c19SMahesh Ravishankar 
176cf6a7c19SMahesh Ravishankar     LLVM_DEBUG({
177cf6a7c19SMahesh Ravishankar       if (!tilingResult.loops.empty()) {
178cf6a7c19SMahesh Ravishankar         llvm::errs() << "After tiled implementation :\n";
179cf6a7c19SMahesh Ravishankar         tilingResult.loops.front().dump();
180cf6a7c19SMahesh Ravishankar         llvm::errs() << "\n";
181cf6a7c19SMahesh Ravishankar       }
182cf6a7c19SMahesh Ravishankar     });
183cf6a7c19SMahesh Ravishankar   }
184cf6a7c19SMahesh Ravishankar 
185cf6a7c19SMahesh Ravishankar   if (op->getNumResults() == 0) {
186cf6a7c19SMahesh Ravishankar     rewriter.eraseOp(op);
187cf6a7c19SMahesh Ravishankar     return tilingResult;
188cf6a7c19SMahesh Ravishankar   }
189cf6a7c19SMahesh Ravishankar 
190cf6a7c19SMahesh Ravishankar   // 5. If the original operations has results, modify the loop nest to yield
191cf6a7c19SMahesh Ravishankar   // the replacement values.
192cf6a7c19SMahesh Ravishankar   SmallVector<Value> replacements;
193cf6a7c19SMahesh Ravishankar   if (tilingResult.loops.empty()) {
194cf6a7c19SMahesh Ravishankar     // 5a. If there were no loops, the tiled implementation results are the
195cf6a7c19SMahesh Ravishankar     // replacements.
196cf6a7c19SMahesh Ravishankar     rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
197cf6a7c19SMahesh Ravishankar     return tilingResult;
198cf6a7c19SMahesh Ravishankar   }
199cf6a7c19SMahesh Ravishankar 
200cf6a7c19SMahesh Ravishankar   // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
201cf6a7c19SMahesh Ravishankar   // replacement values using destructive updates. Use the `TilingInterface`
202cf6a7c19SMahesh Ravishankar   // to get the position of the result tiles and use that to generate the
203cf6a7c19SMahesh Ravishankar   // destructive update pattern, i.e.,
204cf6a7c19SMahesh Ravishankar   //
205cf6a7c19SMahesh Ravishankar   // ```mlir
206cf6a7c19SMahesh Ravishankar   // scf.for %iv0 = ... {
207cf6a7c19SMahesh Ravishankar   //   %0 = tiled_op
208cf6a7c19SMahesh Ravishankar   // }
209cf6a7c19SMahesh Ravishankar   // ```
210cf6a7c19SMahesh Ravishankar   //
211cf6a7c19SMahesh Ravishankar   // is transformed to
212cf6a7c19SMahesh Ravishankar   //
213cf6a7c19SMahesh Ravishankar   // ```mlir
214cf6a7c19SMahesh Ravishankar   // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
215cf6a7c19SMahesh Ravishankar   //   %0 = tiled_op
216cf6a7c19SMahesh Ravishankar   //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
217cf6a7c19SMahesh Ravishankar   //   scf.yield %1
218cf6a7c19SMahesh Ravishankar   // }
219cf6a7c19SMahesh Ravishankar   // ```
220cf6a7c19SMahesh Ravishankar   NewYieldValueFn yieldValueFn =
221cf6a7c19SMahesh Ravishankar       [&](OpBuilder &b, Location loc,
222cf6a7c19SMahesh Ravishankar           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
223cf6a7c19SMahesh Ravishankar     SmallVector<Value> yieldedValues;
224cf6a7c19SMahesh Ravishankar     Attribute one = b.getIndexAttr(1);
225cf6a7c19SMahesh Ravishankar     for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
226cf6a7c19SMahesh Ravishankar       SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
227cf6a7c19SMahesh Ravishankar       if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
228cf6a7c19SMahesh Ravishankar                                           resultTileOffsets,
229cf6a7c19SMahesh Ravishankar                                           resultTileSizes))) {
230cf6a7c19SMahesh Ravishankar         op.emitOpError("unable to get position of result ")
231cf6a7c19SMahesh Ravishankar             << resultNum << " of the tiled implementation";
232cf6a7c19SMahesh Ravishankar         return {};
233cf6a7c19SMahesh Ravishankar       }
234cf6a7c19SMahesh Ravishankar       SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
235cf6a7c19SMahesh Ravishankar                                                   one);
236cf6a7c19SMahesh Ravishankar       Value yieldedValue = b.create<tensor::InsertSliceOp>(
237cf6a7c19SMahesh Ravishankar           op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
238cf6a7c19SMahesh Ravishankar           newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
239cf6a7c19SMahesh Ravishankar           resultTileStrides);
240cf6a7c19SMahesh Ravishankar       yieldedValues.push_back(yieldedValue);
241cf6a7c19SMahesh Ravishankar     }
242cf6a7c19SMahesh Ravishankar     return yieldedValues;
243cf6a7c19SMahesh Ravishankar   };
244cf6a7c19SMahesh Ravishankar   SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
245cf6a7c19SMahesh Ravishankar       rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
246cf6a7c19SMahesh Ravishankar       yieldValueFn);
247*ca2933f3SAdrian Kuegel   for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
248cf6a7c19SMahesh Ravishankar     rewriter.eraseOp(loop.value());
249cf6a7c19SMahesh Ravishankar     tilingResult.loops[loop.index()] = newLoops[loop.index()];
250cf6a7c19SMahesh Ravishankar   }
251cf6a7c19SMahesh Ravishankar   rewriter.replaceOp(op, tilingResult.loops.front().getResults());
252cf6a7c19SMahesh Ravishankar   return tilingResult;
253cf6a7c19SMahesh Ravishankar }
2542f637fe7SMahesh Ravishankar 
2552f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
2562f637fe7SMahesh Ravishankar // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
2572f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
2582f637fe7SMahesh Ravishankar 
2592f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp::
2602f637fe7SMahesh Ravishankar     TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
2612f637fe7SMahesh Ravishankar                                               scf::SCFTilingOptions options,
2622f637fe7SMahesh Ravishankar                                               PatternBenefit benefit)
2632f637fe7SMahesh Ravishankar     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
2642f637fe7SMahesh Ravishankar       tilingPattern(context, std::move(options)) {}
2652f637fe7SMahesh Ravishankar 
2662f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp::
2672f637fe7SMahesh Ravishankar     TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
2682f637fe7SMahesh Ravishankar                                               MLIRContext *context,
2692f637fe7SMahesh Ravishankar                                               scf::SCFTilingOptions options,
2702f637fe7SMahesh Ravishankar                                               PatternBenefit benefit)
2712f637fe7SMahesh Ravishankar     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
2722f637fe7SMahesh Ravishankar       tilingPattern(context, std::move(options)) {}
2732f637fe7SMahesh Ravishankar 
2742f637fe7SMahesh Ravishankar /// Return the `Value` that is defined by an operation that implements
2752f637fe7SMahesh Ravishankar /// the `TilingInterface`. Looks through `iter_args` of scf.for nest
2762f637fe7SMahesh Ravishankar /// if required.
2772f637fe7SMahesh Ravishankar static Optional<OpResult> getFusableProducer(Value v) {
2782f637fe7SMahesh Ravishankar   while (auto blockArg = v.dyn_cast<BlockArgument>()) {
2792f637fe7SMahesh Ravishankar     auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
2802f637fe7SMahesh Ravishankar     if (!loopOp)
2812f637fe7SMahesh Ravishankar       return llvm::None;
2822f637fe7SMahesh Ravishankar     v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
2832f637fe7SMahesh Ravishankar   }
2842f637fe7SMahesh Ravishankar   if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
2852f637fe7SMahesh Ravishankar     return llvm::None;
2862f637fe7SMahesh Ravishankar   return v.cast<OpResult>();
2872f637fe7SMahesh Ravishankar }
2882f637fe7SMahesh Ravishankar 
2892f637fe7SMahesh Ravishankar FailureOr<scf::SCFTileAndFuseResult>
2902f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
2912f637fe7SMahesh Ravishankar     TilingInterface op, PatternRewriter &rewriter) const {
2922f637fe7SMahesh Ravishankar   // This transformation is only valid for ops that return values (i.e. not
2932f637fe7SMahesh Ravishankar   // valid to use with operations that have memref operands).
2942f637fe7SMahesh Ravishankar   if (!op->getNumResults()) {
2952f637fe7SMahesh Ravishankar     return rewriter.notifyMatchFailure(
2962f637fe7SMahesh Ravishankar         op, "invalid pattern for op with no results");
2972f637fe7SMahesh Ravishankar   }
2982f637fe7SMahesh Ravishankar 
2992f637fe7SMahesh Ravishankar   // 1. First tile the consumer.
3002f637fe7SMahesh Ravishankar   SCFTileAndFuseResult tileAndFuseResult;
3012f637fe7SMahesh Ravishankar   {
3022f637fe7SMahesh Ravishankar     FailureOr<SCFTilingResult> tilingResult =
3032f637fe7SMahesh Ravishankar         tilingPattern.returningMatchAndRewrite(op, rewriter);
3042f637fe7SMahesh Ravishankar     if (failed(tilingResult)) {
3052f637fe7SMahesh Ravishankar       return failure();
3062f637fe7SMahesh Ravishankar     }
3072f637fe7SMahesh Ravishankar     tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
3082f637fe7SMahesh Ravishankar     tileAndFuseResult.loops = std::move(tilingResult->loops);
3092f637fe7SMahesh Ravishankar   }
3102f637fe7SMahesh Ravishankar 
3112f637fe7SMahesh Ravishankar   // 2. Typically, the operands of the tiled operation are slices of the
3122f637fe7SMahesh Ravishankar   //    operands of the untiled operation. These are expressed in IR using
3132f637fe7SMahesh Ravishankar   //    `tensor.extract_slice` operations with source being the operands of the
3142f637fe7SMahesh Ravishankar   //    untiled operation. Create a worklist of these `tensor.extract_slice`
3152f637fe7SMahesh Ravishankar   //    operations. If the producers of the source of the `tensor.extract_slice`
3162f637fe7SMahesh Ravishankar   //    can be tiled such that the tiled value is generated in-place, that
3172f637fe7SMahesh Ravishankar   //    effectively tiles + fuses the operations.
3182f637fe7SMahesh Ravishankar   auto addCandidateSlices = [](Operation *fusedOp,
3192f637fe7SMahesh Ravishankar                                std::deque<tensor::ExtractSliceOp> &candidates) {
3202f637fe7SMahesh Ravishankar     for (Value operand : fusedOp->getOperands())
3212f637fe7SMahesh Ravishankar       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
3222f637fe7SMahesh Ravishankar         candidates.push_back(sliceOp);
3232f637fe7SMahesh Ravishankar   };
3242f637fe7SMahesh Ravishankar 
3252f637fe7SMahesh Ravishankar   std::deque<tensor::ExtractSliceOp> candidates;
3262f637fe7SMahesh Ravishankar   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
3272f637fe7SMahesh Ravishankar   OpBuilder::InsertionGuard g(rewriter);
3282f637fe7SMahesh Ravishankar   while (!candidates.empty()) {
3292f637fe7SMahesh Ravishankar     // 2a. Traverse the slices in BFS fashion.
3302f637fe7SMahesh Ravishankar     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
3312f637fe7SMahesh Ravishankar     candidates.pop_front();
3322f637fe7SMahesh Ravishankar 
3332f637fe7SMahesh Ravishankar     // 2b. Get the producer of the source (potentially walking through
3342f637fe7SMahesh Ravishankar     // `iter_args` of nested `scf.for`)
3352f637fe7SMahesh Ravishankar     Optional<OpResult> fusableProducer =
3362f637fe7SMahesh Ravishankar         getFusableProducer(candidateSliceOp.source());
3372f637fe7SMahesh Ravishankar     if (!fusableProducer)
3382f637fe7SMahesh Ravishankar       continue;
3392f637fe7SMahesh Ravishankar 
3402f637fe7SMahesh Ravishankar     // 2c. Generate the tiled implementation of the producer of the source
3412f637fe7SMahesh Ravishankar     rewriter.setInsertionPoint(candidateSliceOp);
3422f637fe7SMahesh Ravishankar     FailureOr<Value> fusedProducerValue =
3433b7c3a65SKazu Hirata         tensor::replaceExtractSliceWithTiledProducer(
3443b7c3a65SKazu Hirata             rewriter, candidateSliceOp, fusableProducer.getValue());
3452f637fe7SMahesh Ravishankar     if (failed(fusedProducerValue))
3462f637fe7SMahesh Ravishankar       continue;
3473b7c3a65SKazu Hirata     rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue());
3482f637fe7SMahesh Ravishankar 
3492f637fe7SMahesh Ravishankar     // 2d. The operands of the fused producer might themselved be slices of
3502f637fe7SMahesh Ravishankar     //     values produced by operations that implement the `TilingInterface`.
3512f637fe7SMahesh Ravishankar     //     Add these operations to the worklist.
3522f637fe7SMahesh Ravishankar     Operation *fusedProducer = fusedProducerValue->getDefiningOp();
3532f637fe7SMahesh Ravishankar     tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
3542f637fe7SMahesh Ravishankar     addCandidateSlices(fusedProducer, candidates);
3552f637fe7SMahesh Ravishankar 
3562f637fe7SMahesh Ravishankar     // 2e. If the operation being fused creates a value that is used as `outs`
3572f637fe7SMahesh Ravishankar     //     in the tiled operation, the result of the unfused operation will be
3582f637fe7SMahesh Ravishankar     //     used in the `iter_args` of the tiled loop generated. When the
3592f637fe7SMahesh Ravishankar     //     operation is fused, this use in `iter_args` needs to be modified to
3602f637fe7SMahesh Ravishankar     //     use the destination of the fused operation. For example, starting
3612f637fe7SMahesh Ravishankar     //     with
3622f637fe7SMahesh Ravishankar     //
3632f637fe7SMahesh Ravishankar     //     ```mlir
3642f637fe7SMahesh Ravishankar     //     %0 = linalg.init_tensor ...
3652f637fe7SMahesh Ravishankar     //     %1 = linalg.fill ... outs(%0:...)...
3662f637fe7SMahesh Ravishankar     //     %2 = linalg.matmul ... outs(%1:...)....
3672f637fe7SMahesh Ravishankar     //     ```
3682f637fe7SMahesh Ravishankar     //
3692f637fe7SMahesh Ravishankar     //     First the `linalg.matmul` gets tiled
3702f637fe7SMahesh Ravishankar     //
3712f637fe7SMahesh Ravishankar     //     ```mlir
3722f637fe7SMahesh Ravishankar     //     %0 = linalg.init_tensor
3732f637fe7SMahesh Ravishankar     //     %1 = linalg.fill
3742f637fe7SMahesh Ravishankar     //     %2 = scf.for .... iter_args(%arg0 = %1)...
3752f637fe7SMahesh Ravishankar     //        ...
3762f637fe7SMahesh Ravishankar     //        ... = linalg.matmul ...
3772f637fe7SMahesh Ravishankar     //
3782f637fe7SMahesh Ravishankar     //     ```
3792f637fe7SMahesh Ravishankar     //
3802f637fe7SMahesh Ravishankar     //     When the `linalg.fill` gets fused, the `iter_args` needs to be
3812f637fe7SMahesh Ravishankar     //     modified
3822f637fe7SMahesh Ravishankar     //
3832f637fe7SMahesh Ravishankar     //     ```mlir
3842f637fe7SMahesh Ravishankar     //     %0 = linalg.init_tensor
3852f637fe7SMahesh Ravishankar     //     %1 = scf.for ... iter_args(%arg0 = %0)...
3862f637fe7SMahesh Ravishankar     //        ...
3872f637fe7SMahesh Ravishankar     //        %2 = linalg.fill ...
3882f637fe7SMahesh Ravishankar     //        %3 = linalg.matmul ... outs(%2: ...)...
3892f637fe7SMahesh Ravishankar     //     ```
3902f637fe7SMahesh Ravishankar     TilingInterface unfusedProducerOp =
3912f637fe7SMahesh Ravishankar         cast<TilingInterface>(fusableProducer->getOwner());
3922f637fe7SMahesh Ravishankar     scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
3932f637fe7SMahesh Ravishankar     SmallVector<Value> unfusedProducerOpDestValues =
3942f637fe7SMahesh Ravishankar         unfusedProducerOp.getDestinationOperands(rewriter);
3952f637fe7SMahesh Ravishankar     for (OpOperand &uses : unfusedProducerOp->getUses()) {
3962f637fe7SMahesh Ravishankar       if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
3972f637fe7SMahesh Ravishankar         unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
3982f637fe7SMahesh Ravishankar         unsigned operandNumber = uses.getOperandNumber();
3992f637fe7SMahesh Ravishankar         outerMostTiledLoop->setOperand(
4002f637fe7SMahesh Ravishankar             operandNumber, unfusedProducerOpDestValues[resultNumber]);
4012f637fe7SMahesh Ravishankar       }
4022f637fe7SMahesh Ravishankar     }
4032f637fe7SMahesh Ravishankar   }
4042f637fe7SMahesh Ravishankar   return tileAndFuseResult;
4052f637fe7SMahesh Ravishankar }
406