//===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2cf6a7c19SMahesh Ravishankar //
3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information.
5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cf6a7c19SMahesh Ravishankar //
7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
8cf6a7c19SMahesh Ravishankar //
9cf6a7c19SMahesh Ravishankar // This file implements the tiling using TilingInterface.
10cf6a7c19SMahesh Ravishankar //
11cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
12cf6a7c19SMahesh Ravishankar
138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14cf6a7c19SMahesh Ravishankar
15cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
16cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h"
18cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/Utils/Utils.h"
19cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
20cf6a7c19SMahesh Ravishankar #include "mlir/IR/Matchers.h"
21cf6a7c19SMahesh Ravishankar #include "mlir/IR/PatternMatch.h"
22cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
23cf6a7c19SMahesh Ravishankar #include "llvm/Support/Debug.h"
24cf6a7c19SMahesh Ravishankar
25cf6a7c19SMahesh Ravishankar #define DEBUG_TYPE "tile-using-interface"
26cf6a7c19SMahesh Ravishankar
27cf6a7c19SMahesh Ravishankar using namespace mlir;
28cf6a7c19SMahesh Ravishankar
29cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions &
setTileSizes(ArrayRef<int64_t> ts)30cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
31cf6a7c19SMahesh Ravishankar assert(!tileSizeComputationFunction && "tile sizes already set");
32b8a1f00dSMahesh Ravishankar SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
33cf6a7c19SMahesh Ravishankar tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
34cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(b);
35cf6a7c19SMahesh Ravishankar b.setInsertionPointToStart(
36cf6a7c19SMahesh Ravishankar &op->getParentOfType<func::FuncOp>().getBody().front());
37cf6a7c19SMahesh Ravishankar return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
38cf6a7c19SMahesh Ravishankar Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
39cf6a7c19SMahesh Ravishankar return v;
40cf6a7c19SMahesh Ravishankar }));
41cf6a7c19SMahesh Ravishankar };
42cf6a7c19SMahesh Ravishankar return *this;
43cf6a7c19SMahesh Ravishankar }
44cf6a7c19SMahesh Ravishankar
45b8a1f00dSMahesh Ravishankar /// Helper method to adjust the interchange vector to match the iteration
46b8a1f00dSMahesh Ravishankar /// domain.
47b8a1f00dSMahesh Ravishankar static SmallVector<unsigned>
fillInterchangeVector(ArrayRef<unsigned> interchangeVector,size_t iterationDomainSize)48b8a1f00dSMahesh Ravishankar fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
49b8a1f00dSMahesh Ravishankar size_t iterationDomainSize) {
50b8a1f00dSMahesh Ravishankar SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector);
51b8a1f00dSMahesh Ravishankar if (filledVector.size() < iterationDomainSize) {
52b8a1f00dSMahesh Ravishankar auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize);
53b8a1f00dSMahesh Ravishankar filledVector.append(range.begin(), range.end());
54b8a1f00dSMahesh Ravishankar }
55b8a1f00dSMahesh Ravishankar if (filledVector.size() > iterationDomainSize)
56b8a1f00dSMahesh Ravishankar filledVector.resize(iterationDomainSize);
57b8a1f00dSMahesh Ravishankar return filledVector;
58b8a1f00dSMahesh Ravishankar }
59b8a1f00dSMahesh Ravishankar
60b8a1f00dSMahesh Ravishankar /// Helper method to apply permutation to a vector
61b8a1f00dSMahesh Ravishankar template <typename T>
applyPermutationToVector(const SmallVector<T> & vector,ArrayRef<unsigned> interchange)62b8a1f00dSMahesh Ravishankar static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
63b8a1f00dSMahesh Ravishankar ArrayRef<unsigned> interchange) {
64b8a1f00dSMahesh Ravishankar assert(interchange.size() == vector.size());
65b8a1f00dSMahesh Ravishankar return llvm::to_vector(
66b8a1f00dSMahesh Ravishankar llvm::map_range(interchange, [&](unsigned val) { return vector[val]; }));
67b8a1f00dSMahesh Ravishankar }
68b8a1f00dSMahesh Ravishankar /// Helper method to apply to invert a permutation.
69b8a1f00dSMahesh Ravishankar static SmallVector<unsigned>
invertPermutationVector(ArrayRef<unsigned> interchange)70b8a1f00dSMahesh Ravishankar invertPermutationVector(ArrayRef<unsigned> interchange) {
71b8a1f00dSMahesh Ravishankar SmallVector<unsigned> inversion(interchange.size());
72b8a1f00dSMahesh Ravishankar for (auto pos : llvm::enumerate(interchange)) {
73b8a1f00dSMahesh Ravishankar inversion[pos.value()] = pos.index();
74b8a1f00dSMahesh Ravishankar }
75b8a1f00dSMahesh Ravishankar return inversion;
76b8a1f00dSMahesh Ravishankar }
77b8a1f00dSMahesh Ravishankar /// Method to check if an interchange vector is a permutation.
isPermutation(ArrayRef<unsigned> interchange)78b8a1f00dSMahesh Ravishankar static bool isPermutation(ArrayRef<unsigned> interchange) {
79b8a1f00dSMahesh Ravishankar llvm::SmallDenseSet<unsigned, 4> seenVals;
80b8a1f00dSMahesh Ravishankar for (auto val : interchange) {
81b8a1f00dSMahesh Ravishankar if (seenVals.count(val))
82b8a1f00dSMahesh Ravishankar return false;
83b8a1f00dSMahesh Ravishankar seenVals.insert(val);
84b8a1f00dSMahesh Ravishankar }
85b8a1f00dSMahesh Ravishankar return seenVals.size() == interchange.size();
86b8a1f00dSMahesh Ravishankar }
87b8a1f00dSMahesh Ravishankar
882f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
892f637fe7SMahesh Ravishankar // TileUsingSCFForOp pattern implementation.
902f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
912f637fe7SMahesh Ravishankar
92cf6a7c19SMahesh Ravishankar /// Generate an empty loop nest that represents the tiled loop nest shell.
93cf6a7c19SMahesh Ravishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
94cf6a7c19SMahesh Ravishankar /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
95cf6a7c19SMahesh Ravishankar /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
96cf6a7c19SMahesh Ravishankar /// the
97cf6a7c19SMahesh Ravishankar /// tile processed within the inner most loop.
98cf6a7c19SMahesh Ravishankar static SmallVector<scf::ForOp>
generateTileLoopNest(OpBuilder & builder,Location loc,ArrayRef<Range> loopRanges,ArrayRef<Value> tileSizeVals,SmallVector<OpFoldResult> & offsets,SmallVector<OpFoldResult> & sizes)99cf6a7c19SMahesh Ravishankar generateTileLoopNest(OpBuilder &builder, Location loc,
100cf6a7c19SMahesh Ravishankar ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
101cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &offsets,
102cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &sizes) {
103cf6a7c19SMahesh Ravishankar assert(!loopRanges.empty() && "expected at least one loop range");
104cf6a7c19SMahesh Ravishankar assert(loopRanges.size() == tileSizeVals.size() &&
105cf6a7c19SMahesh Ravishankar "expected as many tile sizes as loop ranges");
106cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(builder);
107cf6a7c19SMahesh Ravishankar SmallVector<scf::ForOp> loops;
108cf6a7c19SMahesh Ravishankar offsets.resize(loopRanges.size());
109cf6a7c19SMahesh Ravishankar sizes.resize(loopRanges.size());
110cf6a7c19SMahesh Ravishankar
111cf6a7c19SMahesh Ravishankar // The tile size to use (to avoid out of bounds access) is minimum of
112cf6a7c19SMahesh Ravishankar // `tileSize` and `ub - iv`, where `iv` is the induction variable
113cf6a7c19SMahesh Ravishankar // of the tiled loop.
114cf6a7c19SMahesh Ravishankar AffineExpr s0, s1, d0;
115cf6a7c19SMahesh Ravishankar bindDims(builder.getContext(), d0);
116cf6a7c19SMahesh Ravishankar bindSymbols(builder.getContext(), s0, s1);
117cf6a7c19SMahesh Ravishankar AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
118cf6a7c19SMahesh Ravishankar
119cf6a7c19SMahesh Ravishankar for (auto loopRange : llvm::enumerate(loopRanges)) {
120cf6a7c19SMahesh Ravishankar // No loops if tile size is zero. Set offset and size to the loop
121cf6a7c19SMahesh Ravishankar // offset and size.
122cf6a7c19SMahesh Ravishankar if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
123cf6a7c19SMahesh Ravishankar offsets[loopRange.index()] = loopRange.value().offset;
124cf6a7c19SMahesh Ravishankar sizes[loopRange.index()] = loopRange.value().size;
125cf6a7c19SMahesh Ravishankar continue;
126cf6a7c19SMahesh Ravishankar }
127cf6a7c19SMahesh Ravishankar
128cf6a7c19SMahesh Ravishankar auto loop = builder.create<scf::ForOp>(
129cf6a7c19SMahesh Ravishankar loc, loopRange.value().offset, loopRange.value().size,
130cf6a7c19SMahesh Ravishankar tileSizeVals[loopRange.index()], ValueRange{},
131cf6a7c19SMahesh Ravishankar [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
132cf6a7c19SMahesh Ravishankar ValueRange /*iterArgs*/) {
133cf6a7c19SMahesh Ravishankar Value boundedTileSize = builder.create<AffineMinOp>(
134cf6a7c19SMahesh Ravishankar bodyLoc, minMap,
135cf6a7c19SMahesh Ravishankar ValueRange{iv, tileSizeVals[loopRange.index()],
136cf6a7c19SMahesh Ravishankar loopRange.value().size});
137cf6a7c19SMahesh Ravishankar sizes[loopRange.index()] = boundedTileSize;
138cf6a7c19SMahesh Ravishankar builder.create<scf::YieldOp>(loc);
139cf6a7c19SMahesh Ravishankar });
140cf6a7c19SMahesh Ravishankar offsets[loopRange.index()] = loop.getInductionVar();
141cf6a7c19SMahesh Ravishankar loops.push_back(loop);
142cf6a7c19SMahesh Ravishankar builder.setInsertionPoint(loop.getBody()->getTerminator());
143cf6a7c19SMahesh Ravishankar }
144cf6a7c19SMahesh Ravishankar return loops;
145cf6a7c19SMahesh Ravishankar }
146cf6a7c19SMahesh Ravishankar
TileUsingSCFForOp(MLIRContext * context,scf::SCFTilingOptions options,PatternBenefit benefit)147cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
148cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options,
149cf6a7c19SMahesh Ravishankar PatternBenefit benefit)
150cf6a7c19SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
151cf6a7c19SMahesh Ravishankar options(std::move(options)) {}
152cf6a7c19SMahesh Ravishankar
TileUsingSCFForOp(StringRef opName,MLIRContext * context,scf::SCFTilingOptions options,PatternBenefit benefit)153cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
154cf6a7c19SMahesh Ravishankar MLIRContext *context,
155cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options,
156cf6a7c19SMahesh Ravishankar PatternBenefit benefit)
157cf6a7c19SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
158cf6a7c19SMahesh Ravishankar options(std::move(options)) {}
159cf6a7c19SMahesh Ravishankar
160cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult>
returningMatchAndRewrite(TilingInterface op,PatternRewriter & rewriter) const161cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::returningMatchAndRewrite(
162cf6a7c19SMahesh Ravishankar TilingInterface op, PatternRewriter &rewriter) const {
163cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(rewriter);
164cf6a7c19SMahesh Ravishankar rewriter.setInsertionPointAfter(op);
165cf6a7c19SMahesh Ravishankar
166cf6a7c19SMahesh Ravishankar if (!options.tileSizeComputationFunction) {
167cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure(
168cf6a7c19SMahesh Ravishankar op, "missing tile size computation function");
169cf6a7c19SMahesh Ravishankar }
170cf6a7c19SMahesh Ravishankar
171cf6a7c19SMahesh Ravishankar // 1. Get the range of the loops that are represented by the operation.
172cf6a7c19SMahesh Ravishankar SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
173cf6a7c19SMahesh Ravishankar size_t numLoops = iterationDomain.size();
174cf6a7c19SMahesh Ravishankar if (numLoops == 0) {
175cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure(
176cf6a7c19SMahesh Ravishankar op, "unable to tile op with no iteration domain");
177cf6a7c19SMahesh Ravishankar }
178cf6a7c19SMahesh Ravishankar
179cf6a7c19SMahesh Ravishankar // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
180cf6a7c19SMahesh Ravishankar // skips tiling a particular dimension. This convention is significantly
181cf6a7c19SMahesh Ravishankar // simpler to handle instead of adjusting affine maps to account for missing
182cf6a7c19SMahesh Ravishankar // dimensions.
183b8a1f00dSMahesh Ravishankar SmallVector<Value> tileSizeVector =
184cf6a7c19SMahesh Ravishankar options.tileSizeComputationFunction(rewriter, op);
185cf6a7c19SMahesh Ravishankar if (tileSizeVector.size() < iterationDomain.size()) {
186cf6a7c19SMahesh Ravishankar auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
187cf6a7c19SMahesh Ravishankar tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
188cf6a7c19SMahesh Ravishankar }
189cf6a7c19SMahesh Ravishankar
190cf6a7c19SMahesh Ravishankar scf::SCFTilingResult tilingResult;
191cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> offsets, sizes;
192cf6a7c19SMahesh Ravishankar {
193b8a1f00dSMahesh Ravishankar // If there is an interchange specified, permute the iteration domain and
194b8a1f00dSMahesh Ravishankar // the tile sizes.
195b8a1f00dSMahesh Ravishankar SmallVector<unsigned> interchangeVector;
196b8a1f00dSMahesh Ravishankar if (!options.interchangeVector.empty()) {
197b8a1f00dSMahesh Ravishankar interchangeVector = fillInterchangeVector(options.interchangeVector,
198b8a1f00dSMahesh Ravishankar iterationDomain.size());
199b8a1f00dSMahesh Ravishankar }
200b8a1f00dSMahesh Ravishankar if (!interchangeVector.empty()) {
201b8a1f00dSMahesh Ravishankar if (!isPermutation(interchangeVector)) {
202b8a1f00dSMahesh Ravishankar return rewriter.notifyMatchFailure(
203b8a1f00dSMahesh Ravishankar op, "invalid intechange vector, not a permutation of the entire "
204b8a1f00dSMahesh Ravishankar "iteration space");
205b8a1f00dSMahesh Ravishankar }
206b8a1f00dSMahesh Ravishankar
207b8a1f00dSMahesh Ravishankar iterationDomain =
208b8a1f00dSMahesh Ravishankar applyPermutationToVector(iterationDomain, interchangeVector);
209b8a1f00dSMahesh Ravishankar tileSizeVector =
210b8a1f00dSMahesh Ravishankar applyPermutationToVector(tileSizeVector, interchangeVector);
211b8a1f00dSMahesh Ravishankar }
212b8a1f00dSMahesh Ravishankar
213cf6a7c19SMahesh Ravishankar // 3. Materialize an empty loop nest that iterates over the tiles. These
214cf6a7c19SMahesh Ravishankar // loops for now do not return any values even if the original operation has
215cf6a7c19SMahesh Ravishankar // results.
216cf6a7c19SMahesh Ravishankar tilingResult.loops = generateTileLoopNest(
217cf6a7c19SMahesh Ravishankar rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
218cf6a7c19SMahesh Ravishankar
219b8a1f00dSMahesh Ravishankar if (!interchangeVector.empty()) {
220b8a1f00dSMahesh Ravishankar auto inversePermutation = invertPermutationVector(interchangeVector);
221b8a1f00dSMahesh Ravishankar offsets = applyPermutationToVector(offsets, inversePermutation);
222b8a1f00dSMahesh Ravishankar sizes = applyPermutationToVector(sizes, inversePermutation);
223b8a1f00dSMahesh Ravishankar }
224b8a1f00dSMahesh Ravishankar
225cf6a7c19SMahesh Ravishankar LLVM_DEBUG({
226cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) {
227cf6a7c19SMahesh Ravishankar llvm::errs() << "LoopNest shell :\n";
228cf6a7c19SMahesh Ravishankar tilingResult.loops.front().dump();
229cf6a7c19SMahesh Ravishankar llvm::errs() << "\n";
230cf6a7c19SMahesh Ravishankar }
231cf6a7c19SMahesh Ravishankar });
232cf6a7c19SMahesh Ravishankar
233cf6a7c19SMahesh Ravishankar // 4. Generate the tiled implementation within the inner most loop.
234cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty())
235cf6a7c19SMahesh Ravishankar rewriter.setInsertionPoint(
236cf6a7c19SMahesh Ravishankar tilingResult.loops.back().getBody()->getTerminator());
237cf6a7c19SMahesh Ravishankar SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
238cf6a7c19SMahesh Ravishankar rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
239cf6a7c19SMahesh Ravishankar if (tiledImplementation.size() != 1) {
240cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure(
241cf6a7c19SMahesh Ravishankar op, "expected tiled implementation to return a single op");
242cf6a7c19SMahesh Ravishankar }
243cf6a7c19SMahesh Ravishankar tilingResult.tiledOp = tiledImplementation[0];
244cf6a7c19SMahesh Ravishankar
245cf6a7c19SMahesh Ravishankar LLVM_DEBUG({
246cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) {
247cf6a7c19SMahesh Ravishankar llvm::errs() << "After tiled implementation :\n";
248cf6a7c19SMahesh Ravishankar tilingResult.loops.front().dump();
249cf6a7c19SMahesh Ravishankar llvm::errs() << "\n";
250cf6a7c19SMahesh Ravishankar }
251cf6a7c19SMahesh Ravishankar });
252cf6a7c19SMahesh Ravishankar }
253cf6a7c19SMahesh Ravishankar
254cf6a7c19SMahesh Ravishankar if (op->getNumResults() == 0) {
255cf6a7c19SMahesh Ravishankar rewriter.eraseOp(op);
256cf6a7c19SMahesh Ravishankar return tilingResult;
257cf6a7c19SMahesh Ravishankar }
258cf6a7c19SMahesh Ravishankar
259cf6a7c19SMahesh Ravishankar // 5. If the original operations has results, modify the loop nest to yield
260cf6a7c19SMahesh Ravishankar // the replacement values.
261cf6a7c19SMahesh Ravishankar SmallVector<Value> replacements;
262cf6a7c19SMahesh Ravishankar if (tilingResult.loops.empty()) {
263cf6a7c19SMahesh Ravishankar // 5a. If there were no loops, the tiled implementation results are the
264cf6a7c19SMahesh Ravishankar // replacements.
265cf6a7c19SMahesh Ravishankar rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
266cf6a7c19SMahesh Ravishankar return tilingResult;
267cf6a7c19SMahesh Ravishankar }
268cf6a7c19SMahesh Ravishankar
269cf6a7c19SMahesh Ravishankar // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
270cf6a7c19SMahesh Ravishankar // replacement values using destructive updates. Use the `TilingInterface`
271cf6a7c19SMahesh Ravishankar // to get the position of the result tiles and use that to generate the
272cf6a7c19SMahesh Ravishankar // destructive update pattern, i.e.,
273cf6a7c19SMahesh Ravishankar //
274cf6a7c19SMahesh Ravishankar // ```mlir
275cf6a7c19SMahesh Ravishankar // scf.for %iv0 = ... {
276cf6a7c19SMahesh Ravishankar // %0 = tiled_op
277cf6a7c19SMahesh Ravishankar // }
278cf6a7c19SMahesh Ravishankar // ```
279cf6a7c19SMahesh Ravishankar //
280cf6a7c19SMahesh Ravishankar // is transformed to
281cf6a7c19SMahesh Ravishankar //
282cf6a7c19SMahesh Ravishankar // ```mlir
283cf6a7c19SMahesh Ravishankar // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
284cf6a7c19SMahesh Ravishankar // %0 = tiled_op
285cf6a7c19SMahesh Ravishankar // %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
286cf6a7c19SMahesh Ravishankar // scf.yield %1
287cf6a7c19SMahesh Ravishankar // }
288cf6a7c19SMahesh Ravishankar // ```
289cf6a7c19SMahesh Ravishankar NewYieldValueFn yieldValueFn =
290cf6a7c19SMahesh Ravishankar [&](OpBuilder &b, Location loc,
291cf6a7c19SMahesh Ravishankar ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
292cf6a7c19SMahesh Ravishankar SmallVector<Value> yieldedValues;
293cf6a7c19SMahesh Ravishankar Attribute one = b.getIndexAttr(1);
294cf6a7c19SMahesh Ravishankar for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
295cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
296cf6a7c19SMahesh Ravishankar if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
297cf6a7c19SMahesh Ravishankar resultTileOffsets,
298cf6a7c19SMahesh Ravishankar resultTileSizes))) {
299cf6a7c19SMahesh Ravishankar op.emitOpError("unable to get position of result ")
300cf6a7c19SMahesh Ravishankar << resultNum << " of the tiled implementation";
301cf6a7c19SMahesh Ravishankar return {};
302cf6a7c19SMahesh Ravishankar }
303cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
304cf6a7c19SMahesh Ravishankar one);
305cf6a7c19SMahesh Ravishankar Value yieldedValue = b.create<tensor::InsertSliceOp>(
306cf6a7c19SMahesh Ravishankar op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
307cf6a7c19SMahesh Ravishankar newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
308cf6a7c19SMahesh Ravishankar resultTileStrides);
309cf6a7c19SMahesh Ravishankar yieldedValues.push_back(yieldedValue);
310cf6a7c19SMahesh Ravishankar }
311cf6a7c19SMahesh Ravishankar return yieldedValues;
312cf6a7c19SMahesh Ravishankar };
313cf6a7c19SMahesh Ravishankar SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
314cf6a7c19SMahesh Ravishankar rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
315cf6a7c19SMahesh Ravishankar yieldValueFn);
316ca2933f3SAdrian Kuegel for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
317cf6a7c19SMahesh Ravishankar rewriter.eraseOp(loop.value());
318cf6a7c19SMahesh Ravishankar tilingResult.loops[loop.index()] = newLoops[loop.index()];
319cf6a7c19SMahesh Ravishankar }
320cf6a7c19SMahesh Ravishankar rewriter.replaceOp(op, tilingResult.loops.front().getResults());
321cf6a7c19SMahesh Ravishankar return tilingResult;
322cf6a7c19SMahesh Ravishankar }
3232f637fe7SMahesh Ravishankar
3242f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
3252f637fe7SMahesh Ravishankar // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
3262f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===//
3272f637fe7SMahesh Ravishankar
3282f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp::
TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext * context,scf::SCFTilingOptions options,PatternBenefit benefit)3292f637fe7SMahesh Ravishankar TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
3302f637fe7SMahesh Ravishankar scf::SCFTilingOptions options,
3312f637fe7SMahesh Ravishankar PatternBenefit benefit)
3322f637fe7SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
3332f637fe7SMahesh Ravishankar tilingPattern(context, std::move(options)) {}
3342f637fe7SMahesh Ravishankar
3352f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp::
TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,MLIRContext * context,scf::SCFTilingOptions options,PatternBenefit benefit)3362f637fe7SMahesh Ravishankar TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
3372f637fe7SMahesh Ravishankar MLIRContext *context,
3382f637fe7SMahesh Ravishankar scf::SCFTilingOptions options,
3392f637fe7SMahesh Ravishankar PatternBenefit benefit)
3402f637fe7SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
3412f637fe7SMahesh Ravishankar tilingPattern(context, std::move(options)) {}
3422f637fe7SMahesh Ravishankar
3432f637fe7SMahesh Ravishankar /// Return the `Value` that is defined by an operation that implements
3442f637fe7SMahesh Ravishankar /// the `TilingInterface`. Looks through `iter_args` of scf.for nest
3452f637fe7SMahesh Ravishankar /// if required.
getFusableProducer(Value v)3462f637fe7SMahesh Ravishankar static Optional<OpResult> getFusableProducer(Value v) {
3472f637fe7SMahesh Ravishankar while (auto blockArg = v.dyn_cast<BlockArgument>()) {
3482f637fe7SMahesh Ravishankar auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
3492f637fe7SMahesh Ravishankar if (!loopOp)
3502f637fe7SMahesh Ravishankar return llvm::None;
3512f637fe7SMahesh Ravishankar v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
3522f637fe7SMahesh Ravishankar }
3532f637fe7SMahesh Ravishankar if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
3542f637fe7SMahesh Ravishankar return llvm::None;
3552f637fe7SMahesh Ravishankar return v.cast<OpResult>();
3562f637fe7SMahesh Ravishankar }
3572f637fe7SMahesh Ravishankar
358*2ed7c3fdSlorenzo chelini // Replace iter args of the outer most loop with region args of the inner most
359*2ed7c3fdSlorenzo chelini // one.
replaceIterArgs(scf::ForOp outerFor,scf::ForOp innerFor,PatternRewriter & rewriter)360*2ed7c3fdSlorenzo chelini static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
361*2ed7c3fdSlorenzo chelini PatternRewriter &rewriter) {
362*2ed7c3fdSlorenzo chelini assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
363*2ed7c3fdSlorenzo chelini "expect same number of iter args");
364*2ed7c3fdSlorenzo chelini Block *block = &(*innerFor.getRegion().begin());
365*2ed7c3fdSlorenzo chelini for (auto it :
366*2ed7c3fdSlorenzo chelini llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) {
367*2ed7c3fdSlorenzo chelini Value source = std::get<0>(it);
368*2ed7c3fdSlorenzo chelini Value target = std::get<1>(it);
369*2ed7c3fdSlorenzo chelini source.replaceUsesWithIf(target, [&](OpOperand &use) {
370*2ed7c3fdSlorenzo chelini return use.getOwner()->getBlock() == block;
371*2ed7c3fdSlorenzo chelini });
372*2ed7c3fdSlorenzo chelini }
373*2ed7c3fdSlorenzo chelini }
374*2ed7c3fdSlorenzo chelini
3752f637fe7SMahesh Ravishankar FailureOr<scf::SCFTileAndFuseResult>
returningMatchAndRewrite(TilingInterface op,PatternRewriter & rewriter) const3762f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
3772f637fe7SMahesh Ravishankar TilingInterface op, PatternRewriter &rewriter) const {
3782f637fe7SMahesh Ravishankar // This transformation is only valid for ops that return values (i.e. not
3792f637fe7SMahesh Ravishankar // valid to use with operations that have memref operands).
3802f637fe7SMahesh Ravishankar if (!op->getNumResults()) {
3812f637fe7SMahesh Ravishankar return rewriter.notifyMatchFailure(
3822f637fe7SMahesh Ravishankar op, "invalid pattern for op with no results");
3832f637fe7SMahesh Ravishankar }
3842f637fe7SMahesh Ravishankar
3852f637fe7SMahesh Ravishankar // 1. First tile the consumer.
3862f637fe7SMahesh Ravishankar SCFTileAndFuseResult tileAndFuseResult;
3872f637fe7SMahesh Ravishankar {
3882f637fe7SMahesh Ravishankar FailureOr<SCFTilingResult> tilingResult =
3892f637fe7SMahesh Ravishankar tilingPattern.returningMatchAndRewrite(op, rewriter);
3902f637fe7SMahesh Ravishankar if (failed(tilingResult)) {
3912f637fe7SMahesh Ravishankar return failure();
3922f637fe7SMahesh Ravishankar }
3932f637fe7SMahesh Ravishankar tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
3942f637fe7SMahesh Ravishankar tileAndFuseResult.loops = std::move(tilingResult->loops);
3952f637fe7SMahesh Ravishankar }
3962f637fe7SMahesh Ravishankar
3972f637fe7SMahesh Ravishankar // 2. Typically, the operands of the tiled operation are slices of the
3982f637fe7SMahesh Ravishankar // operands of the untiled operation. These are expressed in IR using
3992f637fe7SMahesh Ravishankar // `tensor.extract_slice` operations with source being the operands of the
4002f637fe7SMahesh Ravishankar // untiled operation. Create a worklist of these `tensor.extract_slice`
4012f637fe7SMahesh Ravishankar // operations. If the producers of the source of the `tensor.extract_slice`
4022f637fe7SMahesh Ravishankar // can be tiled such that the tiled value is generated in-place, that
4032f637fe7SMahesh Ravishankar // effectively tiles + fuses the operations.
4042f637fe7SMahesh Ravishankar auto addCandidateSlices = [](Operation *fusedOp,
4052f637fe7SMahesh Ravishankar std::deque<tensor::ExtractSliceOp> &candidates) {
4062f637fe7SMahesh Ravishankar for (Value operand : fusedOp->getOperands())
4072f637fe7SMahesh Ravishankar if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
4082f637fe7SMahesh Ravishankar candidates.push_back(sliceOp);
4092f637fe7SMahesh Ravishankar };
4102f637fe7SMahesh Ravishankar
4112f637fe7SMahesh Ravishankar std::deque<tensor::ExtractSliceOp> candidates;
4122f637fe7SMahesh Ravishankar addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
4132f637fe7SMahesh Ravishankar OpBuilder::InsertionGuard g(rewriter);
4142f637fe7SMahesh Ravishankar while (!candidates.empty()) {
4152f637fe7SMahesh Ravishankar // 2a. Traverse the slices in BFS fashion.
4162f637fe7SMahesh Ravishankar tensor::ExtractSliceOp candidateSliceOp = candidates.front();
4172f637fe7SMahesh Ravishankar candidates.pop_front();
4182f637fe7SMahesh Ravishankar
4192f637fe7SMahesh Ravishankar // 2b. Get the producer of the source (potentially walking through
4202f637fe7SMahesh Ravishankar // `iter_args` of nested `scf.for`)
4212f637fe7SMahesh Ravishankar Optional<OpResult> fusableProducer =
42204235d07SJacques Pienaar getFusableProducer(candidateSliceOp.getSource());
4232f637fe7SMahesh Ravishankar if (!fusableProducer)
4242f637fe7SMahesh Ravishankar continue;
4252f637fe7SMahesh Ravishankar
4262f637fe7SMahesh Ravishankar // 2c. Generate the tiled implementation of the producer of the source
4272f637fe7SMahesh Ravishankar rewriter.setInsertionPoint(candidateSliceOp);
4282f637fe7SMahesh Ravishankar FailureOr<Value> fusedProducerValue =
429c27d8152SKazu Hirata tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
430c27d8152SKazu Hirata fusableProducer.value());
4312f637fe7SMahesh Ravishankar if (failed(fusedProducerValue))
4322f637fe7SMahesh Ravishankar continue;
433c27d8152SKazu Hirata rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
4342f637fe7SMahesh Ravishankar
4352f637fe7SMahesh Ravishankar // 2d. The operands of the fused producer might themselved be slices of
4362f637fe7SMahesh Ravishankar // values produced by operations that implement the `TilingInterface`.
4372f637fe7SMahesh Ravishankar // Add these operations to the worklist.
4382f637fe7SMahesh Ravishankar Operation *fusedProducer = fusedProducerValue->getDefiningOp();
4392f637fe7SMahesh Ravishankar tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
4402f637fe7SMahesh Ravishankar addCandidateSlices(fusedProducer, candidates);
4412f637fe7SMahesh Ravishankar
4422f637fe7SMahesh Ravishankar // 2e. If the operation being fused creates a value that is used as `outs`
4432f637fe7SMahesh Ravishankar // in the tiled operation, the result of the unfused operation will be
4442f637fe7SMahesh Ravishankar // used in the `iter_args` of the tiled loop generated. When the
4452f637fe7SMahesh Ravishankar // operation is fused, this use in `iter_args` needs to be modified to
4462f637fe7SMahesh Ravishankar // use the destination of the fused operation. For example, starting
4472f637fe7SMahesh Ravishankar // with
4482f637fe7SMahesh Ravishankar //
4492f637fe7SMahesh Ravishankar // ```mlir
4502f637fe7SMahesh Ravishankar // %0 = linalg.init_tensor ...
4512f637fe7SMahesh Ravishankar // %1 = linalg.fill ... outs(%0:...)...
4522f637fe7SMahesh Ravishankar // %2 = linalg.matmul ... outs(%1:...)....
4532f637fe7SMahesh Ravishankar // ```
4542f637fe7SMahesh Ravishankar //
4552f637fe7SMahesh Ravishankar // First the `linalg.matmul` gets tiled
4562f637fe7SMahesh Ravishankar //
4572f637fe7SMahesh Ravishankar // ```mlir
4582f637fe7SMahesh Ravishankar // %0 = linalg.init_tensor
4592f637fe7SMahesh Ravishankar // %1 = linalg.fill
4602f637fe7SMahesh Ravishankar // %2 = scf.for .... iter_args(%arg0 = %1)...
4612f637fe7SMahesh Ravishankar // ...
4622f637fe7SMahesh Ravishankar // ... = linalg.matmul ...
4632f637fe7SMahesh Ravishankar //
4642f637fe7SMahesh Ravishankar // ```
4652f637fe7SMahesh Ravishankar //
4662f637fe7SMahesh Ravishankar // When the `linalg.fill` gets fused, the `iter_args` needs to be
4672f637fe7SMahesh Ravishankar // modified
4682f637fe7SMahesh Ravishankar //
4692f637fe7SMahesh Ravishankar // ```mlir
4702f637fe7SMahesh Ravishankar // %0 = linalg.init_tensor
4712f637fe7SMahesh Ravishankar // %1 = scf.for ... iter_args(%arg0 = %0)...
4722f637fe7SMahesh Ravishankar // ...
4732f637fe7SMahesh Ravishankar // %2 = linalg.fill ...
4742f637fe7SMahesh Ravishankar // %3 = linalg.matmul ... outs(%2: ...)...
4752f637fe7SMahesh Ravishankar // ```
4762f637fe7SMahesh Ravishankar TilingInterface unfusedProducerOp =
4772f637fe7SMahesh Ravishankar cast<TilingInterface>(fusableProducer->getOwner());
4782f637fe7SMahesh Ravishankar scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
4792f637fe7SMahesh Ravishankar SmallVector<Value> unfusedProducerOpDestValues =
4802f637fe7SMahesh Ravishankar unfusedProducerOp.getDestinationOperands(rewriter);
4812f637fe7SMahesh Ravishankar for (OpOperand &uses : unfusedProducerOp->getUses()) {
4822f637fe7SMahesh Ravishankar if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
4832f637fe7SMahesh Ravishankar unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
4842f637fe7SMahesh Ravishankar unsigned operandNumber = uses.getOperandNumber();
4852f637fe7SMahesh Ravishankar outerMostTiledLoop->setOperand(
4862f637fe7SMahesh Ravishankar operandNumber, unfusedProducerOpDestValues[resultNumber]);
4872f637fe7SMahesh Ravishankar }
4882f637fe7SMahesh Ravishankar }
4892f637fe7SMahesh Ravishankar }
490*2ed7c3fdSlorenzo chelini replaceIterArgs(tileAndFuseResult.loops.front(),
491*2ed7c3fdSlorenzo chelini tileAndFuseResult.loops.back(), rewriter);
4922f637fe7SMahesh Ravishankar return tileAndFuseResult;
4932f637fe7SMahesh Ravishankar }
494