1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements view-based alias and dependence analyses.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/IR/BuiltinOps.h"
18 
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 
22 #define DEBUG_TYPE "linalg-dependence-analysis"
23 
24 using namespace mlir;
25 using namespace mlir::linalg;
26 
27 using llvm::dbgs;
28 
find(Value v)29 Value Aliases::find(Value v) {
30   if (v.isa<BlockArgument>())
31     return v;
32 
33   auto it = aliases.find(v);
34   if (it != aliases.end()) {
35     assert(it->getSecond().getType().isa<BaseMemRefType>() &&
36            "Memref expected");
37     return it->getSecond();
38   }
39 
40   while (true) {
41     if (v.isa<BlockArgument>())
42       return v;
43 
44     Operation *defOp = v.getDefiningOp();
45     if (!defOp)
46       return v;
47 
48     // Treat RegionBranchOpInterfaces like an allocate and don't try to follow
49     // the aliasing further.
50     if (isa<RegionBranchOpInterface>(defOp))
51       return v;
52     if (isa<bufferization::ToMemrefOp>(defOp))
53       return v;
54 
55     if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
56       // Collect all memory effects on `v`.
57       SmallVector<MemoryEffects::EffectInstance, 1> effects;
58       memEffect.getEffectsOnValue(v, effects);
59 
60       // If we have the 'Allocate' memory effect on `v`, then `v` should be the
61       // original buffer.
62       if (llvm::any_of(
63               effects, [](const MemoryEffects::EffectInstance &instance) {
64                 return isa<MemoryEffects::Allocate>(instance.getEffect());
65               }))
66         return v;
67     }
68 
69     if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
70       auto it =
71           aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
72       return it.first->second;
73     }
74 
75     llvm::errs() << "View alias analysis reduces to: " << v << "\n";
76     llvm_unreachable("unsupported view alias case");
77   }
78 }
79 
getDependenceTypeStr(DependenceType depType)80 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
81   switch (depType) {
82   case LinalgDependenceGraph::DependenceType::RAW:
83     return "RAW";
84   case LinalgDependenceGraph::DependenceType::RAR:
85     return "RAR";
86   case LinalgDependenceGraph::DependenceType::WAR:
87     return "WAR";
88   case LinalgDependenceGraph::DependenceType::WAW:
89     return "WAW";
90   default:
91     break;
92   }
93   llvm_unreachable("Unexpected DependenceType");
94 }
95 
96 LinalgDependenceGraph
buildDependenceGraph(Aliases & aliases,func::FuncOp f)97 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, func::FuncOp f) {
98   SmallVector<LinalgOp, 8> linalgOps;
99   f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
100   return LinalgDependenceGraph(aliases, linalgOps);
101 }
102 
LinalgDependenceGraph(Aliases & aliases,ArrayRef<LinalgOp> ops)103 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
104                                              ArrayRef<LinalgOp> ops)
105     : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
106   for (const auto &en : llvm::enumerate(linalgOps)) {
107     linalgOpPositions.insert(
108         std::make_pair(en.value().getOperation(), en.index()));
109   }
110   for (unsigned i = 0, e = ops.size(); i < e; ++i) {
111     for (unsigned j = i + 1; j < e; ++j) {
112       addDependencesBetween(ops[i], ops[j]);
113     }
114   }
115 }
116 
addDependenceElem(DependenceType dt,LinalgDependenceGraphElem::OpView indexingOpView,LinalgDependenceGraphElem::OpView dependentOpView)117 void LinalgDependenceGraph::addDependenceElem(
118     DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView,
119     LinalgDependenceGraphElem::OpView dependentOpView) {
120   LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
121                     << LinalgDependenceGraphElem::getValue(indexingOpView)
122                     << " @) -> \n\t\t("
123                     << LinalgDependenceGraphElem::getValue(dependentOpView)
124                     << " @)");
125   dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)]
126       .push_back(
127           LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
128   dependencesIntoGraphs[dt]
129                        [LinalgDependenceGraphElem::getOwner(dependentOpView)]
130                            .push_back(LinalgDependenceGraphElem{
131                                indexingOpView, dependentOpView, dt});
132 }
133 
134 LinalgDependenceGraph::dependence_range
getDependencesFrom(LinalgOp src,LinalgDependenceGraph::DependenceType dt) const135 LinalgDependenceGraph::getDependencesFrom(
136     LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
137   return getDependencesFrom(src.getOperation(), dt);
138 }
139 
140 LinalgDependenceGraph::dependence_range
getDependencesFrom(Operation * src,LinalgDependenceGraph::DependenceType dt) const141 LinalgDependenceGraph::getDependencesFrom(
142     Operation *src, LinalgDependenceGraph::DependenceType dt) const {
143   auto iter = dependencesFromGraphs[dt].find(src);
144   if (iter == dependencesFromGraphs[dt].end())
145     return llvm::make_range(nullptr, nullptr);
146   return llvm::make_range(iter->second.begin(), iter->second.end());
147 }
148 
149 LinalgDependenceGraph::dependence_range
getDependencesInto(LinalgOp dst,LinalgDependenceGraph::DependenceType dt) const150 LinalgDependenceGraph::getDependencesInto(
151     LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
152   return getDependencesInto(dst.getOperation(), dt);
153 }
154 
155 LinalgDependenceGraph::dependence_range
getDependencesInto(Operation * dst,LinalgDependenceGraph::DependenceType dt) const156 LinalgDependenceGraph::getDependencesInto(
157     Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
158   auto iter = dependencesIntoGraphs[dt].find(dst);
159   if (iter == dependencesIntoGraphs[dt].end())
160     return llvm::make_range(nullptr, nullptr);
161   return llvm::make_range(iter->second.begin(), iter->second.end());
162 }
163 
addDependencesBetween(LinalgOp src,LinalgOp dst)164 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
165   LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation()
166                     << " and " << *dst.getOperation() << "\n");
167   if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
168     for (OpOperand *dstOpOperand : dst.getInputOperands()) {
169       // Check if the operand is defined by the src.
170       auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
171       if (definingOp && definingOp == src)
172         addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
173                           dstOpOperand);
174     }
175     for (OpOperand *dstOpOperand : dst.getOutputOperands()) {
176       // Check if the operand is defined by the src.
177       auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
178       if (definingOp && definingOp == src) {
179         if (dst.isInitTensor(dstOpOperand)) {
180           addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
181                             dstOpOperand);
182         }
183         addDependenceElem(DependenceType::WAW, dstOpOperand->get(),
184                           dstOpOperand);
185       }
186     }
187     return;
188   }
189   assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
190          "unhandled dependence tracking for mixed buffer/tensor operations");
191   for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W
192     // RAW graph
193     for (OpOperand *dstOpOperand : dst.getInputBufferOperands())   // R
194       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
195         addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
196     // WAW graph
197     for (OpOperand *dstOpOperand : dst.getOutputBufferOperands())  // W
198       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
199         addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
200   }
201   for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R
202     // RAR graph
203     for (OpOperand *dstOpOperand : dst.getInputBufferOperands())   // R
204       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
205         addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
206     // WAR graph
207     for (OpOperand *dstOpOperand : dst.getOutputBufferOperands())  // W
208       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
209         addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
210   }
211 }
212 
213 SmallVector<Operation *, 8>
findCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp) const214 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
215                                                LinalgOp dstLinalgOp) const {
216   return findOperationsWithCoveringDependences(
217       srcLinalgOp, dstLinalgOp, nullptr,
218       {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
219 }
220 
findCoveringWrites(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const221 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
222     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
223   return findOperationsWithCoveringDependences(
224       srcLinalgOp, dstLinalgOp, view,
225       {DependenceType::WAW, DependenceType::WAR});
226 }
227 
findCoveringReads(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const228 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
229     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
230   return findOperationsWithCoveringDependences(
231       srcLinalgOp, dstLinalgOp, view,
232       {DependenceType::RAR, DependenceType::RAW});
233 }
234 
235 SmallVector<Operation *, 8>
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view,ArrayRef<DependenceType> types) const236 LinalgDependenceGraph::findOperationsWithCoveringDependences(
237     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
238     ArrayRef<DependenceType> types) const {
239   auto *src = srcLinalgOp.getOperation();
240   auto *dst = dstLinalgOp.getOperation();
241   auto srcPos = linalgOpPositions.lookup(src);
242   auto dstPos = linalgOpPositions.lookup(dst);
243   assert(srcPos < dstPos && "expected dst after src in IR traversal order");
244 
245   SmallVector<Operation *, 8> res;
246   // Consider an intermediate interleaved `interim` op, look for any dependence
247   // to an aliasing view on a src -> op -> dst path.
248   // TODO: we are not considering paths yet, just interleaved positions.
249   for (auto dt : types) {
250     for (auto dependence : getDependencesFrom(src, dt)) {
251       auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp());
252       // Skip if not interleaved.
253       if (interimPos >= dstPos || interimPos <= srcPos)
254         continue;
255       Value consumerView = dependence.getIndexingValue();
256       if (view && !aliases.alias(view, consumerView))
257         continue;
258       auto *op = dependence.getDependentOp();
259       LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
260                         << getDependenceTypeStr(dt) << ": " << *src << " -> "
261                         << *op << " on " << consumerView);
262       res.push_back(op);
263     }
264   }
265   return res;
266 }
267 
hasDependenceFrom(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const268 bool LinalgDependenceGraph::hasDependenceFrom(
269     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
270     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
271   for (auto dep : depTypes)
272     for (auto dependence : getDependencesInto(dstLinalgOp, dep))
273       if (dependence.getDependentOp() == srcLinalgOp)
274         return true;
275   return false;
276 }
277 
hasDependentOperationsFrom(LinalgOp linalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const278 bool LinalgDependenceGraph::hasDependentOperationsFrom(
279     LinalgOp linalgOp,
280     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
281   for (auto dep : depTypes) {
282     if (!getDependencesFrom(linalgOp, dep).empty())
283       return true;
284   }
285   return false;
286 }
287 
hasDependentOperationsInto(LinalgOp linalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const288 bool LinalgDependenceGraph::hasDependentOperationsInto(
289     LinalgOp linalgOp,
290     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
291   for (auto dep : depTypes) {
292     if (!getDependencesInto(linalgOp, dep).empty())
293       return true;
294   }
295   return false;
296 }
297 
hasDependentOperations(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const298 bool LinalgDependenceGraph::hasDependentOperations(
299     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
300   return hasDependentOperationsInto(linalgOp, depTypes) ||
301          hasDependentOperationsFrom(linalgOp, depTypes);
302 }
303 
304 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperationsInto(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const305 LinalgDependenceGraph::getDependentOperationsInto(
306     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
307   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
308       dependentOperations;
309   for (auto dependenceType : depTypes) {
310     auto dependencies = getDependencesInto(linalgOp, dependenceType);
311     dependentOperations.append(dependencies.begin(), dependencies.end());
312   }
313   return dependentOperations;
314 }
315 
316 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperationsFrom(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const317 LinalgDependenceGraph::getDependentOperationsFrom(
318     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
319   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
320       dependentOperations;
321   for (auto dependenceType : depTypes) {
322     auto dependencies = getDependencesFrom(linalgOp, dependenceType);
323     dependentOperations.append(dependencies.begin(), dependencies.end());
324   }
325   return dependentOperations;
326 }
327 
328 /// Returns all dependent operations (into and from) given `operation`.
329 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperations(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const330 LinalgDependenceGraph::getDependentOperations(
331     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
332   SmallVector<LinalgDependenceGraphElem, 2> dependentOperations =
333       getDependentOperationsInto(linalgOp, depTypes);
334   SmallVector<LinalgDependenceGraphElem, 2> t =
335       getDependentOperationsFrom(linalgOp, depTypes);
336   dependentOperations.append(t.begin(), t.end());
337   return dependentOperations;
338 }
339 
print(raw_ostream & os) const340 void LinalgDependenceGraph::print(raw_ostream &os) const {
341   for (auto dt : {
342            LinalgDependenceGraph::DependenceType::RAW,
343            LinalgDependenceGraph::DependenceType::WAW,
344        }) {
345     const auto &fromGraph = dependencesFromGraphs[dt];
346     for (const auto &it : fromGraph) {
347       os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first
348          << ":\n";
349       for (const auto &dep : it.second) {
350         os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n";
351       }
352     }
353   }
354 }
355 
dump() const356 void LinalgDependenceGraph::dump() const { print(llvm::errs()); }
357