1 //===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements linalg fusion on tensors
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/Support/LLVM.h"
25 
26 using namespace mlir;
27 using namespace linalg;
28 
29 //===----------------------------------------------------------------------===//
30 // StructuredOp specific helpers.
31 //===----------------------------------------------------------------------===//
32 
33 /// Returns the tiled slice dimensions given the tiled consumer loop dimensions.
34 /// The slice defines a hyper rectangular iteration space and fusing the
35 /// producer is always possible. However, depending on the consumer indexing
36 /// map, not all slice elements may be consumed and the tiles may overlap. In
37 /// these cases, fusion introduces redundant computation.
38 static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
39                                               ArrayRef<int64_t> tiledLoopDims) {
40   // Get the consumer operand indexing map.
41   LinalgOp consumerOp = consumerOperand->getOwner();
42   AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
43 
44   // Search the slice dimensions tiled by a tile loop dimension.
45   DenseSet<int64_t> tiledSliceDimIndices;
46   for (const auto &en : enumerate(indexingMap.getResults())) {
47     for (auto tiledLoopDim : tiledLoopDims) {
48       if (en.value().isFunctionOfDim(tiledLoopDim))
49         tiledSliceDimIndices.insert(en.index());
50     }
51   }
52   return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()};
53 }
54 
55 /// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions
56 /// of the producer result slice returns the tiled producer loop dimensions.
57 /// Example:
58 /// ```
59 /// %res = linalg.fill(%cst, %input)
60 /// scf.for %i
61 ///   scf.for %j
62 ///     %slice = tensor.extract_slice %res[%i, %j]
63 /// ```
64 /// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1].
65 static SmallVector<int64_t>
66 getTiledProducerLoops(OpResult producerResult,
67                       ArrayRef<int64_t> tiledSliceDimIndices) {
68   LinalgOp producerOp = producerResult.getOwner();
69 
70   // Get the indexing map of the `producerOp` output operand that matches
71   // ´producerResult´.
72   AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
73       producerOp.getOutputOperand(producerResult.getResultNumber()));
74 
75   // Keep only the tiled result slice dimensions of `producerIndexingMap`.
76   AffineMap tiledProducerIndexingSubMap =
77       producerIndexingMap.getSubMap(SmallVector<unsigned>(
78           tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()));
79 
80   // Compute the producer loop indices mapped to the tiled result slice
81   // dimensions. As the output indexing map of structured operations are
82   // projected permutations, `tiledProducerIndexingSubMap` has to be a
83   // projected permutation as well. We can thus obtain the producer loop indices
84   // by getting the positions of the result dimensions.
85   // Example:
86   // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2].
87   assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
88          "expect slice and producer loop dimensions map one-to-one");
89   SmallVector<int64_t> tiledProducerLoopIndices;
90   transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
91             std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
92               return tiledProducerIndexingSubMap.getDimPosition(idx);
93             });
94 
95   return tiledProducerLoopIndices;
96 }
97 
98 /// Returns the producer fused in place of `sliceOp`. Tile the producer operands
99 /// along the `tiledSliceDimIndices` and clone the producer. Consider the case
100 /// of fusion of an output tensor:
101 /// ```
102 /// %1 = producer ins(...) outs(%0)
103 /// %2 = consumer ins(...) outs(%1)
104 /// ```
105 /// When consumer is tiled, %1 appears in the loop iter_args:
106 /// ```
107 /// %1 = producer ins(...) outs(%0)
108 /// %2 = scf.for ... iter_args(%1) .. (%bbarg) {
109 ///   %t1 = tensor.extract_slice %bbarg[..]
110 ///   %t2 = consumer ins(...) outs(%t1)
111 ///   %r = tensor.insert_slice %t2, %bbarg[...]
112 /// }
113 /// ```
114 /// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0):
115 /// ```
116 /// %2 = scf.for ... iter_args(%0) .. (%bbarg) {
117 ///   %t0 = tensor.extract_slice %bbarg[..]
118 ///   %t1 = producer ins(...) outs(%t0)
119 ///   %t2 = consumer ins(...) outs(%t1)
120 ///   %r = tensor.insert_slice %t2, %bbarg[...]
121 /// }
122 /// ```
123 /// This transformation is only valid if %bbarg is exclusively used by the
124 /// output ExtractSliceOp / InsertSliceOp pair, which is checked by the
125 /// `fuseProducer` method.
126 /// TODO: instead of check and failure, insert new iter_args each time a
127 /// producer is fused into a consumer and fold away unused iter_args.
128 static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
129                                  tensor::ExtractSliceOp sliceOp,
130                                  ArrayRef<int64_t> tiledSliceDimIndices,
131                                  ArrayRef<int64_t> tiledProducerLoopIndices,
132                                  OpOperand *iterArg) {
133   // Clone the producer after `sliceOp` since the slice may be reused to pass in
134   // the producer result.
135   OpBuilder::InsertionGuard guard(b);
136   b.setInsertionPointAfter(sliceOp);
137 
138   // Get the producer.
139   LinalgOp producerOp = producerResult.getOwner();
140   Location loc = producerOp.getLoc();
141 
142   // Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
143   SmallVector<Value> producerLoopBounds;
144   transform(producerOp.createLoopRanges(b, loc),
145             std::back_inserter(producerLoopBounds),
146             [](Range range) { return range.size; });
147   SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
148 
149   // Tile the producer operands given the `sliceOp` ranges. Iterate the
150   // `tiledSliceDimIndices` and store the tile offset and size for the tiled
151   // slice dimension.
152   auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
153   SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
154   SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
155   SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
156   for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) {
157     int64_t tiledSliceDim = std::get<0>(it);
158     int64_t tiledProducerLoop = std::get<1>(it);
159     tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
160     tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
161     allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
162   }
163   erase_value(tileIvs, nullptr);
164   SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
165   tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
166                                   tileSizes, producerLoopBounds);
167 
168   // Output fusion has to update the iteration arguments of the tile loop nest.
169   // In particular, the iteration argument of the outermost tile loop needs to
170   // be set to the producer output instead of the producer result and `clonedOp`
171   // shall use the existing `sliceOp` result instead of the tiled producer
172   // output operand.
173   if (iterArg) {
174     OpOperand *outputOperand =
175         producerOp.getOutputOperand(producerResult.getResultNumber());
176     iterArg->set(outputOperand->get());
177     tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
178   }
179 
180   // Clone the producer using the tiled producer operands.
181   TypeRange resultTypes = ValueRange(tiledOperands)
182                               .take_back(producerOp.getNumOutputs())
183                               .getTypes();
184   LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
185 
186   // Shift all IndexOp results by the tile offset.
187   addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
188 
189   return clonedOp;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // TileLoopNest specific helpers.
194 //===----------------------------------------------------------------------===//
195 
196 bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); }
197 
198 bool TileLoopNest::isValid() {
199   // Check if `rootOp` has been tiled at least once.
200   if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0)
201     return false;
202 
203   // Check if the number of loop operations and dimensions match.
204   if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size())
205     return false;
206 
207   // Check if the innermost tile loop is the parent of `tiledOp`.
208   if (rootOp->getParentOp() != tileLoopOps.back())
209     return false;
210 
211   // Check if the tile loops are directly nested.
212   return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(),
213                             [](Operation *op1, Operation *op2) {
214                               return op1 != op2->getParentOp();
215                             }) == tileLoopOps.end();
216 }
217 
218 SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
219   assert(bbArg && "expect the block argument to be non-zero");
220   SmallVector<BlockArgument> bbArgs;
221 
222   // Search all tile loop block arguments from inner to outer.
223   for (auto tileLoop : reverse(tileLoopOps)) {
224     if (bbArg.getOwner()->getParentOp() != tileLoop)
225       return {};
226     bbArgs.push_back(bbArg);
227     OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
228     bbArg = iterArg->get().dyn_cast<BlockArgument>();
229   }
230 
231   // Reverse the block arguments to order them from outer to inner.
232   return {bbArgs.rbegin(), bbArgs.rend()};
233 }
234 
235 OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
236   // Search all block arguments and return the matching iteration argument.
237   SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
238   if (bbArgs.size() != tileLoopOps.size())
239     return nullptr;
240   return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
241 }
242 
243 bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
244                                 tensor::ExtractSliceOp sliceOp) {
245   // Check the innermost block argument is either used by the ExtractSliceOp
246   // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses
247   // conservatively.
248   for (Operation *op : bbArg.getUsers()) {
249     if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op))
250       return false;
251     if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
252       if (extractSliceOp != sliceOp)
253         return false;
254     }
255     if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
256       SetVector<Operation *> backwardSlice;
257       getBackwardSlice(insertSliceOp.source(), &backwardSlice,
258                        [](Operation *op) {
259                          return isa<LinalgOp, tensor::InsertSliceOp>(op);
260                        });
261       if (backwardSlice.empty() || backwardSlice.front() != sliceOp)
262         return false;
263     }
264   }
265 
266   // Check the block arguments, except for the innermost one, have one use.
267   SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
268   return !all_of(bbArgs, [&](BlockArgument bbArg) {
269     return bbArg.hasOneUse() || bbArg == bbArgs.back();
270   });
271 }
272 
273 LogicalResult TileLoopNest::tileRootOp(
274     OpBuilder &b, ArrayRef<int64_t> tileSizes,
275     ArrayRef<int64_t> tileInterchange,
276     Optional<LinalgLoopDistributionOptions> tileDistribution) {
277   // Exit if all tile sizes are zero.
278   if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
279     return success();
280 
281   // Tile the root operation.
282   LinalgTilingOptions tilingOptions;
283   tilingOptions = tilingOptions
284                       .setInterchange(SmallVector<unsigned>(
285                           tileInterchange.begin(), tileInterchange.end()))
286                       .setTileSizes(tileSizes)
287                       .setLoopType(LinalgTilingLoopType::Loops);
288   if (tileDistribution)
289     tilingOptions =
290         tilingOptions.setDistributionOptions(tileDistribution.getValue());
291 
292   // TODO: Propagate RewriterBase everywhere.
293   IRRewriter rewriter(b);
294   FailureOr<TiledLinalgOp> tiledRootOp =
295       tileLinalgOp(rewriter, rootOp, tilingOptions);
296 
297   // Exit if tiling the root operation fails.
298   if (failed(tiledRootOp))
299     return failure();
300 
301   // Replace all uses of the root operation if it has been tiled before. All
302   // uses of the original untiled root operation are updated by the calling pass
303   // or pattern.
304   if (!isEmpty())
305     rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
306 
307   // Transfer the stored `rootOp` loop dimensions if it has been tiled before.
308   if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) {
309     tiledRootAndFusedOpsLoops[tiledRootOp->op] =
310         tiledRootAndFusedOpsLoops[rootOp];
311   }
312 
313   // Update the root operation and append the loops and tile loop dimensions.
314   rootOp = tiledRootOp->op;
315   tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
316   for (const auto &en : enumerate(tileSizes)) {
317     // Copy only the tiled loop dimensions with non-zero tile size.
318     if (en.value() == 0)
319       continue;
320     tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]);
321   }
322   assert(isValid() && "expect tile loop nest to be valid after tiling");
323   return success();
324 }
325 
326 FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
327                                                OpOperand *consumerOpOperand) {
328   // Check if the consumer has been tiled before. For example, it may not have
329   // been tiled if the outermost tile loop is a reduction loop.
330   if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0)
331     return failure();
332 
333   assert(this->isValid() &&
334          "expect the tile loop nest to satisfy all invariants");
335 
336   // Check the tile loop nest is non-empty.
337   if (isEmpty())
338     return failure();
339 
340   // Check `consumerOpOperand` is defined by an ExtractSliceOp.
341   auto sliceOp =
342       consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
343   if (!sliceOp)
344     return failure();
345 
346   // Check `sliceOp` and `consumerOp` are in the same block.
347   LinalgOp consumerOp = consumerOpOperand->getOwner();
348   if (sliceOp->getBlock() != rootOp->getBlock() ||
349       consumerOp->getBlock() != rootOp->getBlock())
350     return failure();
351 
352   // Check if the producer is a LinalgOp possibly passed by iteration argument.
353   OpOperand *iterArg = nullptr;
354   auto producerResult = sliceOp.source().dyn_cast<OpResult>();
355   if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) {
356     iterArg = getTiedIterArg(bbArg);
357     // Check the iteration argument may be used to pass in the producer output.
358     if (!iterArg || hasOtherUses(bbArg, sliceOp))
359       return failure();
360     producerResult = iterArg->get().dyn_cast<OpResult>();
361   }
362   if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
363     return failure();
364 
365   // Compute the tiled producer slice dimensions given the tiled consumer loops.
366   SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims(
367       consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]);
368   if (tiledSliceDimIndices.empty())
369     return failure();
370 
371   // Compute the tiled producer loop indices.
372   SmallVector<int64_t> tiledProducerLoopIndices =
373       getTiledProducerLoops(producerResult, tiledSliceDimIndices);
374 
375   // Tile the producer operands and clone the producer in place of `sliceOp`.
376   LinalgOp clonedOp =
377       getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices,
378                        tiledProducerLoopIndices, iterArg);
379   tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;
380 
381   // Cast the `clonedOp` result to gap type mismatches before canonicalization.
382   Type consumerOperandType = consumerOpOperand->get().getType();
383   Value newResult = clonedOp->getResult(producerResult.getResultNumber());
384   if (newResult.getType() != consumerOperandType) {
385     OpBuilder::InsertionGuard guard(b);
386     b.setInsertionPointAfter(clonedOp);
387     newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
388                                          consumerOperandType, newResult);
389   }
390 
391   // Replace the `sliceOp` uses except for the `clonedOp` output uses.
392   sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
393   return clonedOp;
394 }
395 
396 ValueRange TileLoopNest::getRootOpReplacementResults() {
397   assert(!isEmpty() && "expect tile loop nest to be non-empty");
398   return tileLoopOps.front()->getOpResults();
399 }
400 
401 SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
402   SmallVector<LinalgOp> result;
403   for (const auto &kvp : tiledRootAndFusedOpsLoops) {
404     auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst());
405     assert(linalgOp &&
406            "expect all tiled and fused operations are linalg operations");
407     result.push_back(linalgOp);
408   }
409   return result;
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // Tile and fuse entry-points.
414 //===----------------------------------------------------------------------===//
415 
416 FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
417     OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
418     ArrayRef<int64_t> tileInterchange,
419     const Optional<LinalgLoopDistributionOptions> &tileDistribution) {
420   assert(tileSizes.size() == tileInterchange.size() &&
421          "expect the number of tile sizes and interchange dims to match");
422   assert(isPermutation(tileInterchange) &&
423          "expect tile interchange is a permutation");
424 
425   // Create an empty tile loop nest.
426   TileLoopNest tileLoopNest(consumerOp);
427 
428   // Search the number of outer parallel loops to separate them from possible
429   // inner reduction dimensions.
430   SmallVector<StringAttr> iterTypes =
431       llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
432   applyPermutationToVector(iterTypes, tileInterchange);
433   auto *it = find_if(iterTypes, [&](StringAttr iterType) {
434     return !isParallelIterator(iterType);
435   });
436   int64_t split = std::distance(iterTypes.begin(), it);
437 
438   // Helper to fuse the producers greedily using a queue of fusion candidates.
439   auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
440     SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
441     while (!candidates.empty()) {
442       FailureOr<LinalgOp> fusedProducer =
443           tileLoopNest.fuseProducer(b, candidates.pop_back_val());
444       if (failed(fusedProducer))
445         continue;
446       candidates.append(fusedProducer->getInputAndOutputOperands());
447     }
448   };
449 
450   // Tile the outer parallel loops and fuse the output operands.
451   SmallVector<int64_t> outerTileSizes;
452   outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
453   outerTileSizes.append(tileSizes.size() - split, 0);
454   if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
455                                      tileDistribution)))
456     return failure();
457   fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
458 
459   // Tile the remaining loops and fuse the input operands.
460   SmallVector<int64_t> innerTileSizes;
461   innerTileSizes.append(split, 0);
462   innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
463   if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
464                                      tileDistribution)))
465     return failure();
466   fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
467 
468   // Exit if the tile loop nest is empty since all tile sizes are zero.
469   if (tileLoopNest.isEmpty())
470     return failure();
471 
472   return tileLoopNest;
473 }
474