1 //===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements linalg fusion on tensors
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "PassDetail.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/Support/LLVM.h"
25
26 using namespace mlir;
27 using namespace linalg;
28
29 //===----------------------------------------------------------------------===//
30 // StructuredOp specific helpers.
31 //===----------------------------------------------------------------------===//
32
33 /// Returns the tiled slice dimensions given the tiled consumer loop dimensions.
34 /// The slice defines a hyper rectangular iteration space and fusing the
35 /// producer is always possible. However, depending on the consumer indexing
36 /// map, not all slice elements may be consumed and the tiles may overlap. In
37 /// these cases, fusion introduces redundant computation.
getTiledSliceDims(OpOperand * consumerOperand,ArrayRef<int64_t> tiledLoopDims)38 static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
39 ArrayRef<int64_t> tiledLoopDims) {
40 // Get the consumer operand indexing map.
41 LinalgOp consumerOp = consumerOperand->getOwner();
42 AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
43
44 // Search the slice dimensions tiled by a tile loop dimension.
45 DenseSet<int64_t> tiledSliceDimIndices;
46 for (const auto &en : enumerate(indexingMap.getResults())) {
47 for (auto tiledLoopDim : tiledLoopDims) {
48 if (en.value().isFunctionOfDim(tiledLoopDim))
49 tiledSliceDimIndices.insert(en.index());
50 }
51 }
52 return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()};
53 }
54
55 /// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions
56 /// of the producer result slice returns the tiled producer loop dimensions.
57 /// Example:
58 /// ```
59 /// %res = linalg.fill(%cst, %input)
60 /// scf.for %i
61 /// scf.for %j
62 /// %slice = tensor.extract_slice %res[%i, %j]
63 /// ```
64 /// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1].
65 static SmallVector<int64_t>
getTiledProducerLoops(OpResult producerResult,ArrayRef<int64_t> tiledSliceDimIndices)66 getTiledProducerLoops(OpResult producerResult,
67 ArrayRef<int64_t> tiledSliceDimIndices) {
68 LinalgOp producerOp = producerResult.getOwner();
69
70 // Get the indexing map of the `producerOp` output operand that matches
71 // ´producerResult´.
72 AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
73 producerOp.getOutputOperand(producerResult.getResultNumber()));
74
75 // Keep only the tiled result slice dimensions of `producerIndexingMap`.
76 AffineMap tiledProducerIndexingSubMap =
77 producerIndexingMap.getSubMap(SmallVector<unsigned>(
78 tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()));
79
80 // Compute the producer loop indices mapped to the tiled result slice
81 // dimensions. As the output indexing map of structured operations are
82 // projected permutations, `tiledProducerIndexingSubMap` has to be a
83 // projected permutation as well. We can thus obtain the producer loop indices
84 // by getting the positions of the result dimensions.
85 // Example:
86 // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2].
87 assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
88 "expect slice and producer loop dimensions map one-to-one");
89 SmallVector<int64_t> tiledProducerLoopIndices;
90 llvm::transform(
91 llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
92 std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
93 return tiledProducerIndexingSubMap.getDimPosition(idx);
94 });
95
96 return tiledProducerLoopIndices;
97 }
98
99 /// Returns the producer fused in place of `sliceOp`. Tile the producer operands
100 /// along the `tiledSliceDimIndices` and clone the producer. Consider the case
101 /// of fusion of an output tensor:
102 /// ```
103 /// %1 = producer ins(...) outs(%0)
104 /// %2 = consumer ins(...) outs(%1)
105 /// ```
106 /// When consumer is tiled, %1 appears in the loop iter_args:
107 /// ```
108 /// %1 = producer ins(...) outs(%0)
109 /// %2 = scf.for ... iter_args(%1) .. (%bbarg) {
110 /// %t1 = tensor.extract_slice %bbarg[..]
111 /// %t2 = consumer ins(...) outs(%t1)
112 /// %r = tensor.insert_slice %t2, %bbarg[...]
113 /// }
114 /// ```
115 /// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0):
116 /// ```
117 /// %2 = scf.for ... iter_args(%0) .. (%bbarg) {
118 /// %t0 = tensor.extract_slice %bbarg[..]
119 /// %t1 = producer ins(...) outs(%t0)
120 /// %t2 = consumer ins(...) outs(%t1)
121 /// %r = tensor.insert_slice %t2, %bbarg[...]
122 /// }
123 /// ```
124 /// This transformation is only valid if %bbarg is exclusively used by the
125 /// output ExtractSliceOp / InsertSliceOp pair, which is checked by the
126 /// `fuseProducer` method.
127 /// TODO: instead of check and failure, insert new iter_args each time a
128 /// producer is fused into a consumer and fold away unused iter_args.
static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
                                 tensor::ExtractSliceOp sliceOp,
                                 ArrayRef<int64_t> tiledSliceDimIndices,
                                 ArrayRef<int64_t> tiledProducerLoopIndices,
                                 OpOperand *iterArg) {
  // Clone the producer after `sliceOp` since the slice may be reused to pass in
  // the producer result.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointAfter(sliceOp);

  // Get the producer.
  LinalgOp producerOp = producerResult.getOwner();
  Location loc = producerOp.getLoc();

  // Obtain the `producerOp` loop bounds (the size of every loop range) and the
  // `sliceOp` ranges (offset/size/stride of every slice dimension).
  SmallVector<Value> producerLoopBounds;
  llvm::transform(producerOp.createLoopRanges(b, loc),
                  std::back_inserter(producerLoopBounds),
                  [](Range range) { return range.size; });
  SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);

  // Tile the producer operands given the `sliceOp` ranges. Iterate the
  // `tiledSliceDimIndices` and store the tile offset and size for the tiled
  // slice dimension. Untiled producer loops keep a zero tile size so
  // `makeTiledShapes` leaves the corresponding dimensions untouched.
  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
  SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
  SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
  for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) {
    int64_t tiledSliceDim = std::get<0>(it);
    int64_t tiledProducerLoop = std::get<1>(it);
    tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
    tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
    allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
  }
  // `makeTiledShapes` expects only the induction values of the tiled loops.
  erase_value(tileIvs, nullptr);
  SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
  tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
                                  tileSizes, producerLoopBounds,
                                  /*omitPartialTileCheck=*/false);

  // Output fusion has to update the iteration arguments of the tile loop nest.
  // In particular, the iteration argument of the outermost tile loop needs to
  // be set to the producer output instead of the producer result and `clonedOp`
  // shall use the existing `sliceOp` result instead of the tiled producer
  // output operand.
  if (iterArg) {
    OpOperand *outputOperand =
        producerOp.getOutputOperand(producerResult.getResultNumber());
    iterArg->set(outputOperand->get());
    tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
  }

  // Clone the producer using the tiled producer operands. The result types are
  // taken from the tiled output operands, which trail the input operands.
  TypeRange resultTypes = ValueRange(tiledOperands)
                              .take_back(producerOp.getNumOutputs())
                              .getTypes();
  LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);

  // Shift all IndexOp results by the tile offset.
  offsetIndices(b, clonedOp, allIvs);

  return clonedOp;
}
193
194 //===----------------------------------------------------------------------===//
195 // TileLoopNest specific helpers.
196 //===----------------------------------------------------------------------===//
197
isEmpty()198 bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); }
199
isValid()200 bool TileLoopNest::isValid() {
201 // Check if `rootOp` has been tiled at least once.
202 if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0)
203 return false;
204
205 // Check if the number of loop operations and dimensions match.
206 if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size())
207 return false;
208
209 // Check if the innermost tile loop is the parent of `tiledOp`.
210 if (rootOp->getParentOp() != tileLoopOps.back())
211 return false;
212
213 // Check if the tile loops are directly nested.
214 return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(),
215 [](Operation *op1, Operation *op2) {
216 return op1 != op2->getParentOp();
217 }) == tileLoopOps.end();
218 }
219
getTiedBBArgs(BlockArgument bbArg)220 SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
221 assert(bbArg && "expect the block argument to be non-zero");
222 SmallVector<BlockArgument> bbArgs;
223
224 // Search all tile loop block arguments from inner to outer.
225 for (auto tileLoop : reverse(tileLoopOps)) {
226 if (bbArg.getOwner()->getParentOp() != tileLoop)
227 return {};
228 bbArgs.push_back(bbArg);
229 OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
230 bbArg = iterArg->get().dyn_cast<BlockArgument>();
231 }
232
233 // Reverse the block arguments to order them from outer to inner.
234 return {bbArgs.rbegin(), bbArgs.rend()};
235 }
236
getTiedIterArg(BlockArgument bbArg)237 OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
238 // Search all block arguments and return the matching iteration argument.
239 SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
240 if (bbArgs.size() != tileLoopOps.size())
241 return nullptr;
242 return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
243 }
244
hasOtherUses(BlockArgument bbArg,tensor::ExtractSliceOp sliceOp)245 bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
246 tensor::ExtractSliceOp sliceOp) {
247 // Check the innermost block argument is either used by the ExtractSliceOp
248 // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses
249 // conservatively.
250 for (Operation *op : bbArg.getUsers()) {
251 if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op))
252 return false;
253 if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
254 if (extractSliceOp != sliceOp)
255 return false;
256 }
257 if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
258 SetVector<Operation *> backwardSlice;
259 getBackwardSlice(insertSliceOp.getSource(), &backwardSlice,
260 [](Operation *op) {
261 return isa<LinalgOp, tensor::InsertSliceOp>(op);
262 });
263 if (backwardSlice.empty() || backwardSlice.front() != sliceOp)
264 return false;
265 }
266 }
267
268 // Check the block arguments, except for the innermost one, have one use.
269 SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
270 return !all_of(bbArgs, [&](BlockArgument bbArg) {
271 return bbArg.hasOneUse() || bbArg == bbArgs.back();
272 });
273 }
274
/// Tiles `rootOp` by `tileSizes`, interchanging loops by `tileInterchange`,
/// and appends the newly created loops to the tile loop nest. Succeeds
/// trivially (without tiling) when all tile sizes are zero.
LogicalResult TileLoopNest::tileRootOp(
    OpBuilder &b, ArrayRef<int64_t> tileSizes,
    ArrayRef<int64_t> tileInterchange,
    Optional<LinalgLoopDistributionOptions> tileDistribution) {
  // Exit if all tile sizes are zero.
  if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
    return success();

  // Tile the root operation.
  LinalgTilingOptions tilingOptions;
  tilingOptions = tilingOptions
                      .setInterchange(SmallVector<unsigned>(
                          tileInterchange.begin(), tileInterchange.end()))
                      .setTileSizes(tileSizes)
                      .setLoopType(LinalgTilingLoopType::Loops);
  if (tileDistribution)
    tilingOptions = tilingOptions.setDistributionOptions(*tileDistribution);

  // TODO: Propagate RewriterBase everywhere.
  IRRewriter rewriter(b);
  FailureOr<TiledLinalgOp> tiledRootOp =
      tileLinalgOp(rewriter, rootOp, tilingOptions);

  // Exit if tiling the root operation fails.
  if (failed(tiledRootOp))
    return failure();

  // Replace all uses of the root operation if it has been tiled before. All
  // uses of the original untiled root operation are updated by the calling pass
  // or pattern.
  if (!isEmpty())
    rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);

  // Transfer the stored `rootOp` loop dimensions if it has been tiled before,
  // so the new (tiled) root keeps the bookkeeping of the old one.
  if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) {
    tiledRootAndFusedOpsLoops[tiledRootOp->op] =
        tiledRootAndFusedOpsLoops[rootOp];
  }

  // Update the root operation and append the loops and tile loop dimensions.
  rootOp = tiledRootOp->op;
  tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
  for (const auto &en : enumerate(tileSizes)) {
    // Copy only the tiled loop dimensions with non-zero tile size.
    if (en.value() == 0)
      continue;
    tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]);
  }
  assert(isValid() && "expect tile loop nest to be valid after tiling");
  return success();
}
326
/// Fuses the producer of `consumerOpOperand` into the tile loop nest, cloning
/// a tiled version of the producer in place of the ExtractSliceOp that defines
/// the operand. Returns the cloned producer on success.
FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
                                               OpOperand *consumerOpOperand) {
  // Check if the consumer has been tiled before. For example, it may not have
  // been tiled if the outermost tile loop is a reduction loop.
  if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0)
    return failure();

  assert(this->isValid() &&
         "expect the tile loop nest to satisfy all invariants");

  // Check the tile loop nest is non-empty.
  if (isEmpty())
    return failure();

  // Check `consumerOpOperand` is defined by an ExtractSliceOp.
  auto sliceOp =
      consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Check `sliceOp` and `consumerOp` are in the same block, i.e. both sit in
  // the innermost tile loop body next to the tiled root operation.
  LinalgOp consumerOp = consumerOpOperand->getOwner();
  if (sliceOp->getBlock() != rootOp->getBlock() ||
      consumerOp->getBlock() != rootOp->getBlock())
    return failure();

  // Check `consumerOpOperand` is not shape-only to avoid fusion if the data is
  // not used by the `consumerOp` computation.
  BlockArgument bbArg = consumerOp.getTiedBlockArgument(consumerOpOperand);
  if (bbArg.getUses().empty())
    return failure();

  // Check if the producer is a LinalgOp possibly passed by iteration argument.
  // For output fusion, the slice source is a block argument threaded through
  // the loop nest, and the producer is found behind the tied iter_arg.
  OpOperand *iterArg = nullptr;
  auto producerResult = sliceOp.getSource().dyn_cast<OpResult>();
  if (auto bbArg = sliceOp.getSource().dyn_cast<BlockArgument>()) {
    iterArg = getTiedIterArg(bbArg);
    // Check the iteration argument may be used to pass in the producer output.
    if (!iterArg || hasOtherUses(bbArg, sliceOp))
      return failure();
    producerResult = iterArg->get().dyn_cast<OpResult>();
  }
  if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
    return failure();

  // Compute the tiled producer slice dimensions given the tiled consumer loops.
  SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims(
      consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]);
  if (tiledSliceDimIndices.empty())
    return failure();

  // Compute the tiled producer loop indices.
  SmallVector<int64_t> tiledProducerLoopIndices =
      getTiledProducerLoops(producerResult, tiledSliceDimIndices);

  // Tile the producer operands and clone the producer in place of `sliceOp`.
  LinalgOp clonedOp =
      getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices,
                       tiledProducerLoopIndices, iterArg);
  tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;

  // Cast the `clonedOp` result to bridge type mismatches before
  // canonicalization (the tiled result type may be more static or dynamic
  // than the original slice type).
  Type consumerOperandType = consumerOpOperand->get().getType();
  Value newResult = clonedOp->getResult(producerResult.getResultNumber());
  if (newResult.getType() != consumerOperandType) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(clonedOp);
    newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
                                         consumerOperandType, newResult);
  }

  // Replace the `sliceOp` uses except for the `clonedOp` output uses.
  sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
  return clonedOp;
}
402
getRootOpReplacementResults()403 ValueRange TileLoopNest::getRootOpReplacementResults() {
404 assert(!isEmpty() && "expect tile loop nest to be non-empty");
405 return tileLoopOps.front()->getOpResults();
406 }
407
getAllTiledAndFusedOps()408 SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
409 SmallVector<LinalgOp> result;
410 for (const auto &kvp : tiledRootAndFusedOpsLoops) {
411 auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst());
412 assert(linalgOp &&
413 "expect all tiled and fused operations are linalg operations");
414 result.push_back(linalgOp);
415 }
416 return result;
417 }
418
419 //===----------------------------------------------------------------------===//
420 // Tile and fuse entry-points.
421 //===----------------------------------------------------------------------===//
422
/// Tiles `consumerOp` and greedily fuses its producers into the resulting
/// tile loop nest. Tiling happens in two steps: first the outer parallel
/// loops (followed by output-operand fusion), then the remaining loops
/// (followed by input-operand fusion).
FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
    OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
    ArrayRef<int64_t> tileInterchange,
    const Optional<LinalgLoopDistributionOptions> &tileDistribution) {
  assert(tileSizes.size() == tileInterchange.size() &&
         "expect the number of tile sizes and interchange dims to match");
  assert(isPermutation(tileInterchange) &&
         "expect tile interchange is a permutation");

  // Create an empty tile loop nest.
  TileLoopNest tileLoopNest(consumerOp);

  // Search the number of outer parallel loops to separate them from possible
  // inner reduction dimensions. `split` is the position (after interchange) of
  // the first non-parallel iterator.
  SmallVector<StringAttr> iterTypes =
      llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
  applyPermutationToVector(iterTypes, tileInterchange);
  auto *it = find_if(iterTypes, [&](StringAttr iterType) {
    return !isParallelIterator(iterType);
  });
  int64_t split = std::distance(iterTypes.begin(), it);

  // Helper to fuse the producers greedily using a queue of fusion candidates.
  // Each successful fusion enqueues the operands of the fused producer so its
  // own producers can be fused transitively.
  auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
    SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
    while (!candidates.empty()) {
      FailureOr<LinalgOp> fusedProducer =
          tileLoopNest.fuseProducer(b, candidates.pop_back_val());
      if (failed(fusedProducer))
        continue;
      candidates.append(fusedProducer->getInputAndOutputOperands());
    }
  };

  // Perform tiling and fusion in two steps. We need to respect the loop
  // interchange here; filter parallel dimensions based on their order *after*
  // permutation but pass in the original configuration *before* permutation,
  // given the tiling and interchange happen together.
  SmallVector<int64_t> outerTileSizes(tileSizes.size(), 0);
  SmallVector<int64_t> innerTileSizes(tileSizes.size(), 0);
  for (int64_t i : tileInterchange.take_front(split))
    outerTileSizes[i] = tileSizes[i];
  for (int64_t i : tileInterchange.drop_front(split))
    innerTileSizes[i] = tileSizes[i];

  // Tile the outer parallel loops and fuse the output operands.
  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
                                     tileDistribution)))
    return failure();
  fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());

  // Tile the remaining loops and fuse the input operands.
  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
                                     tileDistribution)))
    return failure();
  fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());

  // Exit if the tile loop nest is empty since all tile sizes are zero.
  if (tileLoopNest.isEmpty())
    return failure();

  return tileLoopNest;
}
486