1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/IR/Dominance.h"
28 #include "mlir/Support/LLVM.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "mlir/Transforms/RegionUtils.h"
31 #include "llvm/ADT/MapVector.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 
35 #include <set>
36 
37 #define DEBUG_TYPE "linalg-fusion"
38 
39 using namespace mlir;
40 using namespace mlir::edsc;
41 using namespace mlir::edsc::intrinsics;
42 using namespace mlir::linalg;
43 
44 using llvm::dbgs;
45 
46 /// Implements a simple high-level fusion pass on linalg structured operations.
47 ///
48 /// In each block, linalg ops are processed in reverse textual order.
49 /// Given a linalg op `O`, fusion occurs by:
50 ///   1. inspecting the linalg ops that write into the views read by `O`. There
51 ///      are 2 cases:
52 ///      a) buffer case: use the SSA value of the views and a simple alias
53 ///         analysis on subview ops to determine producer-consumer dependences;
54 ///      b) tensor case: use SSA use-def chains on subtensor ops;
55 ///   2. greedily fuse the linalg ops that produce the subview/subtensor.
56 ///   3. inspect the fused ops and determine whether they have other remaining
57 ///      LinalgOp uses. If not, then erase the original producing linalg op.
58 ///
59 /// More advanced use cases, analyses as well as profitability heuristics are
60 /// left for future work.
61 
/// Identifies a single dimension of a shaped value: `dimension` indexes into
/// the dimensions of `shape`.
struct ShapeDimension {
  // Shaped value (an operand of a linalg op) that the dimension belongs to.
  Value shape;
  // Position of the dimension within `shape`.
  unsigned dimension;
};
66 
67 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
68 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
69 // guarantees at least one such dimension is found. If multiple candidates exist
70 // they must agree by construction (i.e. have the same size) and we just return
71 // the first one.
72 static ShapeDimension
73 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
74                           bool fromSubViewOpOnly = false) {
75   auto maps = op.indexing_maps();
76   // Iterate over the inputs and outputs in order.
77   // Extract the subranges from the linearized ranges.
78   for (auto en : llvm::enumerate(op.getShapedOperands())) {
79     // The method `getRangeFromOperandShape` requires using SubViewOp or
80     // SubTensorOps. If the value isnt defined from there continue.
81     // todo: The method should be adapted to get the values from
82     // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
83     // currently returns a `linalg.range`. The fix here is to move this op to
84     // `std` dialect and add the method to `ViewInterface`.
85     if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
86                                  en.value().getDefiningOp()))
87       continue;
88 
89     unsigned idx = en.index();
90     auto map = maps[idx].cast<AffineMapAttr>().getValue();
91     LLVM_DEBUG(llvm::dbgs()
92                << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
93     LLVM_DEBUG(llvm::dbgs()
94                << "getShapeDefiningLoopRange map: " << map << "\n");
95     Value shape = en.value();
96     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
97     for (auto en2 : llvm::enumerate(map.getResults())) {
98       auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
99       if (!dimExpr)
100         continue;
101       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
102         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
103                                 << loopDepth << "\n");
104         LLVM_DEBUG(llvm::dbgs()
105                    << "getShapeDefiningLoopRange shape: " << shape << "\n");
106         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
107       }
108     }
109   }
110   llvm_unreachable("Expect to be able to extract a shape defining loop range");
111 }
112 
113 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
114 /// provides the loop range information for the fused loops. The rest are
115 /// obtained from the producer itself, since they are not tiled + fused.
116 static LinalgOp fuse(OpBuilder &builder, LinalgOp producer,
117                      const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
118   SmallVector<Value, 8> ivs, tileSizes, sizeBounds;
119   SmallVector<Range, 8> loopRanges;
120   auto zero = std_constant_index(0);
121   auto one = std_constant_index(1);
122   Location loc = producer.getLoc();
123 
124   for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
125     auto it = fusedLoopsAndRanges.find(i);
126     if (it != fusedLoopsAndRanges.end()) {
127       ivs.push_back(it->second.offset);
128       tileSizes.push_back(it->second.size);
129       sizeBounds.push_back(nullptr);
130       loopRanges.push_back(it->second);
131       LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
132                               << loopRanges.back() << "\n");
133     } else {
134       auto shapeDim = getShapeDefiningLoopRange(producer, i);
135       Value dim = memref_dim(shapeDim.shape, shapeDim.dimension);
136       tileSizes.push_back(zero);
137       sizeBounds.push_back(dim);
138       loopRanges.push_back(Range{zero, dim, one});
139       LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
140                               << loopRanges.back() << "\n");
141     }
142   }
143 
144   SmallVector<Value, 8> clonedShapes;
145   clonedShapes.reserve(producer.getNumShapedOperands());
146 
147   // Compute subranges for all tensor input/output operands.
148   auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
149   clonedShapes.append(makeTiledShapes(builder, loc, producer, tiledOperands,
150                                       ivs, tileSizes, sizeBounds));
151 
152   // Append the other operands.
153   auto operands = producer.getAssumedNonShapedOperands();
154   clonedShapes.append(operands.begin(), operands.end());
155 
156   // Iterate over the results in order.
157   // Extract the subtensor type from the linearized range.
158   // Since we do not enforce any canonicalizations on the fly, this is always
159   // fully dynamic at construction time.
160   SmallVector<Type, 4> resultTypes;
161   resultTypes.reserve(producer->getNumResults());
162   for (RankedTensorType t : producer.getOutputTensorTypes()) {
163     unsigned rank = t.getRank();
164     SmallVector<int64_t, 4> staticOffsetsVector(
165         rank, ShapedType::kDynamicStrideOrOffset);
166     SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
167     SmallVector<int64_t, 4> staticStridesVector(
168         rank, ShapedType::kDynamicStrideOrOffset);
169     resultTypes.push_back(SubTensorOp::inferResultType(
170         t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
171         staticStridesVector));
172   }
173 
174   Operation *clonedOp = producer.clone(builder, loc, resultTypes, clonedShapes);
175   // When the producer is an IndexedGenericOp, we have to transform its block
176   // IV arguments according to the tiling of the consumer, i.e. offset them by
177   // the values computed in `loopRanges`.
178   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
179     auto &block = indexedGenericOp.region().front();
180     OpBuilder::InsertionGuard g(builder);
181     builder.setInsertionPointToStart(&block);
182     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
183       Value oldIndex = block.getArgument(i);
184       // TODO: replace by an affine_apply.
185       AddIOp newIndex = builder.create<AddIOp>(indexedGenericOp.getLoc(),
186                                                oldIndex, loopRanges[i].offset);
187       oldIndex.replaceAllUsesExcept(newIndex,
188                                     SmallPtrSet<Operation *, 1>{newIndex});
189     }
190   }
191 
192   return clonedOp;
193 }
194 
195 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
196 /// expected to be defined by a subview op or a subtensor op.
197 static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
198                                       Value shapedOperand, unsigned dim) {
199   Operation *shapeProducingOp = shapedOperand.getDefiningOp();
200   if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
201     return subViewOp.getOrCreateRanges(b, loc)[dim];
202   if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
203     return subTensorOp.getOrCreateRanges(b, loc)[dim];
204   llvm_unreachable("SubviewOp or SubTensorOp expected");
205 }
206 
207 /// Fuses the producer of `producerIdx` into the loop immediately enclosing
208 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it
209 /// is needed just before the `consumer.
210 ///
211 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
212 /// 2 cases:
213 ///   1. Buffer case: `producerIdx` is the index of the buffer in
214 ///      `producer.getOutputBuffers()`.
215 ///   2. Tensor case: `producerIdx` is the index of the tensor in
216 ///      `producer.getResults()`.
217 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
218                      OpOperand &consumerOpOperand) {
219   LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
220   DenseMap<unsigned, Range> fusedLoopsAndRanges;
221   Value shapedOperand = consumerOpOperand.get();
222   for (auto en : llvm::enumerate(producerMap.getResults())) {
223     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
224     fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
225         b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
226   }
227   return fuse(b, producerOp, fusedLoopsAndRanges);
228 }
229 
230 // Encode structural fusion safety preconditions.
231 // Some of these will be lifted in the future with better analysis.
232 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
233                                           LinalgOp consumer) {
234   assert(producer.hasBufferSemantics() &&
235          "expected linalg op with buffer semantics");
236   assert(consumer.hasBufferSemantics() &&
237          "expected linalg op with buffer semantics");
238   if (producer.getNumOutputs() != 1) {
239     LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
240     return false;
241   }
242   // Only fuse when the producer block dominates.
243   DominanceInfo dom(producer.getOperation());
244   if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
245     LLVM_DEBUG(
246         llvm::dbgs()
247         << "\nNot structurally fusable (producer block does not dominate)");
248     return false;
249   }
250   return true;
251 }
252 
253 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
254                                              LinalgOp consumer,
255                                              Value consumedView,
256                                              LinalgOp producer) {
257   assert(producer.hasBufferSemantics() &&
258          "expected linalg op with buffer semantics");
259   assert(consumer.hasBufferSemantics() &&
260          "expected linalg op with buffer semantics");
261   // Make some simple structural checks that alleviate the need for more
262   // complex analyses.
263   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
264     LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
265                             << *producer.getOperation());
266     return false;
267   }
268   // Check for any interleaved write to consumedView.
269   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
270     LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
271                             << *producer.getOperation());
272     return false;
273   }
274   return true;
275 }
276 
277 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
278                                  LinalgOp consumer, Value consumedView,
279                                  LinalgOp producer) {
280   assert(producer.hasBufferSemantics() &&
281          "expected linalg op with buffer semantics");
282   assert(consumer.hasBufferSemantics() &&
283          "expected linalg op with buffer semantics");
284   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
285     return false;
286   // Check for any fusion-preventing dependence to any shape read/written that
287   // would violate dependences.
288   if (!graph.findCoveringDependences(producer, consumer).empty()) {
289     LLVM_DEBUG(llvm::dbgs()
290                << "\n***Not fusable due to an interleaved dependence:\t"
291                << *producer.getOperation());
292     return false;
293   }
294   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
295     // TODO: add a level of indirection to linalg.generic.
296     if (convOp.padding())
297       return false;
298   }
299   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
300     // TODO: add a level of indirection to linalg.generic.
301     if (convOp.padding())
302       return false;
303   }
304   return true;
305 }
306 
307 /// For `consumer` with buffer semantics, find the Linalg operation on buffers
308 /// that is the last writer of `consumerOpOperand`. For now the fusable
309 /// dependence is returned as an instance of the `dependenceGraph`.
310 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
311 findFusableProducer(OpOperand &consumerOpOperand,
312                     const LinalgDependenceGraph &dependenceGraph) {
313   LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
314   if (!consumerOp)
315     return {};
316 
317   // Only consider RAW and WAW atm.
318   for (auto depType : {
319            LinalgDependenceGraph::DependenceType::RAW,
320            LinalgDependenceGraph::DependenceType::WAW,
321        }) {
322     for (auto dependence : llvm::make_filter_range(
323              dependenceGraph.getDependencesInto(consumerOp, depType),
324              [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
325                Value v = elem.getIndexingValue();
326                Optional<unsigned> operandNum =
327                    elem.getIndexingOpViewOperandNum();
328                return isa<LinalgOp>(elem.getDependentOp()) &&
329                       v == consumerOpOperand.get() && operandNum &&
330                       operandNum.getValue() ==
331                           consumerOpOperand.getOperandNumber();
332              })) {
333       // Consumer consumes this view, `isStructurallyFusableProducer` also
334       // checks whether it is a strict subview of the producer view.
335       auto producer = cast<LinalgOp>(dependence.getDependentOp());
336       LLVM_DEBUG(llvm::dbgs()
337                  << "\n"
338                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
339                  << "producer: " << *dependence.getDependentOp()
340                  << " view: " << dependence.getDependentValue() << "\n");
341 
342       // If the producer and consumer have tensor semantics, the only dependence
343       // between them is through a RAW dependence and they are fusable by
344       // construction. For buffer semantics need additional checks.
345       if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
346           isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
347                         producer))
348         return dependence;
349       if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
350         assert(dependence.dependenceType ==
351                LinalgDependenceGraph::DependenceType::RAW);
352         return dependence;
353       }
354     }
355   }
356   return {};
357 }
358 
359 Optional<FusionInfo>
360 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
361                                    const LinalgDependenceGraph &graph) {
362   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
363       findFusableProducer(consumerOpOperand, graph);
364   if (!fusableDependence)
365     return llvm::None;
366 
367   LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
368   if (!producerOp)
369     return llvm::None;
370 
371   // If producer is already in the same block as consumer, we are done.
372   if (consumerOpOperand.get().getParentBlock() ==
373       fusableDependence->getDependentValue().getParentBlock())
374     return llvm::None;
375 
376   Optional<AffineMap> producerMap =
377       fusableDependence->getDependentOpViewIndexingMap();
378   if (!producerMap)
379     return llvm::None;
380 
381   // Must be a subview or a slice to guarantee there are loops we can fuse
382   // into.
383   auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
384   if (!subView) {
385     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
386     return llvm::None;
387   }
388 
389   // Fuse `producer` just before `consumer`.
390   OpBuilder::InsertionGuard g(b);
391   b.setInsertionPoint(consumerOpOperand.getOwner());
392   ScopedContext scope(b, consumerOpOperand.getOwner()->getLoc());
393   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
394                           << *consumerOpOperand.getOwner() << "\n");
395 
396   auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand);
397   return FusionInfo{producerOp, fusedProducer};
398 }
399 
400 /// Walk back use-def chain through scf::For yields.
401 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
402 
403 // TODO(ravishankarm, ntv): This can be moved into the dependence graphs
404 // dependence tracking since the dependence tracking is similar to what is done
405 // w.r.t to buffers.
406 static void getProducerOfTensor(Value tensor, OpResult &opResult) {
407   if (!tensor.getType().isa<RankedTensorType>())
408     return;
409 
410   while (true) {
411     LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
412     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
413       opResult = tensor.cast<OpResult>();
414       return;
415     }
416     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
417       tensor = subTensorOp.source();
418       continue;
419     }
420     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
421       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
422         tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
423         continue;
424       }
425     }
426     return;
427   }
428 }
429 
430 Optional<FusionInfo>
431 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
432   Value inputTensor = consumerOpOperand.get();
433   OpResult producerOpResult;
434   getProducerOfTensor(inputTensor, producerOpResult);
435   if (!producerOpResult) {
436     LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
437     return {};
438   }
439   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
440 }
441 
442 Optional<FusionInfo>
443 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
444                                    OpOperand &consumerOpOperand) {
445   auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
446   if (!producerOp)
447     return llvm::None;
448 
449   LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
450   if (!consumerOp)
451     return llvm::None;
452 
453   Value inputTensor = consumerOpOperand.get();
454 
455   // Must be a subtensor to guarantee there are loops we can fuse into.
456   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
457   if (!subTensor) {
458     LLVM_DEBUG(llvm::dbgs()
459                << "\nNot fusable, not a subtensor: " << inputTensor);
460     return {};
461   }
462 
463   // If producer is already in the same block as consumer, we are done.
464   if (consumerOpOperand.get().getParentBlock() ==
465       producerOpResult.getParentBlock())
466     return {};
467 
468   // Insert fused `producer` just before `consumer`.
469   OpBuilder::InsertionGuard g(b);
470   b.setInsertionPoint(consumerOp);
471   ScopedContext scope(b, consumerOp->getLoc());
472   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
473   LinalgOp fusedProducer =
474       fuse(b, producerOp,
475            producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
476            consumerOpOperand);
477 
478   // Replace use.
479   // Canonicalizations are not guaranteed to have happened before constructing
480   // `fusedProducer`. In the tensor case this can result in temporary type
481   // mismatches. Insert a `tensor.cast` op to propagate the transformation
482   // invariant that types are compatible.
483   Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
484   Type consumerType = consumerOpOperand.get().getType();
485   if (consumerType != def.getType())
486     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
487   consumerOpOperand.set(def);
488   return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
489 }
490 
491 /// Prune all dimensions that are of reduction iterator type from `map`.
492 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
493                                            AffineMap map) {
494   llvm::SmallDenseSet<unsigned> projectedDims;
495   for (auto attr : llvm::enumerate(iteratorTypes)) {
496     if (!isParallelIterator(attr.value()))
497       projectedDims.insert(attr.index());
498   }
499   return getProjectedMap(map, projectedDims);
500 }
501 
502 /// Returns the mapping from iterations in the consumer that write to the same
503 /// location as the iterations in the producer. To do so use
504 /// - indexing map of the fused view in the consumer : consumerIndexMap
505 /// - indexing map of the fused view in the producer : producerIndexMap
506 ///     consumerLoopToProducerLoop =
507 ///       inverse(producerIndexMap).compose(consumerIndexMap)
508 static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
509     LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
510   auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
511   if (!producer)
512     return None;
513 
514   Optional<AffineMap> producerIndexingMap =
515       dependence.getDependentOpViewIndexingMap();
516   Optional<AffineMap> consumerIndexingMap =
517       dependence.getIndexingOpViewIndexingMap();
518   if (!producerIndexingMap || !consumerIndexingMap)
519     return None;
520 
521   AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
522       producer.iterator_types().getValue(), *producerIndexingMap);
523   if (!prunedProducerIndexingMap.isPermutation())
524     return None;
525 
526   if (consumerIndexingMap->getNumResults() !=
527       prunedProducerIndexingMap.getNumResults())
528     return None;
529 
530   LLVM_DEBUG({
531     llvm::dbgs() << "\t producerMap : ";
532     producerIndexingMap->print(llvm::dbgs());
533     llvm::dbgs() << "  pruned : ";
534     prunedProducerIndexingMap.print(llvm::dbgs());
535     llvm::dbgs() << "\n";
536     llvm::dbgs() << "\t consumerMap : ";
537     consumerIndexingMap->print(llvm::dbgs());
538     llvm::dbgs() << "\n";
539   });
540 
541   AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
542   if (!invProducerIndexMap)
543     return None;
544 
545   return invProducerIndexMap.compose(*consumerIndexingMap);
546 }
547 
548 /// Given a projected permutation `map`, returns true if the map changes the
549 /// order in which the fused loop dimension appear.
550 static bool doesTransposeAccess(AffineMap map,
551                                 const std::set<unsigned> &fusableLoops) {
552   Optional<unsigned> lastFusableLoop;
553   for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
554          return expr.cast<AffineDimExpr>().getPosition();
555        })) {
556     if (!fusableLoops.count(pos))
557       continue;
558     if (!lastFusableLoop) {
559       lastFusableLoop = pos;
560       continue;
561     }
562     if (pos <= lastFusableLoop.getValue())
563       return true;
564     lastFusableLoop = pos;
565   }
566   return false;
567 }
568 
569 /// Returns the positions of the loop in `op` that can be tiled based on the
570 /// operations that are to be fused with it. For example, in a
571 ///
572 ///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
573 ///
574 /// if the producer of %a needs to be fused with this op, only the `i` loop of
575 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
576 /// fused, then no loops can be tiled while fusing. The conditions used are:
577 /// 1. Only parallel loops can be used for tile + fuse. Find the number of
578 ///    common outer parallel loops between the op and its producers being fused.
/// 2. Of the parallel loops only some can be fused. A loop can be fused only
///    if its iteration space touches a single tile of the fused operation.
///    This is because the producer (which is writing the fused subview) has
///    update semantics.
583 ///
584 /// Since an inverse computation is needed, we need to consider the projection
585 /// of the producerIndexMap w.r.t the parallel loops.  The actual fusable loops
586 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to
587 /// parallel loops and appear in the result of the map
588 ///
589 /// Example 1:
590 ///   linalg.fill(%c, %cst)
591 ///   linalg.matmul ins(%a, %b) outs(%c)
592 ///     Number of parallel loops : 2
593 ///     producerIndexMap = affine_map<(i, j) ->(i , j)>
594 ///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
595 ///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
596 ///     Fused dimensions : i, j
597 ///
598 /// Example 2:
599 ///   linalg.matmul ins(%a, %b) outs(%c)
600 ///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
601 ///                   iterator_types = ["parallel", "parallel"]}
602 ///     ins(%c) ...
603 ///
604 ///     Number of parallel loops = 2:
605 ///     producerIndexMap (projected to parallel loops) =
606 ///       affine_map<(i, j) -> (i, j)>
///     consumerLoopToProducerLoop = affine_map<(i, j) -> (j, i)>
608 ///     Fused dimensions : i, j
609 ///
610 /// Example 3:
611 ///   linalg.copy(%s, %b)
612 ///   linalg.matmul ins(%a, %b) outs(%c)
613 ///
614 ///   Number of parallel loops = 2
///   producerIndexMap : affine_map<(i, j) -> (i, j)>
///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (k, j)>
617 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
618 ///   Fused dimensions : j
619 static std::set<unsigned>
620 collectFusableLoops(ArrayRef<LinalgOp> ops,
621                     const FusableOpDependencesTy &fusableDependences) {
622   assert(!ops.empty());
623   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
624     return linalgOp.iterator_types()
625         .getValue()
626         .take_while([](Attribute attr) -> bool {
627           return attr.cast<StringAttr>().getValue() ==
628                  getParallelIteratorTypeName();
629         })
630         .size();
631   };
632 
633   size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
634   for (auto op : ops.drop_back()) {
635     numOuterParallelLoops =
636         std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
637   }
638 
639   std::set<unsigned> fusableLoops;
640   auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
641   fusableLoops.insert(range.begin(), range.end());
642 
643   for (auto op : reverse(ops)) {
644     for (auto dependence : fusableDependences.lookup(op)) {
645       LLVM_DEBUG({
646         llvm::dbgs() << "\t fusable :";
647         for (unsigned i : fusableLoops)
648           llvm::dbgs() << " " << i;
649         llvm::dbgs() << "\n";
650       });
651 
652       Optional<AffineMap> consumerLoopToProducerLoop =
653           getConsumerLoopToProducerLoopMap(dependence);
654       if (!consumerLoopToProducerLoop) {
655         op.emitRemark("failed to get map from consumer loop to producer loop");
656         return {};
657       }
658       // todo: This condition is only an implementation limitation. When fusing
659       // the operation, if the accesses in the producer/consumer are transposes
660       // of each other, the loop bounds for the tiled producer can be
661       // manipulated accordingly. This requires some additional bookkeeping in
662       // the implementation of tile+fuse that is deferred to later.
663       if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
664         op.emitRemark("unhandled fusion when fusion requires permutation");
665         return {};
666       }
667 
668       std::set<unsigned> candidates;
669       for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
670         unsigned position = expr.cast<AffineDimExpr>().getPosition();
671         if (fusableLoops.count(position))
672           candidates.insert(position);
673       }
674       LLVM_DEBUG({
675         llvm::dbgs() << "\t candidates :";
676         for (unsigned i : candidates)
677           llvm::dbgs() << " " << i;
678         llvm::dbgs() << "\n";
679       });
680       if (candidates.empty())
681         return {};
682       std::swap(candidates, fusableLoops);
683     }
684   }
685 
686   return fusableLoops;
687 }
688 
689 /// Find all dependences that are fusable.
690 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
691     ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
692   FusableOpDependencesTy fusableDependences;
693   DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
694   for (LinalgOp op : reverse(ops)) {
695     for (OpOperand &opOperand : op.getShapedOpOperands()) {
696       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
697           fusableDependence = findFusableProducer(opOperand, dependenceGraph);
698       if (!fusableDependence)
699         continue;
700       LinalgOp producerOp =
701           dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
702       if (!producerOp)
703         continue;
704       // Do not fuse dependences that are to operations not in the same basic
705       // block. This avoid moving fused operations across loops that might
706       // themselves carry dependency making the fusion illegal.
707       if (producerOp->getBlock() != op->getBlock())
708         continue;
709 
710       // Make sure that the indexing map of the view used for fusion in the
711       // producer is a projected permutation.
712       Optional<AffineMap> producerMap =
713           fusableDependence->getDependentOpViewIndexingMap();
714       Optional<AffineMap> consumerMap =
715           fusableDependence->getIndexingOpViewIndexingMap();
716       assert(
717           consumerMap &&
718           "unable to find indexing map of operand/result of indexing OpView");
719       fusedProducerIndexingMap[producerOp.getOperation()].push_back(
720           *consumerMap);
721       if (!producerMap || !producerMap->isProjectedPermutation() ||
722           !consumerMap->isProjectedPermutation())
723         continue;
724 
725       fusableDependences[producerOp.getOperation()].push_back(
726           *fusableDependence);
727     }
728   }
729   // TODO: Currently fusion would not be legal if the fusable dependence is to
730   // the same producer but different indexing map in the consumer. Fix this, but
731   // in the meanwhile disallow such a fusion.
732   for (auto useIndexingMapsList : fusedProducerIndexingMap) {
733     AffineMap map1 = useIndexingMapsList.second.front();
734     for (AffineMap map2 :
735          ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) {
736       if (map1 != map2) {
737         fusableDependences.erase(useIndexingMapsList.first);
738         break;
739       }
740     }
741   }
742   return fusableDependences;
743 }
744 
745 /// Tile the fused loops in the root operation, by setting the tile sizes for
746 /// all other loops to zero (those will be tiled later).
747 static Optional<TiledLinalgOp> tileRootOperation(
748     OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
749     const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
750   SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
751   auto zero = std_constant_index(0);
752   for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
753     if (!fusedLoops.count(i))
754       tileSizes[i] = zero;
755   LinalgTilingOptions tileFusedLoopsOptions = options;
756   tileFusedLoopsOptions.setTileSizes(tileSizes);
757   return tileLinalgOp(builder, op, tileFusedLoopsOptions);
758 }
759 
760 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
761 /// to be a tiled operation such that it is valid to fuse all operations in
762 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
763 /// `tiledOp`.
764 static SmallVector<LinalgOp, 1>
765 fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
766                ArrayRef<LinalgOp> fusionCandidates,
767                const FusableOpDependencesTy &fusableDependences,
768                const std::set<unsigned> &fusedLoops) {
769   OpBuilder::InsertionGuard guard(builder);
770   builder.setInsertionPoint(tiledOp);
771   DenseMap<unsigned, Range> fusedLoopsAndRanges;
772   for (unsigned loop : fusedLoops) {
773     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
774     fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
775         builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
776   }
777 
778   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
779   DenseMap<Operation *, LinalgOp> origOpToFusedOp;
780   origOpToFusedOp[rootOp.getOperation()] = tiledOp;
781   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
782     LinalgOp origOp = candidate.value();
783     LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges);
784     origOpToFusedOp[origOp.getOperation()] = fusedOp;
785     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
786     // If the producer consumer operations are linalg operations on tensors, the
787     // dependence is due to value produced (as a return tensor) by the producer
788     // and used in the consumer. The returned value of the fused op needs to be
789     // made the operand of the tiled/fused consumer operation. By construction
790     // the value returned by the producer is the value used by the consumer.
791     for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
792       if (origOp.hasTensorSemantics() &&
793           dependence.dependenceType ==
794               LinalgDependenceGraph::DependenceType::RAW) {
795         unsigned resultIndex =
796             dependence.getDependentOpViewResultNum().getValue();
797         LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
798         if (!consumer)
799           continue;
800         Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
801         consumer.getOperation()->setOperand(
802             dependence.getIndexingOpViewOperandNum().getValue(),
803             replacementValue);
804       }
805     }
806     builder.setInsertionPoint(fusedOp);
807   }
808   return fusedOps;
809 }
810 
811 template <typename LoopType>
812 static Optional<TiledAndFusedLinalgOps>
813 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
814                          const LinalgDependenceGraph &dependenceGraph,
815                          const LinalgTilingOptions &tilingOptions) {
816   if (ops.size() < 2)
817     return llvm::None;
818   LinalgOp rootOp = ops.back();
819   if (!llvm::all_of(
820           ops,
821           [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
822       !llvm::all_of(ops, [](LinalgOp linalgOp) {
823         return linalgOp.hasTensorSemantics();
824       })) {
825     rootOp.emitError(
826         "unable to fuse operations that have tensor semantics with operations "
827         "that have buffer semantics and viceversa.");
828     return llvm::None;
829   }
830   // TODO: Support interchange with tile + fuse. This might actually help do
831   // better fusion.
832   if (!tilingOptions.interchangeVector.empty()) {
833     rootOp.emitRemark("unable to handle tile and fuse with interchange");
834     return llvm::None;
835   }
836 
837   OpBuilder::InsertionGuard guard(builder);
838   builder.setInsertionPoint(rootOp);
839   ScopedContext scope(builder, rootOp.getLoc());
840 
841   // Find all the producers.
842   FusableOpDependencesTy fusableDependences =
843       findAllFusableDependences(ops, dependenceGraph);
844   if (fusableDependences.empty())
845     return llvm::None;
846 
847   TiledAndFusedLinalgOps ret;
848   // Find the loops that can be tiled and fused.
849   ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
850 
851   // If there are no fusable dependences or there are no tile+fusable loops,
852   // just return.
853   if (ret.fusedLoopDims.empty()) {
854     return llvm::None;
855   }
856 
857   // Tile the fused loops in the last operation in the list.
858   SmallVector<Value, 4> tileSizeVector =
859       tilingOptions.tileSizeComputationFunction(builder, rootOp);
860   Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
861       builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
862   if (!tiledRootOp) {
863     rootOp.emitRemark("failed to tile the fused loops");
864     return llvm::None;
865   }
866   ret.op = tiledRootOp->op;
867   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
868 
869   // Fuse the other operations into the fused inter-tile loops produced above.
870   ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(),
871                                       fusableDependences, ret.fusedLoopDims);
872 
873   return ret;
874 }
875 
876 Optional<TiledAndFusedLinalgOps>
877 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
878                                    const LinalgDependenceGraph &dependenceGraph,
879                                    const LinalgTilingOptions &tilingOptions) {
880   switch (tilingOptions.loopType) {
881   case LinalgTilingLoopType::Loops:
882     return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
883                                                 tilingOptions);
884   case LinalgTilingLoopType::ParallelLoops:
885     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
886         builder, ops, dependenceGraph, tilingOptions);
887   default:;
888   }
889   return llvm::None;
890 }
891