1 //===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements linalg fusion on tensors
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/Support/LLVM.h"
24 
25 using namespace mlir;
26 using namespace linalg;
27 
28 //===----------------------------------------------------------------------===//
29 // StructuredOp specific helpers.
30 //===----------------------------------------------------------------------===//
31 
32 /// Returns the tiled slice dimensions given the tiled consumer loop dimensions.
33 /// The slice defines a hyper rectangular iteration space and fusing the
34 /// producer is always possible. However, depending on the consumer indexing
35 /// map, not all slice elements may be consumed and the tiles may overlap. In
36 /// these cases, fusion introduces redundant computation.
37 static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
38                                               ArrayRef<int64_t> tiledLoopDims) {
39   // Get the consumer operand indexing map.
40   LinalgOp consumerOp = consumerOperand->getOwner();
41   AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
42 
43   // Search the slice dimensions tiled by a tile loop dimension.
44   DenseSet<int64_t> tiledSliceDimIndices;
45   for (const auto &en : enumerate(indexingMap.getResults())) {
46     for (auto tiledLoopDim : tiledLoopDims) {
47       if (en.value().isFunctionOfDim(tiledLoopDim))
48         tiledSliceDimIndices.insert(en.index());
49     }
50   }
51   return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()};
52 }
53 
54 /// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions
55 /// of the producer result slice returns the tiled producer loop dimensions.
56 /// Example:
57 /// ```
58 /// %res = linalg.fill(%cst, %input)
59 /// scf.for %i
60 ///   scf.for %j
61 ///     %slice = tensor.extract_slice %res[%i, %j]
62 /// ```
63 /// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1].
64 static SmallVector<int64_t>
65 getTiledProducerLoops(OpResult producerResult,
66                       ArrayRef<int64_t> tiledSliceDimIndices) {
67   LinalgOp producerOp = producerResult.getOwner();
68 
69   // Get the indexing map of the `producerOp` output operand that matches
70   // ´producerResult´.
71   AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
72       producerOp.getOutputOperand(producerResult.getResultNumber()));
73 
74   // Keep only the tiled result slice dimensions of `producerIndexingMap`.
75   AffineMap tiledProducerIndexingSubMap =
76       producerIndexingMap.getSubMap(SmallVector<unsigned>(
77           tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()));
78 
79   // Compute the producer loop indices mapped to the tiled result slice
80   // dimensions. As the output indexing map of structured operations are
81   // projected permutations, `tiledProducerIndexingSubMap` has to be a
82   // projected permutation as well. We can thus obtain the producer loop indices
83   // by getting the positions of the result dimensions.
84   // Example:
85   // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2].
86   assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
87          "expect slice and producer loop dimensions map one-to-one");
88   SmallVector<int64_t> tiledProducerLoopIndices;
89   transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
90             std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
91               return tiledProducerIndexingSubMap.getDimPosition(idx);
92             });
93 
94   return tiledProducerLoopIndices;
95 }
96 
/// Returns the producer fused in place of `sliceOp`. Tile the producer operands
/// along the `tiledSliceDimIndices` and clone the producer. Consider the case
/// of fusion of an output tensor:
/// ```
/// %1 = producer ins(...) outs(%0)
/// %2 = consumer ins(...) outs(%1)
/// ```
/// When consumer is tiled, %1 appears in the loop iter_args:
/// ```
/// %1 = producer ins(...) outs(%0)
/// %2 = scf.for ... iter_args(%1) .. (%bbarg) {
///   %t1 = tensor.extract_slice %bbarg[..]
///   %t2 = consumer ins(...) outs(%t1)
///   %r = tensor.insert_slice %t2, %bbarg[...]
/// }
/// ```
/// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0):
/// ```
/// %2 = scf.for ... iter_args(%0) .. (%bbarg) {
///   %t0 = tensor.extract_slice %bbarg[..]
///   %t1 = producer ins(...) outs(%t0)
///   %t2 = consumer ins(...) outs(%t1)
///   %r = tensor.insert_slice %t2, %bbarg[...]
/// }
/// ```
/// This transformation is only valid if %bbarg is exclusively used by the
/// output ExtractSliceOp / InsertSliceOp pair, which is checked by the
/// `fuseProducer` method.
/// TODO: instead of check and failure, insert new iter_args each time a
/// producer is fused into a consumer and fold away unused iter_args.
///
/// Parameters:
/// - `b`: builder used to create the new operations. Its insertion point is
///   restored on return (InsertionGuard).
/// - `producerResult`: the producer result sliced by `sliceOp`.
/// - `sliceOp`: the slice being fused. The clone is created right after it.
/// - `tiledSliceDimIndices` / `tiledProducerLoopIndices`: parallel arrays.
///   The i-th slice dimension is tiled by the i-th producer loop.
/// - `iterArg`: non-null for output fusion. It is the tile loop operand that
///   currently carries the producer result and gets rewired to the producer
///   output operand.
/// Returns the cloned, tiled producer.
static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
                                 tensor::ExtractSliceOp sliceOp,
                                 ArrayRef<int64_t> tiledSliceDimIndices,
                                 ArrayRef<int64_t> tiledProducerLoopIndices,
                                 OpOperand *iterArg) {
  // Clone the producer after `sliceOp` since the slice may be reused to pass in
  // the producer result.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointAfter(sliceOp);

  // Get the producer.
  LinalgOp producerOp = producerResult.getOwner();
  Location loc = producerOp.getLoc();

  // Obtain the `producerOp` loop bounds and the `sliceOp` ranges. Only the
  // sizes of the producer loop ranges are kept; they are handed to
  // `makeTiledShapes` below as the loop bounds.
  SmallVector<Value> producerLoopBounds;
  transform(producerOp.createLoopRanges(b, loc),
            std::back_inserter(producerLoopBounds),
            [](Range range) { return range.size; });
  SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);

  // Tile the producer operands given the `sliceOp` ranges. Iterate the
  // `tiledSliceDimIndices` and store the tile offset and size for the tiled
  // slice dimension.
  // All three vectors are indexed by producer loop. `tileSizes` defaults to
  // the constant 0 for loops that are not tiled. NOTE(review): presumably
  // `makeTiledShapes` treats a zero size as "untiled"; confirm.
  // `allIvs` keeps nullptr placeholders for untiled loops, whereas `tileIvs`
  // is compacted below.
  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
  SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
  SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
  for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) {
    int64_t tiledSliceDim = std::get<0>(it);
    int64_t tiledProducerLoop = std::get<1>(it);
    tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
    tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
    allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
  }
  // Drop the placeholders so `tileIvs` only holds the offsets of the tiled
  // loops, in producer loop order.
  erase_value(tileIvs, nullptr);
  SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
  tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
                                  tileSizes, producerLoopBounds);

  // Output fusion has to update the iteration arguments of the tile loop nest.
  // In particular, the iteration argument of the outermost tile loop needs to
  // be set to the producer output instead of the producer result and `clonedOp`
  // shall use the existing `sliceOp` result instead of the tiled producer
  // output operand.
  if (iterArg) {
    OpOperand *outputOperand =
        producerOp.getOutputOperand(producerResult.getResultNumber());
    iterArg->set(outputOperand->get());
    // Indexing by operand number assumes `tiledOperands` lists the inputs and
    // then the outputs in operand order, as `getInputAndOutputOperands` does.
    // NOTE(review): confirm that `makeTiledShapes` preserves this order.
    tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
  }

  // Clone the producer using the tiled producer operands. The clone's result
  // types are the types of its (tiled) output operands, i.e. the trailing
  // `getNumOutputs()` entries of `tiledOperands`.
  TypeRange resultTypes = ValueRange(tiledOperands)
                              .take_back(producerOp.getNumOutputs())
                              .getTypes();
  LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);

  // Shift all IndexOp results by the tile offset. `allIvs` holds nullptr for
  // untiled loops. NOTE(review): presumably those are left unshifted; confirm.
  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);

  return clonedOp;
}
190 
191 //===----------------------------------------------------------------------===//
192 // TileLoopNest specific helpers.
193 //===----------------------------------------------------------------------===//
194 
195 bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); }
196 
197 bool TileLoopNest::isValid() {
198   // Check if `rootOp` has been tiled at least once.
199   if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0)
200     return false;
201 
202   // Check if the number of loop operations and dimensions match.
203   if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size())
204     return false;
205 
206   // Check if the innermost tile loop is the parent of `tiledOp`.
207   if (rootOp->getParentOp() != tileLoopOps.back())
208     return false;
209 
210   // Check if the tile loops are directly nested.
211   return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(),
212                             [](Operation *op1, Operation *op2) {
213                               return op1 != op2->getParentOp();
214                             }) == tileLoopOps.end();
215 }
216 
217 SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
218   assert(bbArg && "expect the block argument to be non-zero");
219   SmallVector<BlockArgument> bbArgs;
220 
221   // Search all tile loop block arguments from inner to outer.
222   for (auto tileLoop : reverse(tileLoopOps)) {
223     if (bbArg.getOwner()->getParentOp() != tileLoop)
224       return {};
225     bbArgs.push_back(bbArg);
226     OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
227     bbArg = iterArg->get().dyn_cast<BlockArgument>();
228   }
229 
230   // Reverse the block arguments to order them from outer to inner.
231   return {bbArgs.rbegin(), bbArgs.rend()};
232 }
233 
234 OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
235   // Search all block arguments and return the matching iteration argument.
236   SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
237   if (bbArgs.size() != tileLoopOps.size())
238     return nullptr;
239   return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
240 }
241 
242 bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
243                                 tensor::ExtractSliceOp sliceOp) {
244   // Check the innermost block argument is either used by the ExtractSliceOp
245   // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses
246   // conservatively.
247   for (Operation *op : bbArg.getUsers()) {
248     if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op))
249       return false;
250     if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
251       if (extractSliceOp != sliceOp)
252         return false;
253     }
254     if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
255       SetVector<Operation *> backwardSlice;
256       getBackwardSlice(insertSliceOp.source(), &backwardSlice,
257                        [](Operation *op) {
258                          return isa<LinalgOp, tensor::InsertSliceOp>(op);
259                        });
260       if (backwardSlice.empty() || backwardSlice.front() != sliceOp)
261         return false;
262     }
263   }
264 
265   // Check the block arguments, except for the innermost one, have one use.
266   SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
267   return !all_of(bbArgs, [&](BlockArgument bbArg) {
268     return bbArg.hasOneUse() || bbArg == bbArgs.back();
269   });
270 }
271 
272 LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
273                                        ArrayRef<int64_t> tileSizes,
274                                        ArrayRef<int64_t> tileInterchange) {
275   // Exit if all tile sizes are zero.
276   if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
277     return success();
278 
279   // Tile the root operation.
280   LinalgTilingOptions tilingOptions;
281   tilingOptions = tilingOptions
282                       .setInterchange(SmallVector<unsigned>(
283                           tileInterchange.begin(), tileInterchange.end()))
284                       .setTileSizes(tileSizes)
285                       .setLoopType(LinalgTilingLoopType::Loops);
286 
287   // TODO: Propagate RewriterBase everywhere.
288   IRRewriter rewriter(b);
289   FailureOr<TiledLinalgOp> tiledRootOp =
290       tileLinalgOp(rewriter, rootOp, tilingOptions);
291 
292   // Exit if tiling the root operation fails.
293   if (failed(tiledRootOp))
294     return failure();
295 
296   // Replace all uses of the root operation if it has been tiled before. All
297   // uses of the original untiled root operation are updated by the calling pass
298   // or pattern.
299   if (!isEmpty())
300     rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
301 
302   // Transfer the stored `rootOp` loop dimensions if it has been tiled before.
303   if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) {
304     tiledRootAndFusedOpsLoops[tiledRootOp->op] =
305         tiledRootAndFusedOpsLoops[rootOp];
306   }
307 
308   // Update the root operation and append the loops and tile loop dimensions.
309   rootOp = tiledRootOp->op;
310   tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
311   for (const auto &en : enumerate(tileSizes)) {
312     // Copy only the tiled loop dimensions with non-zero tile size.
313     if (en.value() == 0)
314       continue;
315     tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]);
316   }
317   assert(isValid() && "expect tile loop nest to be valid after tiling");
318   return success();
319 }
320 
/// Fuses the producer of `consumerOpOperand` into the tile loop nest.
///
/// Preconditions:
/// - the consumer has been tiled by this nest;
/// - `consumerOpOperand` is defined by a tensor::ExtractSliceOp in the block
///   of `rootOp` (the innermost tile loop body), and the consumer lives in the
///   same block;
/// - the sliced value is either the result of a LinalgOp, or a tile loop
///   iteration argument whose outermost initial value is such a result
///   (output fusion). In the second case the iteration argument must have no
///   other uses.
///
/// On success, a tiled clone of the producer is created right after the
/// slice (see `getTiledProducer`). Every other use of the slice is redirected
/// to the clone's result, possibly through a tensor.cast, and the clone is
/// returned. Returns failure if any precondition does not hold.
FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
                                               OpOperand *consumerOpOperand) {
  // Check if the consumer has been tiled before. For example, it may not have
  // been tiled if the outermost tile loop is a reduction loop.
  if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0)
    return failure();

  assert(this->isValid() &&
         "expect the tile loop nest to satisfy all invariants");

  // Check the tile loop nest is non-empty.
  if (isEmpty())
    return failure();

  // Check `consumerOpOperand` is defined by an ExtractSliceOp.
  auto sliceOp =
      consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Check `sliceOp` and `consumerOp` are both in the block of `rootOp`.
  LinalgOp consumerOp = consumerOpOperand->getOwner();
  if (sliceOp->getBlock() != rootOp->getBlock() ||
      consumerOp->getBlock() != rootOp->getBlock())
    return failure();

  // Check if the producer is a LinalgOp possibly passed by iteration argument.
  // In the iteration argument case, `iterArg` is the outermost tile loop
  // operand that carries the producer result into the nest.
  OpOperand *iterArg = nullptr;
  auto producerResult = sliceOp.source().dyn_cast<OpResult>();
  if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) {
    iterArg = getTiedIterArg(bbArg);
    // Check the iteration argument may be used to pass in the producer output.
    if (!iterArg || hasOtherUses(bbArg, sliceOp))
      return failure();
    producerResult = iterArg->get().dyn_cast<OpResult>();
  }
  if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
    return failure();

  // Compute the tiled producer slice dimensions given the tiled consumer loops.
  SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims(
      consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]);
  if (tiledSliceDimIndices.empty())
    return failure();

  // Compute the tiled producer loop indices.
  SmallVector<int64_t> tiledProducerLoopIndices =
      getTiledProducerLoops(producerResult, tiledSliceDimIndices);

  // Tile the producer operands and clone the producer in place of `sliceOp`.
  // Record the producer loops the clone is tiled along so that its own
  // producers can be fused later.
  LinalgOp clonedOp =
      getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices,
                       tiledProducerLoopIndices, iterArg);
  tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;

  // Cast the `clonedOp` result to bridge type mismatches (e.g. static vs.
  // dynamic shapes) with the consumer operand until canonicalization folds
  // them.
  Type consumerOperandType = consumerOpOperand->get().getType();
  Value newResult = clonedOp->getResult(producerResult.getResultNumber());
  if (newResult.getType() != consumerOperandType) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(clonedOp);
    newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
                                         consumerOperandType, newResult);
  }

  // Replace the `sliceOp` uses except for the uses owned by `clonedOp`, which
  // keeps consuming the slice (for output fusion, as its output operand).
  sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
  return clonedOp;
}
390 
391 ValueRange TileLoopNest::getRootOpReplacementResults() {
392   assert(!isEmpty() && "expect tile loop nest to be non-empty");
393   return tileLoopOps.front()->getOpResults();
394 }
395 
396 SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
397   SmallVector<LinalgOp> result;
398   for (const auto &kvp : tiledRootAndFusedOpsLoops) {
399     auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst());
400     assert(linalgOp &&
401            "expect all tiled and fused operations are linalg operations");
402     result.push_back(linalgOp);
403   }
404   return result;
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // Tile and fuse entry-points.
409 //===----------------------------------------------------------------------===//
410 
411 FailureOr<TileLoopNest>
412 mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
413                                            ArrayRef<int64_t> tileSizes,
414                                            ArrayRef<int64_t> tileInterchange) {
415   assert(tileSizes.size() == tileInterchange.size() &&
416          "expect the number of tile sizes and interchange dims to match");
417   assert(isPermutation(tileInterchange) &&
418          "expect tile interchange is a permutation");
419 
420   // Create an empty tile loop nest.
421   TileLoopNest tileLoopNest(consumerOp);
422 
423   // Search the number of outer parallel loops to separate them from possible
424   // inner reduction dimensions.
425   SmallVector<StringAttr> iterTypes =
426       llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
427   applyPermutationToVector(iterTypes, tileInterchange);
428   auto *it = find_if(iterTypes, [&](StringAttr iterType) {
429     return !isParallelIterator(iterType);
430   });
431   int64_t split = std::distance(iterTypes.begin(), it);
432 
433   // Helper to fuse the producers greedily using a queue of fusion candidates.
434   auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
435     SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
436     while (!candidates.empty()) {
437       FailureOr<LinalgOp> fusedProducer =
438           tileLoopNest.fuseProducer(b, candidates.pop_back_val());
439       if (failed(fusedProducer))
440         continue;
441       candidates.append(fusedProducer->getInputAndOutputOperands());
442     }
443   };
444 
445   // Tile the outer parallel loops and fuse the output operands.
446   SmallVector<int64_t> outerTileSizes;
447   outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
448   outerTileSizes.append(tileSizes.size() - split, 0);
449   if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
450     return failure();
451   fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
452 
453   // Tile the remaining loops and fuse the input operands.
454   SmallVector<int64_t> innerTileSizes;
455   innerTileSizes.append(split, 0);
456   innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
457   if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
458     return failure();
459   fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
460 
461   return tileLoopNest;
462 }
463