1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Dominance.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "mlir/Transforms/RegionUtils.h"
29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 
33 #include <set>
34 
35 #define DEBUG_TYPE "linalg-fusion"
36 
37 using namespace mlir;
38 using namespace mlir::edsc;
39 using namespace mlir::edsc::intrinsics;
40 using namespace mlir::linalg;
41 
42 using llvm::dbgs;
43 
44 /// Implements a simple high-level fusion pass on linalg structured operations.
45 ///
46 /// In each block, linalg ops are processed in reverse textual order.
47 /// Given a linalg op `O`, fusion occurs by:
48 ///   1. inspecting the linalg ops that write into the views read by `O`. There
49 ///      are 2 cases:
50 ///      a) buffer case: use the SSA value of the views and a simple alias
51 ///         analysis on subview ops to determine producer-consumer dependences;
52 ///      b) tensor case: use SSA use-def chains on subtensor ops;
53 ///   2. greedily fuse the linalg ops that produce the subview/subtensor.
54 ///   3. inspect the fused ops and determine whether they have other remaining
55 ///      LinalgOp uses. If not, then erase the original producing linalg op.
56 ///
57 /// More advanced use cases, analyses as well as profitability heuristics are
58 /// left for future work.
59 
60 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
61 // by `permutationMap`.
62 static void inferShapeComponents(AffineMap permutationMap,
63                                  ArrayRef<Range> loopRanges,
64                                  SmallVectorImpl<OpFoldResult> &offsets,
65                                  SmallVectorImpl<OpFoldResult> &sizes,
66                                  SmallVectorImpl<OpFoldResult> &strides) {
67   assert(permutationMap.isProjectedPermutation() &&
68          "expected some subset of a permutation map");
69   SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
70   unsigned idx = 0;
71   for (AffineExpr e : permutationMap.getResults()) {
72     // loopToOperandRangesMaps are permutations-only, just swap indices.
73     unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
74     shapeRanges[idx++] = loopRanges[loopPos];
75   }
76   // Construct a new subshape for the tile.
77   unsigned rank = shapeRanges.size();
78   offsets.reserve(rank);
79   sizes.reserve(rank);
80   strides.reserve(rank);
81   for (auto r : shapeRanges) {
82     offsets.push_back(r.offset);
83     sizes.push_back(r.size);
84     strides.push_back(r.stride);
85   }
86 }
87 
// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<Range> loopRanges) {
  SmallVector<Value, 8> clonedShapes;
  clonedShapes.reserve(op.getNumShapedOperands());

  // Iterate over the shape operands in order.
  // Extract the subranges from the linearized ranges.
  for (auto en : llvm::enumerate(op.getShapedOperands())) {
    unsigned shapedOperandIdx = en.index();
    AffineMap map = op.getIndexingMap(shapedOperandIdx);
    LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
                            << " with indexingMap: " << map << "\n");
    SmallVector<OpFoldResult, 4> offsets, sizes, strides;
    inferShapeComponents(map, loopRanges, offsets, sizes, strides);
    Value shape = en.value();
    // Buffers get a subview, tensors a subtensor, of the inferred subshape.
    Value sub = shape.getType().isa<MemRefType>()
                    ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
                          .getResult()
                    : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
                          .getResult();
    clonedShapes.push_back(sub);
  }
  // Append the other operands.
  auto operands = op.getAssumedNonShapedOperands();
  clonedShapes.append(operands.begin(), operands.end());

  // Iterate over the results in order.
  // Extract the subtensor type from the linearized range.
  // Since we do not enforce any canonicalizations on the fly, this is always
  // fully dynamic at construction time.
  SmallVector<Type, 4> resultTypes;
  resultTypes.reserve(op->getNumResults());
  for (RankedTensorType t : op.getOutputTensorTypes()) {
    unsigned rank = t.getRank();
    // All offsets/sizes/strides are dynamic; canonicalization can tighten the
    // types later.
    SmallVector<int64_t, 4> staticOffsetsVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
    SmallVector<int64_t, 4> staticStridesVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    resultTypes.push_back(SubTensorOp::inferResultType(
        t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
        staticStridesVector));
  }

  Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      // TODO: replace by an affine_apply.
      // Shift each index argument by the tile offset so the cloned body still
      // observes the original (untiled) iteration indices. `newIndex` itself
      // must keep using `oldIndex`, hence the exception set.
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }

  return clonedOp;
}
156 
/// A (shape, dimension) pair: dimension `dimension` of the shaped value
/// `shape` identifies a loop range.
struct ShapeDimension {
  Value shape;        // Shaped (memref or tensor) operand of a LinalgOp.
  unsigned dimension; // Dimension of `shape` that provides the range.
};
161 
162 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
163 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
164 // guarantees at least one such dimension is found. If multiple candidates exist
165 // they must agree by construction (i.e. have the same size) and we just return
166 // the first one.
167 static ShapeDimension
168 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
169                           bool fromSubViewOpOnly = false) {
170   auto maps = op.indexing_maps();
171   // Iterate over the inputs and outputs in order.
172   // Extract the subranges from the linearized ranges.
173   for (auto en : llvm::enumerate(op.getShapedOperands())) {
174     // The method `getRangeFromOperandShape` requires using SubViewOp or
175     // SubTensorOps. If the value isnt defined from there continue.
176     // todo: The method should be adapted to get the values from
177     // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
178     // currently returns a `linalg.range`. The fix here is to move this op to
179     // `std` dialect and add the method to `ViewInterface`.
180     if (fromSubViewOpOnly &&
181         !isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp()))
182       continue;
183 
184     unsigned idx = en.index();
185     auto map = maps[idx].cast<AffineMapAttr>().getValue();
186     LLVM_DEBUG(llvm::dbgs()
187                << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
188     LLVM_DEBUG(llvm::dbgs()
189                << "getShapeDefiningLoopRange map: " << map << "\n");
190     Value shape = en.value();
191     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
192     for (auto en2 : llvm::enumerate(map.getResults())) {
193       auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
194       if (!dimExpr)
195         continue;
196       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
197         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
198                                 << loopDepth << "\n");
199         LLVM_DEBUG(llvm::dbgs()
200                    << "getShapeDefiningLoopRange shape: " << shape << "\n");
201         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
202       }
203     }
204   }
205   llvm_unreachable("Expect to be able to extract a shape defining loop range");
206 }
207 
208 /// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
209 /// provides the loop range information for the fused loops. The rest are
210 /// obtained from the producer itself, since they are not tiled + fused.
211 static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
212                      const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
213 
214   unsigned nPar = producer.getNumParallelLoops();
215   unsigned nRed = producer.getNumReductionLoops();
216   unsigned nWin = producer.getNumWindowLoops();
217   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
218   for (auto fusedLoops : fusedLoopsAndRanges)
219     loopRanges[fusedLoops.first] = fusedLoops.second;
220 
221   // Iterate over all dimensions. For the dimensions not identified by the
222   // producer map for `producerIdx`, we need to explicitly compute the shape
223   // that defines the loop ranges using the `producer`.
224   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
225     if (loopRanges[i].offset)
226       LLVM_DEBUG(llvm::dbgs()
227                  << "existing LoopRange: " << loopRanges[i] << "\n");
228     else {
229       auto shapeDim = getShapeDefiningLoopRange(producer, i);
230       loopRanges[i] = Range{std_constant_index(0),
231                             std_dim(shapeDim.shape, shapeDim.dimension),
232                             std_constant_index(1)};
233       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
234     }
235   }
236 
237   return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
238 }
239 
240 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
241 /// expected to be defined by a subview op or a subtensor op.
242 static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
243                                       Value shapedOperand, unsigned dim) {
244   Operation *shapeProducingOp = shapedOperand.getDefiningOp();
245   if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
246     return subViewOp.getOrCreateRanges(b, loc)[dim];
247   if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
248     return subTensorOp.getOrCreateRanges(b, loc)[dim];
249   llvm_unreachable("SubviewOp or SubTensorOp expected");
250 }
251 
/// Fuses `producerOp` into the loop immediately enclosing the owner of
/// `consumerOpOperand`. This is achieved by "recomputing" the producer at the
/// time it is needed just before the consumer.
///
/// `producerMap` is the indexing map of the producer view that
/// `consumerOpOperand` reads: each of its results ties one dimension of the
/// consumed subview/subtensor to the producer loop it indexes. It works for
/// both the buffer case (subview-defined operand) and the tensor case
/// (subtensor-defined operand).
static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                     OpOperand &consumerOpOperand) {
  LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
  // Maps a producer loop position to the range of the corresponding consumed
  // shape dimension.
  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  Value shapedOperand = consumerOpOperand.get();
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
        b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
  }
  return fuse(b, producerOp, fusedLoopsAndRanges);
}
274 
275 // Encode structural fusion safety preconditions.
276 // Some of these will be lifted in the future with better analysis.
277 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
278                                           LinalgOp consumer) {
279   assert(producer.hasBufferSemantics() &&
280          "expected linalg op with buffer semantics");
281   assert(consumer.hasBufferSemantics() &&
282          "expected linalg op with buffer semantics");
283   if (producer.getNumOutputs() != 1) {
284     LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
285     return false;
286   }
287   // Only fuse when the producer block dominates.
288   DominanceInfo dom(producer.getOperation());
289   if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
290     LLVM_DEBUG(
291         llvm::dbgs()
292         << "\nNot structurally fusable (producer block does not dominate)");
293     return false;
294   }
295   return true;
296 }
297 
298 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
299                                              LinalgOp consumer,
300                                              Value consumedView,
301                                              LinalgOp producer) {
302   assert(producer.hasBufferSemantics() &&
303          "expected linalg op with buffer semantics");
304   assert(consumer.hasBufferSemantics() &&
305          "expected linalg op with buffer semantics");
306   // Make some simple structural checks that alleviate the need for more
307   // complex analyses.
308   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
309     LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
310                             << *producer.getOperation());
311     return false;
312   }
313   // Check for any interleaved write to consumedView.
314   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
315     LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
316                             << *producer.getOperation());
317     return false;
318   }
319   return true;
320 }
321 
322 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
323                                  LinalgOp consumer, Value consumedView,
324                                  LinalgOp producer) {
325   assert(producer.hasBufferSemantics() &&
326          "expected linalg op with buffer semantics");
327   assert(consumer.hasBufferSemantics() &&
328          "expected linalg op with buffer semantics");
329   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
330     return false;
331   // Check for any fusion-preventing dependence to any shape read/written that
332   // would violate dependences.
333   if (!graph.findCoveringDependences(producer, consumer).empty()) {
334     LLVM_DEBUG(llvm::dbgs()
335                << "\n***Not fusable due to an interleaved dependence:\t"
336                << *producer.getOperation());
337     return false;
338   }
339   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
340     // TODO: add a level of indirection to linalg.generic.
341     if (convOp.padding())
342       return false;
343   }
344   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
345     // TODO: add a level of indirection to linalg.generic.
346     if (convOp.padding())
347       return false;
348   }
349   return true;
350 }
351 
/// For `consumer` with buffer semantics, find the Linalg operation on buffers
/// that is the last writer of `consumerOpOperand`. For now the fusable
/// dependence is returned as an instance of the `dependenceGraph`.
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(OpOperand &consumerOpOperand,
                    const LinalgDependenceGraph &dependenceGraph) {
  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return {};

  // Only consider RAW and WAW atm.
  for (auto depType : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    // Keep only dependences whose "indexing" side is exactly this operand of
    // the consumer (same SSA value AND same operand number) and whose
    // "dependent" side is a LinalgOp.
    for (auto dependence : llvm::make_filter_range(
             dependenceGraph.getDependencesInto(consumerOp, depType),
             [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
               Value v = elem.getIndexingValue();
               Optional<unsigned> operandNum =
                   elem.getIndexingOpViewOperandNum();
               return isa<LinalgOp>(elem.getDependentOp()) &&
                      v == consumerOpOperand.get() && operandNum &&
                      operandNum.getValue() ==
                          consumerOpOperand.getOperandNumber();
             })) {
      // Consumer consumes this view, `isStructurallyFusableProducer` also
      // checks whether it is a strict subview of the producer view.
      auto producer = cast<LinalgOp>(dependence.getDependentOp());
      LLVM_DEBUG(llvm::dbgs()
                 << "\n"
                 << LinalgDependenceGraph::getDependenceTypeStr(depType)
                 << "producer: " << *dependence.getDependentOp()
                 << " view: " << dependence.getDependentValue() << "\n");

      // If the producer and consumer have tensor semantics, the only dependence
      // between them is through a RAW dependence and they are fusable by
      // construction. For buffer semantics need additional checks.
      if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
          isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
                        producer))
        return dependence;
      if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
        assert(dependence.dependenceType ==
               LinalgDependenceGraph::DependenceType::RAW);
        return dependence;
      }
    }
  }
  return {};
}
403 
Optional<FusionInfo>
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
                                   const LinalgDependenceGraph &graph) {
  // Find a producer whose write into this operand can legally be recomputed
  // at the consumer.
  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
      findFusableProducer(consumerOpOperand, graph);
  if (!fusableDependence)
    return llvm::None;

  LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
  if (!producerOp)
    return llvm::None;

  // If producer is already in the same block as consumer, we are done.
  if (consumerOpOperand.get().getParentBlock() ==
      fusableDependence->getDependentValue().getParentBlock())
    return llvm::None;

  // The producer-side indexing map is needed to map the consumed subview's
  // ranges back onto the producer's loops.
  Optional<AffineMap> producerMap =
      fusableDependence->getDependentOpViewIndexingMap();
  if (!producerMap)
    return llvm::None;

  // Must be a subview or a slice to guarantee there are loops we can fuse
  // into.
  auto subView = consumerOpOperand.get().getDefiningOp<SubViewOp>();
  if (!subView) {
    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
    return llvm::None;
  }

  // Fuse `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOpOperand.getOwner());
  // ScopedContext is required by the EDSC intrinsics (std_constant_index,
  // std_dim) used transitively by `fuse`.
  ScopedContext scope(b, consumerOpOperand.getOwner()->getLoc());
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
                          << *consumerOpOperand.getOwner() << "\n");

  auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand);
  return FusionInfo{producerOp, fusedProducer};
}
444 
445 /// Walk back use-def chain through scf::For yields.
446 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
447 
448 // TODO(ravishankarm, ntv): This can be moved into the dependence graphs
449 // dependence tracking since the dependence tracking is similar to what is done
450 // w.r.t to buffers.
451 static void getProducerOfTensor(Value tensor, OpResult &opResult) {
452   if (!tensor.getType().isa<RankedTensorType>())
453     return;
454 
455   while (true) {
456     LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
457     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
458       opResult = tensor.cast<OpResult>();
459       return;
460     }
461     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
462       tensor = subTensorOp.source();
463       continue;
464     }
465     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
466       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
467         tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
468         continue;
469       }
470     }
471     return;
472   }
473 }
474 
475 Optional<FusionInfo>
476 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
477   Value inputTensor = consumerOpOperand.get();
478   OpResult producerOpResult;
479   getProducerOfTensor(inputTensor, producerOpResult);
480   if (!producerOpResult) {
481     LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
482     return {};
483   }
484   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
485 }
486 
Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                   OpOperand &consumerOpOperand) {
  auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
  if (!producerOp)
    return llvm::None;

  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return llvm::None;

  Value inputTensor = consumerOpOperand.get();

  // Must be a subtensor to guarantee there are loops we can fuse into.
  auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
  if (!subTensor) {
    LLVM_DEBUG(llvm::dbgs()
               << "\nNot fusable, not a subtensor: " << inputTensor);
    return {};
  }

  // If producer is already in the same block as consumer, we are done.
  if (consumerOpOperand.get().getParentBlock() ==
      producerOpResult.getParentBlock())
    return {};

  // Insert fused `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOp);
  // ScopedContext is required by the EDSC intrinsics used transitively by
  // `fuse`.
  ScopedContext scope(b, consumerOp->getLoc());
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
  LinalgOp fusedProducer =
      fuse(b, producerOp,
           producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
           consumerOpOperand);

  // Replace use.
  // Canonicalizations are not guaranteed to have happened before constructing
  // `fusedProducer`. In the tensor case this can result in temporary type
  // mismatches. Insert a `tensor.cast` op to propagate the transformation
  // invariant that types are compatible.
  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
  Type consumerType = consumerOpOperand.get().getType();
  if (consumerType != def.getType())
    def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
  consumerOpOperand.set(def);
  return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
}
535 
536 /// Prune all dimensions that are of reduction iterator type from `map`.
537 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
538                                            AffineMap map) {
539   llvm::SmallDenseSet<unsigned> projectedDims;
540   for (auto attr : llvm::enumerate(iteratorTypes)) {
541     if (!isParallelIterator(attr.value()))
542       projectedDims.insert(attr.index());
543   }
544   return getProjectedMap(map, projectedDims);
545 }
546 
547 /// Returns the mapping from iterations in the consumer that write to the same
548 /// location as the iterations in the producer. To do so use
549 /// - indexing map of the fused view in the consumer : consumerIndexMap
550 /// - indexing map of the fused view in the producer : producerIndexMap
551 ///     consumerLoopToProducerLoop =
552 ///       inverse(producerIndexMap).compose(consumerIndexMap)
553 static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
554     LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
555   auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
556   if (!producer)
557     return None;
558 
559   Optional<AffineMap> producerIndexingMap =
560       dependence.getDependentOpViewIndexingMap();
561   Optional<AffineMap> consumerIndexingMap =
562       dependence.getIndexingOpViewIndexingMap();
563   if (!producerIndexingMap || !consumerIndexingMap)
564     return None;
565 
566   AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
567       producer.iterator_types().getValue(), *producerIndexingMap);
568   if (!prunedProducerIndexingMap.isPermutation())
569     return None;
570 
571   if (consumerIndexingMap->getNumResults() !=
572       prunedProducerIndexingMap.getNumResults())
573     return None;
574 
575   LLVM_DEBUG({
576     llvm::dbgs() << "\t producerMap : ";
577     producerIndexingMap->print(llvm::dbgs());
578     llvm::dbgs() << "  pruned : ";
579     prunedProducerIndexingMap.print(llvm::dbgs());
580     llvm::dbgs() << "\n";
581     llvm::dbgs() << "\t consumerMap : ";
582     consumerIndexingMap->print(llvm::dbgs());
583     llvm::dbgs() << "\n";
584   });
585 
586   AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
587   if (!invProducerIndexMap)
588     return None;
589 
590   return invProducerIndexMap.compose(*consumerIndexingMap);
591 }
592 
593 /// Given a projected permutation `map`, returns true if the map changes the
594 /// order in which the fused loop dimension appear.
595 static bool doesTransposeAccess(AffineMap map,
596                                 const std::set<unsigned> &fusableLoops) {
597   Optional<unsigned> lastFusableLoop;
598   for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
599          return expr.cast<AffineDimExpr>().getPosition();
600        })) {
601     if (!fusableLoops.count(pos))
602       continue;
603     if (!lastFusableLoop) {
604       lastFusableLoop = pos;
605       continue;
606     }
607     if (pos <= lastFusableLoop.getValue())
608       return true;
609     lastFusableLoop = pos;
610   }
611   return false;
612 }
613 
614 /// Returns the positions of the loop in `op` that can be tiled based on the
615 /// operations that are to be fused with it. For example, in a
616 ///
617 ///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
618 ///
619 /// if the producer of %a needs to be fused with this op, only the `i` loop of
620 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
621 /// fused, then no loops can be tiled while fusing. The conditions used are:
622 /// 1. Only parallel loops can be used for tile + fuse. Find the number of
623 ///    common outer parallel loops between the op and its producers being fused.
624 /// 2. Of the parallel loops only some can be fused. Only those loops can be
625 ///    fused such where the fusable loops iteration space only touches one tile
626 ///    of the fused operation. This is because the producer (which is writing
627 ///    the fused subview) has update semantics.
628 ///
629 /// Since an inverse computation is needed, we need to consider the projection
630 /// of the producerIndexMap w.r.t the parallel loops.  The actual fusable loops
631 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to
632 /// parallel loops and appear in the result of the map
633 ///
634 /// Example 1:
635 ///   linalg.fill(%c, %cst)
636 ///   linalg.matmul ins(%a, %b) outs(%c)
637 ///     Number of parallel loops : 2
638 ///     producerIndexMap = affine_map<(i, j) ->(i , j)>
639 ///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
640 ///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
641 ///     Fused dimensions : i, j
642 ///
643 /// Example 2:
644 ///   linalg.matmul ins(%a, %b) outs(%c)
645 ///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
646 ///                   iterator_types = ["parallel", "parallel"]}
647 ///     ins(%c) ...
648 ///
649 ///     Number of parallel loops = 2:
650 ///     producerIndexMap (projected to parallel loops) =
651 ///       affine_map<(i, j) -> (i, j)>
652 ///     consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
653 ///     Fused dimensions : i, j
654 ///
655 /// Example 3:
656 ///   linalg.copy(%s, %b)
657 ///   linalg.matmul ins(%a, %b) outs(%c)
658 ///
659 ///   Number of parallel loops = 2
///   producerIndexMap : affine_map<(i, j) -> (i, j)>
///   consumerLoopToProducerLoops = affine_map<(i, j, k) -> (k, j)>
662 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
663 ///   Fused dimensions : j
static std::set<unsigned>
collectFusableLoops(ArrayRef<LinalgOp> ops,
                    const FusableOpDependencesTy &fusableDependences) {
  assert(!ops.empty());
  // Count the leading "parallel" iterator types of `linalgOp`.
  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
    return linalgOp.iterator_types()
        .getValue()
        .take_while([](Attribute attr) -> bool {
          return attr.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })
        .size();
  };

  // Only loops that are outer-parallel in every op are tile+fuse candidates.
  size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
  for (auto op : ops.drop_back()) {
    numOuterParallelLoops =
        std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
  }

  // Start from all common outer parallel loops and intersect away candidates
  // invalidated by each fusable dependence.
  std::set<unsigned> fusableLoops;
  auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
  fusableLoops.insert(range.begin(), range.end());

  for (auto op : reverse(ops)) {
    for (auto dependence : fusableDependences.lookup(op)) {
      LLVM_DEBUG({
        llvm::dbgs() << "\t fusable :";
        for (unsigned i : fusableLoops)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });

      Optional<AffineMap> consumerLoopToProducerLoop =
          getConsumerLoopToProducerLoopMap(dependence);
      if (!consumerLoopToProducerLoop) {
        op.emitRemark("failed to get map from consumer loop to producer loop");
        return {};
      }
      // TODO: This condition is only an implementation limitation. When fusing
      // the operation, if the accesses in the producer/consumer are transposes
      // of each other, the loop bounds for the tiled producer can be
      // manipulated accordingly. This requires some additional bookkeeping in
      // the implementation of tile+fuse that is deferred to later.
      if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
        op.emitRemark("unhandled fusion when fusion requires permutation");
        return {};
      }

      // Keep only the fusable loops that actually appear in the
      // consumer-to-producer loop mapping.
      std::set<unsigned> candidates;
      for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
        unsigned position = expr.cast<AffineDimExpr>().getPosition();
        if (fusableLoops.count(position))
          candidates.insert(position);
      }
      LLVM_DEBUG({
        llvm::dbgs() << "\t candidates :";
        for (unsigned i : candidates)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });
      if (candidates.empty())
        return {};
      std::swap(candidates, fusableLoops);
    }
  }

  return fusableLoops;
}
733 
734 /// Find all dependences that are fusable.
735 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
736     ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
737   FusableOpDependencesTy fusableDependences;
738   DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
739   for (LinalgOp op : reverse(ops)) {
740     for (OpOperand &opOperand : op.getShapedOpOperands()) {
741       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
742           fusableDependence = findFusableProducer(opOperand, dependenceGraph);
743       if (!fusableDependence)
744         continue;
745       LinalgOp producerOp =
746           dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
747       if (!producerOp)
748         continue;
749       // Do not fuse dependences that are to operations not in the same basic
750       // block. This avoid moving fused operations across loops that might
751       // themselves carry dependency making the fusion illegal.
752       if (producerOp->getBlock() != op->getBlock())
753         continue;
754 
755       // Make sure that the indexing map of the view used for fusion in the
756       // producer is a projected permutation.
757       Optional<AffineMap> producerMap =
758           fusableDependence->getDependentOpViewIndexingMap();
759       Optional<AffineMap> consumerMap =
760           fusableDependence->getIndexingOpViewIndexingMap();
761       assert(
762           consumerMap &&
763           "unable to find indexing map of operand/result of indexing OpView");
764       fusedProducerIndexingMap[producerOp.getOperation()].push_back(
765           *consumerMap);
766       if (!producerMap || !producerMap->isProjectedPermutation() ||
767           !consumerMap->isProjectedPermutation())
768         continue;
769 
770       fusableDependences[producerOp.getOperation()].push_back(
771           *fusableDependence);
772     }
773   }
774   // TODO: Currently fusion would not be legal if the fusable dependence is to
775   // the same producer but different indexing map in the consumer. Fix this, but
776   // in the meanwhile disallow such a fusion.
777   for (auto useIndexingMapsList : fusedProducerIndexingMap) {
778     AffineMap map1 = useIndexingMapsList.second.front();
779     for (AffineMap map2 :
780          ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) {
781       if (map1 != map2) {
782         fusableDependences.erase(useIndexingMapsList.first);
783         break;
784       }
785     }
786   }
787   return fusableDependences;
788 }
789 
790 /// Tile the fused loops in the root operation, by setting the tile sizes for
791 /// all other loops to zero (those will be tiled later).
792 static Optional<TiledLinalgOp> tileRootOperation(
793     OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
794     const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
795   SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
796   auto zero = std_constant_index(0);
797   for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
798     if (!fusedLoops.count(i))
799       tileSizes[i] = zero;
800   LinalgTilingOptions tileFusedLoopsOptions = options;
801   tileFusedLoopsOptions.setTileSizes(tileSizes);
802   return tileLinalgOp(builder, op, tileFusedLoopsOptions);
803 }
804 
805 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
806 /// to be a tiled operation such that it is valid to fuse all operations in
807 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
808 /// `tiledOp`.
809 static SmallVector<LinalgOp, 1>
810 fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
811                ArrayRef<LinalgOp> fusionCandidates,
812                const FusableOpDependencesTy &fusableDependences,
813                const std::set<unsigned> &fusedLoops) {
814   OpBuilder::InsertionGuard guard(builder);
815   builder.setInsertionPoint(tiledOp);
816   DenseMap<unsigned, Range> fusedLoopsAndRanges;
817   for (unsigned loop : fusedLoops) {
818     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
819     fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
820         builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
821   }
822 
823   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
824   DenseMap<Operation *, LinalgOp> origOpToFusedOp;
825   origOpToFusedOp[rootOp.getOperation()] = tiledOp;
826   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
827     LinalgOp origOp = candidate.value();
828     LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges);
829     origOpToFusedOp[origOp.getOperation()] = fusedOp;
830     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
831     // If the producer consumer operations are linalg operations on tensors, the
832     // dependence is due to value produced (as a return tensor) by the producer
833     // and used in the consumer. The returned value of the fused op needs to be
834     // made the operand of the tiled/fused consumer operation. By construction
835     // the value returned by the producer is the value used by the consumer.
836     for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
837       if (origOp.hasTensorSemantics() &&
838           dependence.dependenceType ==
839               LinalgDependenceGraph::DependenceType::RAW) {
840         unsigned resultIndex =
841             dependence.getDependentOpViewResultNum().getValue();
842         LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
843         if (!consumer)
844           continue;
845         Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
846         consumer.getOperation()->setOperand(
847             dependence.getIndexingOpViewOperandNum().getValue(),
848             replacementValue);
849       }
850     }
851     builder.setInsertionPoint(fusedOp);
852   }
853   return fusedOps;
854 }
855 
856 template <typename LoopType>
857 static Optional<TiledAndFusedLinalgOps>
858 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
859                          const LinalgDependenceGraph &dependenceGraph,
860                          const LinalgTilingOptions &tilingOptions) {
861   if (ops.size() < 2)
862     return llvm::None;
863   LinalgOp rootOp = ops.back();
864   if (!llvm::all_of(
865           ops,
866           [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
867       !llvm::all_of(ops, [](LinalgOp linalgOp) {
868         return linalgOp.hasTensorSemantics();
869       })) {
870     rootOp.emitError(
871         "unable to fuse operations that have tensor semantics with operations "
872         "that have buffer semantics and viceversa.");
873     return llvm::None;
874   }
875   // TODO: Support interchange with tile + fuse. This might actually help do
876   // better fusion.
877   if (!tilingOptions.interchangeVector.empty()) {
878     rootOp.emitRemark("unable to handle tile and fuse with interchange");
879     return llvm::None;
880   }
881 
882   OpBuilder::InsertionGuard guard(builder);
883   builder.setInsertionPoint(rootOp);
884   ScopedContext scope(builder, rootOp.getLoc());
885 
886   // Find all the producers.
887   FusableOpDependencesTy fusableDependences =
888       findAllFusableDependences(ops, dependenceGraph);
889   if (fusableDependences.empty())
890     return llvm::None;
891 
892   TiledAndFusedLinalgOps ret;
893   // Find the loops that can be tiled and fused.
894   ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
895 
896   // If there are no fusable dependences or there are no tile+fusable loops,
897   // just return.
898   if (ret.fusedLoopDims.empty()) {
899     return llvm::None;
900   }
901 
902   // Tile the fused loops in the last operation in the list.
903   SmallVector<Value, 4> tileSizeVector =
904       tilingOptions.tileSizeComputationFunction(builder, rootOp);
905   Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
906       builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
907   if (!tiledRootOp) {
908     rootOp.emitRemark("failed to tile the fused loops");
909     return llvm::None;
910   }
911   ret.op = tiledRootOp->op;
912   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
913 
914   // Fuse the other operations into the fused inter-tile loops produced above.
915   ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(),
916                                       fusableDependences, ret.fusedLoopDims);
917 
918   return ret;
919 }
920 
921 Optional<TiledAndFusedLinalgOps>
922 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
923                                    const LinalgDependenceGraph &dependenceGraph,
924                                    const LinalgTilingOptions &tilingOptions) {
925   switch (tilingOptions.loopType) {
926   case LinalgTilingLoopType::Loops:
927     return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
928                                                 tilingOptions);
929   case LinalgTilingLoopType::ParallelLoops:
930     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
931         builder, ops, dependenceGraph, tilingOptions);
932   default:;
933   }
934   return llvm::None;
935 }
936