1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/Dominance.h"
14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/EDSC/Helpers.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Support/STLExtras.h"
27 #include "mlir/Transforms/FoldUtils.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 
32 #define DEBUG_TYPE "linalg-fusion"
33 
34 using namespace mlir;
35 using namespace mlir::edsc;
36 using namespace mlir::edsc::intrinsics;
37 using namespace mlir::linalg;
38 using namespace mlir::linalg::intrinsics;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass of linalg library operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. This
47 ///      uses the SSA value of the views and a simple subview/slice analysis to
48 ///      determine producer-consumer dependences;
///   2. greedily fusing the linalg ops that produce the subviews read by `O`;
///   3. inspecting the fused ops and determining whether they have other
///      remaining LinalgOp uses. If not, the original producing linalg op is
///      erased.
52 ///
53 /// More advanced use cases, analyses as well as profitability heuristics are
54 /// left for future work.
55 
// Command-line option category named after DEBUG_TYPE; the option below is
// attached to it.
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
// `linalg-fusion-tile-sizes`: comma-separated list of unsigned tile sizes that
// may be specified zero or more times.
// NOTE(review): nothing in the visible part of this file reads `clTileSizes`;
// confirm whether the option is still consumed somewhere.
static llvm::cl::list<unsigned> clTileSizes(
    "linalg-fusion-tile-sizes",
    llvm::cl::desc(
        "Tile sizes by which to tile linalg operations during linalg fusion"),
    llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
    llvm::cl::cat(clOptionsCategory));
63 
64 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
65 // a subset of the original loop ranges of `op`.
66 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
67 // to the `loopRanges` in order to obtain view ranges.
68 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
69                                     ArrayRef<SubViewOp::Range> loopRanges) {
70   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
71   auto maps = loopToOperandRangesMaps(op);
72   SmallVector<Value, 8> clonedViews;
73   clonedViews.reserve(op.getNumInputsAndOutputs());
74   // Iterate over the inputs and outputs in order.
75   // Extract the subranges from the linearized ranges.
76   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
77   for (auto en : llvm::enumerate(ios)) {
78     unsigned idx = en.index();
79     auto map = maps[idx];
80     LLVM_DEBUG(dbgs() << "map: " << map << "\n");
81     Value view = en.value();
82     SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
83     for (auto en2 : llvm::enumerate(map.getResults())) {
84       unsigned d = en2.index();
85       // loopToOperandRangesMaps are permutations-only.
86       unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
87       viewRanges[d] = loopRanges[loopPos];
88       LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
89                         << "\t"
90                         << "loopPos: " << loopPos << "\t" << viewRanges[d]);
91     }
92     // Construct a new subview for the tile.
93     unsigned rank = viewRanges.size();
94     SmallVector<Value, 4> offsets, sizes, strides;
95     offsets.reserve(rank);
96     sizes.reserve(rank);
97     strides.reserve(rank);
98     for (auto r : viewRanges) {
99       offsets.push_back(r.offset);
100       sizes.push_back(r.size);
101       strides.push_back(r.stride);
102     }
103     clonedViews.push_back(
104         b.create<SubViewOp>(loc, view, offsets, sizes, strides));
105   }
106   auto operands = getAssumedNonViewOperands(op);
107   clonedViews.append(operands.begin(), operands.end());
108   return op.clone(b, loc, clonedViews);
109 }
110 
// Identifies one dimension of a view: the view value together with the index
// of the dimension inside that view.
struct ViewDimension {
  // The view whose dimension is referenced.
  Value view;
  // Position of the referenced dimension within `view`.
  unsigned dimension;
};
115 
116 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies
117 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
118 // guarantees at least one such dimension is found. If multiple candidates exist
119 // they must agree by construction (i.e. have the same size) and we just return
120 // the first one.
121 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
122   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
123   auto maps = loopToOperandRangesMaps(op);
124   // Iterate over the inputs and outputs in order.
125   // Extract the subranges from the linearized ranges.
126   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
127   for (auto en : llvm::enumerate(ios)) {
128     unsigned idx = en.index();
129     auto map = maps[idx];
130     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
131     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
132     Value view = en.value();
133     SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
134     for (auto en2 : llvm::enumerate(map.getResults())) {
135       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
136         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
137                           << "\n");
138         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
139         return ViewDimension{view, static_cast<unsigned>(en2.index())};
140       }
141     }
142   }
143   llvm_unreachable("Expect to be able to extract a view defining loop range");
144 }
145 
146 static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
147                      unsigned consumerIdx, unsigned producerIdx,
148                      OperationFolder *folder) {
149   assert(producer.hasBufferSemantics() &&
150          "expected linalg op with buffer semantics");
151   assert(consumer.hasBufferSemantics() &&
152          "expected linalg op with buffer semantics");
153   auto subView = dyn_cast_or_null<SubViewOp>(
154       consumer.getInput(consumerIdx).getDefiningOp());
155   auto slice =
156       dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp());
157   assert(subView || slice);
158   (void)subView;
159   (void)slice;
160 
161   // loopToOperandRangesMaps are permutations-only by construction:
162   //   we can always identify a data dimension with a (at least one) loop
163   //   dimension.
164   AffineMap producerMap =
165       loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
166   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
167                     << ", producer map: " << producerMap << "\n");
168 
169   unsigned nPar = producer.getNumParallelLoops();
170   unsigned nRed = producer.getNumReductionLoops();
171   unsigned nWin = producer.getNumWindowLoops();
172   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
173 
174   // Iterate over dimensions identified by the producer map for `producerIdx`.
175   // This defines a subset of the loop ranges that we need to complete later.
176   for (auto en : llvm::enumerate(producerMap.getResults())) {
177     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
178     loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
179   }
180 
181   OpBuilder b(consumer.getOperation());
182   auto loc = consumer.getLoc();
183   // Iterate over all dimensions. For the dimensions not identified by the
184   // producer map for `producerIdx`, we need to explicitly compute the view that
185   // defines the loop ranges using the `producer`.
186   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
187     if (loopRanges[i].offset)
188       LLVM_DEBUG(llvm::dbgs()
189                  << "existing LoopRange: " << loopRanges[i] << "\n");
190     else {
191       auto viewDim = getViewDefiningLoopRange(producer, i);
192       loopRanges[i] = SubViewOp::Range{constant_index(folder, 0),
193                                        dim(viewDim.view, viewDim.dimension),
194                                        constant_index(folder, 1)};
195       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
196     }
197   }
198 
199   return cloneWithLoopRanges(b, loc, producer, loopRanges);
200 }
201 
202 // Encode structural fusion safety preconditions.
203 // Some of these will be lifted in the future with better analysis.
204 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
205                                           LinalgOp consumer) {
206   assert(producer.hasBufferSemantics() &&
207          "expected linalg op with buffer semantics");
208   assert(consumer.hasBufferSemantics() &&
209          "expected linalg op with buffer semantics");
210   if (producer.getNumOutputs() != 1) {
211     LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
212     return false;
213   }
214   // Only fuse when the producer block dominates.
215   DominanceInfo dom(producer.getOperation());
216   if (!dom.dominates(producer.getOperation()->getBlock(),
217                      consumer.getOperation()->getBlock())) {
218     LLVM_DEBUG(
219         dbgs()
220         << "\nNot structurally fusable (producer block does not dominate)");
221     return false;
222   }
223   return true;
224 }
225 
226 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
227                                              LinalgOp consumer,
228                                              Value consumedView,
229                                              LinalgOp producer) {
230   assert(producer.hasBufferSemantics() &&
231          "expected linalg op with buffer semantics");
232   assert(consumer.hasBufferSemantics() &&
233          "expected linalg op with buffer semantics");
234   // Make some simple structural checks that alleviate the need for more
235   // complex analyses.
236   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
237     LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
238                       << *producer.getOperation());
239     return false;
240   }
241   // Check for any interleaved write to consumedView.
242   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
243     LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
244                       << *producer.getOperation());
245     return false;
246   }
247   return true;
248 }
249 
250 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
251                                  LinalgOp consumer, Value consumedView,
252                                  LinalgOp producer) {
253   assert(producer.hasBufferSemantics() &&
254          "expected linalg op with buffer semantics");
255   assert(consumer.hasBufferSemantics() &&
256          "expected linalg op with buffer semantics");
257   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
258     return false;
259   // Check for any fusion-preventing dependence to any view read/written that
260   // would violate dependences.
261   if (!graph.findCoveringDependences(producer, consumer).empty()) {
262     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
263                       << *producer.getOperation());
264     return false;
265   }
266   return true;
267 }
268 
269 // Only consider RAW atm.
270 Optional<FusionInfo> mlir::linalg::fuseProducerOf(
271     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
272     const LinalgDependenceGraph &graph, OperationFolder *folder) {
273   assert(consumer.hasBufferSemantics() &&
274          "expected linalg op with buffer semantics");
275   LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
276                     << *consumer.getOperation());
277   for (auto dependence : graph.getDependencesInto(
278            consumer, LinalgDependenceGraph::DependenceType::RAW)) {
279     LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
280                       << *dependence.dependentOpView.op << "\n");
281     auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
282 
283     // Check that the dependence is indeed on the input `consumerIdx` view.
284     auto consumedView = dependence.indexingView;
285     if (consumer.getInput(consumerIdx) != consumedView)
286       continue;
287 
288     // Consumer consumes this view, `isStructurallyFusableProducer` also checks
289     // whether it is a strict subview of the producer view.
290     auto producedView = dependence.dependentOpView.view;
291     auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
292     // `consumerIdx` and `producerIdx` exist by construction.
293     LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
294                       << " view: " << producedView
295                       << " output index: " << producerIdx);
296 
297     // Must be a subview or a slice to guarantee there are loops we can fuse
298     // into.
299     auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
300     auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
301     if (!subView && !slice) {
302       LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
303       continue;
304     }
305 
306     // Simple fusability checks.
307     if (!isFusableInto(graph, consumer, consumedView, producer))
308       continue;
309 
310     // Fuse `producer` just before `consumer`.
311     OpBuilder::InsertionGuard g(b);
312     b.setInsertionPoint(consumer.getOperation());
313     ScopedContext scope(b, consumer.getLoc());
314     LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
315     auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
316                               producerIdx, folder);
317 
318     return FusionInfo{producer, fusedProducer};
319   }
320   return llvm::None;
321 }
322 
323 static void fuseLinalgOpsGreedily(FuncOp f) {
324   LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
325 
326   OpBuilder b(f);
327   OperationFolder folder(f.getContext());
328   DenseSet<Operation *> eraseSet;
329 
330   // Save original Linalg ops, we only want to make a pass over those.
331   SmallVector<Operation *, 8> linalgOps;
332   f.walk([&](LinalgOp op) {
333     if (op.hasBufferSemantics())
334       linalgOps.push_back(op);
335   });
336 
337   Aliases aliases;
338   LinalgDependenceGraph G(aliases, linalgOps);
339   for (auto *op : llvm::reverse(linalgOps)) {
340     for (unsigned consumerIdx = 0, e = LinalgOp(op).getNumInputs();
341          consumerIdx < e; ++consumerIdx) {
342       if (auto fusionInfo = fuseProducerOf(b, op, consumerIdx, G, &folder))
343         eraseSet.insert(fusionInfo->originalProducer.getOperation());
344     }
345   }
346 
347   // The `fuseProducerOf` function performs structural checks and in particular
348   // that no covering read or write exist between the consumer and the producer.
349   // As a consequence, the only fusions that may occur preserve subsequent
350   // dependences and are guaranteed by construction to produce the whole view.
351   // We may thus erase the producer once it is fused.
352   for (auto *e : eraseSet)
353     e->erase();
354   LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
355 }
356 
357 namespace {
358 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
359   void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
360 };
361 } // namespace
362 
// Factory for the linalg fusion pass: returns a newly allocated
// LinalgFusionPass owned by the returned pointer.
std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}
366 
// Static registration object for LinalgFusionPass, created with the
// `linalg-fusion` argument string and a one-line description.
static PassRegistration<LinalgFusionPass>
    pass("linalg-fusion", "Fuse operations in the linalg dialect");
369