1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/SCF/Utils/Utils.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Interfaces/TilingInterface.h"
23 #include "llvm/Support/Debug.h"
24
25 #define DEBUG_TYPE "tile-using-interface"
26
27 using namespace mlir;
28
29 scf::SCFTilingOptions &
setTileSizes(ArrayRef<int64_t> ts)30 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
31 assert(!tileSizeComputationFunction && "tile sizes already set");
32 SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
33 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
34 OpBuilder::InsertionGuard guard(b);
35 b.setInsertionPointToStart(
36 &op->getParentOfType<func::FuncOp>().getBody().front());
37 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
38 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
39 return v;
40 }));
41 };
42 return *this;
43 }
44
45 /// Helper method to adjust the interchange vector to match the iteration
46 /// domain.
47 static SmallVector<unsigned>
fillInterchangeVector(ArrayRef<unsigned> interchangeVector,size_t iterationDomainSize)48 fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
49 size_t iterationDomainSize) {
50 SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector);
51 if (filledVector.size() < iterationDomainSize) {
52 auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize);
53 filledVector.append(range.begin(), range.end());
54 }
55 if (filledVector.size() > iterationDomainSize)
56 filledVector.resize(iterationDomainSize);
57 return filledVector;
58 }
59
60 /// Helper method to apply permutation to a vector
61 template <typename T>
applyPermutationToVector(const SmallVector<T> & vector,ArrayRef<unsigned> interchange)62 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
63 ArrayRef<unsigned> interchange) {
64 assert(interchange.size() == vector.size());
65 return llvm::to_vector(
66 llvm::map_range(interchange, [&](unsigned val) { return vector[val]; }));
67 }
68 /// Helper method to apply to invert a permutation.
69 static SmallVector<unsigned>
invertPermutationVector(ArrayRef<unsigned> interchange)70 invertPermutationVector(ArrayRef<unsigned> interchange) {
71 SmallVector<unsigned> inversion(interchange.size());
72 for (auto pos : llvm::enumerate(interchange)) {
73 inversion[pos.value()] = pos.index();
74 }
75 return inversion;
76 }
77 /// Method to check if an interchange vector is a permutation.
isPermutation(ArrayRef<unsigned> interchange)78 static bool isPermutation(ArrayRef<unsigned> interchange) {
79 llvm::SmallDenseSet<unsigned, 4> seenVals;
80 for (auto val : interchange) {
81 if (seenVals.count(val))
82 return false;
83 seenVals.insert(val);
84 }
85 return seenVals.size() == interchange.size();
86 }
87
88 //===----------------------------------------------------------------------===//
89 // TileUsingSCFForOp pattern implementation.
90 //===----------------------------------------------------------------------===//
91
92 /// Generate an empty loop nest that represents the tiled loop nest shell.
93 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
94 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
95 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
96 /// the
97 /// tile processed within the inner most loop.
98 static SmallVector<scf::ForOp>
generateTileLoopNest(OpBuilder & builder,Location loc,ArrayRef<Range> loopRanges,ArrayRef<Value> tileSizeVals,SmallVector<OpFoldResult> & offsets,SmallVector<OpFoldResult> & sizes)99 generateTileLoopNest(OpBuilder &builder, Location loc,
100 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
101 SmallVector<OpFoldResult> &offsets,
102 SmallVector<OpFoldResult> &sizes) {
103 assert(!loopRanges.empty() && "expected at least one loop range");
104 assert(loopRanges.size() == tileSizeVals.size() &&
105 "expected as many tile sizes as loop ranges");
106 OpBuilder::InsertionGuard guard(builder);
107 SmallVector<scf::ForOp> loops;
108 offsets.resize(loopRanges.size());
109 sizes.resize(loopRanges.size());
110
111 // The tile size to use (to avoid out of bounds access) is minimum of
112 // `tileSize` and `ub - iv`, where `iv` is the induction variable
113 // of the tiled loop.
114 AffineExpr s0, s1, d0;
115 bindDims(builder.getContext(), d0);
116 bindSymbols(builder.getContext(), s0, s1);
117 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
118
119 for (auto loopRange : llvm::enumerate(loopRanges)) {
120 // No loops if tile size is zero. Set offset and size to the loop
121 // offset and size.
122 if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
123 offsets[loopRange.index()] = loopRange.value().offset;
124 sizes[loopRange.index()] = loopRange.value().size;
125 continue;
126 }
127
128 auto loop = builder.create<scf::ForOp>(
129 loc, loopRange.value().offset, loopRange.value().size,
130 tileSizeVals[loopRange.index()], ValueRange{},
131 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
132 ValueRange /*iterArgs*/) {
133 Value boundedTileSize = builder.create<AffineMinOp>(
134 bodyLoc, minMap,
135 ValueRange{iv, tileSizeVals[loopRange.index()],
136 loopRange.value().size});
137 sizes[loopRange.index()] = boundedTileSize;
138 builder.create<scf::YieldOp>(loc);
139 });
140 offsets[loopRange.index()] = loop.getInductionVar();
141 loops.push_back(loop);
142 builder.setInsertionPoint(loop.getBody()->getTerminator());
143 }
144 return loops;
145 }
146
TileUsingSCFForOp(MLIRContext * context,scf::SCFTilingOptions options,PatternBenefit benefit)147 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
148 scf::SCFTilingOptions options,
149 PatternBenefit benefit)
150 : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
151 options(std::move(options)) {}
152
TileUsingSCFForOp(StringRef opName,MLIRContext * context,scf::SCFTilingOptions options,PatternBenefit benefit)153 scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
154 MLIRContext *context,
155 scf::SCFTilingOptions options,
156 PatternBenefit benefit)
157 : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
158 options(std::move(options)) {}
159
160 FailureOr<scf::SCFTilingResult>
returningMatchAndRewrite(TilingInterface op,PatternRewriter & rewriter) const161 scf::TileUsingSCFForOp::returningMatchAndRewrite(
162 TilingInterface op, PatternRewriter &rewriter) const {
163 OpBuilder::InsertionGuard guard(rewriter);
164 rewriter.setInsertionPointAfter(op);
165
166 if (!options.tileSizeComputationFunction) {
167 return rewriter.notifyMatchFailure(
168 op, "missing tile size computation function");
169 }
170
171 // 1. Get the range of the loops that are represented by the operation.
172 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
173 size_t numLoops = iterationDomain.size();
174 if (numLoops == 0) {
175 return rewriter.notifyMatchFailure(
176 op, "unable to tile op with no iteration domain");
177 }
178
179 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
180 // skips tiling a particular dimension. This convention is significantly
181 // simpler to handle instead of adjusting affine maps to account for missing
182 // dimensions.
183 SmallVector<Value> tileSizeVector =
184 options.tileSizeComputationFunction(rewriter, op);
185 if (tileSizeVector.size() < iterationDomain.size()) {
186 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
187 tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
188 }
189
190 scf::SCFTilingResult tilingResult;
191 SmallVector<OpFoldResult> offsets, sizes;
192 {
193 // If there is an interchange specified, permute the iteration domain and
194 // the tile sizes.
195 SmallVector<unsigned> interchangeVector;
196 if (!options.interchangeVector.empty()) {
197 interchangeVector = fillInterchangeVector(options.interchangeVector,
198 iterationDomain.size());
199 }
200 if (!interchangeVector.empty()) {
201 if (!isPermutation(interchangeVector)) {
202 return rewriter.notifyMatchFailure(
203 op, "invalid intechange vector, not a permutation of the entire "
204 "iteration space");
205 }
206
207 iterationDomain =
208 applyPermutationToVector(iterationDomain, interchangeVector);
209 tileSizeVector =
210 applyPermutationToVector(tileSizeVector, interchangeVector);
211 }
212
213 // 3. Materialize an empty loop nest that iterates over the tiles. These
214 // loops for now do not return any values even if the original operation has
215 // results.
216 tilingResult.loops = generateTileLoopNest(
217 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
218
219 if (!interchangeVector.empty()) {
220 auto inversePermutation = invertPermutationVector(interchangeVector);
221 offsets = applyPermutationToVector(offsets, inversePermutation);
222 sizes = applyPermutationToVector(sizes, inversePermutation);
223 }
224
225 LLVM_DEBUG({
226 if (!tilingResult.loops.empty()) {
227 llvm::errs() << "LoopNest shell :\n";
228 tilingResult.loops.front().dump();
229 llvm::errs() << "\n";
230 }
231 });
232
233 // 4. Generate the tiled implementation within the inner most loop.
234 if (!tilingResult.loops.empty())
235 rewriter.setInsertionPoint(
236 tilingResult.loops.back().getBody()->getTerminator());
237 SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
238 rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
239 if (tiledImplementation.size() != 1) {
240 return rewriter.notifyMatchFailure(
241 op, "expected tiled implementation to return a single op");
242 }
243 tilingResult.tiledOp = tiledImplementation[0];
244
245 LLVM_DEBUG({
246 if (!tilingResult.loops.empty()) {
247 llvm::errs() << "After tiled implementation :\n";
248 tilingResult.loops.front().dump();
249 llvm::errs() << "\n";
250 }
251 });
252 }
253
254 if (op->getNumResults() == 0) {
255 rewriter.eraseOp(op);
256 return tilingResult;
257 }
258
259 // 5. If the original operations has results, modify the loop nest to yield
260 // the replacement values.
261 SmallVector<Value> replacements;
262 if (tilingResult.loops.empty()) {
263 // 5a. If there were no loops, the tiled implementation results are the
264 // replacements.
265 rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
266 return tilingResult;
267 }
268
269 // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
270 // replacement values using destructive updates. Use the `TilingInterface`
271 // to get the position of the result tiles and use that to generate the
272 // destructive update pattern, i.e.,
273 //
274 // ```mlir
275 // scf.for %iv0 = ... {
276 // %0 = tiled_op
277 // }
278 // ```
279 //
280 // is transformed to
281 //
282 // ```mlir
283 // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
284 // %0 = tiled_op
285 // %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
286 // scf.yield %1
287 // }
288 // ```
289 NewYieldValueFn yieldValueFn =
290 [&](OpBuilder &b, Location loc,
291 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
292 SmallVector<Value> yieldedValues;
293 Attribute one = b.getIndexAttr(1);
294 for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
295 SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
296 if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
297 resultTileOffsets,
298 resultTileSizes))) {
299 op.emitOpError("unable to get position of result ")
300 << resultNum << " of the tiled implementation";
301 return {};
302 }
303 SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
304 one);
305 Value yieldedValue = b.create<tensor::InsertSliceOp>(
306 op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
307 newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
308 resultTileStrides);
309 yieldedValues.push_back(yieldedValue);
310 }
311 return yieldedValues;
312 };
313 SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
314 rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
315 yieldValueFn);
316 for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
317 rewriter.eraseOp(loop.value());
318 tilingResult.loops[loop.index()] = newLoops[loop.index()];
319 }
320 rewriter.replaceOp(op, tilingResult.loops.front().getResults());
321 return tilingResult;
322 }
323
324 //===----------------------------------------------------------------------===//
325 // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
326 //===----------------------------------------------------------------------===//
327
328 scf::TileConsumerAndFuseProducersUsingSCFForOp::
TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext * context,scf::SCFTilingOptions options,PatternBenefit benefit)329 TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
330 scf::SCFTilingOptions options,
331 PatternBenefit benefit)
332 : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
333 tilingPattern(context, std::move(options)) {}
334
335 scf::TileConsumerAndFuseProducersUsingSCFForOp::
TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,MLIRContext * context,scf::SCFTilingOptions options,PatternBenefit benefit)336 TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
337 MLIRContext *context,
338 scf::SCFTilingOptions options,
339 PatternBenefit benefit)
340 : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
341 tilingPattern(context, std::move(options)) {}
342
343 /// Return the `Value` that is defined by an operation that implements
344 /// the `TilingInterface`. Looks through `iter_args` of scf.for nest
345 /// if required.
getFusableProducer(Value v)346 static Optional<OpResult> getFusableProducer(Value v) {
347 while (auto blockArg = v.dyn_cast<BlockArgument>()) {
348 auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
349 if (!loopOp)
350 return llvm::None;
351 v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
352 }
353 if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
354 return llvm::None;
355 return v.cast<OpResult>();
356 }
357
358 // Replace iter args of the outer most loop with region args of the inner most
359 // one.
replaceIterArgs(scf::ForOp outerFor,scf::ForOp innerFor,PatternRewriter & rewriter)360 static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
361 PatternRewriter &rewriter) {
362 assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
363 "expect same number of iter args");
364 Block *block = &(*innerFor.getRegion().begin());
365 for (auto it :
366 llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) {
367 Value source = std::get<0>(it);
368 Value target = std::get<1>(it);
369 source.replaceUsesWithIf(target, [&](OpOperand &use) {
370 return use.getOwner()->getBlock() == block;
371 });
372 }
373 }
374
375 FailureOr<scf::SCFTileAndFuseResult>
returningMatchAndRewrite(TilingInterface op,PatternRewriter & rewriter) const376 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
377 TilingInterface op, PatternRewriter &rewriter) const {
378 // This transformation is only valid for ops that return values (i.e. not
379 // valid to use with operations that have memref operands).
380 if (!op->getNumResults()) {
381 return rewriter.notifyMatchFailure(
382 op, "invalid pattern for op with no results");
383 }
384
385 // 1. First tile the consumer.
386 SCFTileAndFuseResult tileAndFuseResult;
387 {
388 FailureOr<SCFTilingResult> tilingResult =
389 tilingPattern.returningMatchAndRewrite(op, rewriter);
390 if (failed(tilingResult)) {
391 return failure();
392 }
393 tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
394 tileAndFuseResult.loops = std::move(tilingResult->loops);
395 }
396
397 // 2. Typically, the operands of the tiled operation are slices of the
398 // operands of the untiled operation. These are expressed in IR using
399 // `tensor.extract_slice` operations with source being the operands of the
400 // untiled operation. Create a worklist of these `tensor.extract_slice`
401 // operations. If the producers of the source of the `tensor.extract_slice`
402 // can be tiled such that the tiled value is generated in-place, that
403 // effectively tiles + fuses the operations.
404 auto addCandidateSlices = [](Operation *fusedOp,
405 std::deque<tensor::ExtractSliceOp> &candidates) {
406 for (Value operand : fusedOp->getOperands())
407 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
408 candidates.push_back(sliceOp);
409 };
410
411 std::deque<tensor::ExtractSliceOp> candidates;
412 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
413 OpBuilder::InsertionGuard g(rewriter);
414 while (!candidates.empty()) {
415 // 2a. Traverse the slices in BFS fashion.
416 tensor::ExtractSliceOp candidateSliceOp = candidates.front();
417 candidates.pop_front();
418
419 // 2b. Get the producer of the source (potentially walking through
420 // `iter_args` of nested `scf.for`)
421 Optional<OpResult> fusableProducer =
422 getFusableProducer(candidateSliceOp.getSource());
423 if (!fusableProducer)
424 continue;
425
426 // 2c. Generate the tiled implementation of the producer of the source
427 rewriter.setInsertionPoint(candidateSliceOp);
428 FailureOr<Value> fusedProducerValue =
429 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
430 fusableProducer.value());
431 if (failed(fusedProducerValue))
432 continue;
433 rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
434
435 // 2d. The operands of the fused producer might themselved be slices of
436 // values produced by operations that implement the `TilingInterface`.
437 // Add these operations to the worklist.
438 Operation *fusedProducer = fusedProducerValue->getDefiningOp();
439 tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
440 addCandidateSlices(fusedProducer, candidates);
441
442 // 2e. If the operation being fused creates a value that is used as `outs`
443 // in the tiled operation, the result of the unfused operation will be
444 // used in the `iter_args` of the tiled loop generated. When the
445 // operation is fused, this use in `iter_args` needs to be modified to
446 // use the destination of the fused operation. For example, starting
447 // with
448 //
449 // ```mlir
450 // %0 = linalg.init_tensor ...
451 // %1 = linalg.fill ... outs(%0:...)...
452 // %2 = linalg.matmul ... outs(%1:...)....
453 // ```
454 //
455 // First the `linalg.matmul` gets tiled
456 //
457 // ```mlir
458 // %0 = linalg.init_tensor
459 // %1 = linalg.fill
460 // %2 = scf.for .... iter_args(%arg0 = %1)...
461 // ...
462 // ... = linalg.matmul ...
463 //
464 // ```
465 //
466 // When the `linalg.fill` gets fused, the `iter_args` needs to be
467 // modified
468 //
469 // ```mlir
470 // %0 = linalg.init_tensor
471 // %1 = scf.for ... iter_args(%arg0 = %0)...
472 // ...
473 // %2 = linalg.fill ...
474 // %3 = linalg.matmul ... outs(%2: ...)...
475 // ```
476 TilingInterface unfusedProducerOp =
477 cast<TilingInterface>(fusableProducer->getOwner());
478 scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
479 SmallVector<Value> unfusedProducerOpDestValues =
480 unfusedProducerOp.getDestinationOperands(rewriter);
481 for (OpOperand &uses : unfusedProducerOp->getUses()) {
482 if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
483 unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
484 unsigned operandNumber = uses.getOperandNumber();
485 outerMostTiledLoop->setOperand(
486 operandNumber, unfusedProducerOpDestValues[resultNumber]);
487 }
488 }
489 }
490 replaceIterArgs(tileAndFuseResult.loops.front(),
491 tileAndFuseResult.loops.back(), rewriter);
492 return tileAndFuseResult;
493 }
494